tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import time
|
|
3
2
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
4
3
|
|
|
5
4
|
import jax
|
|
6
5
|
import jax.numpy as jnp
|
|
7
6
|
import numpy as np
|
|
8
|
-
import vllm.envs as
|
|
7
|
+
import vllm.envs as vllm_envs
|
|
9
8
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
9
|
|
|
10
|
+
import tpu_inference.envs as envs
|
|
11
11
|
from tpu_inference.core.disagg_utils import is_disagg_enabled
|
|
12
12
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
13
13
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
|
|
|
15
15
|
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
16
16
|
TPUSupportedSamplingMetadata
|
|
17
17
|
from tpu_inference.logger import init_logger
|
|
18
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
19
|
+
JaxIntermediateTensors
|
|
18
20
|
from tpu_inference.utils import device_array
|
|
19
21
|
|
|
20
22
|
if TYPE_CHECKING:
|
|
@@ -30,10 +32,10 @@ class CompilationManager:
|
|
|
30
32
|
|
|
31
33
|
def __init__(self, runner: "TPUModelRunner"):
|
|
32
34
|
self.runner = runner
|
|
33
|
-
if not
|
|
35
|
+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
|
|
34
36
|
logger.info("Enabling JAX compile cache.")
|
|
35
37
|
jax.config.update("jax_compilation_cache_dir",
|
|
36
|
-
|
|
38
|
+
vllm_envs.VLLM_XLA_CACHE_PATH)
|
|
37
39
|
|
|
38
40
|
def _create_dummy_tensor(self,
|
|
39
41
|
shape: Tuple[int, ...],
|
|
@@ -67,8 +69,7 @@ class CompilationManager:
|
|
|
67
69
|
logger.info("Compilation finished in %.2f [secs].", end - start)
|
|
68
70
|
|
|
69
71
|
def capture_model(self) -> None:
|
|
70
|
-
if
|
|
71
|
-
False) or self.runner.model_config.enforce_eager:
|
|
72
|
+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
|
|
72
73
|
return
|
|
73
74
|
logger.info("Precompile all the subgraphs with possible input shapes.")
|
|
74
75
|
|
|
@@ -81,6 +82,8 @@ class CompilationManager:
|
|
|
81
82
|
self._precompile_backbone_with_inputs_embeds()
|
|
82
83
|
if self.runner.scheduler_config.async_scheduling:
|
|
83
84
|
self._precompile_substitute_placeholder_token()
|
|
85
|
+
if not self.runner.is_last_rank:
|
|
86
|
+
return
|
|
84
87
|
self._precompile_select_from_array()
|
|
85
88
|
self._precompile_compute_logits()
|
|
86
89
|
self._precompile_disagg_utils()
|
|
@@ -120,8 +123,15 @@ class CompilationManager:
|
|
|
120
123
|
num_tokens=num_tokens,
|
|
121
124
|
)
|
|
122
125
|
|
|
123
|
-
def _precompile_backbone_helper(self,
|
|
124
|
-
|
|
126
|
+
def _precompile_backbone_helper(self,
|
|
127
|
+
name,
|
|
128
|
+
*,
|
|
129
|
+
input_ids,
|
|
130
|
+
positions,
|
|
131
|
+
inputs_embeds,
|
|
132
|
+
intermediate_tensors=None,
|
|
133
|
+
is_first_rank=True,
|
|
134
|
+
is_last_rank=True) -> None:
|
|
125
135
|
num_tokens = None
|
|
126
136
|
if input_ids is not None:
|
|
127
137
|
num_tokens = input_ids.shape[0]
|
|
@@ -181,10 +191,14 @@ class CompilationManager:
|
|
|
181
191
|
inputs_embeds,
|
|
182
192
|
layer_name_to_kvcache_index,
|
|
183
193
|
lora_metadata,
|
|
194
|
+
intermediate_tensors,
|
|
195
|
+
is_first_rank,
|
|
196
|
+
is_last_rank,
|
|
184
197
|
):
|
|
185
198
|
kv_caches, hidden_states, _ = self.runner.model_fn(
|
|
186
199
|
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
|
|
187
|
-
positions, layer_name_to_kvcache_index, lora_metadata
|
|
200
|
+
positions, layer_name_to_kvcache_index, lora_metadata,
|
|
201
|
+
intermediate_tensors, is_first_rank, is_last_rank)
|
|
188
202
|
self.runner.kv_caches = kv_caches
|
|
189
203
|
return hidden_states
|
|
190
204
|
|
|
@@ -207,6 +221,9 @@ class CompilationManager:
|
|
|
207
221
|
inputs_embeds,
|
|
208
222
|
tuple(self.runner.layer_name_to_kvcache_index.items()),
|
|
209
223
|
lora_metadata,
|
|
224
|
+
intermediate_tensors,
|
|
225
|
+
is_first_rank,
|
|
226
|
+
is_last_rank,
|
|
210
227
|
num_tokens=num_tokens,
|
|
211
228
|
)
|
|
212
229
|
|
|
@@ -257,6 +274,7 @@ class CompilationManager:
|
|
|
257
274
|
)
|
|
258
275
|
|
|
259
276
|
def _precompile_backbone_text_only(self) -> None:
|
|
277
|
+
hidden_size = self.runner.model_config.get_hidden_size()
|
|
260
278
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
261
279
|
dp_sharding = NamedSharding(
|
|
262
280
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
|
|
@@ -266,10 +284,28 @@ class CompilationManager:
|
|
|
266
284
|
dp_sharding)
|
|
267
285
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
|
|
268
286
|
dp_sharding)
|
|
269
|
-
self.
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
287
|
+
is_first_rank = self.runner.is_first_rank
|
|
288
|
+
is_last_rank = self.runner.is_last_rank
|
|
289
|
+
if is_first_rank:
|
|
290
|
+
intermediate_tensors = None
|
|
291
|
+
else:
|
|
292
|
+
hidden_states = self._create_dummy_tensor(
|
|
293
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
294
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
295
|
+
jnp.bfloat16)
|
|
296
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
297
|
+
tensors={
|
|
298
|
+
"hidden_states": hidden_states,
|
|
299
|
+
"residual": residual
|
|
300
|
+
})
|
|
301
|
+
self._precompile_backbone_helper(
|
|
302
|
+
f"worker{self.runner.rank} backbone",
|
|
303
|
+
input_ids=input_ids,
|
|
304
|
+
positions=positions,
|
|
305
|
+
inputs_embeds=None,
|
|
306
|
+
intermediate_tensors=intermediate_tensors,
|
|
307
|
+
is_first_rank=is_first_rank,
|
|
308
|
+
is_last_rank=is_last_rank)
|
|
273
309
|
|
|
274
310
|
def _precompile_backbone_with_inputs_embeds(self) -> None:
|
|
275
311
|
hidden_size = self.runner.model_config.get_hidden_size()
|
|
@@ -283,10 +319,28 @@ class CompilationManager:
|
|
|
283
319
|
else:
|
|
284
320
|
positions = self._create_dummy_tensor((num_tokens, ),
|
|
285
321
|
jnp.int32)
|
|
286
|
-
self.
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
322
|
+
is_first_rank = self.runner.is_first_rank
|
|
323
|
+
is_last_rank = self.runner.is_last_rank
|
|
324
|
+
if not is_first_rank:
|
|
325
|
+
hidden_states = self._create_dummy_tensor(
|
|
326
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
327
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
328
|
+
jnp.bfloat16)
|
|
329
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
330
|
+
tensors={
|
|
331
|
+
"hidden_states": hidden_states,
|
|
332
|
+
"residual": residual
|
|
333
|
+
})
|
|
334
|
+
else:
|
|
335
|
+
intermediate_tensors = None
|
|
336
|
+
self._precompile_backbone_helper(
|
|
337
|
+
f"worker{self.runner.rank} backbone with embeds",
|
|
338
|
+
input_ids=None,
|
|
339
|
+
positions=positions,
|
|
340
|
+
inputs_embeds=inputs_embeds,
|
|
341
|
+
intermediate_tensors=intermediate_tensors,
|
|
342
|
+
is_first_rank=is_first_rank,
|
|
343
|
+
is_last_rank=is_last_rank)
|
|
290
344
|
|
|
291
345
|
def _precompile_select_from_array_helper(
|
|
292
346
|
self,
|
|
@@ -354,7 +408,7 @@ class CompilationManager:
|
|
|
354
408
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
|
|
355
409
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
356
410
|
self._precompile_select_from_array_helper(
|
|
357
|
-
name="select all logits",
|
|
411
|
+
name=f"worker{self.runner.rank} select all logits",
|
|
358
412
|
source_paddings=self.runner.num_tokens_paddings,
|
|
359
413
|
indices_paddings=index_paddings,
|
|
360
414
|
hidden_dim=hsize,
|
|
@@ -365,7 +419,8 @@ class CompilationManager:
|
|
|
365
419
|
if self.runner.speculative_config:
|
|
366
420
|
vocab_size = self.runner.model_config.get_vocab_size()
|
|
367
421
|
self._precompile_select_from_array_helper(
|
|
368
|
-
name=
|
|
422
|
+
name=
|
|
423
|
+
f"worker{self.runner.rank} select bonus tokens for spec decoding",
|
|
369
424
|
source_paddings=self.runner.num_logits_paddings,
|
|
370
425
|
indices_paddings=self.runner.num_reqs_paddings,
|
|
371
426
|
hidden_dim=vocab_size,
|
|
@@ -373,7 +428,8 @@ class CompilationManager:
|
|
|
373
428
|
PartitionSpec(None, "model")),
|
|
374
429
|
)
|
|
375
430
|
self._precompile_select_from_array_helper(
|
|
376
|
-
name=
|
|
431
|
+
name=
|
|
432
|
+
f"worker{self.runner.rank} select target tokens for spec decoding",
|
|
377
433
|
source_paddings=self.runner.num_logits_paddings,
|
|
378
434
|
indices_paddings=self.runner.num_logits_paddings,
|
|
379
435
|
hidden_dim=vocab_size,
|
|
@@ -396,7 +452,7 @@ class CompilationManager:
|
|
|
396
452
|
np.array([num_reqs], dtype=np.int32)):
|
|
397
453
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
398
454
|
self._run_compilation(
|
|
399
|
-
"compute_logits",
|
|
455
|
+
f"worker{self.runner.rank} compute_logits",
|
|
400
456
|
self.runner.compute_logits_fn,
|
|
401
457
|
self.runner.state,
|
|
402
458
|
hidden_states,
|
|
@@ -410,11 +466,12 @@ class CompilationManager:
|
|
|
410
466
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
411
467
|
logits_sharding = NamedSharding(
|
|
412
468
|
self.runner.mesh,
|
|
413
|
-
PartitionSpec(ShardingAxisName.
|
|
469
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
470
|
+
ShardingAxisName.MLP_TENSOR))
|
|
414
471
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
415
472
|
sampling_metadata_sharding = NamedSharding(
|
|
416
473
|
self.runner.mesh, PartitionSpec(
|
|
417
|
-
ShardingAxisName.
|
|
474
|
+
ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
|
|
418
475
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
419
476
|
logits_sharding)
|
|
420
477
|
for do_sampling in (True, False):
|
|
@@ -438,7 +495,7 @@ class CompilationManager:
|
|
|
438
495
|
do_sampling=do_sampling,
|
|
439
496
|
)
|
|
440
497
|
self._run_compilation(
|
|
441
|
-
"sample",
|
|
498
|
+
f"worker{self.runner.rank} sample",
|
|
442
499
|
sample,
|
|
443
500
|
self.runner.rng_params_for_sampling,
|
|
444
501
|
self.runner.mesh,
|
|
@@ -479,7 +536,7 @@ class CompilationManager:
|
|
|
479
536
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
|
|
480
537
|
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
|
|
481
538
|
self._run_compilation(
|
|
482
|
-
"gather_logprobs",
|
|
539
|
+
f"worker{self.runner.rank} gather_logprobs",
|
|
483
540
|
self.runner._compute_and_gather_logprobs,
|
|
484
541
|
logits,
|
|
485
542
|
token_ids,
|
|
@@ -531,7 +588,7 @@ class CompilationManager:
|
|
|
531
588
|
do_sampling=do_sampling)
|
|
532
589
|
|
|
533
590
|
self._run_compilation(
|
|
534
|
-
compilation_name,
|
|
591
|
+
f"worker{self.runner.rank} {compilation_name}",
|
|
535
592
|
self.runner.rejection_sampler,
|
|
536
593
|
draft_token_ids,
|
|
537
594
|
num_draft_tokens,
|
|
@@ -548,7 +605,9 @@ class CompilationManager:
|
|
|
548
605
|
def _precompile_eagle3_helpers(self) -> None:
|
|
549
606
|
logger.info(
|
|
550
607
|
"Compiling eagle3 jitted helpers with different input shapes.")
|
|
551
|
-
|
|
608
|
+
target_hidden_size = self.runner.model_config.get_hidden_size()
|
|
609
|
+
draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
|
|
610
|
+
)
|
|
552
611
|
dtype = self.runner.model_config.dtype
|
|
553
612
|
|
|
554
613
|
num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
|
|
@@ -595,10 +654,11 @@ class CompilationManager:
|
|
|
595
654
|
|
|
596
655
|
for num_logits in self.runner.num_logits_paddings:
|
|
597
656
|
hidden_states = self._create_dummy_tensor(
|
|
598
|
-
(num_logits,
|
|
657
|
+
(num_logits, draft_hidden_size), jnp.bfloat16)
|
|
599
658
|
self._run_compilation(
|
|
600
659
|
"eagle3_get_draft_token_ids",
|
|
601
660
|
self.runner.drafter._get_draft_token_ids,
|
|
661
|
+
self.runner.drafter.state,
|
|
602
662
|
hidden_states,
|
|
603
663
|
num_logits=num_logits,
|
|
604
664
|
)
|
|
@@ -606,8 +666,8 @@ class CompilationManager:
|
|
|
606
666
|
input_ids_loop = self._create_dummy_tensor(
|
|
607
667
|
(self.runner.max_num_reqs, ), jnp.int32,
|
|
608
668
|
NamedSharding(self.runner.mesh, PartitionSpec()))
|
|
609
|
-
|
|
610
|
-
(self.runner.max_num_reqs,
|
|
669
|
+
draft_hidden_state_loop = self._create_dummy_tensor(
|
|
670
|
+
(self.runner.max_num_reqs, draft_hidden_size), dtype,
|
|
611
671
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
612
672
|
next_token_ids = self._create_dummy_tensor(
|
|
613
673
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
@@ -615,9 +675,12 @@ class CompilationManager:
|
|
|
615
675
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
616
676
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
617
677
|
aux_hidden_states = [
|
|
618
|
-
self._create_dummy_tensor((num_tokens,
|
|
619
|
-
|
|
620
|
-
self._create_dummy_tensor((num_tokens,
|
|
678
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
679
|
+
dtype),
|
|
680
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
681
|
+
dtype),
|
|
682
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
683
|
+
dtype),
|
|
621
684
|
]
|
|
622
685
|
|
|
623
686
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
@@ -640,23 +703,23 @@ class CompilationManager:
|
|
|
640
703
|
num_reqs,
|
|
641
704
|
):
|
|
642
705
|
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
|
|
643
|
-
token_indices, query_start_loc,
|
|
644
|
-
aux_hidden_states, attention_metadata,
|
|
645
|
-
num_reqs)
|
|
706
|
+
self.runner.drafter.state, token_indices, query_start_loc,
|
|
707
|
+
seq_lens, input_ids, aux_hidden_states, attention_metadata,
|
|
708
|
+
next_token_ids, num_reqs)
|
|
646
709
|
return target_hidden_states, input_ids, last_token_indices
|
|
647
710
|
|
|
648
711
|
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
649
712
|
aux_hidden_states = [
|
|
650
713
|
self._create_dummy_tensor(
|
|
651
|
-
(num_tokens,
|
|
714
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
652
715
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
653
716
|
None))),
|
|
654
717
|
self._create_dummy_tensor(
|
|
655
|
-
(num_tokens,
|
|
718
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
656
719
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
657
720
|
None))),
|
|
658
721
|
self._create_dummy_tensor(
|
|
659
|
-
(num_tokens,
|
|
722
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
660
723
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
661
724
|
None))),
|
|
662
725
|
]
|
|
@@ -688,17 +751,17 @@ class CompilationManager:
|
|
|
688
751
|
state,
|
|
689
752
|
kv_caches,
|
|
690
753
|
input_ids,
|
|
691
|
-
|
|
754
|
+
draft_hidden_states,
|
|
692
755
|
attention_metadata,
|
|
693
756
|
):
|
|
694
757
|
kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
|
|
695
|
-
state, kv_caches, input_ids,
|
|
758
|
+
state, kv_caches, input_ids, draft_hidden_states,
|
|
696
759
|
attention_metadata)
|
|
697
760
|
self.runner.kv_caches = kv_caches
|
|
698
761
|
return hidden_states
|
|
699
762
|
|
|
700
|
-
|
|
701
|
-
(num_tokens,
|
|
763
|
+
draft_hidden_states = self._create_dummy_tensor(
|
|
764
|
+
(num_tokens, draft_hidden_size), dtype,
|
|
702
765
|
NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
|
|
703
766
|
input_ids = self._create_dummy_tensor(
|
|
704
767
|
(num_tokens, ), jnp.int32,
|
|
@@ -709,7 +772,7 @@ class CompilationManager:
|
|
|
709
772
|
self.runner.drafter.state,
|
|
710
773
|
self.runner.kv_caches,
|
|
711
774
|
input_ids,
|
|
712
|
-
|
|
775
|
+
draft_hidden_states,
|
|
713
776
|
attention_metadata,
|
|
714
777
|
num_tokens=num_tokens,
|
|
715
778
|
)
|
|
@@ -719,6 +782,7 @@ class CompilationManager:
|
|
|
719
782
|
self._run_compilation(
|
|
720
783
|
"eagle3_prepare_hidden_states_and_input_ids",
|
|
721
784
|
self.runner.drafter._prepare_hidden_states_and_input_ids,
|
|
785
|
+
self.runner.drafter.state,
|
|
722
786
|
aux_hidden_states,
|
|
723
787
|
query_start_loc,
|
|
724
788
|
target_token_ids,
|
|
@@ -741,18 +805,19 @@ class CompilationManager:
|
|
|
741
805
|
self.runner.drafter.state,
|
|
742
806
|
self.runner.kv_caches,
|
|
743
807
|
input_ids_loop,
|
|
744
|
-
|
|
808
|
+
draft_hidden_state_loop,
|
|
745
809
|
attention_metadata,
|
|
746
810
|
num_tokens=num_tokens,
|
|
747
811
|
)
|
|
748
812
|
|
|
749
813
|
hidden_states = self._create_dummy_tensor(
|
|
750
|
-
(num_tokens,
|
|
814
|
+
(num_tokens, draft_hidden_size), jnp.bfloat16,
|
|
751
815
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
752
816
|
|
|
753
817
|
self._run_compilation(
|
|
754
818
|
"eagle3_select_inputs_for_loop_speculation",
|
|
755
819
|
self.runner.drafter._select_inputs_for_loop_speculation,
|
|
820
|
+
self.runner.drafter.state,
|
|
756
821
|
positions,
|
|
757
822
|
hidden_states,
|
|
758
823
|
hidden_states,
|
|
@@ -763,6 +828,7 @@ class CompilationManager:
|
|
|
763
828
|
self._run_compilation(
|
|
764
829
|
"eagle3_select_draft_token_ids",
|
|
765
830
|
self.runner.drafter._select_draft_token_ids,
|
|
831
|
+
self.runner.drafter.state,
|
|
766
832
|
hidden_states,
|
|
767
833
|
last_token_indices,
|
|
768
834
|
num_tokens=num_tokens,
|
tpu_inference/runner/kv_cache.py
CHANGED
|
@@ -7,6 +7,7 @@ from jax._src import dtypes
|
|
|
7
7
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
8
|
from torchax.ops.mappings import t2j_dtype
|
|
9
9
|
|
|
10
|
+
import tpu_inference.kernels.mla.v1.kernel as mla
|
|
10
11
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
|
|
11
12
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
|
|
12
13
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -17,9 +18,13 @@ logger = init_logger(__name__)
|
|
|
17
18
|
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
22
|
+
total_num_pages: int,
|
|
23
|
+
page_size: int,
|
|
24
|
+
actual_num_kv_heads: int,
|
|
25
|
+
actual_head_dim: int,
|
|
26
|
+
kv_dtype: any,
|
|
27
|
+
use_mla: bool = False):
|
|
23
28
|
"""Gets the KV cache shape based on the mesh configuration."""
|
|
24
29
|
|
|
25
30
|
model_cnt = mesh.shape["model"]
|
|
@@ -28,15 +33,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
|
|
|
28
33
|
# specific model, rather than being determined by the head_dim. If new
|
|
29
34
|
# models are introduced with a head_dim of 64, this will require additional
|
|
30
35
|
# model-specific adjustments.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
36
|
+
if use_mla:
|
|
37
|
+
get_kv_cache_shape_fn = mla.get_kv_cache_shape
|
|
38
|
+
shape = list(
|
|
39
|
+
get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
|
|
40
|
+
kv_dtype))
|
|
41
|
+
else:
|
|
42
|
+
get_kv_cache_shape_fn = (
|
|
43
|
+
rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
|
|
44
|
+
else rpa.get_kv_cache_shape
|
|
45
|
+
)
|
|
46
|
+
shape = list(
|
|
47
|
+
get_kv_cache_shape_fn(total_num_pages, page_size,
|
|
48
|
+
actual_num_kv_heads // model_cnt,
|
|
49
|
+
actual_head_dim, kv_dtype))
|
|
50
|
+
shape[2] *= model_cnt
|
|
40
51
|
return tuple(shape)
|
|
41
52
|
|
|
42
53
|
|
|
@@ -48,6 +59,7 @@ def create_kv_caches(
|
|
|
48
59
|
mesh: Mesh,
|
|
49
60
|
layer_names: List[str],
|
|
50
61
|
cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
|
|
62
|
+
use_mla: bool = False,
|
|
51
63
|
) -> List[jax.Array]:
|
|
52
64
|
"""
|
|
53
65
|
Creates a list of KV cache where each array mapps to single attention layer.
|
|
@@ -74,12 +86,16 @@ def create_kv_caches(
|
|
|
74
86
|
|
|
75
87
|
cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
|
|
76
88
|
num_kv_heads, head_size,
|
|
77
|
-
cache_dtype)
|
|
89
|
+
cache_dtype, use_mla)
|
|
78
90
|
|
|
79
|
-
|
|
80
|
-
mesh,
|
|
81
|
-
|
|
82
|
-
|
|
91
|
+
if use_mla:
|
|
92
|
+
sharding = NamedSharding(mesh,
|
|
93
|
+
PartitionSpec(ShardingAxisName.MLP_TENSOR))
|
|
94
|
+
else:
|
|
95
|
+
sharding = NamedSharding(
|
|
96
|
+
mesh,
|
|
97
|
+
PartitionSpec(ShardingAxisName.ATTN_DATA, None,
|
|
98
|
+
ShardingAxisName.ATTN_HEAD))
|
|
83
99
|
|
|
84
100
|
def _allocate() -> jax.Array:
|
|
85
101
|
return jnp.empty(
|
|
@@ -94,7 +110,8 @@ def create_kv_caches(
|
|
|
94
110
|
return kv_caches
|
|
95
111
|
|
|
96
112
|
|
|
97
|
-
def
|
|
113
|
+
def get_attention_page_size_bytes(mesh: Mesh,
|
|
114
|
+
kv_cache_specs: dict[str, Any]) -> int:
|
|
98
115
|
"""
|
|
99
116
|
Calculate KV cache page size of RPA kernel.
|
|
100
117
|
|
|
@@ -107,14 +124,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
107
124
|
"""
|
|
108
125
|
|
|
109
126
|
# Import it here to avoid circular import.
|
|
110
|
-
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
127
|
+
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
|
111
128
|
|
|
112
129
|
page_size_bytes_set = set()
|
|
113
130
|
for kv_cache_spec in kv_cache_specs.values():
|
|
114
131
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
|
115
132
|
|
|
116
133
|
dtype = t2j_dtype(kv_cache_spec.dtype)
|
|
117
|
-
bits = dtypes.bit_width(dtype)
|
|
134
|
+
bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
|
|
135
|
+
dtypes.itemsize_bits(dtype))
|
|
136
|
+
use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
|
|
118
137
|
|
|
119
138
|
kv_cache_shape = get_kv_cache_shape_with_mesh(
|
|
120
139
|
mesh=mesh,
|
|
@@ -123,6 +142,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
123
142
|
actual_num_kv_heads=kv_cache_spec.num_kv_heads,
|
|
124
143
|
actual_head_dim=kv_cache_spec.head_size,
|
|
125
144
|
kv_dtype=dtype,
|
|
145
|
+
use_mla=use_mla,
|
|
126
146
|
)
|
|
127
147
|
page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
|
|
128
148
|
page_size_bytes_set.add(page_size_bytes)
|