tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/spec_decode/jax/eagle3.py CHANGED

@@ -6,13 +6,19 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array

+logger = init_logger(__name__)
+

 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,9 +57,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
-
-
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
@@ -111,6 +130,17 @@ class Eagle3Proposer:
             max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -122,6 +152,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -130,7 +161,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-
+            state, target_hidden_states)

         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -177,8 +208,8 @@ class Eagle3Proposer:
             block_tables=device_array(
                 self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     # Host copies from the metadata prepared by the runner.
@@ -242,12 +273,13 @@ class Eagle3Proposer:

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -275,35 +307,51 @@ class Eagle3Proposer:
         )

         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states],
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)

         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self,
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(
-
-        return
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-            self, positions: jax.Array, residual: jax.Array,
+            self, state: nnx.State, positions: jax.Array, residual: jax.Array,
             hidden_states: jax.Array,
             last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-
-
-
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids

     def propose(
         self,
@@ -330,11 +378,11 @@ class Eagle3Proposer:

         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)

         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0],
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)

         draft_token_ids_list = [draft_token_ids]

@@ -359,7 +407,8 @@ class Eagle3Proposer:
             attn_metadata,
         )
         hidden_states = residual[0]
-        draft_token_ids = self._get_draft_token_ids(
+        draft_token_ids = self._get_draft_token_ids(
+            self.state, new_hidden_states)
         draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
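
The eagle3.py changes above follow one pattern throughout: the draft model's `nnx.State` is passed explicitly into each `jax.jit`-compiled helper, and intermediate arrays are pinned with `lax.with_sharding_constraint`. The sketch below is a minimal, self-contained illustration of that pattern; `TinyProposer`, its method body, and the toy shapes are hypothetical stand-ins, not code from the package.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


class TinyProposer:
    """Hypothetical stand-in for Eagle3Proposer; not part of tpu-inference."""

    def __init__(self, mesh: Mesh):
        self.mesh = mesh

    @functools.partial(jax.jit, static_argnums=(0, ))
    def _select_draft_token_ids(self, state: jax.Array,
                                hidden_states: jax.Array,
                                last_token_indices: jax.Array) -> jax.Array:
        # Gather the rows for the last token of each request, then pin the
        # result's sharding, mirroring the with_sharding_constraint calls the
        # diff adds around every intermediate.
        sampled = hidden_states[last_token_indices]
        sampled = lax.with_sharding_constraint(
            sampled, NamedSharding(self.mesh, PartitionSpec(None, None)))
        logits = sampled @ state  # stand-in for compute_logits_fn(state, ...)
        draft_token_ids = jnp.argmax(logits, axis=-1)
        return lax.with_sharding_constraint(
            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))


mesh = Mesh(jax.devices(), ("model", ))
proposer = TinyProposer(mesh)
tokens = proposer._select_draft_token_ids(
    jnp.ones((16, 4)),     # toy "state": a 16x4 projection matrix
    jnp.ones((8, 16)),     # toy hidden states for 8 tokens
    jnp.array([1, 4, 7]))  # last-token index per request
print(tokens.shape)        # (3,)
```
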
tpu_inference/utils.py CHANGED

@@ -8,11 +8,14 @@ from typing import Any, Callable, List, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
 from vllm import envs as vllm_envs
 from vllm import utils

@@ -23,21 +26,44 @@ GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8

-#
-
-
-
-    "
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
 }

+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)


+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
 def enable_megacore() -> None:
     global _megacore
     _megacore = True
@@ -164,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:


 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits


@@ -249,11 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:

 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)


 def quantize_kv(key: jax.Array, value: jax.Array,
@@ -295,8 +322,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-
-    return
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)


 def time_function(func):
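
For context on the utils.py hunks above, the sketch below shows how the new dtype helpers and `align_to` would behave; the expected outputs in the comments are inferred from the diff rather than verified against the released wheel.

```python
import jax.numpy as jnp
import torch

from tpu_inference.utils import align_to, to_jax_dtype, to_torch_dtype

print(to_jax_dtype("fp8"))          # float8_e4m3fn, via the vLLM string map
print(to_jax_dtype("bfloat16"))     # bfloat16, falls through to jnp.dtype()
print(to_jax_dtype(torch.float16))  # float16, converted with torchax t2j_dtype
print(to_torch_dtype("float32"))    # torch.float32, via the jax intermediate

print(align_to(100, 128))           # 128: round 100 up to the next multiple of 128
```
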
tpu_inference/worker/tpu_worker.py CHANGED

@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple

 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
@@ -19,7 +18,8 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks,
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@@ -32,17 +32,11 @@ from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
-from tpu_inference.runner.kv_cache import
+from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 @dataclass
 class PPConfig:
@@ -77,21 +71,6 @@ class TPUWorker:
         ip: str = "localhost",
         prev_worker_ip: str = "localhost",
     ):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -108,7 +87,7 @@ class TPUWorker:

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
-            from vllm.utils import init_cached_hf_modules
+            from vllm.utils.import_utils import init_cached_hf_modules

             init_cached_hf_modules()

@@ -250,11 +229,20 @@ class TPUWorker:
             need_pp=self.parallel_config.pipeline_parallel_size > 1)

         ensure_kv_transfer_initialized(self.vllm_config)
-
-
-
+
+        is_first_rank = True
+        is_last_rank = True
+        if self.parallel_config.pipeline_parallel_size > 1:
+            is_first_rank = self.rank == 0
+            is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, is_first_rank,
+                                           is_last_rank)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
+                    f"is_first_rank={is_first_rank} | "
+                    f"is_last_rank={is_last_rank} | "
                     f"node_id={get_node_id()} | "
                     f"is_driver_worker={self.is_driver_worker} | "
                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
@@ -357,7 +345,7 @@ class TPUWorker:
         if is_start:
             options = jax.profiler.ProfileOptions()
             # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-            options.python_tracer_level =
+            options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
             options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
             jax.profiler.start_trace(self.profile_dir,
                                      profiler_options=options)
@@ -395,32 +383,37 @@ class TPUWorker:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()

-        if len(
-            return
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec

         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
-        vllm_page_size_bytes = get_uniform_page_size(
-
-
+        vllm_page_size_bytes = get_uniform_page_size(
+            list(kv_cache_spec.values()))
+        attention_page_size_bytes = get_attention_page_size_bytes(
+            self.model_runner.mesh, kv_cache_spec)

-        if vllm_page_size_bytes !=
+        if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"
-                f"
-                f"
-                f"
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config,
-                                        available_memory,
-
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                        available_memory,
+                                        attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks

-        return
+        return kv_cache_spec

     def initialize_from_config(
         self,
@@ -455,3 +448,8 @@ class TPUWorker:

     def shutdown(self) -> None:
         return
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
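
The tpu_worker.py hunk above recalculates the KV block count whenever the kernel's attention page size differs from the page size vLLM derived. The sketch below is only a back-of-the-envelope model of that arithmetic (it is not vLLM's `get_num_blocks`): the block count is roughly the free HBM divided by the per-block footprint, which grows with the page size actually used.

```python
def estimate_num_blocks(available_memory_bytes: int, group_size: int,
                        page_size_bytes: int) -> int:
    """Rough estimate only; a hypothetical stand-in for vLLM's get_num_blocks."""
    # Every block holds one page per layer in the largest KV cache group, so
    # a bigger kernel page size means fewer blocks fit in the same memory.
    per_block_bytes = group_size * page_size_bytes
    return available_memory_bytes // per_block_bytes


# e.g. 8 GiB of free HBM, a 32-layer attention group, 2 MiB kernel pages:
print(estimate_num_blocks(8 * 1024**3, 32, 2 * 1024**2))  # 128
```
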
{tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202511220812
+Version: 0.12.0.dev20251213
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.
+Requires-Dist: tpu-info==0.7.1
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.
+Requires-Dist: torchax==0.0.10
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
+Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python

 ---

-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥

+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

 <details>