tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (25):
  1. tests/test_envs.py +32 -11
  2. tests/test_utils.py +1 -2
  3. tpu_inference/distributed/tpu_connector.py +1 -1
  4. tpu_inference/envs.py +60 -7
  5. tpu_inference/executors/ray_distributed_executor.py +5 -1
  6. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +72 -19
  7. tpu_inference/layers/common/sharding.py +3 -4
  8. tpu_inference/layers/vllm/quantization/mxfp4.py +2 -1
  9. tpu_inference/models/common/model_loader.py +3 -1
  10. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  11. tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
  12. tpu_inference/platforms/tpu_platform.py +13 -20
  13. tpu_inference/runner/compilation_manager.py +87 -27
  14. tpu_inference/runner/kv_cache_manager.py +8 -15
  15. tpu_inference/runner/persistent_batch_manager.py +40 -2
  16. tpu_inference/runner/tpu_runner.py +68 -45
  17. tpu_inference/runner/utils.py +2 -2
  18. tpu_inference/spec_decode/jax/eagle3.py +52 -19
  19. tpu_inference/utils.py +31 -9
  20. tpu_inference/worker/tpu_worker.py +2 -2
  21. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +1 -1
  22. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +25 -25
  23. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  24. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  25. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0

tpu_inference/runner/compilation_manager.py

@@ -1,13 +1,13 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from jax.sharding import NamedSharding, PartitionSpec

+import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.utils import device_array

 if TYPE_CHECKING:
@@ -30,10 +32,10 @@ class CompilationManager:

     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not envs.VLLM_DISABLE_COMPILE_CACHE:
+        if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-                              envs.VLLM_XLA_CACHE_PATH)
+                              vllm_envs.VLLM_XLA_CACHE_PATH)

     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
@@ -67,8 +69,7 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)

     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")
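
Note: the two hunks above replace direct os.getenv lookups with attribute access on the new tpu_inference.envs module (vllm.envs is kept, renamed to vllm_envs). The contents of tpu_inference/envs.py are not shown in this diff, so the following is only a hypothetical sketch of the usual pattern behind such a module: a table of typed parsers exposed through a module-level __getattr__ (PEP 562). The variable names are taken from the call sites that appear in this diff.

    # envs_sketch.py -- illustrative stand-in, not the shipped tpu_inference/envs.py
    import os
    from typing import Any, Callable, Dict

    def _bool(name: str, default: str = "0") -> bool:
        return os.getenv(name, default).lower() in ("1", "true", "yes")

    # One zero-argument parser per variable, so each access returns a typed value.
    _ENV_PARSERS: Dict[str, Callable[[], Any]] = {
        "SKIP_JAX_PRECOMPILE": lambda: _bool("SKIP_JAX_PRECOMPILE"),
        "NEW_MODEL_DESIGN": lambda: _bool("NEW_MODEL_DESIGN"),
        "NUM_SLICES": lambda: int(os.getenv("NUM_SLICES", "1")),
        "PHASED_PROFILING_DIR": lambda: os.getenv("PHASED_PROFILING_DIR", ""),
        "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "0")),
    }

    def __getattr__(name: str) -> Any:
        # Module-level __getattr__ lets callers write envs.NUM_SLICES and still
        # pick up os.environ changes at access time.
        if name in _ENV_PARSERS:
            return _ENV_PARSERS[name]()
        raise AttributeError(name)

Centralizing defaults and type conversion like this is what makes one-line call sites such as envs.SKIP_JAX_PRECOMPILE possible, and it sidesteps the os.getenv(..., False) truthiness pitfall in the removed code, where any non-empty string (including "0") counted as true.
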
 
@@ -81,6 +82,8 @@ class CompilationManager:
         self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
+        if not self.runner.is_last_rank:
+            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -120,8 +123,15 @@ class CompilationManager:
             num_tokens=num_tokens,
         )

-    def _precompile_backbone_helper(self, name, *, input_ids, positions,
-                                    inputs_embeds) -> None:
+    def _precompile_backbone_helper(self,
+                                    name,
+                                    *,
+                                    input_ids,
+                                    positions,
+                                    inputs_embeds,
+                                    intermediate_tensors=None,
+                                    is_first_rank=True,
+                                    is_last_rank=True) -> None:
         num_tokens = None
         if input_ids is not None:
             num_tokens = input_ids.shape[0]
@@ -181,10 +191,14 @@ class CompilationManager:
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
+            intermediate_tensors,
+            is_first_rank,
+            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                positions, layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata,
+                intermediate_tensors, is_first_rank, is_last_rank)
             self.runner.kv_caches = kv_caches
             return hidden_states

@@ -207,6 +221,9 @@ class CompilationManager:
             inputs_embeds,
             tuple(self.runner.layer_name_to_kvcache_index.items()),
             lora_metadata,
+            intermediate_tensors,
+            is_first_rank,
+            is_last_rank,
             num_tokens=num_tokens,
         )

@@ -257,6 +274,7 @@ class CompilationManager:
         )

     def _precompile_backbone_text_only(self) -> None:
+        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -266,10 +284,28 @@ class CompilationManager:
                                                   dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-            self._precompile_backbone_helper("backbone",
-                                             input_ids=input_ids,
-                                             positions=positions,
-                                             inputs_embeds=None)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if is_first_rank:
+                intermediate_tensors = None
+            else:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone",
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=None,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)

     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()
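
The JaxIntermediateTensors container constructed above is defined in tpu_inference/models/jax/jax_intermediate_tensor.py, which this diff does not show. Judging only from how it is used here (built from a tensors= dict of arrays) and in tpu_runner.py further below (a kv_connector_output attribute is assigned on it before it is returned), a minimal dict-backed stand-in could look like this sketch; the real class may carry more behavior.

    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    import jax

    @dataclass
    class IntermediateTensorsSketch:
        """Per-token activations handed from one pipeline stage to the next."""
        tensors: Dict[str, jax.Array]
        kv_connector_output: Optional[Any] = None

        def __getitem__(self, key: str) -> jax.Array:
            return self.tensors[key]

    # A non-first pipeline rank would precompile against dummies shaped like:
    # IntermediateTensorsSketch(tensors={"hidden_states": h, "residual": r})
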
@@ -283,10 +319,28 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-            self._precompile_backbone_helper("backbone with embeds",
-                                             input_ids=None,
-                                             positions=positions,
-                                             inputs_embeds=inputs_embeds)
+            is_first_rank = self.runner.is_first_rank
+            is_last_rank = self.runner.is_last_rank
+            if not is_first_rank:
+                hidden_states = self._create_dummy_tensor(
+                    (num_tokens, hidden_size), jnp.bfloat16)
+                residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                     jnp.bfloat16)
+                intermediate_tensors = JaxIntermediateTensors(
+                    tensors={
+                        "hidden_states": hidden_states,
+                        "residual": residual
+                    })
+            else:
+                intermediate_tensors = None
+            self._precompile_backbone_helper(
+                f"worker{self.runner.rank} backbone with embeds",
+                input_ids=None,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors,
+                is_first_rank=is_first_rank,
+                is_last_rank=is_last_rank)

     def _precompile_select_from_array_helper(
             self,
@@ -354,7 +408,7 @@ class CompilationManager:
             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name="select all logits",
+            name=f"worker{self.runner.rank} select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
@@ -365,7 +419,8 @@ class CompilationManager:
         if self.runner.speculative_config:
             vocab_size = self.runner.model_config.get_vocab_size()
             self._precompile_select_from_array_helper(
-                name="select bonus tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select bonus tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_reqs_paddings,
                 hidden_dim=vocab_size,
@@ -373,7 +428,8 @@ class CompilationManager:
                                        PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name="select target tokens for spec decoding",
+                name=
+                f"worker{self.runner.rank} select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,
@@ -396,7 +452,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-                "compute_logits",
+                f"worker{self.runner.rank} compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,
@@ -438,7 +494,7 @@ class CompilationManager:
                 do_sampling=do_sampling,
             )
             self._run_compilation(
-                "sample",
+                f"worker{self.runner.rank} sample",
                 sample,
                 self.runner.rng_params_for_sampling,
                 self.runner.mesh,
@@ -479,7 +535,7 @@ class CompilationManager:
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
             token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
             self._run_compilation(
-                "gather_logprobs",
+                f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,
                 logits,
                 token_ids,
@@ -531,7 +587,7 @@ class CompilationManager:
                 do_sampling=do_sampling)

             self._run_compilation(
-                compilation_name,
+                f"worker{self.runner.rank} {compilation_name}",
                 self.runner.rejection_sampler,
                 draft_token_ids,
                 num_draft_tokens,
@@ -601,6 +657,7 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
+                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )
@@ -645,9 +702,9 @@ class CompilationManager:
             num_reqs,
         ):
             target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                token_indices, query_start_loc, seq_lens, input_ids,
-                aux_hidden_states, attention_metadata, next_token_ids,
-                num_reqs)
+                self.runner.drafter.state, token_indices, query_start_loc,
+                seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                next_token_ids, num_reqs)
             return target_hidden_states, input_ids, last_token_indices

         input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -724,6 +781,7 @@ class CompilationManager:
         self._run_compilation(
             "eagle3_prepare_hidden_states_and_input_ids",
             self.runner.drafter._prepare_hidden_states_and_input_ids,
+            self.runner.drafter.state,
             aux_hidden_states,
             query_start_loc,
             target_token_ids,
@@ -758,6 +816,7 @@ class CompilationManager:
         self._run_compilation(
             "eagle3_select_inputs_for_loop_speculation",
             self.runner.drafter._select_inputs_for_loop_speculation,
+            self.runner.drafter.state,
             positions,
             hidden_states,
             hidden_states,
@@ -768,6 +827,7 @@ class CompilationManager:
         self._run_compilation(
             "eagle3_select_draft_token_ids",
             self.runner.drafter._select_draft_token_ids,
+            self.runner.drafter.state,
             hidden_states,
             last_token_indices,
             num_tokens=num_tokens,

tpu_inference/runner/kv_cache_manager.py

@@ -289,13 +289,8 @@ class KVCacheManager:

         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, 1, block_size,
-                                             *slices.shape[1:])
-            for (i, block_idx) in enumerate(block_numbers):
-                cache = jax.lax.dynamic_update_slice_in_dim(cache,
-                                                            reshaped_slices[i],
-                                                            block_idx,
-                                                            axis=0)
+            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])
+            cache.at[block_numbers].set(reshaped_slices)
             return cache

         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
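
Background for the rewrite above: JAX arrays are immutable, so an indexed update is expressed functionally as array.at[indices].set(values), which returns a new array rather than modifying the input, and the caller is expected to keep the returned value. A small self-contained illustration (shapes unrelated to the real cache layout):

    import jax.numpy as jnp

    cache = jnp.zeros((4, 2, 3))            # e.g. (num_blocks, block_size, head_dim)
    blocks = jnp.array([1, 3])              # block indices to overwrite
    slices = jnp.ones((2, 2, 3))            # one slice per selected block

    updated = cache.at[blocks].set(slices)  # functional scatter: returns a new array
    assert float(cache.sum()) == 0.0        # original buffer is untouched
    assert float(updated[1].sum()) == 6.0   # block 1 was overwritten in the copy

Inside a jitted function, XLA can usually perform such an update in place (via buffer donation/aliasing), so the functional form can replace the removed per-block dynamic_update_slice_in_dim loop without implying one copy per block.
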
@@ -348,16 +343,12 @@
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
-                    self.runner.kv_caches, block_ids[0], len(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                self.runner.kv_caches, block_ids[0], len(block_ids))

         else:
-            with runner_utils.LatencyTracker(
-                    "BatchedGatherKVSlices-for-blocks"):
-                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
-                    self.runner.kv_caches, jnp.array(block_ids))
+            batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer

     def transfer_kv_cache(self,
@@ -446,6 +437,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
+                jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -457,6 +449,7 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
+                jax.block_until_ready(self.runner.kv_caches)

         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")
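
Context for the added jax.block_until_ready(...) calls: JAX dispatches device work asynchronously, so Python returns before the updated cache buffers are actually written. Blocking on the caches guarantees the KV data is materialized before anything downstream (latency logging, or handing the blocks to a transfer path) observes it. jax.block_until_ready accepts an arbitrary pytree, as in this sketch:

    import time

    import jax
    import jax.numpy as jnp

    @jax.jit
    def touch(tree):
        return jax.tree.map(lambda x: x + 1, tree)

    caches = {"layer_0": jnp.zeros((1024, 1024)), "layer_1": jnp.zeros((1024, 1024))}
    start = time.perf_counter()
    caches = touch(caches)         # returns almost immediately (async dispatch)
    jax.block_until_ready(caches)  # waits until every leaf buffer is ready
    print(f"update materialized in {time.perf_counter() - start:.4f}s")
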

tpu_inference/runner/persistent_batch_manager.py

@@ -14,12 +14,13 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config):
+                 uses_mrope: bool, model_config, is_last_rank: bool):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
+        self.is_last_rank = is_last_rank

     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -179,9 +180,35 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index.get(req_id)
+
+            if not self.is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                del req_state.output_token_ids[num_output_tokens:]
+                if req_index is not None:
+                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                               num_output_tokens)
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -194,7 +221,6 @@
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

-            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -209,6 +235,18 @@
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)

+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not self.is_last_rank:
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())
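
To make the pipeline-parallel bookkeeping above concrete, here is the num_new_tokens arithmetic with hypothetical numbers: a request with 10 prompt tokens and 2 previously appended outputs (req_state.num_tokens == 12), num_computed_tokens == 12, and two freshly sampled ids sent back by the scheduler appends exactly those two ids.

    # Hypothetical values, mirroring the non-last-rank branch above.
    num_computed_tokens = 12        # tokens whose KV is already computed
    new_token_ids = [501, 502]      # sampled ids the scheduler sent back
    req_state_num_tokens = 12       # prompt (10) + previously appended outputs (2)
    output_token_ids = [317, 421]   # outputs already cached for this request

    num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state_num_tokens
    assert num_new_tokens == 2      # both sampled ids are genuinely new

    if num_new_tokens == 1:
        output_token_ids.append(new_token_ids[-1])
    elif num_new_tokens > 0:
        output_token_ids.extend(new_token_ids[-num_new_tokens:])

    assert output_token_ids == [317, 421, 501, 502]

The subtraction guards against double-appending: if the scheduler re-sends ids the worker has already cached, num_new_tokens comes out non-positive and nothing is appended.
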

tpu_inference/runner/tpu_runner.py

@@ -10,17 +10,16 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
+from torchax.ops.mappings import t2j_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +34,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +48,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,7 +66,7 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_torch_dtype)

 logger = init_logger(__name__)

@@ -78,17 +80,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )

-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-

 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +234,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank

         self._init_random()
         self._init_mesh()
@@ -253,31 +247,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)

-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)

         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
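
The to_torch_dtype helper now imported from tpu_inference.utils replaces both the module-level TPU_STR_DTYPE_TO_TORCH_DTYPE table and the removed isinstance ladder, but utils.py itself is not among the hunks shown here. Based only on the code it replaces, it presumably has to accept cache-dtype strings, jnp-style dtypes and torch dtypes; the following is a hypothetical reconstruction, not the shipped implementation.

    import jax.numpy as jnp
    import torch
    from torchax.ops.mappings import j2t_dtype  # jax -> torch dtype mapping

    _STR_TO_TORCH = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
        "int8": torch.int8,
        "uint8": torch.uint8,
    }

    def to_torch_dtype_sketch(dtype) -> torch.dtype:
        """Best-effort mapping of a string / jax / torch dtype to a torch.dtype."""
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            return _STR_TO_TORCH[dtype]
        if isinstance(getattr(dtype, "dtype", None), jnp.dtype):
            return j2t_dtype(dtype.dtype)
        return j2t_dtype(jnp.dtype(dtype))
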
@@ -291,7 +275,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)

     def _init_mesh(self) -> None:
-        if os.getenv("NEW_MODEL_DESIGN", False):
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +286,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")

     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices = int(os.environ.get('NUM_SLICES', 1))
+        num_slices = envs.NUM_SLICES

         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -371,7 +355,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                          devices=self.devices)

     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +397,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -555,12 +539,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output

     def sample_tokens(
@@ -686,7 +670,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +750,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -775,8 +760,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
+                intermediate_tensors,
+                self.is_first_rank,
+                self.is_last_rank,
             )
-
+            if not get_pp_group().is_last_rank:
+                assert isinstance(hidden_states, JaxIntermediateTensors)
+                hidden_states.kv_connector_output = kv_connector_output
+                return attn_metadata, hidden_states
             hidden_states = self._select_from_array_fn(hidden_states,
                                                        logits_indices)
             logits = self.compute_logits_fn(
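
The early return above is the heart of the pipeline-parallel split: every rank except the last hands a JaxIntermediateTensors bundle (with the kv_connector_output attached) to the next stage instead of selecting logits and sampling. A schematic driver loop for one rank might look like the sketch below; run_pipeline_step and send_to_next_stage are illustrative placeholders, not APIs from this codebase.

    from typing import Any, Optional

    def send_to_next_stage(tensors: Any) -> None:
        """Placeholder transport hook; a real system would push the arrays to the
        next stage's devices, keyed e.g. by the transfer-uuid helper added below."""

    def run_pipeline_step(runner: Any, scheduler_output: Any,
                          recv_intermediate: Optional[Any]) -> Optional[Any]:
        """One scheduling step on one pipeline rank (schematic only)."""
        # The first rank starts from token ids; later ranks resume from the
        # intermediate tensors received from the previous stage.
        result = runner.execute_model(scheduler_output,
                                      intermediate_tensors=recv_intermediate)
        if not runner.is_last_rank:
            send_to_next_stage(result)  # forward activations downstream
            return None                 # no sampled tokens on this rank
        return result                   # last rank returns a ModelRunnerOutput
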
@@ -1719,3 +1710,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
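
A quick standalone check of the transfer-uuid scheme added above: because the id depends only on the sorted per-request scheduled-token counts, the step counter, and the sender's rank, both ends of a transfer can derive the same 64-bit value independently. The snippet reproduces the hashing with hypothetical inputs and does not import the runner.

    import hashlib

    def uuid_for_transfer(num_scheduled_tokens: dict, rank: int, step: int) -> int:
        # Mirrors get_uuid_for_jax_transfer: hash of scheduler view + step + sender rank.
        if not num_scheduled_tokens:
            scheduler_output_str = "empty_batch"
        else:
            scheduler_output_str = str(sorted(num_scheduled_tokens.items()))
        unique_str = f"{scheduler_output_str} {step} {rank}"
        digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")

    # Sender and receiver agree as long as they observe the same scheduler output
    # and step, regardless of dict ordering, and key the id on the sender's rank.
    view_a = {"req-a": 32, "req-b": 1}
    view_b = {"req-b": 1, "req-a": 32}
    assert uuid_for_transfer(view_a, rank=0, step=7) == \
           uuid_for_transfer(view_b, rank=0, step=7)
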

tpu_inference/runner/utils.py

@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level = os.getenv(
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

         self.current_phase: str = ""