tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
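
Items 18, 21, and 64 in the list above are module moves from tpu_inference/layers/common/ to tpu_inference/layers/jax/, so downstream imports change with them; for example, as the tpu_worker.py hunks later in this diff show:

    # Before (0.0.1rc1)
    from tpu_inference.layers.common.sharding import ShardingConfigManager

    # After (0.11.1.dev202511130813)
    from tpu_inference.layers.jax.sharding import ShardingConfigManager
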
@@ -15,7 +15,6 @@ import jax
  from jax._src.interpreters import pxla
  from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

- from tpu_inference import envs
  from tpu_inference.logger import init_logger
  from tpu_inference.runner.input_batch import InputBatch

@@ -307,7 +306,8 @@ class PhasedBasedProfiler:
  InferencePhase.BALANCED: False
  }
  self.default_profiling_options = jax.profiler.ProfileOptions()
- self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
+ self.default_profiling_options.python_tracer_level = os.getenv(
+ "PYTHON_TRACER_LEVEL", 0)

  self.current_phase: str = ""

tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -6,19 +6,13 @@ from typing import Any, Optional
  import jax
  import jax.numpy as jnp
  import numpy as np
- from flax import nnx
- from jax import lax
- from jax.sharding import NamedSharding, PartitionSpec
  from vllm.config import VllmConfig

  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
  from tpu_inference.runner import utils as runner_utils
  from tpu_inference.utils import device_array

- logger = init_logger(__name__)
-

  class Eagle3Proposer:
  """A proposer for speculative decoding using the Eagle3 method.
@@ -57,22 +51,8 @@ class Eagle3Proposer:
  """Loads the draft model."""
  self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
  self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
- draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
- if draft_embed_tokens is None or ~jnp.any(
- draft_embed_tokens.embedding):
- logger.info(
- "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
- )
- self.state.model.embed_tokens = target_model.model.embed
- elif jnp.array_equal(draft_embed_tokens.embedding,
- target_model.model.embed.embedding):
- logger.info(
- "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
- )
- self.state.model.embed_tokens = target_model.model.embed
- else:
- logger.info("Draft model has its own embed_tokens.")
+ del self.state.model['embed_tokens']
+ self.state.model.embed_tokens = target_model.model.embed

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_input_ids(
@@ -130,17 +110,6 @@ class Eagle3Proposer:
  max_num_blocks_per_req)
  new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

- positions = lax.with_sharding_constraint(
- positions, NamedSharding(self.mesh, PartitionSpec(None, )))
- clamped_positions = lax.with_sharding_constraint(
- clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
- new_seq_lens = lax.with_sharding_constraint(
- new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
- query_start_loc = lax.with_sharding_constraint(
- query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
- new_block_tables = lax.with_sharding_constraint(
- new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
-
  return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

  @functools.partial(jax.jit, static_argnums=(0, ))
@@ -152,7 +121,6 @@
  @functools.partial(jax.jit, static_argnums=(0, ))
  def _prepare_hidden_states_and_input_ids(
  self,
- state: nnx.State,
  aux_hidden_states: tuple[jax.Array, ...],
  query_start_loc: jax.Array,
  target_token_ids: jax.Array,
@@ -161,7 +129,7 @@
  ) -> tuple[jax.Array, jax.Array, jax.Array]:
  target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
  target_hidden_states = self.combine_hidden_states_fn(
- state, target_hidden_states)
+ self.state, target_hidden_states)

  input_ids, last_token_indices = self._prepare_input_ids(
  query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -208,8 +176,8 @@
  block_tables=device_array(
  self.mesh, block_tables))
  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- self.state, aux_hidden_states, attn_metadata.query_start_loc,
- input_ids, next_token_ids, num_reqs)
+ aux_hidden_states, attn_metadata.query_start_loc, input_ids,
+ next_token_ids, num_reqs)
  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  # Host copies from the metadata prepared by the runner.
@@ -273,13 +241,12 @@

  attn_metadata = replace(attn_metadata, block_tables=block_tables)
  return self._filter_token_and_prepare_initial_inputs(
- self.state, token_indices, query_start_loc, seq_lens, input_ids,
+ token_indices, query_start_loc, seq_lens, input_ids,
  aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _filter_token_and_prepare_initial_inputs(
  self,
- state: nnx.State,
  token_indices: jax.Array,
  query_start_loc: jax.Array,
  seq_lens: jax.Array,
@@ -307,51 +274,35 @@
  )

  target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
- state, [h[token_indices] for h in aux_hidden_states],
- query_start_loc, target_token_ids, next_token_ids, num_reqs)
+ [h[token_indices] for h in aux_hidden_states], query_start_loc,
+ target_token_ids, next_token_ids, num_reqs)

  return target_hidden_states, input_ids, last_token_indices, attn_metadata

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_draft_token_ids(
  self,
- state: nnx.State,
  hidden_states: jax.Array,
  last_token_indices: jax.Array,
  ) -> jax.Array:
  sample_hidden_states = hidden_states[last_token_indices]
- sample_hidden_states = lax.with_sharding_constraint(
- sample_hidden_states,
- NamedSharding(self.mesh, PartitionSpec(None, None)))
- return self._get_draft_token_ids(state, sample_hidden_states)
+ return self._get_draft_token_ids(sample_hidden_states)

  @functools.partial(jax.jit, static_argnums=(0, ))
- def _get_draft_token_ids(self, state: nnx.State,
- hidden_states: jax.Array) -> jax.Array:
+ def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
  lora_metadata = None
- logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
- draft_token_ids = jnp.argmax(logits, axis=-1)
- return lax.with_sharding_constraint(
- draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+ logits = self.compute_logits_fn(self.state, hidden_states,
+ lora_metadata)
+ return jnp.argmax(logits, axis=-1)

  @functools.partial(jax.jit, static_argnums=(0, ))
  def _select_inputs_for_loop_speculation(
- self, state: nnx.State, positions: jax.Array, residual: jax.Array,
+ self, positions: jax.Array, residual: jax.Array,
  hidden_states: jax.Array,
  last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
- positions = positions[last_token_indices]
- residual = residual[last_token_indices]
- draft_token_ids = self._select_draft_token_ids(state, hidden_states,
- last_token_indices)
-
- positions = lax.with_sharding_constraint(
- positions, NamedSharding(self.mesh, PartitionSpec(None, )))
- residual = lax.with_sharding_constraint(
- residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
- draft_token_ids = lax.with_sharding_constraint(
- draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
-
- return positions, residual, draft_token_ids
+ return positions[last_token_indices], residual[
+ last_token_indices], self._select_draft_token_ids(
+ hidden_states, last_token_indices)

  def propose(
  self,
@@ -378,11 +329,11 @@

  if self.num_speculative_tokens == 1:
  return kv_caches, self._select_draft_token_ids(
- self.state, hidden_states, last_token_indices)
+ hidden_states, last_token_indices)

  positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
- self.state, attn_metadata.input_positions, residual[0],
- hidden_states, last_token_indices)
+ attn_metadata.input_positions, residual[0], hidden_states,
+ last_token_indices)

  draft_token_ids_list = [draft_token_ids]

@@ -407,8 +358,7 @@
  attn_metadata,
  )
  hidden_states = residual[0]
- draft_token_ids = self._get_draft_token_ids(
- self.state, new_hidden_states)
+ draft_token_ids = self._get_draft_token_ids(new_hidden_states)
  draft_token_ids_list.append(draft_token_ids)

  # [batch_size, num_speculative_tokens]
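
Throughout Eagle3Proposer the jitted helpers are declared with @functools.partial(jax.jit, static_argnums=(0, )), marking self as a static argument; the refactor above additionally drops the explicit state parameter in favour of reading self.state. A toy, self-contained sketch of the decorator pattern itself (the class and values are my own illustration, not code from the package):

    import functools

    import jax
    import jax.numpy as jnp


    class Scaler:
        """Stand-in for a stateful class whose methods are jitted."""

        def __init__(self, scale: float):
            self.scale = scale

        @functools.partial(jax.jit, static_argnums=(0, ))
        def apply(self, x: jax.Array) -> jax.Array:
            # `self` is static, so attributes like `self.scale` are read at
            # trace time and baked into the compiled function; only `x` is traced.
            return x * self.scale


    print(Scaler(0.5).apply(jnp.arange(4)))  # [0.  0.5 1.  1.5]

Because the instance is part of the static signature, each new instance triggers a fresh compilation, which is why this pattern suits long-lived objects such as a proposer.
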
tpu_inference/tpu_info.py CHANGED
@@ -3,7 +3,6 @@ import os

  import requests

- from tpu_inference import envs
  from tpu_inference.logger import init_logger

  logger = init_logger(__name__)
@@ -33,14 +32,14 @@ def get_tpu_metadata(key: str = "") -> str:


  def get_tpu_type() -> str:
- tpu_type = envs.TPU_ACCELERATOR_TYPE
+ tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
  if tpu_type is None:
  tpu_type = get_tpu_metadata(key="accelerator-type")
  return tpu_type


  def get_node_name() -> str:
- tpu_name = envs.TPU_NAME
+ tpu_name = os.getenv("TPU_NAME", None)
  if not tpu_name:
  tpu_name = get_tpu_metadata(key="instance-id")
  return tpu_name
@@ -48,7 +47,7 @@ def get_node_name() -> str:

  def get_node_worker_id() -> int:
  """For multi-host TPU VM, this returns the worker id for the current node."""
- worker_id = envs.TPU_WORKER_ID
+ worker_id = os.getenv("TPU_WORKER_ID", None)
  if worker_id is None:
  worker_id = get_tpu_metadata(key="agent-worker-number")
  if worker_id is None:
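
Each of these accessors now follows the same lookup order: use the environment variable when set, otherwise fall back to the TPU metadata server via get_tpu_metadata. A generic sketch of that order (env_or_metadata is a hypothetical helper of mine, not part of tpu_info.py):

    import os
    from typing import Callable, Optional


    def env_or_metadata(var: str,
                        fetch: Callable[[], Optional[str]]) -> Optional[str]:
        # Prefer the environment variable; fall back to the metadata lookup,
        # mirroring get_tpu_type / get_node_name / get_node_worker_id above.
        value = os.getenv(var)
        if not value:
            value = fetch()
        return value


    # Example: env_or_metadata("TPU_ACCELERATOR_TYPE",
    #                          lambda: get_tpu_metadata(key="accelerator-type"))
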
tpu_inference/utils.py CHANGED
@@ -1,4 +1,5 @@
  # SPDX-License-Identifier: Apache-2.0
+ import os
  import time
  from collections import defaultdict
  from collections.abc import Sequence
@@ -8,54 +9,30 @@ from typing import Any, Callable, List, Tuple
  import jax
  import jax.numpy as jnp
  import numpy as np
- import torch
  from jax._src import dtypes
  from jax._src import mesh as mesh_lib
  from jax._src import xla_bridge as xb
  from jax._src.lib import xla_client as xc
- from jax._src.numpy.scalar_types import _ScalarMeta
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
- from torchax.ops.mappings import j2t_dtype, t2j_dtype
- from vllm import envs as vllm_envs
- from vllm import utils
+ from vllm import envs, utils

- from tpu_inference import envs
  from tpu_inference.logger import init_logger

  GBYTES = 1024 * 1024 * 1024
  TPU_HEAD_SIZE_ALIGNMENT = 128
  TPU_SECOND_LAST_MINOR = 8

- # Map vllm dtype string that doesn't exactly match jax dtype string name.
- _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+ # This is used to translate from a string name for a dtype
+ # to formal jax.numpy DType. One use case for this is
+ # converting the `--kv_cache_dtype` flag to a dtype.
+ TPU_STR_DTYPE_TO_JAX_DTYPE = {
+ "bfloat16": jnp.bfloat16,
  "fp8": jnp.float8_e4m3fn,
- "fp8_e4m3": jnp.float8_e4m3fn,
+ "fp8_e4m3": jnp.float8_e4m3,
  "fp8_e5m2": jnp.float8_e5m2,
+ "int8": jnp.int8,
  }

-
- def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
- if isinstance(dtype, str):
- if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
- return dict_dtype
- return jnp.dtype(dtype)
- elif isinstance(dtype, torch.dtype):
- return t2j_dtype(dtype)
- elif isinstance(dtype, jnp.dtype):
- return dtype
- elif isinstance(dtype, _ScalarMeta):
- return dtype.dtype
- else:
- raise ValueError(f"Argument is unsupported data type {type(dtype)}")
-
-
- def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
- # Use jax dtype as an intermediate dtype which we'll be used to convert it
- # into torch dtype.
- dtype = to_jax_dtype(dtype)
- return j2t_dtype(dtype)
-
-
  _megacore = False
  logger = init_logger(__name__)

@@ -80,10 +57,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

  def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
  usage = []
- if vllm_envs.VLLM_TPU_USING_PATHWAYS:
+ if envs.VLLM_TPU_USING_PATHWAYS:
  return pathways_hbm_usage_gb(devices)

- multihost_backend = envs.TPU_MULTIHOST_BACKEND
+ multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
  if multihost_backend == "ray":
  # MemoryStats is only supported for addressable PjRt devices.
  # Assume all the devices have similar memory usage for now.
@@ -155,8 +132,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
  hbm_used = defaultdict(int)
  hbm_limit = get_device_hbm_limit()
  for array in live_arrays:
- for buffer in array.addressable_shards:
- hbm_used[buffer.data.device] += buffer.data.nbytes
+ for buffer in array.device_buffers:
+ hbm_used[buffer.device] += buffer.nbytes
  return [(hbm_used[device], hbm_limit) for device in devices]


@@ -317,8 +294,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
  Returns:
  jnp.dtype: The JAX dtype.
  """
- # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
- return to_jax_dtype(str_dtype)
+ str_dtype = str_dtype.lower().strip()
+ return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)


  def time_function(func):
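
With the torch-aware converter functions gone, dtype strings are resolved purely through TPU_STR_DTYPE_TO_JAX_DTYPE plus the normalise-then-get logic in get_jax_dtype_from_str_dtype. A small usage sketch of that behaviour (the table is copied from the diff, the example calls are mine, and it assumes a JAX/ml_dtypes build that provides jnp.float8_e4m3 as the new mapping does); note that unrecognised names now come back as None rather than raising:

    import jax.numpy as jnp

    # Copied from the new tpu_inference/utils.py mapping in this diff.
    TPU_STR_DTYPE_TO_JAX_DTYPE = {
        "bfloat16": jnp.bfloat16,
        "fp8": jnp.float8_e4m3fn,
        "fp8_e4m3": jnp.float8_e4m3,
        "fp8_e5m2": jnp.float8_e5m2,
        "int8": jnp.int8,
    }


    def get_jax_dtype_from_str_dtype(str_dtype: str):
        # Normalise the flag value (e.g. --kv_cache_dtype) before the lookup.
        return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype.lower().strip())


    assert get_jax_dtype_from_str_dtype(" FP8_E5M2 ") is jnp.float8_e5m2
    assert get_jax_dtype_from_str_dtype("float16") is None  # no jnp.dtype() fallback anymore
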
tpu_inference/worker/tpu_worker.py CHANGED
@@ -2,7 +2,6 @@

  import os
  import tempfile
- from dataclasses import dataclass, field
  from typing import Callable, Dict, Optional, Tuple

  import jax
@@ -11,7 +10,6 @@ import jaxlib
  import jaxtyping
  import vllm.envs as vllm_envs
  from vllm.config import VllmConfig, set_current_vllm_config
- from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
  has_kv_transfer_group)
  from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -25,13 +23,10 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
  from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

  from tpu_inference import envs, utils
- from tpu_inference.distributed import jax_parallel_state
  from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
  get_node_id)
- from tpu_inference.layers.common.sharding import ShardingConfigManager
+ from tpu_inference.layers.jax.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
- from tpu_inference.models.jax.jax_intermediate_tensor import \
- JaxIntermediateTensors
  from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
  from tpu_inference.runner.tpu_runner import TPUModelRunner

@@ -44,39 +39,15 @@ _DTYPE: dict[str, jnp.dtype] = {
  }


- @dataclass
- class PPConfig:
- rank: int
- ip: str
- prev_worker_ip: str
- pp_world_size: int
-
- # default env vars for
- # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
- # if PP is used in single host.
- default_tpu_process_bounds: str = field(init=False)
- default_tpu_chips_per_process_bounds: str = field(init=False)
- default_tpu_visible_chips: str = field(init=False)
-
- def __post_init__(self):
- self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
- self.default_tpu_chips_per_process_bounds = "1,1,1"
- self.default_tpu_visible_chips = f"{self.rank}"
-
-
  class TPUWorker:

- def __init__(
- self,
- vllm_config: VllmConfig,
- local_rank: int,
- rank: int,
- distributed_init_method: str,
- is_driver_worker: bool = False,
- devices=None,
- ip: str = "localhost",
- prev_worker_ip: str = "localhost",
- ):
+ def __init__(self,
+ vllm_config: VllmConfig,
+ local_rank: int,
+ rank: int,
+ distributed_init_method: str,
+ is_driver_worker: bool = False,
+ devices=None):
  # If we use vLLM's model implementation in PyTorch, we should set it
  # with torch version of the dtype.
  impl = envs.MODEL_IMPL_TYPE
@@ -103,12 +74,10 @@ class TPUWorker:
  self.devices = devices if devices is not None else []
  self.device_ranks = set(device.id for device in self.devices
  if isinstance(device, jaxlib._jax.Device))
- self.pp_config = PPConfig(rank, ip, prev_worker_ip,
- self.parallel_config.pipeline_parallel_size)

  if self.model_config.trust_remote_code:
  # note: lazy import to avoid importing torch before initializing
- from vllm.utils.import_utils import init_cached_hf_modules
+ from vllm.utils import init_cached_hf_modules

  init_cached_hf_modules()

@@ -117,7 +86,7 @@
  # TPU Worker is initialized. The profiler server needs to start after
  # MP runtime is initialized.
  self.profile_dir = None
- if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
+ if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
  if not self.devices or 0 in self.device_ranks:
  # For TPU, we can only have 1 active profiler session for 1 profiler
  # server. So we only profile on rank0.
@@ -125,14 +94,6 @@
  logger.info("Profiling enabled. Traces will be saved to: %s",
  self.profile_dir)

- # For PP, we use MPMD so we want to profile every worker.
- if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
- self.profile_dir = os.path.join(
- vllm_envs.VLLM_TORCH_PROFILER_DIR,
- f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
- )
- os.makedirs(self.profile_dir, exist_ok=True)
-
  use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
  # Only one instance of profiler is allowed
  if use_jax_profiler_server and self.rank < 1:
@@ -144,87 +105,31 @@
  )
  jax.profiler.start_server(jax_profiler_server_port)

- # step_counter is used to calculate uuid to transfer intermediate tensors.
- self.step_counter = 0
-
  def initialize_cache(self, num_gpu_blocks: int,
  num_cpu_blocks: int) -> None:
  self.cache_config.num_gpu_blocks = num_gpu_blocks
  self.cache_config.num_cpu_blocks = num_cpu_blocks

- def init_device(self,
- tpu_process_bounds="",
- tpu_chips_per_process_bounds="",
- tpu_visible_chips=""):
- # set tpu visible devices for Jax runtime in single host PP.
- multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
- if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
- tpu_ports = [
- jax_parallel_state.BASE_JAX_PORT + i
- for i in range(self.pp_config.pp_world_size)
- ]
- os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
- [f"localhost:{port}" for port in tpu_ports])
- os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
- os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
-
- # Note: Below is the setting for v6e8 host (8 chips of v6e)
- # Replace with your own topology.
- # There are 2 ways of subslicing a v6e
- # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
- # TPU_PROCESS_BOUNDS = "1,1,1"
- # TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
- # TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
- # 2) 1 chip for each subslice, with at most 8 subslices,
- # we can do TP=1, PP=1/2/3/4/5/6/7/8
- os.environ[
- "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
- if tpu_process_bounds \
- else self.pp_config.default_tpu_process_bounds
- os.environ[
- "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
- if tpu_chips_per_process_bounds \
- else self.pp_config.default_tpu_chips_per_process_bounds
- os.environ[
- "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
- if tpu_visible_chips \
- else self.pp_config.default_tpu_visible_chips
-
+ def init_device(self):
  if not self.devices:
  sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
  device_indexes = sharding_config.device_indexes
  if device_indexes is not None and len(device_indexes) > 0:
  # Enforcing the devices sequence to be consistent with the specified device indexes
- all_local_devices = jax.local_devices()
- device_dict = {
- device.id: device
- for device in all_local_devices
- }
+ all_devices = jax.devices()
+ device_dict = {device.id: device for device in all_devices}
  self.devices = []
  for device_index in device_indexes:
  device = device_dict[device_index]
  if device is None:
  raise KeyError(
  f"Device index {device_index} not found in "
- f"jax.local_devices() with IDs {list(device_dict.keys())}!"
+ f"jax.devices() with IDs {list(device_dict.keys())}!"
  )
  self.devices.append(device)
- assert len(self.devices) >= sharding_config.total_devices
  self.devices = self.devices[:sharding_config.total_devices]
  else:
- if self.pp_config.pp_world_size > 1:
- # We only support a mixed tp + pp scenario that tp size is
- # smaller or equals the total TPUs in one node
- # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
- assert jax.local_device_count(
- ) >= sharding_config.total_devices
- self.devices = jax.local_devices()[:sharding_config.
- total_devices]
- else:
- # In a multi-host distributed env, say: Ray, local_device count may smaller
- # than the total devices, we just choose the smaller set here.
- self.devices = jax.devices()[:sharding_config.
- total_devices]
+ self.devices = jax.devices()[:sharding_config.total_devices]

  # Initialize the vLLM distribution layer as a single chip environment,
  # we'll swap the model's parallel modules with TPU SPMD equivalents.
@@ -241,18 +146,8 @@
  tensor_model_parallel_size=1,
  pipeline_model_parallel_size=1,
  )
-
- jax_parallel_state.init_pp_distributed_environment(
- self.pp_config.ip,
- self.rank,
- self.parallel_config.pipeline_parallel_size,
- self.devices[0],
- need_pp=self.parallel_config.pipeline_parallel_size > 1)
-
  ensure_kv_transfer_initialized(self.vllm_config)
- self.model_runner = TPUModelRunner(
- self.vllm_config, self.devices, self.rank, self.rank == 0,
- self.rank == self.pp_config.pp_world_size - 1)
+ self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
  logger.info(f"Init worker | "
  f"rank={self.rank} | "
  f"node_id={get_node_id()} | "
@@ -260,12 +155,6 @@
  f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
  vllm_utils.report_usage_stats(self.vllm_config)

- def initialize_pp_transfer_connect(self):
- if self.rank == 0:
- return
- jax_parallel_state.connect(self.pp_config.prev_worker_ip,
- self.rank - 1)
-
  def determine_available_memory(self) -> int:
  gpu_memory_utilization = self.cache_config.gpu_memory_utilization
  hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -305,39 +194,14 @@
  # deliberate, temporary compromise for the same reasons outlined in
  # the `get_kv_cache_spec` method.

- if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
- intermediate_tensors = None
- else:
- # receive intermediate tensors
- uuid = self.model_runner.get_uuid_for_jax_transfer(
- scheduler_output, self.rank - 1, self.step_counter)
- # TODO: this method might only works for vllm model, not sure about jax models.
- tensor_spec = self.model_runner.get_intermediate_tensor_spec(
- scheduler_output.total_num_scheduled_tokens)
- intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
- uuid, tensor_spec)
- intermediate_tensors = JaxIntermediateTensors(
- intermediate_tensors_dict)
-
- output = self.model_runner.execute_model(scheduler_output,
- intermediate_tensors)
-
- if isinstance(output, JaxIntermediateTensors):
- assert self.parallel_config.pipeline_parallel_size > 1
- assert not get_pp_group().is_last_rank
- # send intermediate tensors
- uuid = self.model_runner.get_uuid_for_jax_transfer(
- scheduler_output, self.rank, self.step_counter)
- get_pp_group().send_tensor_dict(uuid, output.tensors)
- self.step_counter += 1
- return None
- else:
- self.step_counter += 1
- # With a connector, the scheduler expects output from all workers
- # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
- if has_kv_transfer_group():
- return output
- return output if self.is_driver_worker else None
+ output = self.model_runner.execute_model(scheduler_output)
+
+ # With a connector, the scheduler expects output from all workers
+ # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+ if has_kv_transfer_group():
+ return output
+
+ return output if self.is_driver_worker else None

  def sample_tokens(self,
  grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -357,7 +221,7 @@
  if is_start:
  options = jax.profiler.ProfileOptions()
  # default: https://docs.jax.dev/en/latest/profiling.html#general-options
- options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
+ options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
  options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
  jax.profiler.start_trace(self.profile_dir,
  profiler_options=options)
@@ -402,8 +266,7 @@

  # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
  # feature that allows overriding page_size_bytes of KVCacheSpec.
- vllm_page_size_bytes = get_uniform_page_size(
- list(kv_cache_specs.values()))
+ vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
  rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
  kv_cache_specs)
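
Both profiling changes in this diff (the PhasedBasedProfiler hunk near the top and the profile hunk above) replace the package-level envs lookup with direct environment reads for the tracer levels. A minimal sketch of that pattern, written for illustration rather than taken from the package, assuming a JAX version that exposes jax.profiler.ProfileOptions as the diff does; since os.getenv returns strings, the sketch casts the values to int:

    import os

    import jax


    def build_profile_options():
        # Tracer levels come from environment variables, with the same
        # fallbacks used in the diff (see the JAX profiling docs for their meaning).
        options = jax.profiler.ProfileOptions()
        options.python_tracer_level = int(os.getenv("PYTHON_TRACER_LEVEL", "0"))
        options.host_tracer_level = int(os.getenv("HOST_TRACER_LEVEL", "1"))
        return options


    # e.g. jax.profiler.start_trace("/tmp/profiles",
    #                               profiler_options=build_profile_options())
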