tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/runner/utils.py
CHANGED
@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level =
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

         self.current_phase: str = ""

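A recurring change in this release is swapping raw os.getenv() lookups for attributes on tpu_inference.envs (PYTHON_TRACER_LEVEL above; TPU_ACCELERATOR_TYPE, TPU_NAME, TPU_WORKER_ID and TPU_MULTIHOST_BACKEND in later files). The envs.py source is not part of this excerpt, so the following is only an illustrative sketch of the lazy-lookup pattern such modules typically use (modelled on vLLM's envs.py); the attribute names come from the hunks, but the parsers and defaults shown here are assumptions.

# Hypothetical sketch of a centralized env-var module in the style of vllm/envs.py.
# Attribute names come from the diff hunks; defaults and parsing are assumptions.
import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # Profiling
    "PYTHON_TRACER_LEVEL": lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "0")),
    # TPU VM metadata overrides (fall back to the metadata server when unset)
    "TPU_ACCELERATOR_TYPE": lambda: os.getenv("TPU_ACCELERATOR_TYPE"),
    "TPU_NAME": lambda: os.getenv("TPU_NAME"),
    "TPU_WORKER_ID": lambda: os.getenv("TPU_WORKER_ID"),  # real parser may return int
    # Multi-host backend selection ("ray" or "")
    "TPU_MULTIHOST_BACKEND": lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name: str) -> Any:
    # Resolve envs.FOO lazily so the process environment is read at access time.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")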
tpu_inference/spec_decode/jax/eagle3.py
CHANGED
@@ -6,13 +6,19 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array

+logger = init_logger(__name__)
+

 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +57,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
-        self.state.model
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
@@ -110,6 +130,17 @@ class Eagle3Proposer:
             max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -121,6 +152,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -129,7 +161,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-
+            state, target_hidden_states)

         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -176,8 +208,8 @@ class Eagle3Proposer:
                                       block_tables=device_array(
                                           self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     # Host copies from the metadata prepared by the runner.
@@ -241,12 +273,13 @@ class Eagle3Proposer:

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -274,35 +307,51 @@ class Eagle3Proposer:
         )

         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states],
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)

         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self,
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(
-
-        return
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-        self, positions: jax.Array, residual: jax.Array,
+        self, state: nnx.State, positions: jax.Array, residual: jax.Array,
         hidden_states: jax.Array,
         last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-
-
-
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids

     def propose(
         self,
@@ -329,11 +378,11 @@ class Eagle3Proposer:

         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)

         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0],
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)

         draft_token_ids_list = [draft_token_ids]

@@ -358,7 +407,8 @@ class Eagle3Proposer:
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
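Most of the added lines in this file wrap intermediate arrays in lax.with_sharding_constraint so that values stay replicated (empty PartitionSpec()) or unsharded along explicit axes under the proposer's mesh inside jitted functions. The snippet below is a standalone illustration of that JAX idiom, not code from the wheel; the mesh construction and array shapes are invented for the example.

# Standalone illustration of the lax.with_sharding_constraint idiom used above.
# Mesh axis names and shapes here are illustrative, not taken from tpu-inference.
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("model",))

@jax.jit
def constrain(positions: jax.Array, hidden: jax.Array):
    # Keep the 1-D positions unsharded on its only axis (PartitionSpec(None,))
    # and the 2-D hidden states fully replicated (PartitionSpec(None, None)),
    # so the compiler does not introduce surprise reshards between jitted steps.
    positions = lax.with_sharding_constraint(
        positions, NamedSharding(mesh, PartitionSpec(None, )))
    hidden = lax.with_sharding_constraint(
        hidden, NamedSharding(mesh, PartitionSpec(None, None)))
    return positions, hidden

p, h = constrain(jnp.arange(8), jnp.ones((8, 128)))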
tpu_inference/tpu_info.py
CHANGED
@@ -3,6 +3,7 @@ import os

 import requests

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:


 def get_tpu_type() -> str:
-    tpu_type =
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type


 def get_node_name() -> str:
-    tpu_name =
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:

 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id =
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
tpu_inference/utils.py
CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -9,30 +8,54 @@ from typing import Any, Callable, List, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
+from vllm import envs as vllm_envs
+from vllm import utils

+from tpu_inference import envs
 from tpu_inference.logger import init_logger

 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8

-#
-
-# converting the `--kv_cache_dtype` flag to a dtype.
-TPU_STR_DTYPE_TO_JAX_DTYPE = {
-    "bfloat16": jnp.bfloat16,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
     "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.
+    "fp8_e4m3": jnp.float8_e4m3fn,
     "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
 }

+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)

@@ -57,10 +80,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)

-    multihost_backend =
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
@@ -294,8 +317,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-
-    return
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)


 def time_function(func):
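The new to_jax_dtype/to_torch_dtype helpers normalize a vLLM dtype string, a torch dtype, or a JAX dtype through JAX, replacing the old TPU_STR_DTYPE_TO_JAX_DTYPE table. A possible usage sketch, assuming the helpers are importable from tpu_inference.utils exactly as added above:

# Usage sketch for the new dtype helpers; assumes they ship in tpu_inference.utils
# as shown in the hunk above.
import torch
from tpu_inference.utils import to_jax_dtype, to_torch_dtype

print(to_jax_dtype("fp8"))           # vLLM alias resolved to jnp.float8_e4m3fn via the mapping table
print(to_jax_dtype("bfloat16"))      # plain names fall through to jnp.dtype("bfloat16")
print(to_jax_dtype(torch.bfloat16))  # torch dtypes are converted with torchax's t2j_dtype
print(to_torch_dtype("bfloat16"))    # torch round-trip goes through the JAX intermediate dtype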
tpu_inference/worker/tpu_worker.py
CHANGED
@@ -2,6 +2,7 @@

 import os
 import tempfile
+from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
@@ -10,6 +11,7 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -23,10 +25,13 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

 from tpu_inference import envs, utils
+from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner

@@ -39,15 +44,39 @@ _DTYPE: dict[str, jnp.dtype] = {
 }


+@dataclass
+class PPConfig:
+    rank: int
+    ip: str
+    prev_worker_ip: str
+    pp_world_size: int
+
+    # default env vars for
+    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
+    # if PP is used in single host.
+    default_tpu_process_bounds: str = field(init=False)
+    default_tpu_chips_per_process_bounds: str = field(init=False)
+    default_tpu_visible_chips: str = field(init=False)
+
+    def __post_init__(self):
+        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
+        self.default_tpu_chips_per_process_bounds = "1,1,1"
+        self.default_tpu_visible_chips = f"{self.rank}"
+
+
 class TPUWorker:

-    def __init__(
-
-
-
-
-
-
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+        devices=None,
+        ip: str = "localhost",
+        prev_worker_ip: str = "localhost",
+    ):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
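PPConfig derives single-host pipeline-parallel defaults for the TPU runtime purely from the worker rank and the PP world size. A small worked example (the rank, IPs and world size chosen here are arbitrary; the derived strings follow __post_init__ verbatim, and the import path assumes the module layout shown in the file list):

# Worked example for PPConfig defaults; rank/IP/world-size values are arbitrary.
from tpu_inference.worker.tpu_worker import PPConfig

# With pp_world_size=4 and rank=2, __post_init__ yields a 1x4x1 process grid,
# one chip per process, and only chip 2 visible to this worker's process.
cfg = PPConfig(rank=2, ip="10.0.0.2", prev_worker_ip="10.0.0.1", pp_world_size=4)
print(cfg.default_tpu_process_bounds)            # "1,4,1"
print(cfg.default_tpu_chips_per_process_bounds)  # "1,1,1"
print(cfg.default_tpu_visible_chips)             # "2"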
@@ -74,10 +103,12 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
+        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
+                                  self.parallel_config.pipeline_parallel_size)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
+            from vllm.utils.import_utils import init_cached_hf_modules

             init_cached_hf_modules()

@@ -86,7 +117,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1 and self.pp_config.pp_world_size == 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -94,6 +125,14 @@ class TPUWorker:
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)

+        # For PP, we use MPMD so we want to profile every worker.
+        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
+            self.profile_dir = os.path.join(
+                vllm_envs.VLLM_TORCH_PROFILER_DIR,
+                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
+            )
+            os.makedirs(self.profile_dir, exist_ok=True)
+
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -105,31 +144,87 @@ class TPUWorker:
             )
             jax.profiler.start_server(jax_profiler_server_port)

+        # step_counter is used to calculate uuid to transfer intermediate tensors.
+        self.step_counter = 0
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-    def init_device(self
+    def init_device(self,
+                    tpu_process_bounds="",
+                    tpu_chips_per_process_bounds="",
+                    tpu_visible_chips=""):
+        # set tpu visible devices for Jax runtime in single host PP.
+        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
+            tpu_ports = [
+                jax_parallel_state.BASE_JAX_PORT + i
+                for i in range(self.pp_config.pp_world_size)
+            ]
+            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
+                [f"localhost:{port}" for port in tpu_ports])
+            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
+            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
+
+            # Note: Below is the setting for v6e8 host (8 chips of v6e)
+            # Replace with your own topology.
+            # There are 2 ways of subslicing a v6e
+            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
+            #    TPU_PROCESS_BOUNDS = "1,1,1"
+            #    TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
+            #    TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
+            # 2) 1 chip for each subslice, with at most 8 subslices,
+            #    we can do TP=1, PP=1/2/3/4/5/6/7/8
+            os.environ[
+                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
+                if tpu_process_bounds \
+                else self.pp_config.default_tpu_process_bounds
+            os.environ[
+                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
+                if tpu_chips_per_process_bounds \
+                else self.pp_config.default_tpu_chips_per_process_bounds
+            os.environ[
+                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
+                if tpu_visible_chips \
+                else self.pp_config.default_tpu_visible_chips
+
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-
-                device_dict = {
+                all_local_devices = jax.local_devices()
+                device_dict = {
+                    device.id: device
+                    for device in all_local_devices
+                }
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.
+                            f"jax.local_devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
+                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-                self.
+                if self.pp_config.pp_world_size > 1:
+                    # We only support a mixed tp + pp scenario that tp size is
+                    # smaller or equals the total TPUs in one node
+                    # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
+                    assert jax.local_device_count(
+                    ) >= sharding_config.total_devices
+                    self.devices = jax.local_devices()[:sharding_config.
+                                                       total_devices]
+                else:
+                    # In a multi-host distributed env, say: Ray, local_device count may smaller
+                    # than the total devices, we just choose the smaller set here.
+                    self.devices = jax.devices()[:sharding_config.
+                                                 total_devices]

         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
@@ -146,8 +241,18 @@ class TPUWorker:
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
+
+        jax_parallel_state.init_pp_distributed_environment(
+            self.pp_config.ip,
+            self.rank,
+            self.parallel_config.pipeline_parallel_size,
+            self.devices[0],
+            need_pp=self.parallel_config.pipeline_parallel_size > 1)
+
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(
+        self.model_runner = TPUModelRunner(
+            self.vllm_config, self.devices, self.rank, self.rank == 0,
+            self.rank == self.pp_config.pp_world_size - 1)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -155,6 +260,12 @@ class TPUWorker:
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)

+    def initialize_pp_transfer_connect(self):
+        if self.rank == 0:
+            return
+        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
+                                   self.rank - 1)
+
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -194,14 +305,39 @@ class TPUWorker:
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.

-
-
-
-
-
-
-
-
+        if self.parallel_config.pipeline_parallel_size == 1 or self.rank == 0:
+            intermediate_tensors = None
+        else:
+            # receive intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank - 1, self.step_counter)
+            # TODO: this method might only works for vllm model, not sure about jax models.
+            tensor_spec = self.model_runner.get_intermediate_tensor_spec(
+                scheduler_output.total_num_scheduled_tokens)
+            intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
+                uuid, tensor_spec)
+            intermediate_tensors = JaxIntermediateTensors(
+                intermediate_tensors_dict)
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if isinstance(output, JaxIntermediateTensors):
+            assert self.parallel_config.pipeline_parallel_size > 1
+            assert not get_pp_group().is_last_rank
+            # send intermediate tensors
+            uuid = self.model_runner.get_uuid_for_jax_transfer(
+                scheduler_output, self.rank, self.step_counter)
+            get_pp_group().send_tensor_dict(uuid, output.tensors)
+            self.step_counter += 1
+            return None
+        else:
+            self.step_counter += 1
+            # With a connector, the scheduler expects output from all workers
+            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+            if has_kv_transfer_group():
+                return output
+            return output if self.is_driver_worker else None

     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -221,7 +357,7 @@ class TPUWorker:
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level =
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -266,7 +402,8 @@ class TPUWorker:

         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_specs.values()))
         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
                                                       kv_cache_specs)

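For a single-host pipeline-parallel run, init_device exports per-process TPU runtime variables derived from PPConfig and jax_parallel_state.BASE_JAX_PORT. The sketch below shows the values it would compute for rank 1 of a two-stage pipeline; the port base 8476 is a made-up placeholder, since the real constant lives in tpu_inference/distributed/jax_parallel_state.py and is not shown in this diff.

# Illustration of the env vars init_device would export for rank 1 of a
# pp_world_size=2 run on one host. BASE_JAX_PORT=8476 is a placeholder value.
BASE_JAX_PORT = 8476
rank, pp_world_size = 1, 2

tpu_ports = [BASE_JAX_PORT + i for i in range(pp_world_size)]
env = {
    "TPU_PROCESS_ADDRESSES": ",".join(f"localhost:{p}" for p in tpu_ports),  # "localhost:8476,localhost:8477"
    "TPU_PROCESS_PORT": f"{tpu_ports[rank]}",                                # "8477"
    "CLOUD_TPU_TASK_ID": f"{rank}",                                          # "1"
    # PPConfig defaults when no explicit bounds are passed:
    "TPU_PROCESS_BOUNDS": f"1,{pp_world_size},1",                            # "1,2,1"
    "TPU_CHIPS_PER_PROCESS_BOUNDS": "1,1,1",
    "TPU_VISIBLE_CHIPS": f"{rank}",                                          # "1"
}
print(env)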
{tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511150811
+Version: 0.11.1.dev202512030818
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -27,10 +27,11 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.
+Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
+Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description