tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/runner/utils.py
CHANGED

@@ -15,7 +15,6 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch

@@ -307,7 +306,8 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level =
+        self.default_profiling_options.python_tracer_level = os.getenv(
+            "PYTHON_TRACER_LEVEL", 0)

         self.current_phase: str = ""

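The change above drops the tpu_inference.envs lookup and reads the tracer level straight from the environment. A minimal sketch of the resulting pattern follows; the explicit int() casts are an assumption added here (os.getenv returns strings), not something the diff itself does:

import os

import jax

# Sketch: build JAX profiler options from environment variables, falling back
# to the documented defaults (see https://docs.jax.dev/en/latest/profiling.html).
options = jax.profiler.ProfileOptions()
options.python_tracer_level = int(os.getenv("PYTHON_TRACER_LEVEL", 0))
options.host_tracer_level = int(os.getenv("HOST_TRACER_LEVEL", 1))

# The runner/worker then passes these options when starting a trace:
# jax.profiler.start_trace("/tmp/jax-trace", profiler_options=options)
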
tpu_inference/spec_decode/jax/eagle3.py
CHANGED

@@ -6,19 +6,13 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
-from flax import nnx
-from jax import lax
-from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array

-logger = init_logger(__name__)
-

 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -57,22 +51,8 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
-
-        if draft_embed_tokens is None or ~jnp.any(
-                draft_embed_tokens.embedding):
-            logger.info(
-                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
-            )
-            self.state.model.embed_tokens = target_model.model.embed
-        elif jnp.array_equal(draft_embed_tokens.embedding,
-                             target_model.model.embed.embedding):
-            logger.info(
-                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
-            )
-            self.state.model.embed_tokens = target_model.model.embed
-        else:
-            logger.info("Draft model has its own embed_tokens.")
+        del self.state.model['embed_tokens']
+        self.state.model.embed_tokens = target_model.model.embed

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
@@ -130,17 +110,6 @@ class Eagle3Proposer:
                                 max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

-        positions = lax.with_sharding_constraint(
-            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
-        clamped_positions = lax.with_sharding_constraint(
-            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
-        new_seq_lens = lax.with_sharding_constraint(
-            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
-        query_start_loc = lax.with_sharding_constraint(
-            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
-        new_block_tables = lax.with_sharding_constraint(
-            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
-
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -152,7 +121,6 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
-        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -161,7 +129,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-            state, target_hidden_states)
+            self.state, target_hidden_states)

         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -208,8 +176,8 @@ class Eagle3Proposer:
                                  block_tables=device_array(
                                      self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-
-
+            aux_hidden_states, attn_metadata.query_start_loc, input_ids,
+            next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     # Host copies from the metadata prepared by the runner.
@@ -273,13 +241,12 @@ class Eagle3Proposer:

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-
+            token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
-        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -307,51 +274,35 @@ class Eagle3Proposer:
         )

         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-
-
+            [h[token_indices] for h in aux_hidden_states], query_start_loc,
+            target_token_ids, next_token_ids, num_reqs)

         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
-        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-
-            sample_hidden_states,
-            NamedSharding(self.mesh, PartitionSpec(None, None)))
-        return self._get_draft_token_ids(state, sample_hidden_states)
+        return self._get_draft_token_ids(sample_hidden_states)

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self,
-                             hidden_states: jax.Array) -> jax.Array:
+    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(state, hidden_states,
-
-        return
-            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+        logits = self.compute_logits_fn(self.state, hidden_states,
+                                        lora_metadata)
+        return jnp.argmax(logits, axis=-1)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-        self,
+        self, positions: jax.Array, residual: jax.Array,
         hidden_states: jax.Array,
         last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-
-
-
-            last_token_indices)
-
-        positions = lax.with_sharding_constraint(
-            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
-        residual = lax.with_sharding_constraint(
-            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
-        draft_token_ids = lax.with_sharding_constraint(
-            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
-
-        return positions, residual, draft_token_ids
+        return positions[last_token_indices], residual[
+            last_token_indices], self._select_draft_token_ids(
+                hidden_states, last_token_indices)

     def propose(
         self,
@@ -378,11 +329,11 @@ class Eagle3Proposer:

         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-
+                hidden_states, last_token_indices)

         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-
-
+            attn_metadata.input_positions, residual[0], hidden_states,
+            last_token_indices)

         draft_token_ids_list = [draft_token_ids]

@@ -407,8 +358,7 @@ class Eagle3Proposer:
             attn_metadata,
         )
         hidden_states = residual[0]
-        draft_token_ids = self._get_draft_token_ids(
-            self.state, new_hidden_states)
+        draft_token_ids = self._get_draft_token_ids(new_hidden_states)
         draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
tpu_inference/tpu_info.py
CHANGED

@@ -3,7 +3,6 @@ import os

 import requests

-from tpu_inference import envs
 from tpu_inference.logger import init_logger

 logger = init_logger(__name__)
@@ -33,14 +32,14 @@ def get_tpu_metadata(key: str = "") -> str:


 def get_tpu_type() -> str:
-    tpu_type =
+    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type


 def get_node_name() -> str:
-    tpu_name =
+    tpu_name = os.getenv("TPU_NAME", None)
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -48,7 +47,7 @@ def get_node_name() -> str:

 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id =
+    worker_id = os.getenv("TPU_WORKER_ID", None)
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
tpu_inference/utils.py
CHANGED

@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -8,54 +9,30 @@ from typing import Any, Callable, List, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
-import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
-from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from
-from vllm import envs as vllm_envs
-from vllm import utils
+from vllm import envs, utils

-from tpu_inference import envs
 from tpu_inference.logger import init_logger

 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8

-#
-
+# This is used to translate from a string name for a dtype
+# to formal jax.numpy DType. One use case for this is
+# converting the `--kv_cache_dtype` flag to a dtype.
+TPU_STR_DTYPE_TO_JAX_DTYPE = {
+    "bfloat16": jnp.bfloat16,
     "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.
+    "fp8_e4m3": jnp.float8_e4m3,
     "fp8_e5m2": jnp.float8_e5m2,
+    "int8": jnp.int8,
 }

-
-def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
-    if isinstance(dtype, str):
-        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
-            return dict_dtype
-        return jnp.dtype(dtype)
-    elif isinstance(dtype, torch.dtype):
-        return t2j_dtype(dtype)
-    elif isinstance(dtype, jnp.dtype):
-        return dtype
-    elif isinstance(dtype, _ScalarMeta):
-        return dtype.dtype
-    else:
-        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
-
-
-def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
-    # Use jax dtype as an intermediate dtype which we'll be used to convert it
-    # into torch dtype.
-    dtype = to_jax_dtype(dtype)
-    return j2t_dtype(dtype)
-
-
 _megacore = False
 logger = init_logger(__name__)

@@ -80,10 +57,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:

 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if
+    if envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)

-    multihost_backend =
+    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
@@ -317,8 +294,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-
-    return
+    str_dtype = str_dtype.lower().strip()
+    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)


 def time_function(func):

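The new TPU_STR_DTYPE_TO_JAX_DTYPE table replaces the removed torch-aware converters: string dtype names (for example the value of --kv_cache_dtype) now map straight to jax.numpy dtypes. A minimal sketch of the lookup, with the dictionary reproduced from the diff and the helper mirroring get_jax_dtype_from_str_dtype:

import jax.numpy as jnp

# Reproduced from the diff: string dtype names -> jax.numpy dtypes.
TPU_STR_DTYPE_TO_JAX_DTYPE = {
    "bfloat16": jnp.bfloat16,
    "fp8": jnp.float8_e4m3fn,
    "fp8_e4m3": jnp.float8_e4m3,
    "fp8_e5m2": jnp.float8_e5m2,
    "int8": jnp.int8,
}


def get_jax_dtype_from_str_dtype(str_dtype: str):
    """Normalize the flag value and look it up; returns None if unknown."""
    return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype.lower().strip())


print(get_jax_dtype_from_str_dtype(" FP8_E5M2 "))  # -> the float8_e5m2 dtype
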
tpu_inference/worker/tpu_worker.py
CHANGED

@@ -2,7 +2,6 @@

 import os
 import tempfile
-from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
@@ -11,7 +10,6 @@ import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -25,13 +23,10 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

 from tpu_inference import envs, utils
-from tpu_inference.distributed import jax_parallel_state
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner

@@ -44,39 +39,15 @@ _DTYPE: dict[str, jnp.dtype] = {
 }


-@dataclass
-class PPConfig:
-    rank: int
-    ip: str
-    prev_worker_ip: str
-    pp_world_size: int
-
-    # default env vars for
-    # TPU_PROCESS_BOUNDS, TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_VISIBLE_CHIPS
-    # if PP is used in single host.
-    default_tpu_process_bounds: str = field(init=False)
-    default_tpu_chips_per_process_bounds: str = field(init=False)
-    default_tpu_visible_chips: str = field(init=False)
-
-    def __post_init__(self):
-        self.default_tpu_process_bounds = f"1,{self.pp_world_size},1"
-        self.default_tpu_chips_per_process_bounds = "1,1,1"
-        self.default_tpu_visible_chips = f"{self.rank}"
-
-
 class TPUWorker:

-    def __init__(
-
-
-
-
-
-
-            devices=None,
-            ip: str = "localhost",
-            prev_worker_ip: str = "localhost",
-    ):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 local_rank: int,
+                 rank: int,
+                 distributed_init_method: str,
+                 is_driver_worker: bool = False,
+                 devices=None):
         # If we use vLLM's model implementation in PyTorch, we should set it
         # with torch version of the dtype.
         impl = envs.MODEL_IMPL_TYPE
@@ -103,12 +74,10 @@ class TPUWorker:
         self.devices = devices if devices is not None else []
         self.device_ranks = set(device.id for device in self.devices
                                 if isinstance(device, jaxlib._jax.Device))
-        self.pp_config = PPConfig(rank, ip, prev_worker_ip,
-                                  self.parallel_config.pipeline_parallel_size)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils
+            from vllm.utils import init_cached_hf_modules

             init_cached_hf_modules()

@@ -117,7 +86,7 @@ class TPUWorker:
         # TPU Worker is initialized. The profiler server needs to start after
         # MP runtime is initialized.
         self.profile_dir = None
-        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1
+        if vllm_envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
             if not self.devices or 0 in self.device_ranks:
                 # For TPU, we can only have 1 active profiler session for 1 profiler
                 # server. So we only profile on rank0.
@@ -125,14 +94,6 @@ class TPUWorker:
                 logger.info("Profiling enabled. Traces will be saved to: %s",
                             self.profile_dir)

-        # For PP, we use MPMD so we want to profile every worker.
-        if self.pp_config.pp_world_size > 1 and vllm_envs.VLLM_TORCH_PROFILER_DIR:
-            self.profile_dir = os.path.join(
-                vllm_envs.VLLM_TORCH_PROFILER_DIR,
-                f"pprank_{self.rank}_ppworldsize_{self.pp_config.pp_world_size}"
-            )
-            os.makedirs(self.profile_dir, exist_ok=True)
-
         use_jax_profiler_server = os.getenv("USE_JAX_PROFILER_SERVER", False)
         # Only one instance of profiler is allowed
         if use_jax_profiler_server and self.rank < 1:
@@ -144,87 +105,31 @@ class TPUWorker:
             )
             jax.profiler.start_server(jax_profiler_server_port)

-        # step_counter is used to calculate uuid to transfer intermediate tensors.
-        self.step_counter = 0
-
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-    def init_device(self
-                    tpu_process_bounds="",
-                    tpu_chips_per_process_bounds="",
-                    tpu_visible_chips=""):
-        # set tpu visible devices for Jax runtime in single host PP.
-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
-        if multihost_backend != "ray" and self.parallel_config.pipeline_parallel_size > 1:
-            tpu_ports = [
-                jax_parallel_state.BASE_JAX_PORT + i
-                for i in range(self.pp_config.pp_world_size)
-            ]
-            os.environ["TPU_PROCESS_ADDRESSES"] = ",".join(
-                [f"localhost:{port}" for port in tpu_ports])
-            os.environ["TPU_PROCESS_PORT"] = f"{tpu_ports[self.rank]}"
-            os.environ["CLOUD_TPU_TASK_ID"] = f"{self.rank}"
-
-            # Note: Below is the setting for v6e8 host (8 chips of v6e)
-            # Replace with your own topology.
-            # There are 2 ways of subslicing a v6e
-            # 1) 2 slices with 4 TPU chips each, we can do PP=2, TP=1/2/3/4
-            # TPU_PROCESS_BOUNDS = "1,1,1"
-            # TPU_CHIPS_PER_PROCESS_BOUNDS = "1,4,1"
-            # TPU_VISIBLE_CHIPS = "0,1,2,3" or "4,5,6,7"
-            # 2) 1 chip for each subslice, with at most 8 subslices,
-            # we can do TP=1, PP=1/2/3/4/5/6/7/8
-            os.environ[
-                "TPU_PROCESS_BOUNDS"] = tpu_process_bounds \
-                if tpu_process_bounds \
-                else self.pp_config.default_tpu_process_bounds
-            os.environ[
-                "TPU_CHIPS_PER_PROCESS_BOUNDS"] = tpu_chips_per_process_bounds \
-                if tpu_chips_per_process_bounds \
-                else self.pp_config.default_tpu_chips_per_process_bounds
-            os.environ[
-                "TPU_VISIBLE_CHIPS"] = tpu_visible_chips \
-                if tpu_visible_chips \
-                else self.pp_config.default_tpu_visible_chips
-
+    def init_device(self):
         if not self.devices:
             sharding_config: ShardingConfigManager = self.vllm_config.sharding_config
             device_indexes = sharding_config.device_indexes
             if device_indexes is not None and len(device_indexes) > 0:
                 # Enforcing the devices sequence to be consistent with the specified device indexes
-
-                device_dict = {
-                    device.id: device
-                    for device in all_local_devices
-                }
+                all_devices = jax.devices()
+                device_dict = {device.id: device for device in all_devices}
                 self.devices = []
                 for device_index in device_indexes:
                     device = device_dict[device_index]
                     if device is None:
                         raise KeyError(
                             f"Device index {device_index} not found in "
-                            f"jax.
+                            f"jax.devices() with IDs {list(device_dict.keys())}!"
                         )
                     self.devices.append(device)
-                assert len(self.devices) >= sharding_config.total_devices
                 self.devices = self.devices[:sharding_config.total_devices]
             else:
-
-                # We only support a mixed tp + pp scenario that tp size is
-                # smaller or equals the total TPUs in one node
-                # say: we have 4 nodes with 4 TPUs each, we can only do pp:4, tp:4, but not pp:2, tp:8
-                    assert jax.local_device_count(
-                    ) >= sharding_config.total_devices
-                    self.devices = jax.local_devices()[:sharding_config.
-                                                       total_devices]
-                else:
-                    # In a multi-host distributed env, say: Ray, local_device count may smaller
-                    # than the total devices, we just choose the smaller set here.
-                    self.devices = jax.devices()[:sharding_config.
-                                                 total_devices]
+                self.devices = jax.devices()[:sharding_config.total_devices]

         # Initialize the vLLM distribution layer as a single chip environment,
         # we'll swap the model's parallel modules with TPU SPMD equivalents.
@@ -241,18 +146,8 @@ class TPUWorker:
             tensor_model_parallel_size=1,
             pipeline_model_parallel_size=1,
         )
-
-        jax_parallel_state.init_pp_distributed_environment(
-            self.pp_config.ip,
-            self.rank,
-            self.parallel_config.pipeline_parallel_size,
-            self.devices[0],
-            need_pp=self.parallel_config.pipeline_parallel_size > 1)
-
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(
-            self.vllm_config, self.devices, self.rank, self.rank == 0,
-            self.rank == self.pp_config.pp_world_size - 1)
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
                     f"node_id={get_node_id()} | "
@@ -260,12 +155,6 @@ class TPUWorker:
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
         vllm_utils.report_usage_stats(self.vllm_config)

-    def initialize_pp_transfer_connect(self):
-        if self.rank == 0:
-            return
-        jax_parallel_state.connect(self.pp_config.prev_worker_ip,
-                                   self.rank - 1)
-
     def determine_available_memory(self) -> int:
         gpu_memory_utilization = self.cache_config.gpu_memory_utilization
         hbm_usage = utils.hbm_usage_bytes(self.devices)
@@ -305,39 +194,14 @@ class TPUWorker:
         # deliberate, temporary compromise for the same reasons outlined in
         # the `get_kv_cache_spec` method.

-
-
-
-
-
-
-
-            scheduler_output.total_num_scheduled_tokens)
-        intermediate_tensors_dict = get_pp_group().recv_tensor_dict(
-            uuid, tensor_spec)
-        intermediate_tensors = JaxIntermediateTensors(
-            intermediate_tensors_dict)
-
-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
-
-        if isinstance(output, JaxIntermediateTensors):
-            assert self.parallel_config.pipeline_parallel_size > 1
-            assert not get_pp_group().is_last_rank
-            # send intermediate tensors
-            uuid = self.model_runner.get_uuid_for_jax_transfer(
-                scheduler_output, self.rank, self.step_counter)
-            get_pp_group().send_tensor_dict(uuid, output.tensors)
-            self.step_counter += 1
-            return None
-        else:
-            self.step_counter += 1
-            # With a connector, the scheduler expects output from all workers
-            # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
-            if has_kv_transfer_group():
-                return output
-            return output if self.is_driver_worker else None
+        output = self.model_runner.execute_model(scheduler_output)
+
+        # With a connector, the scheduler expects output from all workers
+        # TODO(mrjunwan): Figure out if this is ok after https://github.com/vllm-project/vllm/pull/26866
+        if has_kv_transfer_group():
+            return output
+
+        return output if self.is_driver_worker else None

     def sample_tokens(self,
                       grammar_output: GrammarOutput) -> ModelRunnerOutput:
@@ -357,7 +221,7 @@ class TPUWorker:
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level =
+            options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -402,8 +266,7 @@ class TPUWorker:

         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(
-            list(kv_cache_specs.values()))
+        vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
                                                       kv_cache_specs)

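With the pipeline-parallel plumbing removed, execute_model reduces to running the model runner and deciding whether this worker should return output to the scheduler. A minimal sketch of that control flow; the helper names mirror the diff, while the stub parameter types are illustrative:

from typing import Any, Optional


def execute_model_flow(model_runner: Any, scheduler_output: Any,
                       is_driver_worker: bool,
                       has_kv_transfer_group: bool) -> Optional[Any]:
    """Sketch of the simplified TPUWorker.execute_model control flow."""
    output = model_runner.execute_model(scheduler_output)

    # With a KV-transfer connector the scheduler expects output from every
    # worker, so return it unconditionally.
    if has_kv_transfer_group:
        return output

    # Otherwise only the driver worker reports back to the scheduler.
    return output if is_driver_worker else None
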
{tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.
+Version: 0.11.1.dev202511180814
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.
+Requires-Dist: tpu-info==0.4.0
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock
@@ -27,11 +27,10 @@ Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
 Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
-Requires-Dist: torchvision==0.
+Requires-Dist: torchvision==0.23.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
-Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description