tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the source registry flags this version of tpu-inference as possibly problematic.
Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
- import os
  import time
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import jax
  import jax.numpy as jnp
  import numpy as np
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from jax.sharding import NamedSharding, PartitionSpec

+ import tpu_inference.envs as envs
  from tpu_inference.core.disagg_utils import is_disagg_enabled
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
  from tpu_inference.layers.jax.sample.sampling_metadata import \
      TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
  from tpu_inference.utils import device_array

  if TYPE_CHECKING:
@@ -30,10 +32,10 @@ class CompilationManager:

      def __init__(self, runner: "TPUModelRunner"):
          self.runner = runner
-         if not envs.VLLM_DISABLE_COMPILE_CACHE:
+         if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
              logger.info("Enabling JAX compile cache.")
              jax.config.update("jax_compilation_cache_dir",
-                               envs.VLLM_XLA_CACHE_PATH)
+                               vllm_envs.VLLM_XLA_CACHE_PATH)

      def _create_dummy_tensor(self,
                               shape: Tuple[int, ...],
@@ -67,8 +69,7 @@ class CompilationManager:
          logger.info("Compilation finished in %.2f [secs].", end - start)

      def capture_model(self) -> None:
-         if os.getenv("SKIP_JAX_PRECOMPILE",
-                      False) or self.runner.model_config.enforce_eager:
+         if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
              return
          logger.info("Precompile all the subgraphs with possible input shapes.")

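Why this change matters: os.getenv returns the raw string, so the old os.getenv("SKIP_JAX_PRECOMPILE", False) was truthy for any non-empty value, including "0" and "false". Routing the lookup through tpu_inference.envs gives one typed parse point. A minimal sketch of such a module, assuming a boolean parser (the registry layout below is illustrative, not the actual contents of envs.py):

    # sketch_envs.py -- hypothetical typed env registry
    import os

    def _parse_bool(value: str) -> bool:
        # "0", "false", "no" and "" become False instead of truthy strings.
        return value.strip().lower() not in ("", "0", "false", "no")

    _ENV_SPECS = {
        "SKIP_JAX_PRECOMPILE": (_parse_bool, "0"),
    }

    def __getattr__(name: str):
        # PEP 562 module __getattr__: envs.SKIP_JAX_PRECOMPILE re-reads
        # and parses the variable on every access.
        if name in _ENV_SPECS:
            parser, default = _ENV_SPECS[name]
            return parser(os.environ.get(name, default))
        raise AttributeError(name)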
@@ -81,6 +82,8 @@ class CompilationManager:
          self._precompile_backbone_with_inputs_embeds()
          if self.runner.scheduler_config.async_scheduling:
              self._precompile_substitute_placeholder_token()
+         if not self.runner.is_last_rank:
+             return
          self._precompile_select_from_array()
          self._precompile_compute_logits()
          self._precompile_disagg_utils()
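The new early return divides precompilation by pipeline stage: every rank compiles the backbone graphs, while the logits, sampling, and logprob paths below it are reached only on the last rank, the stage that actually produces tokens. A schematic of the gating (hypothetical helper, not part of the diff):

    def graphs_to_precompile(rank: int, pp_world_size: int) -> list[str]:
        # Every pipeline stage runs the backbone; only the final stage
        # computes logits and samples, so only it compiles those graphs.
        graphs = ["backbone", "backbone_with_embeds"]
        if rank == pp_world_size - 1:
            graphs += ["select_from_array", "compute_logits", "sample",
                       "gather_logprobs"]
        return graphs

    assert "compute_logits" not in graphs_to_precompile(0, 4)
    assert "compute_logits" in graphs_to_precompile(3, 4)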
@@ -120,8 +123,15 @@ class CompilationManager:
              num_tokens=num_tokens,
          )

-     def _precompile_backbone_helper(self, name, *, input_ids, positions,
-                                     inputs_embeds) -> None:
+     def _precompile_backbone_helper(self,
+                                     name,
+                                     *,
+                                     input_ids,
+                                     positions,
+                                     inputs_embeds,
+                                     intermediate_tensors=None,
+                                     is_first_rank=True,
+                                     is_last_rank=True) -> None:
          num_tokens = None
          if input_ids is not None:
              num_tokens = input_ids.shape[0]
@@ -135,12 +145,6 @@ class CompilationManager:
              ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

          # Keep existing pattern for complex array operations
-         block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-         block_tables = block_tables.reshape(-1)
-         block_tables = device_array(self.runner.mesh,
-                                     block_tables,
-                                     sharding=dp_sharding)
-
          seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                               jnp.int32, dp_sharding)
          query_start_loc = self._create_dummy_tensor(
@@ -152,26 +156,49 @@ class CompilationManager:
              request_distribution,
              sharding=dp_sharding)

-         attention_metadata = AttentionMetadata(
-             input_positions=positions,
-             block_tables=block_tables,
-             seq_lens=seq_lens,
-             query_start_loc=query_start_loc,
-             request_distribution=request_distribution,
-         )
+         attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+         uniform_attention_metadata: AttentionMetadata = None
+         for kv_cache_gid, kv_cache_group in enumerate(
+                 self.runner.kv_cache_config.kv_cache_groups):
+             block_tables = self.runner.block_tables_cpu[
+                 kv_cache_gid][:self.runner.max_num_reqs]
+             block_tables = block_tables.reshape(-1)
+             block_tables = device_array(self.runner.mesh,
+                                         block_tables,
+                                         sharding=dp_sharding)
+
+             attention_metadata_gid = AttentionMetadata(
+                 input_positions=positions,
+                 block_tables=block_tables,
+                 seq_lens=seq_lens,
+                 query_start_loc=query_start_loc,
+                 request_distribution=request_distribution,
+             )
+             if not self.runner.use_hybrid_kvcache:
+                 # all layers share the same attention metadata
+                 uniform_attention_metadata = attention_metadata_gid
+             else:
+                 for layer_name in kv_cache_group.layer_names:
+                     attention_metadata_per_layer[
+                         layer_name] = attention_metadata_gid

          def model_fn_wrapper(
              state,
              kv_caches,
              input_ids,
              attention_metadata,
+             positions,
              inputs_embeds,
              layer_name_to_kvcache_index,
              lora_metadata,
+             intermediate_tensors,
+             is_first_rank,
+             is_last_rank,
          ):
              kv_caches, hidden_states, _ = self.runner.model_fn(
                  state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                 layer_name_to_kvcache_index, lora_metadata)
+                 positions, layer_name_to_kvcache_index, lora_metadata,
+                 intermediate_tensors, is_first_rank, is_last_rank)
              self.runner.kv_caches = kv_caches
              return hidden_states

@@ -179,6 +206,10 @@ class CompilationManager:
                  self.runner.lora_config, np.array([num_tokens],
                                                    dtype=np.int32)):
              lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+             if self.runner.use_hybrid_kvcache:
+                 attention_metadata = attention_metadata_per_layer
+             else:
+                 attention_metadata = uniform_attention_metadata
              self._run_compilation(
                  name,
                  model_fn_wrapper,
@@ -186,9 +217,13 @@ class CompilationManager:
                  self.runner.kv_caches,
                  input_ids,
                  attention_metadata,
+                 positions,
                  inputs_embeds,
                  tuple(self.runner.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 is_first_rank,
+                 is_last_rank,
                  num_tokens=num_tokens,
              )

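Context for the per-group loop above: with a hybrid KV cache (for example, full-attention layers mixed with sliding-window layers), each kv-cache group has its own block table, so each layer must see the metadata of its group; with a single group, one shared AttentionMetadata serves every layer. A minimal sketch of how a consumer might dispatch on the two shapes (the helper is illustrative, not the actual model_fn):

    from typing import Dict, Union

    def metadata_for_layer(
        attention_metadata: Union["AttentionMetadata",
                                  Dict[str, "AttentionMetadata"]],
        layer_name: str,
    ) -> "AttentionMetadata":
        # Hybrid KV cache: per-layer dict keyed by layer name.
        if isinstance(attention_metadata, dict):
            return attention_metadata[layer_name]
        # Uniform KV cache: one object shared by all layers.
        return attention_metadata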
@@ -239,6 +274,7 @@ class CompilationManager:
          )

      def _precompile_backbone_text_only(self) -> None:
+         hidden_size = self.runner.model_config.get_hidden_size()
          for num_tokens in self.runner.num_tokens_paddings:
              dp_sharding = NamedSharding(
                  self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -248,10 +284,28 @@ class CompilationManager:
                                                    dp_sharding)
              positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                    dp_sharding)
-             self._precompile_backbone_helper("backbone",
-                                              input_ids=input_ids,
-                                              positions=positions,
-                                              inputs_embeds=None)
+             is_first_rank = self.runner.is_first_rank
+             is_last_rank = self.runner.is_last_rank
+             if is_first_rank:
+                 intermediate_tensors = None
+             else:
+                 hidden_states = self._create_dummy_tensor(
+                     (num_tokens, hidden_size), jnp.bfloat16)
+                 residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                      jnp.bfloat16)
+                 intermediate_tensors = JaxIntermediateTensors(
+                     tensors={
+                         "hidden_states": hidden_states,
+                         "residual": residual
+                     })
+             self._precompile_backbone_helper(
+                 f"worker{self.runner.rank} backbone",
+                 input_ids=input_ids,
+                 positions=positions,
+                 inputs_embeds=None,
+                 intermediate_tensors=intermediate_tensors,
+                 is_first_rank=is_first_rank,
+                 is_last_rank=is_last_rank)

      def _precompile_backbone_with_inputs_embeds(self) -> None:
          hidden_size = self.runner.model_config.get_hidden_size()
@@ -265,10 +319,28 @@ class CompilationManager:
              else:
                  positions = self._create_dummy_tensor((num_tokens, ),
                                                        jnp.int32)
-             self._precompile_backbone_helper("backbone with embeds",
-                                              input_ids=None,
-                                              positions=positions,
-                                              inputs_embeds=inputs_embeds)
+             is_first_rank = self.runner.is_first_rank
+             is_last_rank = self.runner.is_last_rank
+             if not is_first_rank:
+                 hidden_states = self._create_dummy_tensor(
+                     (num_tokens, hidden_size), jnp.bfloat16)
+                 residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                      jnp.bfloat16)
+                 intermediate_tensors = JaxIntermediateTensors(
+                     tensors={
+                         "hidden_states": hidden_states,
+                         "residual": residual
+                     })
+             else:
+                 intermediate_tensors = None
+             self._precompile_backbone_helper(
+                 f"worker{self.runner.rank} backbone with embeds",
+                 input_ids=None,
+                 positions=positions,
+                 inputs_embeds=inputs_embeds,
+                 intermediate_tensors=intermediate_tensors,
+                 is_first_rank=is_first_rank,
+                 is_last_rank=is_last_rank)

      def _precompile_select_from_array_helper(
          self,
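Both backbone precompile paths now trace the pipeline-parallel variants: the first rank starts from token ids (or embeddings), while every later rank is traced with dummy (num_tokens, hidden_size) bfloat16 hidden_states and residual tensors, since that is what it receives from the previous stage. A rough sketch of the consuming side, treating JaxIntermediateTensors as a plain wrapper around a dict of arrays (its real interface may differ):

    def run_backbone_stage(embed_fn, intermediate_tensors, is_first_rank,
                           input_ids):
        # embed_fn stands in for the model's embedding lookup.
        if is_first_rank:
            hidden_states = embed_fn(input_ids)   # stage 0: start from tokens
            residual = None
        else:
            # Later stages resume from the previous stage's activations,
            # matching the dummy shapes compiled above.
            hidden_states = intermediate_tensors.tensors["hidden_states"]
            residual = intermediate_tensors.tensors["residual"]
        return hidden_states, residual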
@@ -332,20 +404,23 @@ class CompilationManager:
          index_paddings = self.runner.num_reqs_paddings
          dp_sharding = NamedSharding(self.runner.mesh,
                                      PartitionSpec(ShardingAxisName.ATTN_DATA))
+         hidden_states_sharding = NamedSharding(
+             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
          dp_size = self.runner.vllm_config.sharding_config.total_dp_size
          self._precompile_select_from_array_helper(
-             name="select all logits",
+             name=f"worker{self.runner.rank} select all logits",
              source_paddings=self.runner.num_tokens_paddings,
              indices_paddings=index_paddings,
              hidden_dim=hsize,
-             input_sharding=dp_sharding,
+             input_sharding=hidden_states_sharding,
              indices_sharding=dp_sharding if dp_size > 1 else None,
          )

          if self.runner.speculative_config:
              vocab_size = self.runner.model_config.get_vocab_size()
              self._precompile_select_from_array_helper(
-                 name="select bonus tokens for spec decoding",
+                 name=
+                 f"worker{self.runner.rank} select bonus tokens for spec decoding",
                  source_paddings=self.runner.num_logits_paddings,
                  indices_paddings=self.runner.num_reqs_paddings,
                  hidden_dim=vocab_size,
@@ -353,7 +428,8 @@ class CompilationManager:
                                             PartitionSpec(None, "model")),
              )
              self._precompile_select_from_array_helper(
-                 name="select target tokens for spec decoding",
+                 name=
+                 f"worker{self.runner.rank} select target tokens for spec decoding",
                  source_paddings=self.runner.num_logits_paddings,
                  indices_paddings=self.runner.num_logits_paddings,
                  hidden_dim=vocab_size,
@@ -376,7 +452,7 @@ class CompilationManager:
                  np.array([num_reqs], dtype=np.int32)):
              lora_metadata = self.runner.lora_utils.extract_lora_metadata()
              self._run_compilation(
-                 "compute_logits",
+                 f"worker{self.runner.rank} compute_logits",
                  self.runner.compute_logits_fn,
                  self.runner.state,
                  hidden_states,
@@ -418,7 +494,7 @@ class CompilationManager:
                  do_sampling=do_sampling,
              )
              self._run_compilation(
-                 "sample",
+                 f"worker{self.runner.rank} sample",
                  sample,
                  self.runner.rng_params_for_sampling,
                  self.runner.mesh,
@@ -459,7 +535,7 @@ class CompilationManager:
          logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
          token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
          self._run_compilation(
-             "gather_logprobs",
+             f"worker{self.runner.rank} gather_logprobs",
              self.runner._compute_and_gather_logprobs,
              logits,
              token_ids,
@@ -511,7 +587,7 @@ class CompilationManager:
                              do_sampling=do_sampling)

          self._run_compilation(
-             compilation_name,
+             f"worker{self.runner.rank} {compilation_name}",
              self.runner.rejection_sampler,
              draft_token_ids,
              num_draft_tokens,
@@ -528,7 +604,9 @@ class CompilationManager:
      def _precompile_eagle3_helpers(self) -> None:
          logger.info(
              "Compiling eagle3 jitted helpers with different input shapes.")
-         hidden_size = self.runner.model_config.get_hidden_size()
+         target_hidden_size = self.runner.model_config.get_hidden_size()
+         draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+         )
          dtype = self.runner.model_config.dtype

          num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -575,10 +653,11 @@ class CompilationManager:

          for num_logits in self.runner.num_logits_paddings:
              hidden_states = self._create_dummy_tensor(
-                 (num_logits, hidden_size), jnp.bfloat16)
+                 (num_logits, draft_hidden_size), jnp.bfloat16)
              self._run_compilation(
                  "eagle3_get_draft_token_ids",
                  self.runner.drafter._get_draft_token_ids,
+                 self.runner.drafter.state,
                  hidden_states,
                  num_logits=num_logits,
              )
@@ -586,8 +665,8 @@ class CompilationManager:
          input_ids_loop = self._create_dummy_tensor(
              (self.runner.max_num_reqs, ), jnp.int32,
              NamedSharding(self.runner.mesh, PartitionSpec()))
-         target_hidden_state_loop = self._create_dummy_tensor(
-             (self.runner.max_num_reqs, hidden_size), dtype,
+         draft_hidden_state_loop = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, draft_hidden_size), dtype,
              NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
          next_token_ids = self._create_dummy_tensor(
              (self.runner.max_num_reqs, ), jnp.int32)
@@ -595,9 +674,12 @@ class CompilationManager:
              (self.runner.max_num_reqs, ), jnp.int32)
          for num_tokens in self.runner.num_tokens_paddings:
              aux_hidden_states = [
-                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
-                 self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
              ]

              positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
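The split of hidden_size into target_hidden_size and draft_hidden_size reflects that an EAGLE-3 draft head need not share the target model's width: the three aux_hidden_states are captured from the target backbone (target-sized), while the drafter's own loop runs at its own width. In the usual EAGLE-3 arrangement the aux states are concatenated and projected down to the draft width; sketched here purely for shapes (the projection weight is an assumption about the drafter, not shown in this diff):

    import jax.numpy as jnp

    # Hypothetical sizes: 4096-d target model, 2048-d draft head.
    num_tokens, target_hidden_size, draft_hidden_size = 16, 4096, 2048
    aux_hidden_states = [
        jnp.zeros((num_tokens, target_hidden_size)) for _ in range(3)
    ]

    concat = jnp.concatenate(aux_hidden_states, axis=-1)  # (16, 3 * 4096)
    w_fc = jnp.zeros((3 * target_hidden_size, draft_hidden_size))
    draft_inputs = concat @ w_fc                          # (16, 2048)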
@@ -620,23 +702,23 @@ class CompilationManager:
                  num_reqs,
              ):
                  target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                     token_indices, query_start_loc, seq_lens, input_ids,
-                     aux_hidden_states, attention_metadata, next_token_ids,
-                     num_reqs)
+                     self.runner.drafter.state, token_indices, query_start_loc,
+                     seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                     next_token_ids, num_reqs)
                  return target_hidden_states, input_ids, last_token_indices

              input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
              aux_hidden_states = [
                  self._create_dummy_tensor(
-                     (num_tokens, hidden_size), jnp.bfloat16,
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
                      NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                    None))),
                  self._create_dummy_tensor(
-                     (num_tokens, hidden_size), jnp.bfloat16,
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
                      NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                    None))),
                  self._create_dummy_tensor(
-                     (num_tokens, hidden_size), jnp.bfloat16,
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
                      NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                    None))),
              ]
@@ -668,17 +750,17 @@ class CompilationManager:
                  state,
                  kv_caches,
                  input_ids,
-                 target_hidden_states,
+                 draft_hidden_states,
                  attention_metadata,
              ):
                  kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                     state, kv_caches, input_ids, target_hidden_states,
+                     state, kv_caches, input_ids, draft_hidden_states,
                      attention_metadata)
                  self.runner.kv_caches = kv_caches
                  return hidden_states

-             target_hidden_states = self._create_dummy_tensor(
-                 (num_tokens, hidden_size), dtype,
+             draft_hidden_states = self._create_dummy_tensor(
+                 (num_tokens, draft_hidden_size), dtype,
                  NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
              input_ids = self._create_dummy_tensor(
                  (num_tokens, ), jnp.int32,
@@ -689,7 +771,7 @@ class CompilationManager:
                  self.runner.drafter.state,
                  self.runner.kv_caches,
                  input_ids,
-                 target_hidden_states,
+                 draft_hidden_states,
                  attention_metadata,
                  num_tokens=num_tokens,
              )
@@ -699,6 +781,7 @@ class CompilationManager:
              self._run_compilation(
                  "eagle3_prepare_hidden_states_and_input_ids",
                  self.runner.drafter._prepare_hidden_states_and_input_ids,
+                 self.runner.drafter.state,
                  aux_hidden_states,
                  query_start_loc,
                  target_token_ids,
@@ -721,18 +804,19 @@ class CompilationManager:
                  self.runner.drafter.state,
                  self.runner.kv_caches,
                  input_ids_loop,
-                 target_hidden_state_loop,
+                 draft_hidden_state_loop,
                  attention_metadata,
                  num_tokens=num_tokens,
              )

              hidden_states = self._create_dummy_tensor(
-                 (num_tokens, hidden_size), jnp.bfloat16,
+                 (num_tokens, draft_hidden_size), jnp.bfloat16,
                  NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

              self._run_compilation(
                  "eagle3_select_inputs_for_loop_speculation",
                  self.runner.drafter._select_inputs_for_loop_speculation,
+                 self.runner.drafter.state,
                  positions,
                  hidden_states,
                  hidden_states,
@@ -743,6 +827,7 @@ class CompilationManager:
              self._run_compilation(
                  "eagle3_select_draft_token_ids",
                  self.runner.drafter._select_draft_token_ids,
+                 self.runner.drafter.state,
                  hidden_states,
                  last_token_indices,
                  num_tokens=num_tokens,
tpu_inference/runner/kv_cache_manager.py
@@ -1,15 +1,16 @@
  import functools
- import math
  from typing import TYPE_CHECKING, Dict, List

  import jax
  import jax.numpy as jnp
+ import numpy as np
  import vllm.envs as envs
  from jax.sharding import NamedSharding, PartitionSpec
  from torchax.ops.mappings import t2j_dtype
- from vllm.attention import Attention
  from vllm.attention.backends.abstract import AttentionType
+ from vllm.attention.layer import Attention
  from vllm.config import get_layers_from_vllm_config
+ from vllm.utils.math_utils import cdiv
  from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                          KVCacheSpec, MLAAttentionSpec,
                                          SlidingWindowSpec)
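The new cdiv import replaces math.ceil for block-table sizing. cdiv is integer ceiling division; a self-contained equivalent for reference (the -(a // -b) formulation below is a common way to write it and may differ cosmetically from vllm's):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division on ints without going through floats.
        return -(a // -b)

    # Block-table width: columns needed to cover max_model_len tokens.
    # e.g. max_model_len=8200, block_size=16 -> 513 (512 full blocks
    # plus one partial block of 8 tokens).
    assert cdiv(8200, 16) == 513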
@@ -175,6 +176,11 @@ class KVCacheManager:
          )
          self.runner.input_batch = new_input_batch
          self.runner.persistent_batch_manager.input_batch = new_input_batch
+         self.runner.block_tables_cpu = [
+             np.zeros((self.runner.max_num_reqs,
+                       cdiv(self.runner.max_model_len, block_size)),
+                      dtype=np.int32) for block_size in block_sizes
+         ]

      def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
          self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -190,7 +196,7 @@ class KVCacheManager:
              num_blocks = kv_cache_tensor.size // page_size_bytes
              dp_size = self.runner.vllm_config.sharding_config.total_dp_size
              # num_blocks must be a multiple of dp_size
-             num_blocks = math.ceil(num_blocks / dp_size) * dp_size
+             num_blocks = (num_blocks // dp_size) * dp_size
              # NOTE: we'll multiply the num_kv_heads by 2 in the function
              kv_cache = create_kv_caches(
                  num_blocks=num_blocks,
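The rounding direction flips here: math.ceil could round num_blocks up past what kv_cache_tensor.size actually holds, while flooring keeps the count within the allocation and still a multiple of dp_size (the motivation is inferred from the surrounding size computation, not stated in the diff). A worked example:

    import math

    num_blocks = 1001   # as derived from kv_cache_tensor.size // page_size_bytes
    dp_size = 8

    assert math.ceil(num_blocks / dp_size) * dp_size == 1008  # old: overshoots
    assert (num_blocks // dp_size) * dp_size == 1000          # new: fits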
@@ -283,13 +289,8 @@ class KVCacheManager:

          def _update_layer(cache, slices):
              """The function to apply to each layer's cache and slices."""
-             reshaped_slices = slices.reshape(-1, 1, block_size,
-                                              *slices.shape[1:])
-             for (i, block_idx) in enumerate(block_numbers):
-                 cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                             reshaped_slices[i],
-                                                             block_idx,
-                                                             axis=0)
+             reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+             cache.at[block_numbers].set(reshaped_slices)
              return cache

          return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
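The per-block Python loop of jax.lax.dynamic_update_slice_in_dim calls, which unrolls into one scatter per block under jit, becomes a single batched scatter. One property of the idiom worth keeping in mind: JAX arrays are immutable, so .at[...].set(...) is out-of-place and returns the updated array rather than mutating its input, which is why such results are conventionally rebound. A self-contained illustration:

    import jax.numpy as jnp

    cache = jnp.zeros((8, 2, 4))        # (num_blocks, block_size, head_dim)
    block_numbers = jnp.array([1, 5])
    slices = jnp.ones((2, 2, 4))        # two blocks' worth of KV entries

    updated = cache.at[block_numbers].set(slices)  # one batched scatter
    assert float(cache.sum()) == 0.0               # input is untouched
    assert float(updated[1].sum()) == 8.0          # block 1 now holds ones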
@@ -342,16 +343,12 @@ class KVCacheManager:
          """
          if block_ids == list(range(block_ids[0],
                                     block_ids[0] + len(block_ids))):
-             with runner_utils.LatencyTracker(
-                     "BatchedGatherKVSlices-for-blocks"):
-                 batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                     self.runner.kv_caches, block_ids[0], len(block_ids))
+             batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                 self.runner.kv_caches, block_ids[0], len(block_ids))

          else:
-             with runner_utils.LatencyTracker(
-                     "BatchedGatherKVSlices-for-blocks"):
-                 batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                     self.runner.kv_caches, jnp.array(block_ids))
+             batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                 self.runner.kv_caches, jnp.array(block_ids))
          return batched_kv_cache_per_layer

      def transfer_kv_cache(self,
@@ -440,6 +437,7 @@ class KVCacheManager:
                      kv_cache_slices,
                      start_block,
                  )
+                 jax.block_until_ready(self.runner.kv_caches)
          else:
              with runner_utils.LatencyTracker(
                      f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -451,6 +449,7 @@ class KVCacheManager:
                      kv_cache_slices,
                      jnp.array(block_numbers),
                  )
+                 jax.block_until_ready(self.runner.kv_caches)

          logger.debug(
              f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
tpu_inference/runner/persistent_batch_manager.py
@@ -14,12 +14,13 @@ class PersistentBatchManager:
      def __init__(self, requests: Dict[str, CachedRequestState],
                   input_batch: InputBatch, encoder_cache: Dict[str,
                                                                'jax.Array'],
-                  uses_mrope: bool, model_config):
+                  uses_mrope: bool, model_config, is_last_rank: bool):
          self.requests = requests
          self.input_batch = input_batch
          self.encoder_cache = encoder_cache
          self.uses_mrope = uses_mrope
          self.model_config = model_config
+         self.is_last_rank = is_last_rank

      def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
          """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
              num_computed_tokens = req_data.num_computed_tokens[i]
              new_block_ids = req_data.new_block_ids[i]
              resumed_from_preemption = req_data.resumed_from_preemption[i]
+             num_output_tokens = req_data.num_output_tokens[i]

              # Update the cached states.
              req_state.num_computed_tokens = num_computed_tokens
+             req_index = self.input_batch.req_id_to_index.get(req_id)
+
+             if not self.is_last_rank:
+                 # When using PP, the scheduler sends the sampled tokens back,
+                 # because there's no direct communication between the first-
+                 # stage worker and the last-stage worker.
+                 new_token_ids = req_data.new_token_ids[i]
+                 # Add the sampled token(s) from the previous step (if any).
+                 # This doesn't include "unverified" tokens like spec tokens.
+                 num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                   req_state.num_tokens)
+                 if num_new_tokens == 1:
+                     req_state.output_token_ids.append(new_token_ids[-1])
+                 elif num_new_tokens > 0:
+                     req_state.output_token_ids.extend(
+                         new_token_ids[-num_new_tokens:])
+                 elif num_output_tokens < len(req_state.output_token_ids):
+                     del req_state.output_token_ids[num_output_tokens:]
+                 if req_index is not None:
+                     end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                                num_output_tokens)
+                     self.input_batch.num_tokens[req_index] = end_idx
+                     self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+             # Update the block IDs.
              if not resumed_from_preemption:
                  if new_block_ids is not None:
                      # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@ class PersistentBatchManager:
                      # Replace the existing block IDs with the new ones.
                      req_state.block_ids = new_block_ids

-             req_index = self.input_batch.req_id_to_index.get(req_id)
              if req_index is None:
                  # The request is not in the persistent batch.
                  # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@ class PersistentBatchManager:
                  self.input_batch.block_table.append_row(
                      new_block_ids, req_index)

+             # For the last rank, we don't need to update the token_ids_cpu
+             # because the sampled tokens are already cached.
+             if not self.is_last_rank:
+                 start_token_index = num_computed_tokens
+                 end_token_index = num_computed_tokens + len(new_token_ids)
+                 self.input_batch.token_ids_cpu[
+                     req_index,
+                     start_token_index:end_token_index] = new_token_ids
+                 self.input_batch.num_tokens_no_spec[
+                     req_index] = end_token_index
+                 self.input_batch.num_tokens[req_index] = end_token_index
+
              # Add spec_token_ids to token_ids_cpu.
              spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                  req_id, ())
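The num_new_tokens arithmetic above handles the overlap between what the scheduler reports and what the worker already holds: on non-last ranks the sampled ids arrive via the scheduler, possibly alongside tokens the worker has already recorded. A worked example with assumed values:

    # The worker already holds 10 tokens for this request; the scheduler
    # reports 8 computed tokens plus 3 token ids (2 already known).
    num_computed_tokens = 8
    new_token_ids = [101, 102, 103]
    req_state_num_tokens = 10

    num_new_tokens = (num_computed_tokens + len(new_token_ids)
                      - req_state_num_tokens)
    assert num_new_tokens == 1
    # -> output_token_ids.append(new_token_ids[-1]) records only 103.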
tpu_inference/runner/structured_decoding_manager.py
@@ -61,11 +61,10 @@ class StructuredDecodingManager:
          self.runner.require_structured_out_cpu.fill(0)

          sorted_struct_requests = sorted(
-             grammar_output.structured_output_request_ids.items(),
-             key=lambda item: item[1])
+             grammar_output.structured_output_request_ids)

          cumulative_mask_idx = 0
-         for req_id, _ in sorted_struct_requests:
+         for req_id in sorted_struct_requests:
              if req_id not in self.runner.input_batch.req_id_to_index:
                  continue
              batch_index = self.runner.input_batch.req_id_to_index[req_id]