tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/runner/utils.py
CHANGED
@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch
 
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level =
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
 
         self.current_phase: str = ""
 
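The hunks above replace an inline environment-variable read with a module attribute on tpu_inference.envs. The new contents of envs.py are not shown in this section (it changes by +93 -9), but vLLM-style env modules usually expose each variable through a parser table plus a module-level __getattr__; the following is a hypothetical sketch of that pattern, not the packaged implementation.

import os

# Hypothetical parser table; the names mirror accesses visible in this diff.
_ENV_PARSERS = {
    "PYTHON_TRACER_LEVEL": lambda: int(os.environ.get("PYTHON_TRACER_LEVEL", "0")),
    "TPU_MULTIHOST_BACKEND": lambda: os.environ.get("TPU_MULTIHOST_BACKEND", "").lower(),
}

def __getattr__(name: str):
    # PEP 562 module __getattr__: `envs.PYTHON_TRACER_LEVEL` parses the
    # variable on access, so callers get a typed value with a default.
    if name in _ENV_PARSERS:
        return _ENV_PARSERS[name]()
    raise AttributeError(name)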
tpu_inference/spec_decode/jax/eagle3.py
CHANGED
@@ -6,13 +6,19 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.runner import utils as runner_utils
 from tpu_inference.utils import device_array
 
+logger = init_logger(__name__)
+
 
 class Eagle3Proposer:
     """A proposer for speculative decoding using the Eagle3 method.
@@ -51,8 +57,22 @@ class Eagle3Proposer:
         """Loads the draft model."""
         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-
-        self.state.model
+
+        draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+        if draft_embed_tokens is None or ~jnp.any(
+                draft_embed_tokens.embedding):
+            logger.info(
+                "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        elif jnp.array_equal(draft_embed_tokens.embedding,
+                             target_model.model.embed.embedding):
+            logger.info(
+                "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+            )
+            self.state.model.embed_tokens = target_model.model.embed
+        else:
+            logger.info("Draft model has its own embed_tokens.")
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_input_ids(
@@ -110,6 +130,17 @@ class Eagle3Proposer:
             max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
 
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
 
     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -121,6 +152,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -129,7 +161,7 @@
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-
+            state, target_hidden_states)
 
         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -176,8 +208,8 @@ class Eagle3Proposer:
             block_tables=device_array(
                 self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     # Host copies from the metadata prepared by the runner.
@@ -241,12 +273,13 @@ class Eagle3Proposer:
 
         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
            aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -274,35 +307,51 @@ class Eagle3Proposer:
         )
 
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states],
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)
 
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self,
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(
-
-        return
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-        self, positions: jax.Array, residual: jax.Array,
+        self, state: nnx.State, positions: jax.Array, residual: jax.Array,
         hidden_states: jax.Array,
         last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-
-
-
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids
 
     def propose(
         self,
@@ -329,11 +378,11 @@
 
         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)
 
         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0],
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)
 
         draft_token_ids_list = [draft_token_ids]
 
@@ -358,7 +407,8 @@
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
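Most of the churn in this file threads the draft model's nnx.State through the jitted helpers as an explicit argument (with lax.with_sharding_constraint annotations on intermediates) instead of reading self.state inside the traced functions, which matters once load_model may swap in the target model's embedding. The toy sketch below uses hypothetical names and is only meant to illustrate the difference between closing over state via a static self and passing it as a traced argument.

import functools

import jax
import jax.numpy as jnp


class TinyProposer:
    """Illustrative only; not the Eagle3Proposer API."""

    def __init__(self, params: jax.Array):
        self.params = params  # analogous to self.state

    @functools.partial(jax.jit, static_argnums=(0, ))
    def logits_closed_over(self, hidden: jax.Array) -> jax.Array:
        # `self` is a static argument, so self.params is captured as a
        # constant at trace time; replacing self.params later is not seen
        # by the already-compiled function.
        return hidden @ self.params

    @functools.partial(jax.jit, static_argnums=(0, ))
    def logits_explicit(self, params: jax.Array, hidden: jax.Array) -> jax.Array:
        # params is a traced argument, so a value swapped in after loading
        # (e.g. a shared embedding) flows through without stale constants.
        return hidden @ params


prop = TinyProposer(jnp.ones((4, 8)))
hidden = jnp.ones((2, 4))
print(prop.logits_explicit(prop.params, hidden).shape)  # (2, 8)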
tpu_inference/tpu_info.py
CHANGED
@@ -3,6 +3,7 @@ import os
 
 import requests
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
@@ -32,14 +33,14 @@ def get_tpu_metadata(key: str = "") -> str:
 
 
 def get_tpu_type() -> str:
-    tpu_type =
+    tpu_type = envs.TPU_ACCELERATOR_TYPE
     if tpu_type is None:
         tpu_type = get_tpu_metadata(key="accelerator-type")
     return tpu_type
 
 
 def get_node_name() -> str:
-    tpu_name =
+    tpu_name = envs.TPU_NAME
     if not tpu_name:
         tpu_name = get_tpu_metadata(key="instance-id")
     return tpu_name
@@ -47,7 +48,7 @@ def get_node_name() -> str:
 
 def get_node_worker_id() -> int:
     """For multi-host TPU VM, this returns the worker id for the current node."""
-    worker_id =
+    worker_id = envs.TPU_WORKER_ID
     if worker_id is None:
         worker_id = get_tpu_metadata(key="agent-worker-number")
     if worker_id is None:
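Each accessor above now prefers an env override (envs.TPU_ACCELERATOR_TYPE, envs.TPU_NAME, envs.TPU_WORKER_ID) and only falls back to get_tpu_metadata. The internals of get_tpu_metadata are not part of this hunk; the sketch below shows the conventional GCE metadata-server query such a fallback typically performs, with a hypothetical helper name.

import requests

_GCE_METADATA_URL = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/"

def query_gce_metadata(key: str) -> str | None:
    """Hypothetical stand-in for the metadata fallback used by get_tpu_type & co."""
    try:
        resp = requests.get(_GCE_METADATA_URL + key,
                            headers={"Metadata-Flavor": "Google"},
                            timeout=2)
        resp.raise_for_status()
        return resp.text
    except requests.RequestException:
        return None

# On a TPU VM, query_gce_metadata("accelerator-type") returns the chip type;
# off-GCE it returns None and callers keep the env-provided value.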
tpu_inference/utils.py
CHANGED
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import time
 from collections import defaultdict
 from collections.abc import Sequence
@@ -9,34 +8,62 @@ from typing import Any, Callable, List, Tuple
 import jax
 import jax.numpy as jnp
 import numpy as np
+import torch
 from jax._src import dtypes
 from jax._src import mesh as mesh_lib
 from jax._src import xla_bridge as xb
 from jax._src.lib import xla_client as xc
+from jax._src.numpy.scalar_types import _ScalarMeta
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from
+from torchax.ops.mappings import j2t_dtype, t2j_dtype
+from vllm import envs as vllm_envs
+from vllm import utils
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 GBYTES = 1024 * 1024 * 1024
 TPU_HEAD_SIZE_ALIGNMENT = 128
 TPU_SECOND_LAST_MINOR = 8
 
-#
-
-
-
-"
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3,
-    "fp8_e5m2": jnp.float8_e5m2,
-    "int8": jnp.int8,
+# Map vllm dtype string that doesn't exactly match jax dtype string name.
+_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
 }
 
+
+def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+    if isinstance(dtype, str):
+        if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+            return dict_dtype
+        return jnp.dtype(dtype)
+    elif isinstance(dtype, torch.dtype):
+        return t2j_dtype(dtype)
+    elif isinstance(dtype, jnp.dtype):
+        return dtype
+    elif isinstance(dtype, _ScalarMeta):
+        return dtype.dtype
+    else:
+        raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+    # Use jax dtype as an intermediate dtype which we'll be used to convert it
+    # into torch dtype.
+    dtype = to_jax_dtype(dtype)
+    return j2t_dtype(dtype)
+
+
 _megacore = False
 logger = init_logger(__name__)
 
 
+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
 def enable_megacore() -> None:
     global _megacore
     _megacore = True
@@ -57,10 +84,10 @@ def get_num_kv_heads_by_tp(num_kv_heads: int, tp_size: int) -> int:
 
 def hbm_usage_bytes(devices: Any) -> List[Tuple[int, int]]:
     usage = []
-    if
+    if vllm_envs.VLLM_TPU_USING_PATHWAYS:
         return pathways_hbm_usage_gb(devices)
 
-    multihost_backend =
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
     if multihost_backend == "ray":
         # MemoryStats is only supported for addressable PjRt devices.
         # Assume all the devices have similar memory usage for now.
@@ -163,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
@@ -248,11 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
 
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)
 
 
 def quantize_kv(key: jax.Array, value: jax.Array,
@@ -294,8 +322,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     Returns:
         jnp.dtype: The JAX dtype.
     """
-
-    return
+    # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+    return to_jax_dtype(str_dtype)
 
 
 def time_function(func):
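A brief usage sketch of the helpers this hunk introduces (to_jax_dtype, to_torch_dtype, align_to), assuming torch and torchax are installed as the new imports require; the expected values in the comments follow from the code shown above.

import jax.numpy as jnp
import torch

from tpu_inference.utils import align_to, to_jax_dtype, to_torch_dtype

print(to_jax_dtype("fp8"))           # float8_e4m3fn, via the vLLM alias table
print(to_jax_dtype("float32"))       # float32, via the plain jnp.dtype(...) parse
print(to_jax_dtype(torch.bfloat16))  # bfloat16, routed through torchax's t2j_dtype
print(to_torch_dtype("float16"))     # torch.float16, using the jax dtype as intermediate
print(align_to(129, 128))            # 256: rounds up to the next multiple of 128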