tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py

@@ -1,13 +1,13 @@
- import os
  import time
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import jax
  import jax.numpy as jnp
  import numpy as np
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from jax.sharding import NamedSharding, PartitionSpec

+ import tpu_inference.envs as envs
  from tpu_inference.core.disagg_utils import is_disagg_enabled
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
  from tpu_inference.layers.jax.sample.sampling_metadata import \
      TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
  from tpu_inference.utils import device_array

  if TYPE_CHECKING:
@@ -30,10 +32,10 @@ class CompilationManager:

  def __init__(self, runner: "TPUModelRunner"):
  self.runner = runner
- if not envs.VLLM_DISABLE_COMPILE_CACHE:
+ if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
  logger.info("Enabling JAX compile cache.")
  jax.config.update("jax_compilation_cache_dir",
- envs.VLLM_XLA_CACHE_PATH)
+ vllm_envs.VLLM_XLA_CACHE_PATH)

  def _create_dummy_tensor(self,
  shape: Tuple[int, ...],
@@ -67,8 +69,7 @@ class CompilationManager:
  logger.info("Compilation finished in %.2f [secs].", end - start)

  def capture_model(self) -> None:
- if os.getenv("SKIP_JAX_PRECOMPILE",
- False) or self.runner.model_config.enforce_eager:
+ if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
  return
  logger.info("Precompile all the subgraphs with possible input shapes.")

@@ -81,6 +82,8 @@ class CompilationManager:
  self._precompile_backbone_with_inputs_embeds()
  if self.runner.scheduler_config.async_scheduling:
  self._precompile_substitute_placeholder_token()
+ if not self.runner.is_last_rank:
+ return
  self._precompile_select_from_array()
  self._precompile_compute_logits()
  self._precompile_disagg_utils()
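Note on the `capture_model()` hunk above: it replaces a raw `os.getenv` lookup with a typed flag from `tpu_inference.envs`. A likely motivation (an inference, not stated in the diff) is that `os.getenv` returns a string whenever the variable is set, so even `SKIP_JAX_PRECOMPILE=0` was truthy and skipped precompilation. The sketch below illustrates the pitfall with a hypothetical `parse_bool_env` helper; it is not the actual `tpu_inference.envs` implementation.

```python
import os

def parse_bool_env(name: str, default: bool = False) -> bool:
    """Interpret common falsy spellings instead of relying on string truthiness."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ("", "0", "false", "no", "off")

# Old pattern: any set value, including "0", evaluates truthy.
os.environ["SKIP_JAX_PRECOMPILE"] = "0"
print(bool(os.getenv("SKIP_JAX_PRECOMPILE", False)))  # True (surprising)

# Typed flag: "0" is treated as False.
print(parse_bool_env("SKIP_JAX_PRECOMPILE"))  # False
```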
@@ -120,8 +123,15 @@ class CompilationManager:
  num_tokens=num_tokens,
  )

- def _precompile_backbone_helper(self, name, *, input_ids, positions,
- inputs_embeds) -> None:
+ def _precompile_backbone_helper(self,
+ name,
+ *,
+ input_ids,
+ positions,
+ inputs_embeds,
+ intermediate_tensors=None,
+ is_first_rank=True,
+ is_last_rank=True) -> None:
  num_tokens = None
  if input_ids is not None:
  num_tokens = input_ids.shape[0]
@@ -135,12 +145,6 @@ class CompilationManager:
  ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

  # Keep existing pattern for complex array operations
- block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
- block_tables = block_tables.reshape(-1)
- block_tables = device_array(self.runner.mesh,
- block_tables,
- sharding=dp_sharding)
-
  seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
  jnp.int32, dp_sharding)
  query_start_loc = self._create_dummy_tensor(
@@ -152,26 +156,49 @@ class CompilationManager:
  request_distribution,
  sharding=dp_sharding)

- attention_metadata = AttentionMetadata(
- input_positions=positions,
- block_tables=block_tables,
- seq_lens=seq_lens,
- query_start_loc=query_start_loc,
- request_distribution=request_distribution,
- )
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+ uniform_attention_metadata: AttentionMetadata = None
+ for kv_cache_gid, kv_cache_group in enumerate(
+ self.runner.kv_cache_config.kv_cache_groups):
+ block_tables = self.runner.block_tables_cpu[
+ kv_cache_gid][:self.runner.max_num_reqs]
+ block_tables = block_tables.reshape(-1)
+ block_tables = device_array(self.runner.mesh,
+ block_tables,
+ sharding=dp_sharding)
+
+ attention_metadata_gid = AttentionMetadata(
+ input_positions=positions,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ query_start_loc=query_start_loc,
+ request_distribution=request_distribution,
+ )
+ if not self.runner.use_hybrid_kvcache:
+ # all layers share the same attention metadata
+ uniform_attention_metadata = attention_metadata_gid
+ else:
+ for layer_name in kv_cache_group.layer_names:
+ attention_metadata_per_layer[
+ layer_name] = attention_metadata_gid

  def model_fn_wrapper(
  state,
  kv_caches,
  input_ids,
  attention_metadata,
+ positions,
  inputs_embeds,
  layer_name_to_kvcache_index,
  lora_metadata,
+ intermediate_tensors,
+ is_first_rank,
+ is_last_rank,
  ):
  kv_caches, hidden_states, _ = self.runner.model_fn(
  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
- layer_name_to_kvcache_index, lora_metadata)
+ positions, layer_name_to_kvcache_index, lora_metadata,
+ intermediate_tensors, is_first_rank, is_last_rank)
  self.runner.kv_caches = kv_caches
  return hidden_states

@@ -179,6 +206,10 @@ class CompilationManager:
  self.runner.lora_config, np.array([num_tokens],
  dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+ if self.runner.use_hybrid_kvcache:
+ attention_metadata = attention_metadata_per_layer
+ else:
+ attention_metadata = uniform_attention_metadata
  self._run_compilation(
  name,
  model_fn_wrapper,
@@ -186,9 +217,13 @@ class CompilationManager:
  self.runner.kv_caches,
  input_ids,
  attention_metadata,
+ positions,
  inputs_embeds,
  tuple(self.runner.layer_name_to_kvcache_index.items()),
  lora_metadata,
+ intermediate_tensors,
+ is_first_rank,
+ is_last_rank,
  num_tokens=num_tokens,
  )

@@ -239,6 +274,7 @@ class CompilationManager:
  )

  def _precompile_backbone_text_only(self) -> None:
+ hidden_size = self.runner.model_config.get_hidden_size()
  for num_tokens in self.runner.num_tokens_paddings:
  dp_sharding = NamedSharding(
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -248,10 +284,28 @@ class CompilationManager:
  dp_sharding)
  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
  dp_sharding)
- self._precompile_backbone_helper("backbone",
- input_ids=input_ids,
- positions=positions,
- inputs_embeds=None)
+ is_first_rank = self.runner.is_first_rank
+ is_last_rank = self.runner.is_last_rank
+ if is_first_rank:
+ intermediate_tensors = None
+ else:
+ hidden_states = self._create_dummy_tensor(
+ (num_tokens, hidden_size), jnp.bfloat16)
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
+ jnp.bfloat16)
+ intermediate_tensors = JaxIntermediateTensors(
+ tensors={
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ self._precompile_backbone_helper(
+ f"worker{self.runner.rank} backbone",
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=None,
+ intermediate_tensors=intermediate_tensors,
+ is_first_rank=is_first_rank,
+ is_last_rank=is_last_rank)

  def _precompile_backbone_with_inputs_embeds(self) -> None:
  hidden_size = self.runner.model_config.get_hidden_size()
@@ -265,10 +319,28 @@ class CompilationManager:
  else:
  positions = self._create_dummy_tensor((num_tokens, ),
  jnp.int32)
- self._precompile_backbone_helper("backbone with embeds",
- input_ids=None,
- positions=positions,
- inputs_embeds=inputs_embeds)
+ is_first_rank = self.runner.is_first_rank
+ is_last_rank = self.runner.is_last_rank
+ if not is_first_rank:
+ hidden_states = self._create_dummy_tensor(
+ (num_tokens, hidden_size), jnp.bfloat16)
+ residual = self._create_dummy_tensor((num_tokens, hidden_size),
+ jnp.bfloat16)
+ intermediate_tensors = JaxIntermediateTensors(
+ tensors={
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ else:
+ intermediate_tensors = None
+ self._precompile_backbone_helper(
+ f"worker{self.runner.rank} backbone with embeds",
+ input_ids=None,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ intermediate_tensors=intermediate_tensors,
+ is_first_rank=is_first_rank,
+ is_last_rank=is_last_rank)

  def _precompile_select_from_array_helper(
  self,
@@ -336,7 +408,7 @@ class CompilationManager:
  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  self._precompile_select_from_array_helper(
- name="select all logits",
+ name=f"worker{self.runner.rank} select all logits",
  source_paddings=self.runner.num_tokens_paddings,
  indices_paddings=index_paddings,
  hidden_dim=hsize,
@@ -347,7 +419,8 @@ class CompilationManager:
  if self.runner.speculative_config:
  vocab_size = self.runner.model_config.get_vocab_size()
  self._precompile_select_from_array_helper(
- name="select bonus tokens for spec decoding",
+ name=
+ f"worker{self.runner.rank} select bonus tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_reqs_paddings,
  hidden_dim=vocab_size,
@@ -355,7 +428,8 @@ class CompilationManager:
  PartitionSpec(None, "model")),
  )
  self._precompile_select_from_array_helper(
- name="select target tokens for spec decoding",
+ name=
+ f"worker{self.runner.rank} select target tokens for spec decoding",
  source_paddings=self.runner.num_logits_paddings,
  indices_paddings=self.runner.num_logits_paddings,
  hidden_dim=vocab_size,
@@ -378,7 +452,7 @@ class CompilationManager:
  np.array([num_reqs], dtype=np.int32)):
  lora_metadata = self.runner.lora_utils.extract_lora_metadata()
  self._run_compilation(
- "compute_logits",
+ f"worker{self.runner.rank} compute_logits",
  self.runner.compute_logits_fn,
  self.runner.state,
  hidden_states,
@@ -392,11 +466,12 @@ class CompilationManager:
  for num_reqs in self.runner.num_reqs_paddings:
  logits_sharding = NamedSharding(
  self.runner.mesh,
- PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
+ PartitionSpec(ShardingAxisName.MLP_DATA,
+ ShardingAxisName.MLP_TENSOR))
  dp_size = self.runner.vllm_config.sharding_config.total_dp_size
  sampling_metadata_sharding = NamedSharding(
  self.runner.mesh, PartitionSpec(
- ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
+ ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
  logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
  logits_sharding)
  for do_sampling in (True, False):
@@ -420,7 +495,7 @@ class CompilationManager:
  do_sampling=do_sampling,
  )
  self._run_compilation(
- "sample",
+ f"worker{self.runner.rank} sample",
  sample,
  self.runner.rng_params_for_sampling,
  self.runner.mesh,
@@ -461,7 +536,7 @@ class CompilationManager:
  logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
  token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
  self._run_compilation(
- "gather_logprobs",
+ f"worker{self.runner.rank} gather_logprobs",
  self.runner._compute_and_gather_logprobs,
  logits,
  token_ids,
@@ -513,7 +588,7 @@ class CompilationManager:
  do_sampling=do_sampling)

  self._run_compilation(
- compilation_name,
+ f"worker{self.runner.rank} {compilation_name}",
  self.runner.rejection_sampler,
  draft_token_ids,
  num_draft_tokens,
@@ -530,7 +605,9 @@ class CompilationManager:
  def _precompile_eagle3_helpers(self) -> None:
  logger.info(
  "Compiling eagle3 jitted helpers with different input shapes.")
- hidden_size = self.runner.model_config.get_hidden_size()
+ target_hidden_size = self.runner.model_config.get_hidden_size()
+ draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+ )
  dtype = self.runner.model_config.dtype

  num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -577,10 +654,11 @@ class CompilationManager:

  for num_logits in self.runner.num_logits_paddings:
  hidden_states = self._create_dummy_tensor(
- (num_logits, hidden_size), jnp.bfloat16)
+ (num_logits, draft_hidden_size), jnp.bfloat16)
  self._run_compilation(
  "eagle3_get_draft_token_ids",
  self.runner.drafter._get_draft_token_ids,
+ self.runner.drafter.state,
  hidden_states,
  num_logits=num_logits,
  )
@@ -588,8 +666,8 @@ class CompilationManager:
  input_ids_loop = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32,
  NamedSharding(self.runner.mesh, PartitionSpec()))
- target_hidden_state_loop = self._create_dummy_tensor(
- (self.runner.max_num_reqs, hidden_size), dtype,
+ draft_hidden_state_loop = self._create_dummy_tensor(
+ (self.runner.max_num_reqs, draft_hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
  next_token_ids = self._create_dummy_tensor(
  (self.runner.max_num_reqs, ), jnp.int32)
@@ -597,9 +675,12 @@ class CompilationManager:
  (self.runner.max_num_reqs, ), jnp.int32)
  for num_tokens in self.runner.num_tokens_paddings:
  aux_hidden_states = [
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
- self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
+ self._create_dummy_tensor((num_tokens, target_hidden_size),
+ dtype),
  ]

  positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -622,23 +703,23 @@ class CompilationManager:
  num_reqs,
  ):
  target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
- token_indices, query_start_loc, seq_lens, input_ids,
- aux_hidden_states, attention_metadata, next_token_ids,
- num_reqs)
+ self.runner.drafter.state, token_indices, query_start_loc,
+ seq_lens, input_ids, aux_hidden_states, attention_metadata,
+ next_token_ids, num_reqs)
  return target_hidden_states, input_ids, last_token_indices

  input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
  aux_hidden_states = [
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, target_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None,
  None))),
  ]
@@ -670,17 +751,17 @@ class CompilationManager:
  state,
  kv_caches,
  input_ids,
- target_hidden_states,
+ draft_hidden_states,
  attention_metadata,
  ):
  kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
- state, kv_caches, input_ids, target_hidden_states,
+ state, kv_caches, input_ids, draft_hidden_states,
  attention_metadata)
  self.runner.kv_caches = kv_caches
  return hidden_states

- target_hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), dtype,
+ draft_hidden_states = self._create_dummy_tensor(
+ (num_tokens, draft_hidden_size), dtype,
  NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
  input_ids = self._create_dummy_tensor(
  (num_tokens, ), jnp.int32,
@@ -691,7 +772,7 @@ class CompilationManager:
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids,
- target_hidden_states,
+ draft_hidden_states,
  attention_metadata,
  num_tokens=num_tokens,
  )
@@ -701,6 +782,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_prepare_hidden_states_and_input_ids",
  self.runner.drafter._prepare_hidden_states_and_input_ids,
+ self.runner.drafter.state,
  aux_hidden_states,
  query_start_loc,
  target_token_ids,
@@ -723,18 +805,19 @@ class CompilationManager:
  self.runner.drafter.state,
  self.runner.kv_caches,
  input_ids_loop,
- target_hidden_state_loop,
+ draft_hidden_state_loop,
  attention_metadata,
  num_tokens=num_tokens,
  )

  hidden_states = self._create_dummy_tensor(
- (num_tokens, hidden_size), jnp.bfloat16,
+ (num_tokens, draft_hidden_size), jnp.bfloat16,
  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

  self._run_compilation(
  "eagle3_select_inputs_for_loop_speculation",
  self.runner.drafter._select_inputs_for_loop_speculation,
+ self.runner.drafter.state,
  positions,
  hidden_states,
  hidden_states,
@@ -745,6 +828,7 @@ class CompilationManager:
  self._run_compilation(
  "eagle3_select_draft_token_ids",
  self.runner.drafter._select_draft_token_ids,
+ self.runner.drafter.state,
  hidden_states,
  last_token_indices,
  num_tokens=num_tokens,
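The main behavioral change in the compilation_manager.py hunks above is that attention metadata is now built per KV-cache group and, when the hybrid KV cache is enabled, passed to the model as a dict keyed by layer name instead of a single shared object. The sketch below mirrors that control flow with simplified stand-in types; `AttentionMetadata`, `KVCacheGroup`, and `build_attention_metadata` here are illustrative, not the actual tpu_inference API.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

# Simplified stand-ins for illustration only; the real classes live in
# tpu_inference.layers.common.attention_metadata and vLLM's KV-cache config.
@dataclass
class AttentionMetadata:
    block_tables: List[int]
    seq_lens: List[int]

@dataclass
class KVCacheGroup:
    layer_names: List[str]

def build_attention_metadata(
    kv_cache_groups: List[KVCacheGroup],
    block_tables_per_group: List[List[int]],
    seq_lens: List[int],
    use_hybrid_kvcache: bool,
) -> Union[AttentionMetadata, Dict[str, AttentionMetadata]]:
    """Build either one shared metadata object or a per-layer mapping."""
    per_layer: Dict[str, AttentionMetadata] = {}
    uniform: Optional[AttentionMetadata] = None
    for gid, group in enumerate(kv_cache_groups):
        md = AttentionMetadata(block_tables=block_tables_per_group[gid],
                               seq_lens=seq_lens)
        if not use_hybrid_kvcache:
            # All layers share the same metadata (single KV-cache group).
            uniform = md
        else:
            # Hybrid KV cache: each group gets its own block tables,
            # keyed by layer name so the model can look them up per layer.
            for layer_name in group.layer_names:
                per_layer[layer_name] = md
    return per_layer if use_hybrid_kvcache else uniform

if __name__ == "__main__":
    groups = [KVCacheGroup(["layers.0", "layers.2"]),
              KVCacheGroup(["layers.1", "layers.3"])]
    tables = [[0, 1, 2], [3, 4, 5]]
    out = build_attention_metadata(groups, tables, seq_lens=[8, 8],
                                   use_hybrid_kvcache=True)
    print(sorted(out))  # ['layers.0', 'layers.1', 'layers.2', 'layers.3']
```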
tpu_inference/runner/kv_cache.py

@@ -7,6 +7,7 @@ from jax._src import dtypes
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
  from torchax.ops.mappings import t2j_dtype

+ import tpu_inference.kernels.mla.v1.kernel as mla
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
  import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -17,9 +18,13 @@ logger = init_logger(__name__)
  DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16


- def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
- page_size: int, actual_num_kv_heads: int,
- actual_head_dim: int, kv_dtype: any):
+ def get_kv_cache_shape_with_mesh(mesh: Mesh,
+ total_num_pages: int,
+ page_size: int,
+ actual_num_kv_heads: int,
+ actual_head_dim: int,
+ kv_dtype: any,
+ use_mla: bool = False):
  """Gets the KV cache shape based on the mesh configuration."""

  model_cnt = mesh.shape["model"]
@@ -28,15 +33,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
  # specific model, rather than being determined by the head_dim. If new
  # models are introduced with a head_dim of 64, this will require additional
  # model-specific adjustments.
- get_kv_cache_shape_fn = (
- rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
- else rpa.get_kv_cache_shape
- )
- shape = list(
- get_kv_cache_shape_fn(total_num_pages, page_size,
- actual_num_kv_heads // model_cnt,
- actual_head_dim, kv_dtype))
- shape[2] *= model_cnt
+ if use_mla:
+ get_kv_cache_shape_fn = mla.get_kv_cache_shape
+ shape = list(
+ get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
+ kv_dtype))
+ else:
+ get_kv_cache_shape_fn = (
+ rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
+ else rpa.get_kv_cache_shape
+ )
+ shape = list(
+ get_kv_cache_shape_fn(total_num_pages, page_size,
+ actual_num_kv_heads // model_cnt,
+ actual_head_dim, kv_dtype))
+ shape[2] *= model_cnt
  return tuple(shape)


@@ -48,6 +59,7 @@ def create_kv_caches(
  mesh: Mesh,
  layer_names: List[str],
  cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
+ use_mla: bool = False,
  ) -> List[jax.Array]:
  """
  Creates a list of KV cache where each array mapps to single attention layer.
@@ -74,12 +86,16 @@ def create_kv_caches(

  cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
  num_kv_heads, head_size,
- cache_dtype)
+ cache_dtype, use_mla)

- sharding = NamedSharding(
- mesh,
- PartitionSpec(ShardingAxisName.ATTN_DATA, None,
- ShardingAxisName.ATTN_HEAD))
+ if use_mla:
+ sharding = NamedSharding(mesh,
+ PartitionSpec(ShardingAxisName.MLP_TENSOR))
+ else:
+ sharding = NamedSharding(
+ mesh,
+ PartitionSpec(ShardingAxisName.ATTN_DATA, None,
+ ShardingAxisName.ATTN_HEAD))

  def _allocate() -> jax.Array:
  return jnp.empty(
@@ -94,7 +110,8 @@ def create_kv_caches(
  return kv_caches


- def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
+ def get_attention_page_size_bytes(mesh: Mesh,
+ kv_cache_specs: dict[str, Any]) -> int:
  """
  Calculate KV cache page size of RPA kernel.

@@ -107,14 +124,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
  """

  # Import it here to avoid circular import.
- from vllm.v1.kv_cache_interface import AttentionSpec
+ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

  page_size_bytes_set = set()
  for kv_cache_spec in kv_cache_specs.values():
  assert isinstance(kv_cache_spec, AttentionSpec)

  dtype = t2j_dtype(kv_cache_spec.dtype)
- bits = dtypes.bit_width(dtype)
+ bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
+ dtypes.itemsize_bits(dtype))
+ use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)

  kv_cache_shape = get_kv_cache_shape_with_mesh(
  mesh=mesh,
@@ -123,6 +142,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
  actual_num_kv_heads=kv_cache_spec.num_kv_heads,
  actual_head_dim=kv_cache_spec.head_size,
  kv_dtype=dtype,
+ use_mla=use_mla,
  )
  page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
  page_size_bytes_set.add(page_size_bytes)
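The kv_cache.py hunks above thread a `use_mla` flag through shape selection, sharding, and page-size accounting: MLA caches take their shape from the MLA kernel (no per-head K/V split), while non-MLA caches keep the ragged-paged-attention layout, and the page size is then `(bits * prod(shape)) // 8` either way. The sketch below mirrors only that size computation; the two shape helpers are illustrative assumptions, not the real `mla.get_kv_cache_shape` / `rpa.get_kv_cache_shape` layouts.

```python
import numpy as np

def mla_kv_cache_shape(total_num_pages: int, page_size: int, head_dim: int):
    # Assumed MLA layout for illustration: one latent vector per token,
    # so there is no num_kv_heads dimension.
    return (total_num_pages, page_size, head_dim)

def rpa_kv_cache_shape(total_num_pages: int, page_size: int,
                       num_kv_heads: int, head_dim: int):
    # Assumed RPA layout for illustration: K and V stored together along
    # the head axis (hence the factor of 2).
    return (total_num_pages, page_size, 2 * num_kv_heads, head_dim)

def page_size_bytes(total_num_pages: int, page_size: int, num_kv_heads: int,
                    head_dim: int, bits_per_element: int,
                    use_mla: bool = False) -> int:
    """Bytes for one KV-cache allocation, using the diff's bits*prod//8 formula."""
    if use_mla:
        shape = mla_kv_cache_shape(total_num_pages, page_size, head_dim)
    else:
        shape = rpa_kv_cache_shape(total_num_pages, page_size, num_kv_heads,
                                   head_dim)
    return (bits_per_element * int(np.prod(shape))) // 8

if __name__ == "__main__":
    # bfloat16 => 16 bits per element.
    print(page_size_bytes(total_num_pages=1, page_size=16, num_kv_heads=8,
                          head_dim=128, bits_per_element=16))  # RPA-style cache
    print(page_size_bytes(total_num_pages=1, page_size=16, num_kv_heads=8,
                          head_dim=576, bits_per_element=16,
                          use_mla=True))                        # MLA-style cache
```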