tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch
 
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level = os.getenv(
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
 
         self.current_phase: str = ""
 
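
The hunk above swaps a raw os.getenv lookup for the typed tpu_inference.envs accessor. A minimal sketch of what such a typed accessor can look like (the real envs module may be implemented differently); the point is that os.getenv returns a string when the variable is set, so it should be parsed before being handed to jax.profiler.ProfileOptions:

import os

def _env_int(name: str, default: int) -> int:
    # os.getenv returns a str when the variable is set, so parse it explicitly
    # instead of passing the raw value to ProfileOptions.
    raw = os.getenv(name)
    return default if raw is None else int(raw)

PYTHON_TRACER_LEVEL = _env_int("PYTHON_TRACER_LEVEL", 0)
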
@@ -6,13 +6,19 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array
 
+logger = init_logger(__name__)
+
 
 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +57,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-        del self.state.model['embed_tokens']
-        self.state.model.embed_tokens = target_model.model.embed
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
@@ -110,6 +130,17 @@ class Eagle3Proposer:
             max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
 
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
 
     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -121,6 +152,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -129,7 +161,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-            self.state, target_hidden_states)
+            state, target_hidden_states)
 
         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -176,8 +208,8 @@ class Eagle3Proposer:
                                       block_tables=device_array(
                                           self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc, input_ids,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     # Host copies from the metadata prepared by the runner.
@@ -241,12 +273,13 @@ class Eagle3Proposer:
 
         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
            aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -274,35 +307,51 @@ class Eagle3Proposer:
         )
 
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states], query_start_loc,
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)
 
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-        return self._get_draft_token_ids(sample_hidden_states)
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(self.state, hidden_states,
-                                        lora_metadata)
-        return jnp.argmax(logits, axis=-1)
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-            self, positions: jax.Array, residual: jax.Array,
+            self, state: nnx.State, positions: jax.Array, residual: jax.Array,
             hidden_states: jax.Array,
             last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-        return positions[last_token_indices], residual[
-            last_token_indices], self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids
 
     def propose(
         self,
@@ -329,11 +378,11 @@ class Eagle3Proposer:
 
         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)
 
         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0], hidden_states,
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)
 
         draft_token_ids_list = [draft_token_ids]
 
@@ -358,7 +407,8 @@ class Eagle3Proposer:
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
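
The draft-model loading hunk above decides whether the draft keeps its own embed_tokens or shares the target model's embedding. A standalone sketch of that decision with toy arrays (the names below are placeholders, not the package's API):

import jax.numpy as jnp

def should_share_embedding(draft_embedding, target_embedding) -> bool:
    # Share when the draft has no usable embedding (missing or all zeros),
    # or when it is element-for-element identical to the target's embedding.
    if draft_embedding is None or not bool(jnp.any(draft_embedding)):
        return True
    return bool(jnp.array_equal(draft_embedding, target_embedding))

target = jnp.ones((4, 8))
assert should_share_embedding(None, target)
assert should_share_embedding(jnp.zeros((4, 8)), target)
assert should_share_embedding(target, target)
assert not should_share_embedding(target * 2.0, target)
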
tpu_inference/tpu_info.py CHANGED
@@ -3,6 +3,7 @@ import os
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name = os.getenv("TPU_NAME", None)
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id = os.getenv("TPU_WORKER_ID", None)
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
tpu_inference/utils.py CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -9,30 +8,54 @@ from typing import Any, Callable, List, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from vllm import envs, utils
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8
 
-# This is used to translate from a string name for a dtype
-# to formal jax.numpy DType. One use case for this is
-# converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
     "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
+    "fp8_e4m3": jnp.float8_e4m3fn,
     "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
 }
 
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)
 
@@ -57,10 +80,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if envs.VLLM_TPU_USING_PATHWAYS:
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
@@ -294,8 +317,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-    str_dtype = str_dtype.lower().strip()
-    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)
 
 
 def time_function(func):
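
The utils.py hunks above replace the old TPU_STR_DTYPE_TO_JAX_DTYPE table with to_jax_dtype/to_torch_dtype helpers that accept strings, JAX dtypes, or torch dtypes. A quick illustration of how they can be exercised, assuming jax, torch, and torchax are installed (exact printed reprs vary by version):

import torch
from tpu_inference.utils import to_jax_dtype, to_torch_dtype

print(to_jax_dtype("bfloat16"))     # plain jnp dtype name, resolved via jnp.dtype(...)
print(to_jax_dtype("fp8_e4m3"))     # vLLM alias, resolved via the alias table
print(to_jax_dtype(torch.float16))  # torch dtype, mapped through torchax
print(to_torch_dtype("float32"))    # round-trips through the jax dtype
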
@@ -2,6 +2,7 @@
 
 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
@@ -10,6 +11,7 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
@@ -39,15 +44,39 @@ _DTYPE: dict[str, jnp.dtype] = {
 }
 
 
+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:
 
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 local_rank: int,
-                 rank: int,
-                 distributed_init_method: str,
-                 is_driver_worker: bool = False,
-                 devices=None):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
@@ -74,10 +103,12 @@
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
+            from vllm.utils.import_utils import init_cached_hf_modules
 
             init_cached_hf_modules()
 
@@ -86,7 +117,7 @@
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
            if not self.devices or 0 in self.device_ranks:
                # For TPU, we can only have 1 active profiler session for 1 profiler
                # server. So we only profile on rank0.
@@ -94,6 +125,14 @@
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)
 
+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,31 +144,87 @@
             )
             jax.profiler.start_server(jax_profiler_server_port)
 
+        # step_counter is used to calculate uuid to transfer intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-    def init_device(self):
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # set tpu visible devices for Jax runtime in single host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-                all_devices = jax.devices()
-                device_dict = {device.id: device for device in all_devices}
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.devices() with IDs {list(device_dict.keys())}!"
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.devices = jax.devices()[:sharding_config.total_devices]
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed tp + pp scenario that tp size is
+                    # smaller or equals the total TPUs in one node
+                    # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env, say: Ray, local_device count may smaller
+                    # than the total devices, we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]
 
         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
@@ -146,8 +241,18 @@
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +260,12 @@
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)
 
+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +305,39 @@
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.
 
-        output = self.model_runner.execute_model(scheduler_output)
-
-        # With a connector, the scheduler expects output from all workers
-        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-        if has_kv_transfer_group():
-            return output
-
-        return output if self.is_driver_worker else None
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only works for vllm model, not sure about jax models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None
 
     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -221,7 +357,7 @@
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -266,7 +402,8 @@
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_specs.values()))
         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
                                                       kv_cache_specs)
 
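
The execute_model hunk above threads intermediate tensors between pipeline-parallel ranks: rank 0 starts from the scheduler input, middle ranks receive from the previous stage and send to the next, and only the last rank returns a real ModelRunnerOutput. A schematic of that per-rank control flow; the callables are placeholders for the pp-group tensor-dict transfer, not the real vLLM API:

from typing import Any, Callable, Optional

def run_step(rank: int, world_size: int,
             execute: Callable[[Optional[Any]], Any],
             recv_from_prev: Callable[[], Any],
             send_to_next: Callable[[Any], None]) -> Optional[Any]:
    # First rank starts from the token inputs; later ranks wait for the
    # previous stage's intermediate tensors.
    intermediate = None if rank == 0 else recv_from_prev()
    result = execute(intermediate)
    if rank < world_size - 1:
        # Intermediate stage: pass activations on, return nothing upstream.
        send_to_next(result)
        return None
    # Last stage returns the model output to the scheduler.
    return result
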
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511150811
+Version: 0.11.1.dev202512030818
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -27,10 +27,11 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.23.0
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
+Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description