tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from jax.sharding import NamedSharding, PartitionSpec
 
+import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:
@@ -30,10 +32,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              envs.VLLM_XLA_CACHE_PATH)
+                              vllm_envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -67,8 +69,7 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")
 
@@ -81,6 +82,8 @@ class CompilationManager:
         self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
+        if not self.runner.is_last_rank:
+            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
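
The `capture_model` hunk above swaps a raw `os.getenv("SKIP_JAX_PRECOMPILE", False)` check, where any non-empty string (even "0") counts as truthy, for a typed `envs.SKIP_JAX_PRECOMPILE` flag from the new `tpu_inference.envs` import. That module is changed elsewhere in this release (+92 -8) but not shown here, so the following is only a minimal sketch of how such a boolean flag might be parsed; the `env_flag` helper and the accepted spellings are assumptions.

```python
import os

# Hypothetical parser for boolean environment flags; the real
# tpu_inference/envs.py may differ.
_TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Read an env var as a bool instead of relying on string truthiness."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY


# With os.getenv alone, SKIP_JAX_PRECOMPILE="0" would still skip precompile;
# parsing the value avoids that pitfall.
SKIP_JAX_PRECOMPILE = env_flag("SKIP_JAX_PRECOMPILE", False)
```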
@@ -120,8 +123,15 @@ class CompilationManager:
             num_tokens=num_tokens,
         )
 
-    def _precompile_backbone_helper(self, name, *, input_ids, positions,
-                                    inputs_embeds) -> None:
+    def _precompile_backbone_helper(self,
+                                    name,
+                                    *,
+                                    input_ids,
+                                    positions,
+                                    inputs_embeds,
+                                    intermediate_tensors=None,
+                                    is_first_rank=True,
+                                    is_last_rank=True) -> None:
         num_tokens = None
         if input_ids is not None:
             num_tokens = input_ids.shape[0]
@@ -181,10 +191,14 @@ class CompilationManager:
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
+            intermediate_tensors,
+            is_first_rank,
+            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                positions, layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata,
+                intermediate_tensors, is_first_rank, is_last_rank)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -207,6 +221,9 @@ class CompilationManager:
             inputs_embeds,
             tuple(self.runner.layer_name_to_kvcache_index.items()),
             lora_metadata,
+            intermediate_tensors,
+            is_first_rank,
+            is_last_rank,
             num_tokens=num_tokens,
         )
 
@@ -257,6 +274,7 @@ class CompilationManager:
         )
 
     def _precompile_backbone_text_only(self) -> None:
+        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -266,10 +284,28 @@ class CompilationManager:
                                                   dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-            self._precompile_backbone_helper("backbone",
-                                             input_ids=input_ids,
-                                             positions=positions,
-                                             inputs_embeds=None)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if is_first_rank:
+                intermediate_tensors = None
+            else:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone",
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=None,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)
 
     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()
@@ -283,10 +319,28 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-            self._precompile_backbone_helper("backbone with embeds",
-                                             input_ids=None,
-                                             positions=positions,
-                                             inputs_embeds=inputs_embeds)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if not is_first_rank:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            else:
+                intermediate_tensors = None
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone with embeds",
+                input_ids=None,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)
 
     def _precompile_select_from_array_helper(
             self,
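
For non-first pipeline ranks, the two backbone precompile hunks above build a `JaxIntermediateTensors` holding dummy `hidden_states` and `residual` arrays instead of feeding token ids. The class itself (in `tpu_inference/models/jax/jax_intermediate_tensor.py`) is not part of the hunks shown, so this is only a rough stand-in inferred from how the constructor is called; the real class very likely carries more behavior.

```python
from dataclasses import dataclass, field
from typing import Dict

import jax
import jax.numpy as jnp


@dataclass
class IntermediateTensorsSketch:
    """Hypothetical stand-in for JaxIntermediateTensors, not the real class."""
    tensors: Dict[str, jax.Array] = field(default_factory=dict)

    def __getitem__(self, key: str) -> jax.Array:
        return self.tensors[key]


# A non-first pipeline rank is traced with activations from the previous
# stage rather than token ids, so dummy bfloat16 arrays of the backbone's
# hidden size are enough to drive compilation.
num_tokens, hidden_size = 16, 4096
dummy = IntermediateTensorsSketch(
    tensors={
        "hidden_states": jnp.zeros((num_tokens, hidden_size), jnp.bfloat16),
        "residual": jnp.zeros((num_tokens, hidden_size), jnp.bfloat16),
    })
print(dummy["hidden_states"].shape)
```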
@@ -354,7 +408,7 @@ class CompilationManager:
             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name="select all logits",
+            name=f"worker{self.runner.rank} select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
@@ -365,7 +419,8 @@ class CompilationManager:
         if self.runner.speculative_config:
             vocab_size = self.runner.model_config.get_vocab_size()
             self._precompile_select_from_array_helper(
-                name="select bonus tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select bonus tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_reqs_paddings,
                 hidden_dim=vocab_size,
@@ -373,7 +428,8 @@ class CompilationManager:
                                          PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name="select target tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,
@@ -396,7 +452,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-                "compute_logits",
+                f"worker{self.runner.rank} compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,
@@ -410,11 +466,12 @@ class CompilationManager:
         for num_reqs in self.runner.num_reqs_paddings:
             logits_sharding = NamedSharding(
                 self.runner.mesh,
-                PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             sampling_metadata_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(
-                    ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
+                    ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                                logits_sharding)
             for do_sampling in (True, False):
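
The hunk above moves the dummy logits from `PartitionSpec(ShardingAxisName.ATTN_DATA, "model")` to `PartitionSpec(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)`. As a reminder of what such a spec means, here is a self-contained JAX sketch with illustrative axis names; the real mesh and the `ShardingAxisName` constants are defined elsewhere in tpu_inference and are not shown in this diff.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative 1x1 mesh with "data" and "model" axes; the real axis names
# (MLP_DATA / MLP_TENSOR) map onto mesh axes built by the sharding config.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, ("data", "model"))

# Rows laid out along the "data" axis, columns along "model": the same
# pattern the precompiled logits tensor uses after this change.
sharding = NamedSharding(mesh, PartitionSpec("data", "model"))
logits = jax.device_put(jnp.zeros((8, 128), jnp.bfloat16), sharding)
print(logits.sharding)
```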
@@ -438,7 +495,7 @@ class CompilationManager:
                 do_sampling=do_sampling,
             )
             self._run_compilation(
-                "sample",
+                f"worker{self.runner.rank} sample",
                 sample,
                 self.runner.rng_params_for_sampling,
                 self.runner.mesh,
@@ -479,7 +536,7 @@ class CompilationManager:
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
             token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
             self._run_compilation(
-                "gather_logprobs",
+                f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
                 logits,
                 token_ids,
@@ -531,7 +588,7 @@ class CompilationManager:
                 do_sampling=do_sampling)
 
             self._run_compilation(
-                compilation_name,
+                f"worker{self.runner.rank} {compilation_name}",
                 self.runner.rejection_sampler,
                 draft_token_ids,
                 num_draft_tokens,
@@ -548,7 +605,9 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-        hidden_size = self.runner.model_config.get_hidden_size()
+        target_hidden_size = self.runner.model_config.get_hidden_size()
+        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+        )
         dtype = self.runner.model_config.dtype
 
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -595,10 +654,11 @@ class CompilationManager:
 
         for num_logits in self.runner.num_logits_paddings:
            hidden_states = self._create_dummy_tensor(
-                (num_logits, hidden_size), jnp.bfloat16)
+                (num_logits, draft_hidden_size), jnp.bfloat16)
            self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
+                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )
@@ -606,8 +666,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-        target_hidden_state_loop = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, hidden_size), dtype,
+        draft_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, draft_hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
@@ -615,9 +675,12 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
+                self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                          dtype),
             ]
 
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -640,23 +703,23 @@ class CompilationManager:
                 num_reqs,
             ):
                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                    token_indices, query_start_loc, seq_lens, input_ids,
-                    aux_hidden_states, attention_metadata, next_token_ids,
-                    num_reqs)
+                    self.runner.drafter.state, token_indices, query_start_loc,
+                    seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                    next_token_ids, num_reqs)
                 return target_hidden_states, input_ids, last_token_indices
 
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16,
+                    (num_tokens, target_hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -688,17 +751,17 @@ class CompilationManager:
                 state,
                 kv_caches,
                 input_ids,
-                target_hidden_states,
+                draft_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids, target_hidden_states,
+                    state, kv_caches, input_ids, draft_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states
 
-            target_hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), dtype,
+            draft_hidden_states = self._create_dummy_tensor(
+                (num_tokens, draft_hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,
@@ -709,7 +772,7 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-                target_hidden_states,
+                draft_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
@@ -719,6 +782,7 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_prepare_hidden_states_and_input_ids",
                 self.runner.drafter._prepare_hidden_states_and_input_ids,
+                self.runner.drafter.state,
                 aux_hidden_states,
                 query_start_loc,
                 target_token_ids,
@@ -741,18 +805,19 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-                target_hidden_state_loop,
+                draft_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
 
             hidden_states = self._create_dummy_tensor(
-                (num_tokens, hidden_size), jnp.bfloat16,
+                (num_tokens, draft_hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
 
             self._run_compilation(
                 "eagle3_select_inputs_for_loop_speculation",
                 self.runner.drafter._select_inputs_for_loop_speculation,
+                self.runner.drafter.state,
                 positions,
                 hidden_states,
                 hidden_states,
@@ -763,6 +828,7 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_select_draft_token_ids",
                 self.runner.drafter._select_draft_token_ids,
+                self.runner.drafter.state,
                 hidden_states,
                 last_token_indices,
                 num_tokens=num_tokens,
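
All of the `_run_compilation` calls in this file exist to trace and compile jitted functions against padded dummy inputs so that no shape is compiled for the first time while serving. `_run_compilation` itself is not shown in these hunks; below is only a minimal sketch of the JAX ahead-of-time pattern it presumably wraps, using a toy function and made-up padding sizes.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_fn(x):
    return jnp.tanh(x) * 2.0


def precompile(fn, *dummy_args, name="fn"):
    """Trace and compile fn for the given dummy shapes ahead of time."""
    start = time.perf_counter()
    compiled = fn.lower(*dummy_args).compile()  # JAX AOT: lower, then compile
    print(f"{name}: compiled in {time.perf_counter() - start:.2f}s")
    return compiled


# Compile once per padded token count, mirroring num_tokens_paddings.
for num_tokens in (16, 32, 64):
    precompile(toy_fn, jnp.zeros((num_tokens, 8), jnp.bfloat16),
               name=f"toy_fn[{num_tokens}]")
```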
tpu_inference/runner/kv_cache.py
@@ -7,6 +7,7 @@ from jax._src import dtypes
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 
+import tpu_inference.kernels.mla.v1.kernel as mla
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +18,13 @@ logger = init_logger(__name__)
 DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
 
 
-def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
-                                 page_size: int, actual_num_kv_heads: int,
-                                 actual_head_dim: int, kv_dtype: any):
+def get_kv_cache_shape_with_mesh(mesh: Mesh,
+                                 total_num_pages: int,
+                                 page_size: int,
+                                 actual_num_kv_heads: int,
+                                 actual_head_dim: int,
+                                 kv_dtype: any,
+                                 use_mla: bool = False):
     """Gets the KV cache shape based on the mesh configuration."""
 
     model_cnt = mesh.shape["model"]
@@ -28,15 +33,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
     # specific model, rather than being determined by the head_dim. If new
     # models are introduced with a head_dim of 64, this will require additional
     # model-specific adjustments.
-    get_kv_cache_shape_fn = (
-        rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
-        else rpa.get_kv_cache_shape
-    )
-    shape = list(
-        get_kv_cache_shape_fn(total_num_pages, page_size,
-                              actual_num_kv_heads // model_cnt,
-                              actual_head_dim, kv_dtype))
-    shape[2] *= model_cnt
+    if use_mla:
+        get_kv_cache_shape_fn = mla.get_kv_cache_shape
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
+                                  kv_dtype))
+    else:
+        get_kv_cache_shape_fn = (
+            rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+            else rpa.get_kv_cache_shape
+        )
+        shape = list(
+            get_kv_cache_shape_fn(total_num_pages, page_size,
+                                  actual_num_kv_heads // model_cnt,
+                                  actual_head_dim, kv_dtype))
+        shape[2] *= model_cnt
     return tuple(shape)
 
 
@@ -48,6 +59,7 @@ def create_kv_caches(
     mesh: Mesh,
     layer_names: List[str],
     cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+    use_mla: bool = False,
 ) -> List[jax.Array]:
     """
     Creates a list of KV cache where each array mapps to single attention layer.
@@ -74,12 +86,16 @@ def create_kv_caches(
 
     cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
                                                num_kv_heads, head_size,
-                                               cache_dtype)
+                                               cache_dtype, use_mla)
 
-    sharding = NamedSharding(
-        mesh,
-        PartitionSpec(ShardingAxisName.ATTN_DATA, None,
-                      ShardingAxisName.ATTN_HEAD))
+    if use_mla:
+        sharding = NamedSharding(mesh,
+                                 PartitionSpec(ShardingAxisName.MLP_TENSOR))
+    else:
+        sharding = NamedSharding(
+            mesh,
+            PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+                          ShardingAxisName.ATTN_HEAD))
 
     def _allocate() -> jax.Array:
         return jnp.empty(
@@ -94,7 +110,8 @@ def create_kv_caches(
     return kv_caches
 
 
-def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+def get_attention_page_size_bytes(mesh: Mesh,
+                                  kv_cache_specs: dict[str, Any]) -> int:
     """
     Calculate KV cache page size of RPA kernel.
 
@@ -107,14 +124,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
     """
 
     # Import it here to avoid circular import.
-    from vllm.v1.kv_cache_interface import AttentionSpec
+    from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 
     page_size_bytes_set = set()
     for kv_cache_spec in kv_cache_specs.values():
         assert isinstance(kv_cache_spec, AttentionSpec)
 
         dtype = t2j_dtype(kv_cache_spec.dtype)
-        bits = dtypes.bit_width(dtype)
+        bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+                dtypes.itemsize_bits(dtype))
+        use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
 
         kv_cache_shape = get_kv_cache_shape_with_mesh(
             mesh=mesh,
@@ -123,6 +142,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
             actual_num_kv_heads=kv_cache_spec.num_kv_heads,
             actual_head_dim=kv_cache_spec.head_size,
             kv_dtype=dtype,
+            use_mla=use_mla,
         )
         page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
         page_size_bytes_set.add(page_size_bytes)
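
`get_attention_page_size_bytes` (renamed from `get_rpa_page_size_bytes`) computes the per-page byte count as `bits * prod(shape) // 8`, reading the dtype's bit width from whichever of the two private `jax._src.dtypes` helpers the installed JAX provides, exactly as in the hunk above. A self-contained sketch of the same arithmetic with a hypothetical page layout:

```python
import numpy as np
import jax.numpy as jnp
from jax._src import dtypes  # private JAX module, as used in kv_cache.py


def page_size_bytes(kv_cache_shape, dtype) -> int:
    """Bytes per KV cache page: dtype bit width times element count, over 8."""
    bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
            dtypes.itemsize_bits(dtype))
    return (bits * int(np.prod(kv_cache_shape))) // 8


# Hypothetical RPA-style page: 128 tokens x (16 KV heads x K and V) x
# head_dim 128 in bfloat16 -> 1 MiB per page. The real layouts come from
# the RPA / MLA kernels' get_kv_cache_shape helpers, not this sketch.
print(page_size_bytes((128, 32, 128), jnp.bfloat16))
```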