tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry flags this version of tpu-inference as potentially problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
-import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:
@@ -30,10 +32,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              envs.VLLM_XLA_CACHE_PATH)
+                              vllm_envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -67,8 +69,7 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")
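The capture_model hunk appears to replace an ad-hoc `os.getenv("SKIP_JAX_PRECOMPILE", ...)` lookup (where a set-but-"0" value is a non-empty string and therefore truthy) with a typed flag from the new `tpu_inference.envs` module. A minimal sketch of that pattern, assuming a simple boolean parser; the helper name and parsing rules here are illustrative, not the actual `tpu_inference/envs.py` code:

    import os

    def _bool_env(name: str, default: bool = False) -> bool:
        # Treat unset as the default; "", "0", "false" read as False.
        raw = os.environ.get(name)
        if raw is None:
            return default
        return raw.strip().lower() not in ("", "0", "false")

    # Mirrors the flag the diff reads as envs.SKIP_JAX_PRECOMPILE.
    SKIP_JAX_PRECOMPILE: bool = _bool_env("SKIP_JAX_PRECOMPILE")

Centralizing the parse also explains the alias churn above: `vllm.envs` keeps the name `vllm_envs`, freeing `envs` for the package's own flags.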
@@ -81,6 +82,8 @@ class CompilationManager:
             self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
+        if not self.runner.is_last_rank:
+            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -120,8 +123,15 @@ class CompilationManager:
             num_tokens=num_tokens,
         )
 
-    def _precompile_backbone_helper(self, name, *, input_ids,
-                                    positions, inputs_embeds) -> None:
+    def _precompile_backbone_helper(self,
+                                    name,
+                                    *,
+                                    input_ids,
+                                    positions,
+                                    inputs_embeds,
+                                    intermediate_tensors=None,
+                                    is_first_rank=True,
+                                    is_last_rank=True) -> None:
         num_tokens = None
         if input_ids is not None:
             num_tokens = input_ids.shape[0]
@@ -135,12 +145,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +156,49 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
+            intermediate_tensors,
+            is_first_rank,
+            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata,
+                intermediate_tensors, is_first_rank, is_last_rank)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -179,6 +206,10 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
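The hunk above is the core of the hybrid-KV-cache support: with one KV cache group, every layer shares a single metadata object; with several groups (e.g. full plus sliding-window attention), each layer is mapped to its group's metadata. A self-contained sketch of that mapping with stand-in types (the real AttentionMetadata and kv_cache_groups come from tpu_inference/vLLM):

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class FakeMetadata:              # stand-in for AttentionMetadata
        block_tables: list

    @dataclass
    class FakeGroup:                 # stand-in for a KV cache group
        layer_names: List[str]

    def build(groups: List[FakeGroup], tables: List[list], hybrid: bool):
        per_layer: Dict[str, FakeMetadata] = {}
        uniform: Optional[FakeMetadata] = None
        for gid, group in enumerate(groups):
            md = FakeMetadata(block_tables=tables[gid])
            if not hybrid:
                uniform = md                 # one metadata for every layer
            else:
                for name in group.layer_names:
                    per_layer[name] = md     # layers in a group share md
        return per_layer if hybrid else uniform

    groups = [FakeGroup(["attn.0"]), FakeGroup(["attn.1", "attn.2"])]
    tables = [[0, 1], [2, 3]]
    print(build(groups, tables, hybrid=True))

Sharing one metadata object per group (rather than one per layer) keeps the number of distinct traced inputs small, which matters for the precompile step that follows.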
@@ -186,9 +217,13 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
+                intermediate_tensors,
+                is_first_rank,
+                is_last_rank,
                 num_tokens=num_tokens,
             )
 
@@ -239,6 +274,7 @@ class CompilationManager:
         )
 
     def _precompile_backbone_text_only(self) -> None:
+        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -248,10 +284,28 @@ class CompilationManager:
                 dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-            self._precompile_backbone_helper("backbone",
-                                             input_ids=input_ids,
-                                             positions=positions,
-                                             inputs_embeds=None)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if is_first_rank:
+                intermediate_tensors = None
+            else:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone",
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=None,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)
 
     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()
@@ -265,10 +319,28 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-            self._precompile_backbone_helper("backbone with embeds",
-                                             input_ids=None,
-                                             positions=positions,
-                                             inputs_embeds=inputs_embeds)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if not is_first_rank:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            else:
+                intermediate_tensors = None
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone with embeds",
+                input_ids=None,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)
 
     def _precompile_select_from_array_helper(
         self,
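These two hunks wire pipeline parallelism into the warm-up: a non-first rank never sees token ids, it consumes the previous stage's hidden states and residual, so the precompile pass must fabricate tensors of exactly that shape and dtype. A hedged sketch of the idea (JaxIntermediateTensors is real in this diff; the dict-returning helper below is illustrative):

    import jax.numpy as jnp

    def dummy_intermediates(num_tokens: int, hidden_size: int) -> dict:
        # Same shape/dtype the previous pipeline stage would emit at runtime.
        return {
            "hidden_states": jnp.zeros((num_tokens, hidden_size), jnp.bfloat16),
            "residual": jnp.zeros((num_tokens, hidden_size), jnp.bfloat16),
        }

    warm_up_inputs = dummy_intermediates(num_tokens=32, hidden_size=4096)

If the dummy shapes drifted from the real inter-stage tensors, the first real step would trigger a fresh XLA compile, which is exactly what this warm-up exists to avoid.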
@@ -332,20 +404,23 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+        hidden_states_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name="select all logits",
+            name=f"worker{self.runner.rank} select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
-            input_sharding=dp_sharding if dp_size > 1 else None,
+            input_sharding=hidden_states_sharding,
             indices_sharding=dp_sharding if dp_size > 1 else None,
         )
 
         if self.runner.speculative_config:
             vocab_size = self.runner.model_config.get_vocab_size()
             self._precompile_select_from_array_helper(
-                name="select bonus tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select bonus tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_reqs_paddings,
                 hidden_dim=vocab_size,
@@ -353,7 +428,8 @@ class CompilationManager:
                                            PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name="select target tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,
@@ -376,7 +452,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-                "compute_logits",
+                f"worker{self.runner.rank} compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,
@@ -418,7 +494,7 @@ class CompilationManager:
             do_sampling=do_sampling,
         )
         self._run_compilation(
-            "sample",
+            f"worker{self.runner.rank} sample",
             sample,
             self.runner.rng_params_for_sampling,
             self.runner.mesh,
@@ -459,7 +535,7 @@ class CompilationManager:
         logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
         token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
         self._run_compilation(
-            "gather_logprobs",
+            f"worker{self.runner.rank} gather_logprobs",
            self.runner._compute_and_gather_logprobs,
             logits,
             token_ids,
@@ -511,7 +587,7 @@ class CompilationManager:
             do_sampling=do_sampling)
 
         self._run_compilation(
-            compilation_name,
+            f"worker{self.runner.rank} {compilation_name}",
             self.runner.rejection_sampler,
             draft_token_ids,
             num_draft_tokens,
@@ -528,7 +604,9 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-        hidden_size = self.runner.model_config.get_hidden_size()
+        target_hidden_size = self.runner.model_config.get_hidden_size()
+        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+        )
         dtype = self.runner.model_config.dtype
 
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -575,10 +653,11 @@ class CompilationManager:
 
         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits, hidden_size), jnp.bfloat16)
+                (num_logits, draft_hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
+                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )
@@ -586,8 +665,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-        hidden_state_loop = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, hidden_size), dtype,
+        draft_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, draft_hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
@@ -595,9 +674,12 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
             ]
 
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -620,23 +702,23 @@ class CompilationManager:
                 num_reqs,
             ):
                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                    token_indices, query_start_loc,
-                    aux_hidden_states, attention_metadata,
-                    num_reqs)
+                    self.runner.drafter.state, token_indices, query_start_loc,
+                    seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                    next_token_ids, num_reqs)
                 return target_hidden_states, input_ids, last_token_indices
 
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -668,17 +750,17 @@ class CompilationManager:
                 state,
                 kv_caches,
                 input_ids,
-                hidden_states,
+                draft_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids, hidden_states,
+                    state, kv_caches, input_ids, draft_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states
 
-            hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), dtype,
+            draft_hidden_states = self._create_dummy_tensor(
+                (num_tokens, draft_hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,
@@ -689,7 +771,7 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-                hidden_states,
+                draft_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
@@ -699,6 +781,7 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_prepare_hidden_states_and_input_ids",
                 self.runner.drafter._prepare_hidden_states_and_input_ids,
+                self.runner.drafter.state,
                 aux_hidden_states,
                 query_start_loc,
                 target_token_ids,
@@ -721,18 +804,19 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-                hidden_state_loop,
+                draft_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
 
             hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), jnp.bfloat16,
+                (num_tokens, draft_hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
 
             self._run_compilation(
                 "eagle3_select_inputs_for_loop_speculation",
                 self.runner.drafter._select_inputs_for_loop_speculation,
+                self.runner.drafter.state,
                 positions,
                 hidden_states,
                 hidden_states,
@@ -743,6 +827,7 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_select_draft_token_ids",
                 self.runner.drafter._select_draft_token_ids,
+                self.runner.drafter.state,
                 hidden_states,
                 last_token_indices,
                 num_tokens=num_tokens,
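Beyond the eagle3-specific changes (distinct target vs. draft hidden sizes, and threading drafter.state through every helper), all of these hunks follow one general technique: call each jitted function once per padded input bucket so that serving traffic never hits a cold XLA compile. A minimal sketch of the pattern under assumed bucket sizes (the real buckets come from runner.num_tokens_paddings):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def forward(x):
        return x * 2.0                      # stand-in for a model step

    num_tokens_paddings = [16, 32, 64]      # assumed padding buckets
    for n in num_tokens_paddings:
        dummy = jnp.zeros((n, 128), dtype=jnp.bfloat16)
        forward(dummy)                      # traces/compiles the (n, 128) shape
        # later calls with the same padded shape reuse the cached executable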
tpu_inference/runner/kv_cache_manager.py
@@ -1,15 +1,16 @@
 import functools
-import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
+from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -175,6 +176,11 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -190,7 +196,7 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = math.floor(num_blocks / dp_size) * dp_size
+            num_blocks = (num_blocks // dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
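The block-count rounding above is plain integer arithmetic; a quick standalone check of the invariant it enforces:

    num_blocks, dp_size = 1003, 4
    num_blocks = (num_blocks // dp_size) * dp_size
    assert num_blocks == 1000
    assert num_blocks % dp_size == 0   # each data-parallel shard gets an equal share

Integer floor division also avoids the float round-trip of the removed math.floor form, which is why `import math` could be dropped.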
@@ -283,13 +289,8 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, block_size,
-                                             *slices.shape[1:])
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache = cache.at[block_numbers].set(reshaped_slices)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
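Note that `.at[...].set(...)` is out-of-place in JAX, so the result must be rebound to `cache` (as shown above) for the return to carry the update. The rewrite replaces a Python loop of per-block `dynamic_update_slice_in_dim` calls with one batched scatter, keeping the traced graph O(1) in the number of blocks. A runnable toy version with illustrative shapes:

    import jax.numpy as jnp

    cache = jnp.zeros((8, 4, 2))            # (num_blocks, block_size, head_dim)
    slices = jnp.ones((3 * 4, 2))           # three blocks' worth of rows
    block_numbers = jnp.array([5, 1, 7])
    reshaped = slices.reshape(-1, 4, 2)     # (3, block_size, head_dim)
    cache = cache.at[block_numbers].set(reshaped)   # one batched scatter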
@@ -342,16 +343,12 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-
-
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
@@ -440,6 +437,7 @@ class KVCacheManager:
                 kv_cache_slices,
                 start_block,
             )
+            jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -451,6 +449,7 @@ class KVCacheManager:
                 kv_cache_slices,
                 jnp.array(block_numbers),
             )
+            jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
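The added `jax.block_until_ready` calls matter because JAX dispatch is asynchronous: without them, the surrounding latency tracker would only time the enqueue, and callers could race ahead of the device-side cache write. A small sketch of the same pattern in isolation:

    import time
    import jax
    import jax.numpy as jnp

    x = jnp.ones((1024, 1024))
    start = time.perf_counter()
    y = x @ x                        # dispatched asynchronously
    jax.block_until_ready(y)         # wait for the device result to exist
    print(f"matmul took {time.perf_counter() - start:.4f}s")

`jax.block_until_ready` accepts whole pytrees, which is why it can be applied directly to the list of per-layer caches.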
tpu_inference/runner/persistent_batch_manager.py
@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+                elif num_output_tokens < len(req_state.output_token_ids):
+                    del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
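The bookkeeping added for non-last pipeline ranks hinges on one expression: how many of the scheduler-echoed `new_token_ids` this worker has not yet stored. A standalone check of that arithmetic with made-up numbers:

    num_computed_tokens = 10      # tokens the scheduler reports as computed
    new_token_ids = [42, 43]      # sampled ids echoed back by the scheduler
    req_num_tokens = 11           # tokens this worker has already stored
    num_new = num_computed_tokens + len(new_token_ids) - req_num_tokens
    assert num_new == 1           # only the last sampled id gets appended

A non-positive `num_new` combined with a shrinking `num_output_tokens` is the rollback path: the worker truncates `output_token_ids` to match the scheduler's view.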
tpu_inference/runner/structured_decoding_manager.py
@@ -61,11 +61,10 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids.items(),
-            key=lambda item: item[1])
+            grammar_output.structured_output_request_ids)
 
         cumulative_mask_idx = 0
-        for req_id, _ in sorted_struct_requests:
+        for req_id in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]