tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py
@@ -1,13 +1,13 @@
+import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as vllm_envs
+import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 
-import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,8 +15,6 @@ from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.utils import device_array
 
 if TYPE_CHECKING:
@@ -32,10 +30,10 @@ class CompilationManager:
 
     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              vllm_envs.VLLM_XLA_CACHE_PATH)
+                              envs.VLLM_XLA_CACHE_PATH)
 
     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -69,7 +67,8 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)
 
     def capture_model(self) -> None:
-        if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
+        if os.getenv("SKIP_JAX_PRECOMPILE",
+                     False) or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")
 
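Note: the new gate reads the environment variable directly, and os.getenv returns the raw string whenever the variable is set, so any non-empty value is truthy, including "0". A minimal sketch of the resulting semantics (values are illustrative):

    import os

    os.environ["SKIP_JAX_PRECOMPILE"] = "0"
    flag = os.getenv("SKIP_JAX_PRECOMPILE", False)
    # flag is the string "0", and bool("0") is True, so precompilation is
    # skipped for any non-empty value; only an unset (or empty-string)
    # variable yields a falsy result.
    print(bool(flag))  # True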
@@ -82,8 +81,6 @@ class CompilationManager:
         self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
-        if not self.runner.is_last_rank:
-            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -123,15 +120,8 @@ class CompilationManager:
             num_tokens=num_tokens,
         )
 
-    def _precompile_backbone_helper(self,
-                                    name,
-                                    *,
-                                    input_ids,
-                                    positions,
-                                    inputs_embeds,
-                                    intermediate_tensors=None,
-                                    is_first_rank=True,
-                                    is_last_rank=True) -> None:
+    def _precompile_backbone_helper(self, name, *, input_ids, positions,
+                                    inputs_embeds) -> None:
        num_tokens = None
        if input_ids is not None:
            num_tokens = input_ids.shape[0]
@@ -145,6 +135,12 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
+        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
+        block_tables = block_tables.reshape(-1)
+        block_tables = device_array(self.runner.mesh,
+                                    block_tables,
+                                    sharding=dp_sharding)
+
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
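Note: the block table handed to the precompiled graph is now a single flat array; the CPU-side table is truncated to max_num_reqs rows and flattened before device placement. A self-contained sketch with illustrative sizes:

    import numpy as np

    max_num_reqs, max_blocks_per_req = 4, 8  # illustrative sizes
    block_table_cpu = np.zeros((16, max_blocks_per_req), dtype=np.int32)
    flat = block_table_cpu[:max_num_reqs].reshape(-1)
    assert flat.shape == (max_num_reqs * max_blocks_per_req,)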
@@ -156,49 +152,26 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)
 
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.runner.kv_cache_config.kv_cache_groups):
-            block_tables = self.runner.block_tables_cpu[
-                kv_cache_gid][:self.runner.max_num_reqs]
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(self.runner.mesh,
-                                        block_tables,
-                                        sharding=dp_sharding)
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution,
-            )
-            if not self.runner.use_hybrid_kvcache:
-                # all layers share the same attention metadata
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
-            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
-            intermediate_tensors,
-            is_first_rank,
-            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                positions, layer_name_to_kvcache_index, lora_metadata,
-                intermediate_tensors, is_first_rank, is_last_rank)
+                layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -206,10 +179,6 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
-            if self.runner.use_hybrid_kvcache:
-                attention_metadata = attention_metadata_per_layer
-            else:
-                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -217,13 +186,9 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
-                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
-                intermediate_tensors,
-                is_first_rank,
-                is_last_rank,
                 num_tokens=num_tokens,
             )
 
@@ -274,7 +239,6 @@ class CompilationManager:
         )
 
     def _precompile_backbone_text_only(self) -> None:
-        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -284,28 +248,10 @@ class CompilationManager:
                                                   dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-            is_first_rank = self.runner.is_first_rank
-            is_last_rank = self.runner.is_last_rank
-            if is_first_rank:
-                intermediate_tensors = None
-            else:
-                hidden_states = self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone",
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=None,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone",
+                                             input_ids=input_ids,
+                                             positions=positions,
+                                             inputs_embeds=None)
 
     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()
@@ -319,28 +265,10 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-            is_first_rank = self.runner.is_first_rank
-            is_last_rank = self.runner.is_last_rank
-            if not is_first_rank:
-                hidden_states = self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            else:
-                intermediate_tensors = None
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone with embeds",
-                input_ids=None,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone with embeds",
+                                             input_ids=None,
+                                             positions=positions,
+                                             inputs_embeds=inputs_embeds)
 
     def _precompile_select_from_array_helper(
         self,
@@ -408,7 +336,7 @@ class CompilationManager:
             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name=f"worker{self.runner.rank} select all logits",
+            name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
@@ -419,8 +347,7 @@ class CompilationManager:
         if self.runner.speculative_config:
             vocab_size = self.runner.model_config.get_vocab_size()
             self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select bonus tokens for spec decoding",
+                name="select bonus tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_reqs_paddings,
                 hidden_dim=vocab_size,
@@ -428,8 +355,7 @@ class CompilationManager:
                 PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select target tokens for spec decoding",
+                name="select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,
@@ -452,7 +378,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-                f"worker{self.runner.rank} compute_logits",
+                "compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,
@@ -494,7 +420,7 @@ class CompilationManager:
                 do_sampling=do_sampling,
             )
             self._run_compilation(
-                f"worker{self.runner.rank} sample",
+                "sample",
                 sample,
                 self.runner.rng_params_for_sampling,
                 self.runner.mesh,
@@ -535,7 +461,7 @@ class CompilationManager:
         logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
         token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
         self._run_compilation(
-            f"worker{self.runner.rank} gather_logprobs",
+            "gather_logprobs",
             self.runner._compute_and_gather_logprobs,
             logits,
             token_ids,
@@ -587,7 +513,7 @@ class CompilationManager:
                         do_sampling=do_sampling)
 
                 self._run_compilation(
-                    f"worker{self.runner.rank} {compilation_name}",
+                    compilation_name,
                     self.runner.rejection_sampler,
                     draft_token_ids,
                     num_draft_tokens,
@@ -604,9 +530,7 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-        target_hidden_size = self.runner.model_config.get_hidden_size()
-        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
-        )
+        hidden_size = self.runner.model_config.get_hidden_size()
         dtype = self.runner.model_config.dtype
 
         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -653,11 +577,10 @@ class CompilationManager:
 
         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits, draft_hidden_size), jnp.bfloat16)
+                (num_logits, hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )
@@ -665,8 +588,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-        draft_hidden_state_loop = self._create_dummy_tensor(
-            (self.runner.max_num_reqs, draft_hidden_size), dtype,
+        target_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
@@ -674,12 +597,9 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens, target_hidden_size),
-                                          dtype),
-                self._create_dummy_tensor((num_tokens, target_hidden_size),
-                                          dtype),
-                self._create_dummy_tensor((num_tokens, target_hidden_size),
-                                          dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
             ]
 
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -702,23 +622,23 @@ class CompilationManager:
                 num_reqs,
             ):
                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                    self.runner.drafter.state, token_indices, query_start_loc,
-                    seq_lens, input_ids, aux_hidden_states, attention_metadata,
-                    next_token_ids, num_reqs)
+                    token_indices, query_start_loc, seq_lens, input_ids,
+                    aux_hidden_states, attention_metadata, next_token_ids,
+                    num_reqs)
                 return target_hidden_states, input_ids, last_token_indices
 
             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens, target_hidden_size), jnp.bfloat16,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, target_hidden_size), jnp.bfloat16,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens, target_hidden_size), jnp.bfloat16,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -750,17 +670,17 @@ class CompilationManager:
                 state,
                 kv_caches,
                 input_ids,
-                draft_hidden_states,
+                target_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids, draft_hidden_states,
+                    state, kv_caches, input_ids, target_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states
 
-            draft_hidden_states = self._create_dummy_tensor(
-                (num_tokens, draft_hidden_size), dtype,
+            target_hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,
@@ -771,7 +691,7 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-                draft_hidden_states,
+                target_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
@@ -781,7 +701,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_prepare_hidden_states_and_input_ids",
                 self.runner.drafter._prepare_hidden_states_and_input_ids,
-                self.runner.drafter.state,
                 aux_hidden_states,
                 query_start_loc,
                 target_token_ids,
@@ -804,19 +723,18 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-                draft_hidden_state_loop,
+                target_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
 
             hidden_states = self._create_dummy_tensor(
-                (num_tokens, draft_hidden_size), jnp.bfloat16,
+                (num_tokens, hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
 
             self._run_compilation(
                 "eagle3_select_inputs_for_loop_speculation",
                 self.runner.drafter._select_inputs_for_loop_speculation,
-                self.runner.drafter.state,
                 positions,
                 hidden_states,
                 hidden_states,
@@ -827,7 +745,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_select_draft_token_ids",
                 self.runner.drafter._select_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 last_token_indices,
                 num_tokens=num_tokens,
tpu_inference/runner/kv_cache.py
@@ -82,7 +82,7 @@ def create_kv_caches(
                              ShardingAxisName.ATTN_HEAD))
 
     def _allocate() -> jax.Array:
-        return jnp.zeros(
+        return jnp.empty(
             shape=cache_shape,
             dtype=cache_dtype,
         )
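Note on zeros vs. empty: unlike NumPy, JAX documents jnp.empty as returning a zero-filled array, since XLA cannot materialize uninitialized buffers; the switch mainly signals that the cache is always written before it is read. A quick check:

    import jax.numpy as jnp

    a = jnp.zeros((2, 3), jnp.float32)
    b = jnp.empty((2, 3), jnp.float32)  # also zero-filled under JAX
    assert bool((a == b).all())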
tpu_inference/runner/kv_cache_manager.py
@@ -1,16 +1,15 @@
 import functools
+import math
 from typing import TYPE_CHECKING, Dict, List
 
 import jax
 import jax.numpy as jnp
-import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
+from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
-from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -176,11 +175,6 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
-        self.runner.block_tables_cpu = [
-            np.zeros((self.runner.max_num_reqs,
-                      cdiv(self.runner.max_model_len, block_size)),
-                     dtype=np.int32) for block_size in block_sizes
-        ]
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -196,7 +190,7 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = (num_blocks // dp_size) * dp_size
+            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
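The rounding direction changes here: the old expression rounded num_blocks down to a multiple of dp_size, while math.ceil rounds it up. A worked example:

    import math

    num_blocks, dp_size = 1000, 3
    down = (num_blocks // dp_size) * dp_size        # 999
    up = math.ceil(num_blocks / dp_size) * dp_size  # 1002

Rounding up can yield slightly more blocks than kv_cache_tensor.size // page_size_bytes strictly provides; presumably the surrounding allocation tolerates this, but that is not visible in this hunk.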
@@ -289,8 +283,13 @@ class KVCacheManager:
 
         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
-            cache.at[block_numbers].set(reshaped_slices)
+            reshaped_slices = slices.reshape(-1, 1, block_size,
+                                             *slices.shape[1:])
+            for (i, block_idx) in enumerate(block_numbers):
+                cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                            reshaped_slices[i],
+                                                            block_idx,
+                                                            axis=0)
             return cache
 
         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
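Worth noting: in JAX, cache.at[...].set(...) is functional and returns a new array, so the removed form discarded its result; the replacement rebinds cache on every dynamic_update_slice_in_dim call. A minimal illustration:

    import jax
    import jax.numpy as jnp

    cache = jnp.zeros((8, 4))
    cache.at[3].set(1.0)  # result discarded: cache itself is unchanged
    assert not bool(cache[3].any())

    update = jnp.ones((1, 4))  # same rank as cache, length 1 along axis 0
    cache = jax.lax.dynamic_update_slice_in_dim(cache, update, 3, axis=0)
    assert bool(cache[3].all())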
@@ -343,12 +342,16 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                self.runner.kv_caches, block_ids[0], len(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                    self.runner.kv_caches, block_ids[0], len(block_ids))
 
         else:
-            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                self.runner.kv_caches, jnp.array(block_ids))
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                    self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer
 
     def transfer_kv_cache(self,
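Both gather paths are now wrapped in runner_utils.LatencyTracker, whose implementation is not part of this diff. A minimal stand-in with the same shape (hypothetical, for illustration only):

    import time
    from contextlib import contextmanager

    @contextmanager
    def latency_tracker(name: str):
        # Hypothetical stand-in for runner_utils.LatencyTracker.
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed_ms = (time.perf_counter() - start) * 1e3
            print(f"{name}: {elapsed_ms:.2f} ms")

Because JAX dispatches device work asynchronously, such a tracker measures dispatch time rather than device time unless the result is explicitly blocked on.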
@@ -437,7 +440,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
-                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -449,7 +451,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
-                jax.block_until_ready(self.runner.kv_caches)
 
         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
tpu_inference/runner/persistent_batch_manager.py
@@ -14,13 +14,12 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config, is_last_rank: bool):
+                 uses_mrope: bool, model_config):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
-        self.is_last_rank = is_last_rank
 
     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -180,35 +179,9 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
-            num_output_tokens = req_data.num_output_tokens[i]
 
             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-
-            if not self.is_last_rank:
-                # When using PP, the scheduler sends the sampled tokens back,
-                # because there's no direct communication between the first-
-                # stage worker and the last-stage worker.
-                new_token_ids = req_data.new_token_ids[i]
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
-                elif num_output_tokens < len(req_state.output_token_ids):
-                    del req_state.output_token_ids[num_output_tokens:]
-                if req_index is not None:
-                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
-                               num_output_tokens)
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
-
-            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -221,6 +194,7 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
+            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -235,18 +209,6 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)
 
-            # For the last rank, we don't need to update the token_ids_cpu
-            # because the sampled tokens are already cached.
-            if not self.is_last_rank:
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                self.input_batch.num_tokens[req_index] = end_token_index
-
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
tpu_inference/runner/structured_decoding_manager.py
@@ -61,10 +61,11 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)
 
         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids)
+            grammar_output.structured_output_request_ids.items(),
+            key=lambda item: item[1])
 
         cumulative_mask_idx = 0
-        for req_id in sorted_struct_requests:
+        for req_id, _ in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]
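The sorted-iteration change suggests structured_output_request_ids is now a mapping from request id to a position rather than a plain iterable of ids; sorting by value keeps the grammar bitmask rows aligned with that position. Illustrative values (hypothetical, not from the diff):

    structured_output_request_ids = {"req-b": 1, "req-a": 0}
    ordered = sorted(structured_output_request_ids.items(),
                     key=lambda item: item[1])
    assert ordered == [("req-a", 0), ("req-b", 1)]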