tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +60 -7
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
- tpu_inference/layers/common/sharding.py +3 -4
- tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
- tpu_inference/models/common/model_loader.py +3 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
- tpu_inference/platforms/tpu_platform.py +13 -20
- tpu_inference/runner/compilation_manager.py +87 -27
- tpu_inference/runner/kv_cache_manager.py +8 -15
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/tpu_runner.py +68 -45
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +52 -19
- tpu_inference/utils.py +31 -9
- tpu_inference/worker/tpu_worker.py +2 -2
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import time
|
|
3
2
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
4
3
|
|
|
5
4
|
import jax
|
|
6
5
|
import jax.numpy as jnp
|
|
7
6
|
import numpy as np
|
|
8
|
-
import vllm.envs as
|
|
7
|
+
import vllm.envs as vllm_envs
|
|
9
8
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
9
|
|
|
10
|
+
import tpu_inference.envs as envs
|
|
11
11
|
from tpu_inference.core.disagg_utils import is_disagg_enabled
|
|
12
12
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
13
13
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
|
|
|
15
15
|
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
16
16
|
TPUSupportedSamplingMetadata
|
|
17
17
|
from tpu_inference.logger import init_logger
|
|
18
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
19
|
+
JaxIntermediateTensors
|
|
18
20
|
from tpu_inference.utils import device_array
|
|
19
21
|
|
|
20
22
|
if TYPE_CHECKING:
|
|
@@ -30,10 +32,10 @@ class CompilationManager:
|
|
|
30
32
|
|
|
31
33
|
def __init__(self, runner: "TPUModelRunner"):
|
|
32
34
|
self.runner = runner
|
|
33
|
-
if not
|
|
35
|
+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
|
|
34
36
|
logger.info("Enabling JAX compile cache.")
|
|
35
37
|
jax.config.update("jax_compilation_cache_dir",
|
|
36
|
-
|
|
38
|
+
vllm_envs.VLLM_XLA_CACHE_PATH)
|
|
37
39
|
|
|
38
40
|
def _create_dummy_tensor(self,
|
|
39
41
|
shape: Tuple[int, ...],
|
|
@@ -67,8 +69,7 @@ class CompilationManager:
|
|
|
67
69
|
logger.info("Compilation finished in %.2f [secs].", end - start)
|
|
68
70
|
|
|
69
71
|
def capture_model(self) -> None:
|
|
70
|
-
if
|
|
71
|
-
False) or self.runner.model_config.enforce_eager:
|
|
72
|
+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
|
|
72
73
|
return
|
|
73
74
|
logger.info("Precompile all the subgraphs with possible input shapes.")
|
|
74
75
|
|
|
@@ -81,6 +82,8 @@ class CompilationManager:
|
|
|
81
82
|
self._precompile_backbone_with_inputs_embeds()
|
|
82
83
|
if self.runner.scheduler_config.async_scheduling:
|
|
83
84
|
self._precompile_substitute_placeholder_token()
|
|
85
|
+
if not self.runner.is_last_rank:
|
|
86
|
+
return
|
|
84
87
|
self._precompile_select_from_array()
|
|
85
88
|
self._precompile_compute_logits()
|
|
86
89
|
self._precompile_disagg_utils()
|
|
@@ -120,8 +123,15 @@ class CompilationManager:
|
|
|
120
123
|
num_tokens=num_tokens,
|
|
121
124
|
)
|
|
122
125
|
|
|
123
|
-
def _precompile_backbone_helper(self,
|
|
124
|
-
|
|
126
|
+
def _precompile_backbone_helper(self,
|
|
127
|
+
name,
|
|
128
|
+
*,
|
|
129
|
+
input_ids,
|
|
130
|
+
positions,
|
|
131
|
+
inputs_embeds,
|
|
132
|
+
intermediate_tensors=None,
|
|
133
|
+
is_first_rank=True,
|
|
134
|
+
is_last_rank=True) -> None:
|
|
125
135
|
num_tokens = None
|
|
126
136
|
if input_ids is not None:
|
|
127
137
|
num_tokens = input_ids.shape[0]
|
|
@@ -181,10 +191,14 @@ class CompilationManager:
|
|
|
181
191
|
inputs_embeds,
|
|
182
192
|
layer_name_to_kvcache_index,
|
|
183
193
|
lora_metadata,
|
|
194
|
+
intermediate_tensors,
|
|
195
|
+
is_first_rank,
|
|
196
|
+
is_last_rank,
|
|
184
197
|
):
|
|
185
198
|
kv_caches, hidden_states, _ = self.runner.model_fn(
|
|
186
199
|
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
|
|
187
|
-
positions, layer_name_to_kvcache_index, lora_metadata
|
|
200
|
+
positions, layer_name_to_kvcache_index, lora_metadata,
|
|
201
|
+
intermediate_tensors, is_first_rank, is_last_rank)
|
|
188
202
|
self.runner.kv_caches = kv_caches
|
|
189
203
|
return hidden_states
|
|
190
204
|
|
|
@@ -207,6 +221,9 @@ class CompilationManager:
|
|
|
207
221
|
inputs_embeds,
|
|
208
222
|
tuple(self.runner.layer_name_to_kvcache_index.items()),
|
|
209
223
|
lora_metadata,
|
|
224
|
+
intermediate_tensors,
|
|
225
|
+
is_first_rank,
|
|
226
|
+
is_last_rank,
|
|
210
227
|
num_tokens=num_tokens,
|
|
211
228
|
)
|
|
212
229
|
|
|
@@ -257,6 +274,7 @@ class CompilationManager:
|
|
|
257
274
|
)
|
|
258
275
|
|
|
259
276
|
def _precompile_backbone_text_only(self) -> None:
|
|
277
|
+
hidden_size = self.runner.model_config.get_hidden_size()
|
|
260
278
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
261
279
|
dp_sharding = NamedSharding(
|
|
262
280
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
|
|
@@ -266,10 +284,28 @@ class CompilationManager:
|
|
|
266
284
|
dp_sharding)
|
|
267
285
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
|
|
268
286
|
dp_sharding)
|
|
269
|
-
self.
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
287
|
+
is_first_rank = self.runner.is_first_rank
|
|
288
|
+
is_last_rank = self.runner.is_last_rank
|
|
289
|
+
if is_first_rank:
|
|
290
|
+
intermediate_tensors = None
|
|
291
|
+
else:
|
|
292
|
+
hidden_states = self._create_dummy_tensor(
|
|
293
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
294
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
295
|
+
jnp.bfloat16)
|
|
296
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
297
|
+
tensors={
|
|
298
|
+
"hidden_states": hidden_states,
|
|
299
|
+
"residual": residual
|
|
300
|
+
})
|
|
301
|
+
self._precompile_backbone_helper(
|
|
302
|
+
f"worker{self.runner.rank} backbone",
|
|
303
|
+
input_ids=input_ids,
|
|
304
|
+
positions=positions,
|
|
305
|
+
inputs_embeds=None,
|
|
306
|
+
intermediate_tensors=intermediate_tensors,
|
|
307
|
+
is_first_rank=is_first_rank,
|
|
308
|
+
is_last_rank=is_last_rank)
|
|
273
309
|
|
|
274
310
|
def _precompile_backbone_with_inputs_embeds(self) -> None:
|
|
275
311
|
hidden_size = self.runner.model_config.get_hidden_size()
|
|
@@ -283,10 +319,28 @@ class CompilationManager:
|
|
|
283
319
|
else:
|
|
284
320
|
positions = self._create_dummy_tensor((num_tokens, ),
|
|
285
321
|
jnp.int32)
|
|
286
|
-
self.
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
322
|
+
is_first_rank = self.runner.is_first_rank
|
|
323
|
+
is_last_rank = self.runner.is_last_rank
|
|
324
|
+
if not is_first_rank:
|
|
325
|
+
hidden_states = self._create_dummy_tensor(
|
|
326
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
327
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
328
|
+
jnp.bfloat16)
|
|
329
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
330
|
+
tensors={
|
|
331
|
+
"hidden_states": hidden_states,
|
|
332
|
+
"residual": residual
|
|
333
|
+
})
|
|
334
|
+
else:
|
|
335
|
+
intermediate_tensors = None
|
|
336
|
+
self._precompile_backbone_helper(
|
|
337
|
+
f"worker{self.runner.rank} backbone with embeds",
|
|
338
|
+
input_ids=None,
|
|
339
|
+
positions=positions,
|
|
340
|
+
inputs_embeds=inputs_embeds,
|
|
341
|
+
intermediate_tensors=intermediate_tensors,
|
|
342
|
+
is_first_rank=is_first_rank,
|
|
343
|
+
is_last_rank=is_last_rank)
|
|
290
344
|
|
|
291
345
|
def _precompile_select_from_array_helper(
|
|
292
346
|
self,
|
|
@@ -354,7 +408,7 @@ class CompilationManager:
|
|
|
354
408
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
|
|
355
409
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
356
410
|
self._precompile_select_from_array_helper(
|
|
357
|
-
name="select all logits",
|
|
411
|
+
name=f"worker{self.runner.rank} select all logits",
|
|
358
412
|
source_paddings=self.runner.num_tokens_paddings,
|
|
359
413
|
indices_paddings=index_paddings,
|
|
360
414
|
hidden_dim=hsize,
|
|
@@ -365,7 +419,8 @@ class CompilationManager:
|
|
|
365
419
|
if self.runner.speculative_config:
|
|
366
420
|
vocab_size = self.runner.model_config.get_vocab_size()
|
|
367
421
|
self._precompile_select_from_array_helper(
|
|
368
|
-
name=
|
|
422
|
+
name=
|
|
423
|
+
f"worker{self.runner.rank} select bonus tokens for spec decoding",
|
|
369
424
|
source_paddings=self.runner.num_logits_paddings,
|
|
370
425
|
indices_paddings=self.runner.num_reqs_paddings,
|
|
371
426
|
hidden_dim=vocab_size,
|
|
@@ -373,7 +428,8 @@ class CompilationManager:
|
|
|
373
428
|
PartitionSpec(None, "model")),
|
|
374
429
|
)
|
|
375
430
|
self._precompile_select_from_array_helper(
|
|
376
|
-
name=
|
|
431
|
+
name=
|
|
432
|
+
f"worker{self.runner.rank} select target tokens for spec decoding",
|
|
377
433
|
source_paddings=self.runner.num_logits_paddings,
|
|
378
434
|
indices_paddings=self.runner.num_logits_paddings,
|
|
379
435
|
hidden_dim=vocab_size,
|
|
@@ -396,7 +452,7 @@ class CompilationManager:
|
|
|
396
452
|
np.array([num_reqs], dtype=np.int32)):
|
|
397
453
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
398
454
|
self._run_compilation(
|
|
399
|
-
"compute_logits",
|
|
455
|
+
f"worker{self.runner.rank} compute_logits",
|
|
400
456
|
self.runner.compute_logits_fn,
|
|
401
457
|
self.runner.state,
|
|
402
458
|
hidden_states,
|
|
@@ -438,7 +494,7 @@ class CompilationManager:
|
|
|
438
494
|
do_sampling=do_sampling,
|
|
439
495
|
)
|
|
440
496
|
self._run_compilation(
|
|
441
|
-
"sample",
|
|
497
|
+
f"worker{self.runner.rank} sample",
|
|
442
498
|
sample,
|
|
443
499
|
self.runner.rng_params_for_sampling,
|
|
444
500
|
self.runner.mesh,
|
|
@@ -479,7 +535,7 @@ class CompilationManager:
|
|
|
479
535
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
|
|
480
536
|
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
|
|
481
537
|
self._run_compilation(
|
|
482
|
-
"gather_logprobs",
|
|
538
|
+
f"worker{self.runner.rank} gather_logprobs",
|
|
483
539
|
self.runner._compute_and_gather_logprobs,
|
|
484
540
|
logits,
|
|
485
541
|
token_ids,
|
|
@@ -531,7 +587,7 @@ class CompilationManager:
|
|
|
531
587
|
do_sampling=do_sampling)
|
|
532
588
|
|
|
533
589
|
self._run_compilation(
|
|
534
|
-
compilation_name,
|
|
590
|
+
f"worker{self.runner.rank} {compilation_name}",
|
|
535
591
|
self.runner.rejection_sampler,
|
|
536
592
|
draft_token_ids,
|
|
537
593
|
num_draft_tokens,
|
|
@@ -601,6 +657,7 @@ class CompilationManager:
|
|
|
601
657
|
self._run_compilation(
|
|
602
658
|
"eagle3_get_draft_token_ids",
|
|
603
659
|
self.runner.drafter._get_draft_token_ids,
|
|
660
|
+
self.runner.drafter.state,
|
|
604
661
|
hidden_states,
|
|
605
662
|
num_logits=num_logits,
|
|
606
663
|
)
|
|
@@ -645,9 +702,9 @@ class CompilationManager:
|
|
|
645
702
|
num_reqs,
|
|
646
703
|
):
|
|
647
704
|
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
|
|
648
|
-
token_indices, query_start_loc,
|
|
649
|
-
aux_hidden_states, attention_metadata,
|
|
650
|
-
num_reqs)
|
|
705
|
+
self.runner.drafter.state, token_indices, query_start_loc,
|
|
706
|
+
seq_lens, input_ids, aux_hidden_states, attention_metadata,
|
|
707
|
+
next_token_ids, num_reqs)
|
|
651
708
|
return target_hidden_states, input_ids, last_token_indices
|
|
652
709
|
|
|
653
710
|
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
@@ -724,6 +781,7 @@ class CompilationManager:
|
|
|
724
781
|
self._run_compilation(
|
|
725
782
|
"eagle3_prepare_hidden_states_and_input_ids",
|
|
726
783
|
self.runner.drafter._prepare_hidden_states_and_input_ids,
|
|
784
|
+
self.runner.drafter.state,
|
|
727
785
|
aux_hidden_states,
|
|
728
786
|
query_start_loc,
|
|
729
787
|
target_token_ids,
|
|
@@ -758,6 +816,7 @@ class CompilationManager:
|
|
|
758
816
|
self._run_compilation(
|
|
759
817
|
"eagle3_select_inputs_for_loop_speculation",
|
|
760
818
|
self.runner.drafter._select_inputs_for_loop_speculation,
|
|
819
|
+
self.runner.drafter.state,
|
|
761
820
|
positions,
|
|
762
821
|
hidden_states,
|
|
763
822
|
hidden_states,
|
|
@@ -768,6 +827,7 @@ class CompilationManager:
|
|
|
768
827
|
self._run_compilation(
|
|
769
828
|
"eagle3_select_draft_token_ids",
|
|
770
829
|
self.runner.drafter._select_draft_token_ids,
|
|
830
|
+
self.runner.drafter.state,
|
|
771
831
|
hidden_states,
|
|
772
832
|
last_token_indices,
|
|
773
833
|
num_tokens=num_tokens,
|
|
@@ -289,13 +289,8 @@ class KVCacheManager:
|
|
|
289
289
|
|
|
290
290
|
def _update_layer(cache, slices):
|
|
291
291
|
"""The function to apply to each layer's cache and slices."""
|
|
292
|
-
reshaped_slices = slices.reshape(-1,
|
|
293
|
-
|
|
294
|
-
for (i, block_idx) in enumerate(block_numbers):
|
|
295
|
-
cache = jax.lax.dynamic_update_slice_in_dim(cache,
|
|
296
|
-
reshaped_slices[i],
|
|
297
|
-
block_idx,
|
|
298
|
-
axis=0)
|
|
292
|
+
reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
|
|
293
|
+
cache.at[block_numbers].set(reshaped_slices)
|
|
299
294
|
return cache
|
|
300
295
|
|
|
301
296
|
return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
|
|
@@ -348,16 +343,12 @@ class KVCacheManager:
|
|
|
348
343
|
"""
|
|
349
344
|
if block_ids == list(range(block_ids[0],
|
|
350
345
|
block_ids[0] + len(block_ids))):
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
|
|
354
|
-
self.runner.kv_caches, block_ids[0], len(block_ids))
|
|
346
|
+
batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
|
|
347
|
+
self.runner.kv_caches, block_ids[0], len(block_ids))
|
|
355
348
|
|
|
356
349
|
else:
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
|
|
360
|
-
self.runner.kv_caches, jnp.array(block_ids))
|
|
350
|
+
batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
|
|
351
|
+
self.runner.kv_caches, jnp.array(block_ids))
|
|
361
352
|
return batched_kv_cache_per_layer
|
|
362
353
|
|
|
363
354
|
def transfer_kv_cache(self,
|
|
@@ -446,6 +437,7 @@ class KVCacheManager:
|
|
|
446
437
|
kv_cache_slices,
|
|
447
438
|
start_block,
|
|
448
439
|
)
|
|
440
|
+
jax.block_until_ready(self.runner.kv_caches)
|
|
449
441
|
else:
|
|
450
442
|
with runner_utils.LatencyTracker(
|
|
451
443
|
f"JittedInsertKVCache-b{len(block_numbers)}"):
|
|
@@ -457,6 +449,7 @@ class KVCacheManager:
|
|
|
457
449
|
kv_cache_slices,
|
|
458
450
|
jnp.array(block_numbers),
|
|
459
451
|
)
|
|
452
|
+
jax.block_until_ready(self.runner.kv_caches)
|
|
460
453
|
|
|
461
454
|
logger.debug(
|
|
462
455
|
f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
|
|
@@ -14,12 +14,13 @@ class PersistentBatchManager:
|
|
|
14
14
|
def __init__(self, requests: Dict[str, CachedRequestState],
|
|
15
15
|
input_batch: InputBatch, encoder_cache: Dict[str,
|
|
16
16
|
'jax.Array'],
|
|
17
|
-
uses_mrope: bool, model_config):
|
|
17
|
+
uses_mrope: bool, model_config, is_last_rank: bool):
|
|
18
18
|
self.requests = requests
|
|
19
19
|
self.input_batch = input_batch
|
|
20
20
|
self.encoder_cache = encoder_cache
|
|
21
21
|
self.uses_mrope = uses_mrope
|
|
22
22
|
self.model_config = model_config
|
|
23
|
+
self.is_last_rank = is_last_rank
|
|
23
24
|
|
|
24
25
|
def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
|
|
25
26
|
""" Reorder the sheduled requests to RPA kernel friendly distribution
|
|
@@ -179,9 +180,35 @@ class PersistentBatchManager:
|
|
|
179
180
|
num_computed_tokens = req_data.num_computed_tokens[i]
|
|
180
181
|
new_block_ids = req_data.new_block_ids[i]
|
|
181
182
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
|
183
|
+
num_output_tokens = req_data.num_output_tokens[i]
|
|
182
184
|
|
|
183
185
|
# Update the cached states.
|
|
184
186
|
req_state.num_computed_tokens = num_computed_tokens
|
|
187
|
+
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
188
|
+
|
|
189
|
+
if not self.is_last_rank:
|
|
190
|
+
# When using PP, the scheduler sends the sampled tokens back,
|
|
191
|
+
# because there's no direct communication between the first-
|
|
192
|
+
# stage worker and the last-stage worker.
|
|
193
|
+
new_token_ids = req_data.new_token_ids[i]
|
|
194
|
+
# Add the sampled token(s) from the previous step (if any).
|
|
195
|
+
# This doesn't include "unverified" tokens like spec tokens.
|
|
196
|
+
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
|
197
|
+
req_state.num_tokens)
|
|
198
|
+
if num_new_tokens == 1:
|
|
199
|
+
req_state.output_token_ids.append(new_token_ids[-1])
|
|
200
|
+
elif num_new_tokens > 0:
|
|
201
|
+
req_state.output_token_ids.extend(
|
|
202
|
+
new_token_ids[-num_new_tokens:])
|
|
203
|
+
elif num_output_tokens < len(req_state.output_token_ids):
|
|
204
|
+
del req_state.output_token_ids[num_output_tokens:]
|
|
205
|
+
if req_index is not None:
|
|
206
|
+
end_idx = (self.input_batch.num_prompt_tokens[req_index] +
|
|
207
|
+
num_output_tokens)
|
|
208
|
+
self.input_batch.num_tokens[req_index] = end_idx
|
|
209
|
+
self.input_batch.num_tokens_no_spec[req_index] = end_idx
|
|
210
|
+
|
|
211
|
+
# Update the block IDs.
|
|
185
212
|
if not resumed_from_preemption:
|
|
186
213
|
if new_block_ids is not None:
|
|
187
214
|
# Append the new blocks to the existing block IDs.
|
|
@@ -194,7 +221,6 @@ class PersistentBatchManager:
|
|
|
194
221
|
# Replace the existing block IDs with the new ones.
|
|
195
222
|
req_state.block_ids = new_block_ids
|
|
196
223
|
|
|
197
|
-
req_index = self.input_batch.req_id_to_index.get(req_id)
|
|
198
224
|
if req_index is None:
|
|
199
225
|
# The request is not in the persistent batch.
|
|
200
226
|
# The request was either preempted and resumed later, or was not
|
|
@@ -209,6 +235,18 @@ class PersistentBatchManager:
|
|
|
209
235
|
self.input_batch.block_table.append_row(
|
|
210
236
|
new_block_ids, req_index)
|
|
211
237
|
|
|
238
|
+
# For the last rank, we don't need to update the token_ids_cpu
|
|
239
|
+
# because the sampled tokens are already cached.
|
|
240
|
+
if not self.is_last_rank:
|
|
241
|
+
start_token_index = num_computed_tokens
|
|
242
|
+
end_token_index = num_computed_tokens + len(new_token_ids)
|
|
243
|
+
self.input_batch.token_ids_cpu[
|
|
244
|
+
req_index,
|
|
245
|
+
start_token_index:end_token_index] = new_token_ids
|
|
246
|
+
self.input_batch.num_tokens_no_spec[
|
|
247
|
+
req_index] = end_token_index
|
|
248
|
+
self.input_batch.num_tokens[req_index] = end_token_index
|
|
249
|
+
|
|
212
250
|
# Add spec_token_ids to token_ids_cpu.
|
|
213
251
|
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
|
214
252
|
req_id, ())
|
|
@@ -10,17 +10,16 @@ import jax
|
|
|
10
10
|
import jax.numpy as jnp
|
|
11
11
|
import jaxtyping
|
|
12
12
|
import numpy as np
|
|
13
|
-
import
|
|
14
|
-
import vllm.envs as envs
|
|
13
|
+
import vllm.envs as vllm_envs
|
|
15
14
|
from flax import nnx
|
|
16
15
|
from jax.experimental import mesh_utils
|
|
17
16
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
18
|
-
from torchax.ops.mappings import
|
|
17
|
+
from torchax.ops.mappings import t2j_dtype
|
|
19
18
|
from vllm.config import VllmConfig
|
|
19
|
+
from vllm.distributed import get_pp_group
|
|
20
20
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|
21
21
|
has_kv_transfer_group)
|
|
22
22
|
from vllm.forward_context import set_forward_context
|
|
23
|
-
from vllm.sequence import IntermediateTensors
|
|
24
23
|
from vllm.tasks import SupportedTask
|
|
25
24
|
from vllm.utils.math_utils import cdiv
|
|
26
25
|
from vllm.v1.core.sched.output import GrammarOutput
|
|
@@ -35,6 +34,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
|
|
35
34
|
KVConnectorModelRunnerMixin
|
|
36
35
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
|
37
36
|
|
|
37
|
+
import tpu_inference.envs as envs
|
|
38
38
|
from tpu_inference import utils as common_utils
|
|
39
39
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
40
40
|
from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
|
|
@@ -48,6 +48,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
|
48
48
|
TPUSupportedSamplingMetadata
|
|
49
49
|
from tpu_inference.logger import init_logger
|
|
50
50
|
from tpu_inference.models.common.model_loader import get_model
|
|
51
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
52
|
+
JaxIntermediateTensors
|
|
51
53
|
from tpu_inference.models.jax.utils.weight_utils import (
|
|
52
54
|
shard_put, transfer_state_with_mappings)
|
|
53
55
|
from tpu_inference.runner import utils as runner_utils
|
|
@@ -64,7 +66,7 @@ from tpu_inference.runner.structured_decoding_manager import \
|
|
|
64
66
|
StructuredDecodingManager
|
|
65
67
|
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
|
|
66
68
|
from tpu_inference.utils import (device_array, make_optimized_mesh,
|
|
67
|
-
time_function)
|
|
69
|
+
time_function, to_torch_dtype)
|
|
68
70
|
|
|
69
71
|
logger = init_logger(__name__)
|
|
70
72
|
|
|
@@ -78,17 +80,6 @@ DUMMY_METADATA = AttentionMetadata(
|
|
|
78
80
|
request_distribution=[0, 0, 0],
|
|
79
81
|
)
|
|
80
82
|
|
|
81
|
-
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
|
82
|
-
"half": torch.half,
|
|
83
|
-
"bfloat16": torch.bfloat16,
|
|
84
|
-
"float": torch.float,
|
|
85
|
-
"fp8": torch.float8_e4m3fn,
|
|
86
|
-
"fp8_e4m3": torch.float8_e4m3fn,
|
|
87
|
-
"fp8_e5m2": torch.float8_e5m2,
|
|
88
|
-
"int8": torch.int8,
|
|
89
|
-
"uint8": torch.uint8,
|
|
90
|
-
}
|
|
91
|
-
|
|
92
83
|
|
|
93
84
|
class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|
94
85
|
"""Holds asynchronous model output specifically from a TPU runner.
|
|
@@ -243,6 +234,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
243
234
|
self.maybe_forbid_compile = runner_utils.ForbidCompile(
|
|
244
235
|
) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
|
|
245
236
|
self.dp_size = self.vllm_config.sharding_config.total_dp_size
|
|
237
|
+
self.rank = rank
|
|
238
|
+
self.is_first_rank = is_first_rank
|
|
239
|
+
self.is_last_rank = is_last_rank
|
|
246
240
|
|
|
247
241
|
self._init_random()
|
|
248
242
|
self._init_mesh()
|
|
@@ -253,31 +247,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
253
247
|
|
|
254
248
|
# Delegate functions to specific manager classes.
|
|
255
249
|
self.compilation_manager = CompilationManager(self)
|
|
256
|
-
self.
|
|
257
|
-
|
|
250
|
+
if self.is_last_rank:
|
|
251
|
+
self.speculative_decoding_manager = SpeculativeDecodingManager(
|
|
252
|
+
self)
|
|
253
|
+
self.structured_decoding_manager = StructuredDecodingManager(self)
|
|
258
254
|
self.kv_cache_manager = KVCacheManager(self)
|
|
259
255
|
self.mm_manager = MultiModalManager(self)
|
|
260
256
|
self.persistent_batch_manager = PersistentBatchManager(
|
|
261
257
|
self.requests, self.input_batch, self.encoder_cache,
|
|
262
|
-
self.uses_mrope, self.model_config)
|
|
258
|
+
self.uses_mrope, self.model_config, self.is_last_rank)
|
|
263
259
|
self.lora_utils = LoraUtils(self)
|
|
264
260
|
|
|
265
|
-
|
|
266
|
-
if
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
|
270
|
-
elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
|
|
271
|
-
self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
|
|
272
|
-
elif isinstance(model_dtype, torch.dtype):
|
|
273
|
-
self.kv_cache_dtype = model_dtype
|
|
274
|
-
else:
|
|
275
|
-
raise ValueError(
|
|
276
|
-
"KV cache is unsupported for model_dtype of %s",
|
|
277
|
-
model_dtype)
|
|
278
|
-
else:
|
|
279
|
-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
|
|
280
|
-
cache_config.cache_dtype]
|
|
261
|
+
cache_dtype = self.cache_config.cache_dtype
|
|
262
|
+
if cache_dtype == "auto":
|
|
263
|
+
cache_dtype = self.dtype
|
|
264
|
+
self.kv_cache_dtype = to_torch_dtype(cache_dtype)
|
|
281
265
|
|
|
282
266
|
self._pre_async_results: AsyncPreResults | None = None
|
|
283
267
|
self._substitute_placeholder_token_fn = _substitute_placeholder_token
|
|
@@ -291,7 +275,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
291
275
|
self.rng_key = jax.random.key(self.model_config.seed)
|
|
292
276
|
|
|
293
277
|
def _init_mesh(self) -> None:
|
|
294
|
-
if
|
|
278
|
+
if envs.NEW_MODEL_DESIGN:
|
|
295
279
|
self.mesh = self._create_new_model_mesh()
|
|
296
280
|
else:
|
|
297
281
|
# NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
|
|
@@ -302,7 +286,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
302
286
|
logger.info(f"Init mesh | mesh={self.mesh}")
|
|
303
287
|
|
|
304
288
|
def _create_new_model_mesh(self) -> jax.sharding.Mesh:
|
|
305
|
-
num_slices =
|
|
289
|
+
num_slices = envs.NUM_SLICES
|
|
306
290
|
|
|
307
291
|
logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
|
|
308
292
|
f"num_slices={num_slices}")
|
|
@@ -371,7 +355,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
371
355
|
devices=self.devices)
|
|
372
356
|
|
|
373
357
|
def _init_phased_profiling(self) -> None:
|
|
374
|
-
self.phased_profiling_dir =
|
|
358
|
+
self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
|
|
375
359
|
self.phase_based_profiler = None
|
|
376
360
|
if self.phased_profiling_dir:
|
|
377
361
|
self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
|
|
@@ -413,7 +397,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
413
397
|
min_token_size=max(16, self.dp_size),
|
|
414
398
|
max_token_size=scheduler_config.max_num_batched_tokens *
|
|
415
399
|
self.dp_size,
|
|
416
|
-
padding_gap=
|
|
400
|
+
padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
|
|
417
401
|
self.num_tokens_paddings_per_dp = [
|
|
418
402
|
padding // self.dp_size for padding in self.num_tokens_paddings
|
|
419
403
|
]
|
|
@@ -555,12 +539,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
555
539
|
def execute_model(
|
|
556
540
|
self,
|
|
557
541
|
scheduler_output: "VllmSchedulerOutput",
|
|
558
|
-
intermediate_tensors: Optional[
|
|
559
|
-
) -> ModelRunnerOutput | None:
|
|
542
|
+
intermediate_tensors: Optional[JaxIntermediateTensors] = None,
|
|
543
|
+
) -> ModelRunnerOutput | JaxIntermediateTensors | None:
|
|
560
544
|
if self.execute_model_state is not None:
|
|
561
545
|
raise RuntimeError("State error: sample_tokens() must be called "
|
|
562
546
|
"after execute_model() returns None.")
|
|
563
|
-
_, output = self._execute_model(scheduler_output)
|
|
547
|
+
_, output = self._execute_model(scheduler_output, intermediate_tensors)
|
|
564
548
|
return output
|
|
565
549
|
|
|
566
550
|
def sample_tokens(
|
|
@@ -686,7 +670,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
686
670
|
def _execute_model(
|
|
687
671
|
self,
|
|
688
672
|
scheduler_output: "VllmSchedulerOutput",
|
|
689
|
-
|
|
673
|
+
intermediate_tensors: Optional[JaxIntermediateTensors] = None,
|
|
674
|
+
) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
|
|
675
|
+
| None]:
|
|
690
676
|
self.persistent_batch_manager.update_states(
|
|
691
677
|
scheduler_output, self.get_mrope_input_positions_fn)
|
|
692
678
|
if not scheduler_output.total_num_scheduled_tokens:
|
|
@@ -764,7 +750,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
764
750
|
scheduler_output) as kv_connector_output:
|
|
765
751
|
# NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
|
|
766
752
|
# but one of them would be `None`
|
|
767
|
-
|
|
768
753
|
(self.kv_caches, hidden_states,
|
|
769
754
|
aux_hidden_states) = self.model_fn(
|
|
770
755
|
self.state,
|
|
@@ -775,8 +760,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
775
760
|
input_positions,
|
|
776
761
|
tuple(self.layer_name_to_kvcache_index.items()),
|
|
777
762
|
lora_metadata,
|
|
763
|
+
intermediate_tensors,
|
|
764
|
+
self.is_first_rank,
|
|
765
|
+
self.is_last_rank,
|
|
778
766
|
)
|
|
779
|
-
|
|
767
|
+
if not get_pp_group().is_last_rank:
|
|
768
|
+
assert isinstance(hidden_states, JaxIntermediateTensors)
|
|
769
|
+
hidden_states.kv_connector_output = kv_connector_output
|
|
770
|
+
return attn_metadata, hidden_states
|
|
780
771
|
hidden_states = self._select_from_array_fn(hidden_states,
|
|
781
772
|
logits_indices)
|
|
782
773
|
logits = self.compute_logits_fn(
|
|
@@ -1719,3 +1710,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
|
|
|
1719
1710
|
mappings=mappings,
|
|
1720
1711
|
transpose_keys=transpose_keys,
|
|
1721
1712
|
shard=shard)
|
|
1713
|
+
|
|
1714
|
+
def get_intermediate_tensor_spec(self, num_tokens: int):
|
|
1715
|
+
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
|
|
1716
|
+
jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
|
|
1717
|
+
num_padded_tokens = runner_utils.get_padded_token_len(
|
|
1718
|
+
self.num_tokens_paddings, num_tokens)
|
|
1719
|
+
sharding = NamedSharding(self.mesh, PartitionSpec())
|
|
1720
|
+
hidden_size = self.model_config.get_hidden_size()
|
|
1721
|
+
spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
|
|
1722
|
+
dtype=jax_dtype,
|
|
1723
|
+
sharding=sharding)
|
|
1724
|
+
tensor_spec = {"hidden_states": spec, "residual": spec}
|
|
1725
|
+
return tensor_spec
|
|
1726
|
+
|
|
1727
|
+
def get_uuid_for_jax_transfer(self,
|
|
1728
|
+
scheduler_output: "VllmSchedulerOutput",
|
|
1729
|
+
rank: int, step: int) -> int:
|
|
1730
|
+
'''
|
|
1731
|
+
Get a uuid for jax.transfer, here we use the hash of
|
|
1732
|
+
scheduler_output + counter_step + sender's rank
|
|
1733
|
+
'''
|
|
1734
|
+
scheduler_output_str = ""
|
|
1735
|
+
if not scheduler_output.num_scheduled_tokens:
|
|
1736
|
+
scheduler_output_str = "empty_batch"
|
|
1737
|
+
else:
|
|
1738
|
+
scheduler_output_str = str(
|
|
1739
|
+
sorted(scheduler_output.num_scheduled_tokens.items()))
|
|
1740
|
+
unique_str = f'{scheduler_output_str} {step} {rank}'
|
|
1741
|
+
import hashlib
|
|
1742
|
+
hasher = hashlib.sha1()
|
|
1743
|
+
hasher.update(unique_str.encode('utf-8'))
|
|
1744
|
+
return int.from_bytes(hasher.digest()[:8], 'big')
|
tpu_inference/runner/utils.py
CHANGED
|
@@ -15,6 +15,7 @@ import jax
|
|
|
15
15
|
from jax._src.interpreters import pxla
|
|
16
16
|
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
|
|
17
17
|
|
|
18
|
+
from tpu_inference import envs
|
|
18
19
|
from tpu_inference.logger import init_logger
|
|
19
20
|
from tpu_inference.runner.input_batch import InputBatch
|
|
20
21
|
|
|
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
|
|
|
306
307
|
InferencePhase.BALANCED: False
|
|
307
308
|
}
|
|
308
309
|
self.default_profiling_options = jax.profiler.ProfileOptions()
|
|
309
|
-
self.default_profiling_options.python_tracer_level =
|
|
310
|
-
"PYTHON_TRACER_LEVEL", 0)
|
|
310
|
+
self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
|
|
311
311
|
|
|
312
312
|
self.current_phase: str = ""
|
|
313
313
|
|