tpu-inference 0.11.1__py3-none-any.whl
This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects changes between versions as they appear in the public registry.
Potentially problematic release.
This version of tpu-inference has been flagged as potentially problematic; see the registry listing for details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/spec_decode/jax/eagle3.py
@@ -0,0 +1,334 @@
+"""Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
+import functools
+from dataclasses import replace
+from typing import Any, Optional
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.runner import utils as runner_utils
+from tpu_inference.utils import device_array
+
+
+class Eagle3Proposer:
+    """A proposer for speculative decoding using the Eagle3 method.
+
+    This class is responsible for loading the draft model and generating draft
+    tokens based on the target model's outputs.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        runner: Any,  # TPUModelRunner
+    ):
+        """Initializes the Eagle3Proposer.
+
+        Args:
+            vllm_config: The vLLM configuration.
+            runner: The TPUModelRunner instance.
+        """
+        self.vllm_config = vllm_config
+        self.speculative_config = vllm_config.speculative_config
+        assert self.speculative_config is not None
+        self.draft_model_config = self.speculative_config.draft_model_config
+        self.method = self.speculative_config.method
+
+        self.runner = runner
+        self.mesh = runner.mesh
+        self.num_speculative_tokens = (
+            self.speculative_config.num_speculative_tokens)
+        self.block_size = vllm_config.cache_config.block_size
+        self.rng_key = jax.random.key(self.vllm_config.model_config.seed)
+        self.max_num_tokens = runner.max_num_tokens
+        self.token_arange = jnp.arange(self.max_num_tokens)
+
+    def load_model(self, target_model: Any) -> None:
+        """Loads the draft model."""
+        self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, _, _, self.state, _, _ = get_model(
+            self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
+        del self.state.model['embed_tokens']
+        self.state.model.embed_tokens = target_model.model.embed
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _concate_hidden_states(self, aux_hidden_states):
+        """JIT-compiled helper for concatenating auxiliary hidden states."""
+        # Concat aux hidden states along feature dim.
+        return jnp.concatenate(aux_hidden_states, axis=-1)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _select_target_hidden_states(self, aux_hidden_states, token_indices):
+        """JIT-compiled helper for selecting target hidden states."""
+        return jnp.concatenate([h[token_indices] for h in aux_hidden_states],
+                               axis=-1)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _prepare_input_ids(self, query_start_loc: jax.Array,
+                           target_token_ids: jax.Array,
+                           next_token_ids: jax.Array,
+                           num_reqs: int) -> tuple[jnp.ndarray, jnp.ndarray]:
+        """JIT-compiled helper for preparing the input IDs for the draft model."""
+
+        last_token_indices = query_start_loc[1:] - 1
+        # Shift the input ids by one token.
+        rolled_input_ids = jnp.roll(target_token_ids, -1, axis=0)
+
+        # To make the update JIT-compatible with a dynamic `num_reqs`, we perform a
+        # scatter update of a static size, using a mask to handle the dynamic part.
+        max_num_reqs = last_token_indices.shape[0]
+        mask = jnp.arange(max_num_reqs) < num_reqs
+
+        # For padded requests (where mask is False), we use the original value from
+        # the rolled array, making the update a no-op for them.
+        original_values_at_indices = rolled_input_ids[last_token_indices]
+        values_to_set = jnp.where(mask, next_token_ids,
+                                  original_values_at_indices)
+
+        input_ids = rolled_input_ids.at[last_token_indices].set(values_to_set)
+
+        return input_ids, last_token_indices
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _prepare_input_loop(self, positions, seq_lens, block_tables):
+        """JIT-compiled helper for preparing inputs in the loop of prediction."""
+
+        positions += 1
+        exceeds_max_model_len = positions >= self.runner.max_model_len
+        clamped_positions = jnp.where(exceeds_max_model_len, 0, positions)
+
+        new_seq_lens = seq_lens + 1
+        new_seq_lens = jnp.minimum(new_seq_lens, self.runner.max_model_len)
+        new_seq_lens = jnp.where(exceeds_max_model_len, 1, new_seq_lens)
+
+        num_reqs = seq_lens.shape[0]
+        query_start_loc = jnp.arange(num_reqs + 1)
+
+        # Compute the slot mapping.
+        # NOTE(woosuk): We should handle the case where the draft model
+        # generates tokens beyond the max model length. Since it is complex
+        # to remove such requests from the batch, we keep them in the batch
+        # but adjust the position ids and slot mappings to avoid the
+        # out-of-range access during the model execution. The draft tokens
+        # generated with this adjustment should be ignored.
+        max_num_blocks_per_req = block_tables.shape[0] // num_reqs
+        expanded_exceeds_mask = jnp.repeat(exceeds_max_model_len,
+                                           max_num_blocks_per_req)
+        new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
+
+        return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _reshape_block_tables(self, block_tables: jax.Array) -> jax.Array:
+        """JIT-compiled helper for reshaping block tables."""
+        return block_tables.reshape(-1)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _get_draft_token_ids(self, logits: jax.Array) -> jnp.ndarray:
+        """JIT-compiled helper for getting draft token IDs from logits."""
+        return jnp.argmax(logits, axis=-1)
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def _stack_draft_token_ids(
+            self, draft_token_ids_list: list[jax.Array]) -> jnp.ndarray:
+        """JIT-compiled helper for stacking draft token IDs."""
+        return jnp.stack(draft_token_ids_list, axis=1)
+
+    def prepare_inputs(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: jax.Array,
+        aux_hidden_states: tuple[jax.Array, ...],
+        num_rejected_tokens: Optional[jax.Array] = None,
+    ) -> tuple[AttentionMetadata, jnp.ndarray, jnp.ndarray]:
+        """Prepare drafter inputs based on target forward outputs.
+
+        Mirrors the GPU reference logic but adapted to TPU/JAX types:
+        - When no rejection happened, select the first N scheduled tokens.
+        - When rejections happened, trim the per-request tail tokens and
+          update attention metadata accordingly.
+        - Build the EAGLE3 hidden input by concatenating auxiliary hidden
+          states along the last dimension.
+
+        Returns updated AttentionMetadata (positions, query_start_loc, seq_lens)
+        and the selected `target_token_ids` and `target_hidden_states`.
+        """
+        assert aux_hidden_states is not None and len(aux_hidden_states) > 0, (
+            "EAGLE3 requires auxiliary hidden states from the target model.")
+
+        if num_rejected_tokens is None:
+            return attn_metadata, input_ids, self._concate_hidden_states(
+                aux_hidden_states)
+
+        # Number of active requests in this step (un-padded count).
+        num_reqs = self.runner.input_batch.num_reqs
+
+        # Host copies from the metadata prepared by the runner.
+        query_start_loc_cpu = attn_metadata.query_start_loc_cpu
+        seq_lens_cpu = attn_metadata.seq_lens_cpu
+        assert query_start_loc_cpu is not None and seq_lens_cpu is not None
+
+        # Rejection-aware path: compute new per-request lengths and token indices.
+        # Convert to host numpy for efficient prefix-sum and repeat ops.
+        nrt_cpu = jax.device_get(num_rejected_tokens).astype("int32")
+
+        # query_len_per_req = [q1, q2, ...]
+        query_len_per_req = (query_start_loc_cpu[1:] -
+                             query_start_loc_cpu[:-1])
+
+        # query_start_loc_cpu and consequently query_len_per_req are padded.
+        # For padded requests, the query length should be 0.
+        query_len_per_req[num_reqs:] = 1
+        # num_tokens_per_req = [q1 - n1, q2 - n2, ...]
+        num_tokens_per_req = (query_len_per_req - nrt_cpu)
+        assert (num_tokens_per_req
+                >= 0).all(), ("num_tokens_per_req must be non-negative")
+
+        # new_query_start_loc = [0, q1-n1, q1+q2-n1-n2, ...]
+        # Use numpy for cumsum and then convert back.
+        new_query_start_loc_cpu = np.zeros_like(query_start_loc_cpu)
+        np.cumsum(num_tokens_per_req, out=new_query_start_loc_cpu[1:])
+
+        # Build token indices selecting the kept tokens from each request.
+        total_num_tokens = int(new_query_start_loc_cpu[-1])
+
+        # Pad to total_num_tokens.
+        padded_total_num_tokens = runner_utils.get_padded_token_len(
+            self.runner.num_tokens_paddings, total_num_tokens)
+        pad_width = padded_total_num_tokens - total_num_tokens
+        assert pad_width >= 0, (
+            f"total_num_tokens {total_num_tokens} exceeds "
+            f"num_tokens_paddings {self.runner.num_tokens_paddings}")
+
+        # Expand request starts: [0, 0, q1-n1, ...,]
+        expanded_new_query_start_loc = np.repeat(new_query_start_loc_cpu[:-1],
+                                                 num_tokens_per_req)
+        # Offsets within each request window: [0,1,2, 0,1,2,3, ...]
+        token_offsets = np.arange(total_num_tokens, dtype=np.int32)
+        token_offsets -= expanded_new_query_start_loc
+        # Map into old flat indices by adding original request starts.
+        old_query_start_loc_expanded = np.repeat(query_start_loc_cpu[:-1],
+                                                 num_tokens_per_req)
+
+        token_indices_cpu = token_offsets + old_query_start_loc_expanded
+        token_indices_cpu = np.pad(token_indices_cpu, (0, pad_width),
+                                   "constant",
+                                   constant_values=0)
+        token_indices = jnp.asarray(token_indices_cpu, dtype=jnp.int32)
+        # Select tokens and hidden states.
+        target_token_ids = self.runner._select_from_array_fn(
+            input_ids, token_indices)
+        target_hidden_states = self._select_target_hidden_states(
+            aux_hidden_states, token_indices)
+        # Update positions to match the selected tokens.
+        if attn_metadata.input_positions.ndim == 2:
+            input_positions = attn_metadata.input_positions[:, token_indices]
+        else:
+            input_positions = self.runner._select_from_array_fn(
+                attn_metadata.input_positions, token_indices)
+
+        # Update seq_lens for active requests only: new_seq_lens = s - n.
+        new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+
+        query_start_loc, seq_lens = device_array(self.mesh, (
+            new_query_start_loc_cpu,
+            new_seq_lens_cpu,
+        ))
+
+        # Return updated metadata with positions, qsl, and seq_lens.
+        updated_attn = AttentionMetadata(
+            input_positions=input_positions,
+            block_tables=attn_metadata.block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=attn_metadata.request_distribution,
+        )
+        return updated_attn, target_token_ids, target_hidden_states
+
+    def propose(
+        self,
+        kv_caches: list[jax.Array],
+        next_token_ids: jnp.ndarray,  # [batch_size]
+        attn_metadata: AttentionMetadata,
+        target_token_ids,
+        target_hidden_states,
+    ) -> tuple[list[jax.Array], jnp.ndarray]:
+        """Proposes draft tokens using the draft model.
+        Returns:
+            A tuple containing the updated KV caches and a tensor of proposed
+            draft token IDs.
+        """
+
+        target_hidden_states = self.combine_hidden_states_fn(
+            self.state, target_hidden_states)
+
+        input_ids, last_token_indices = self._prepare_input_ids(
+            attn_metadata.query_start_loc, target_token_ids, next_token_ids,
+            self.runner.input_batch.num_reqs)
+        # NOTE(pooyam): For now, we don't support multimodal.
+
+        # The last KV cache group is for the draft model.
+        num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
+        draft_kv_cache_group_id = num_kv_cache_groups - 1
+        block_tables = self.runner.input_batch.block_table[
+            draft_kv_cache_group_id].get_device_tensor()
+        block_tables = self._reshape_block_tables(block_tables)
+        attn_metadata = replace(attn_metadata, block_tables=block_tables)
+
+        kv_caches, hidden_states, residual = self.model_fn(
+            self.state,
+            kv_caches,
+            input_ids,
+            target_hidden_states,
+            attn_metadata,
+        )
+        sample_hidden_states = self.runner._select_from_array_fn(
+            hidden_states, last_token_indices)
+        lora_metadata = None
+        logits = self.compute_logits_fn(self.state, sample_hidden_states,
+                                        lora_metadata)
+        draft_token_ids = self._get_draft_token_ids(logits)
+
+        draft_token_ids_list = [draft_token_ids]
+        # Early exit if there is only one draft token to be generated.
+        if self.num_speculative_tokens == 1:
+            return kv_caches, self._stack_draft_token_ids(draft_token_ids_list)
+
+        positions = self.runner._select_from_array_fn(
+            attn_metadata.input_positions, last_token_indices)
+        hidden_states = self.runner._select_from_array_fn(
+            residual[0], last_token_indices)
+
+        for _ in range(self.num_speculative_tokens - 1):
+            input_ids_loop = draft_token_ids_list[-1]
+
+            positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables = self._prepare_input_loop(
+                positions, attn_metadata.seq_lens, attn_metadata.block_tables)
+
+            attn_metadata = replace(
+                attn_metadata,
+                input_positions=clamped_positions,
+                seq_lens=new_seq_lens,
+                query_start_loc=query_start_loc,
+                block_tables=new_block_tables,
+            )
+            kv_caches, new_hidden_states, residual = self.model_fn(
+                self.state,
+                kv_caches,
+                input_ids_loop,
+                hidden_states,  # The hidden states from the previous step.
+                attn_metadata,
+            )
+            hidden_states = residual[0]
+            logits = self.compute_logits_fn(self.state, new_hidden_states,
+                                            lora_metadata)
+            draft_token_ids = self._get_draft_token_ids(logits)
+            draft_token_ids_list.append(draft_token_ids)
+
+        # [batch_size, num_speculative_tokens]
+        draft_token_ids = self._stack_draft_token_ids(draft_token_ids_list)
+
+        return kv_caches, draft_token_ids
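Two index manipulations in the hunk above are worth isolating. First, `_prepare_input_ids` keeps its scatter update JIT-compatible under a dynamic `num_reqs` by always scattering a static number of slots and using a mask to make the padded slots a no-op. A minimal standalone sketch of that trick (the sizes and values here are illustrative, not taken from the package):

import jax
import jax.numpy as jnp

@jax.jit
def masked_set(values, indices, updates, num_active):
    # Scatter `updates` into `values` at `indices`, but only for the first
    # `num_active` slots; padded slots are written back with their original
    # value, so the update is a no-op for them.
    mask = jnp.arange(indices.shape[0]) < num_active
    originals = values[indices]
    return values.at[indices].set(jnp.where(mask, updates, originals))

tokens = jnp.arange(8)                  # stands in for the rolled token buffer
idx = jnp.array([2, 5, 7])              # static max_num_reqs = 3
upd = jnp.array([100, 200, 300])
print(masked_set(tokens, idx, upd, 2))  # [0 1 100 3 4 200 6 7]; slot 7 is padding

Second, the rejection-aware path in `prepare_inputs` rebuilds the flat token indices with a cumsum/repeat pattern. A toy host-side example, assuming two requests with query lengths [3, 4] and [1, 2] rejected tokens:

import numpy as np

qsl = np.array([0, 3, 7])                         # old query_start_loc
keep = np.array([3, 4]) - np.array([1, 2])        # num_tokens_per_req = [2, 2]
new_qsl = np.concatenate([[0], np.cumsum(keep)])  # [0, 2, 4]
offsets = np.arange(new_qsl[-1]) - np.repeat(new_qsl[:-1], keep)
token_indices = offsets + np.repeat(qsl[:-1], keep)
print(token_indices)  # [0 1 3 4]: the first two kept tokens of each request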
tpu_inference/tpu_info.py
@@ -0,0 +1,77 @@
+import glob
+import os
+
+import requests
+
+from tpu_inference.logger import init_logger
+
+logger = init_logger(__name__)
+
+GCE_TPU_ACCELERATOR_ENDPOINT = (
+    "http://metadata.google.internal/computeMetadata/v1/instance/attributes/")
+GCE_TPU_HEADERS = {"Metadata-Flavor": "Google"}
+
+
+def get_tpu_metadata(key: str = "") -> str | None:
+    try:
+        accelerator_type_request = requests.get(
+            os.path.join(GCE_TPU_ACCELERATOR_ENDPOINT, key),
+            headers=GCE_TPU_HEADERS,
+        )
+        if (accelerator_type_request.status_code == 200
+                and accelerator_type_request.text):
+            return accelerator_type_request.text
+        else:
+            logger.error(
+                "Unable to poll TPU GCE Metadata. Got "
+                f"status code: {accelerator_type_request.status_code} and "
+                f"content: {accelerator_type_request.text}")
+    except requests.RequestException as e:
+        logger.error("Unable to poll the TPU GCE Metadata: %s", e)
+    return None
+
+
+def get_tpu_type() -> str | None:
+    tpu_type = os.getenv("TPU_ACCELERATOR_TYPE", None)
+    if tpu_type is None:
+        tpu_type = get_tpu_metadata(key="accelerator-type")
+    return tpu_type
+
+
+def get_node_name() -> str | None:
+    tpu_name = os.getenv("TPU_NAME", None)
+    if not tpu_name:
+        tpu_name = get_tpu_metadata(key="instance-id")
+    return tpu_name
+
+
+def get_node_worker_id() -> int:
+    """For multi-host TPU VM, this returns the worker id for the current node."""
+    worker_id = os.getenv("TPU_WORKER_ID", None)
+    if worker_id is None:
+        worker_id = get_tpu_metadata(key="agent-worker-number")
+    if worker_id is None:
+        return 0
+    return int(worker_id)
+
+
+def get_num_cores_per_chip() -> int:
+    tpu_type = get_tpu_type()
+    if tpu_type.startswith(("v5litepod", "v6e")):
+        return 1
+    return 2
+
+
+def get_num_chips() -> int:
+    accel_files = glob.glob("/dev/accel*")
+    if accel_files:
+        return len(accel_files)
+    try:
+        vfio_entries = os.listdir("/dev/vfio")
+        numeric_entries = [
+            int(entry) for entry in vfio_entries if entry.isdigit()
+        ]
+        return len(numeric_entries)
+    except FileNotFoundError as e:
+        logger.error("Failed to detect number of TPUs: %s", e)
+    return 0
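For context, a hypothetical usage sketch of these helpers (it is only meaningful on a GCE TPU VM, where the metadata server at metadata.google.internal is reachable; elsewhere the functions fall back to environment variables and may return None):

from tpu_inference import tpu_info

tpu_type = tpu_info.get_tpu_type()  # e.g. "v5litepod-8"
if tpu_type is not None:
    cores = tpu_info.get_num_chips() * tpu_info.get_num_cores_per_chip()
    print(f"{tpu_type}: worker {tpu_info.get_node_worker_id()}, "
          f"{cores} cores on this host")

The None guard matters here: `get_num_cores_per_chip` calls `startswith` on the result of `get_tpu_type`, which can be None when neither the environment variable nor the metadata server is available.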