tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -0,0 +1,515 @@
+"""
+JAX-based rejection sampler for speculative decoding on TPU.
+
+This implementation follows the same algorithm as the GPU version but is
+designed for JAX/TPU compatibility. It currently only supports greedy sampling.
+"""
+
+import functools
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.jax.sample.sampling_metadata import \
+    TPUSupportedSamplingMetadata
+
+# Placeholder token ID for rejected tokens
+PLACEHOLDER_TOKEN_ID = -1
+GREEDY_TEMPERATURE = -1
+
+
+class RejectionSampler:
+    """
+    JAX-based rejection sampler for speculative decoding.
+
+    The implementation follows the algorithm described in
+    https://arxiv.org/abs/2211.17192.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(
+        self,
+        # [num_tokens] - flattened format
+        draft_token_ids: jnp.ndarray,
+        # [batch_size] - number of draft tokens per request
+        num_draft_tokens: jnp.ndarray,
+        # [num_tokens, vocab_size] - flattened format
+        draft_probs: Optional[jnp.ndarray],
+        # [num_tokens, vocab_size] - flattened format
+        target_logits: jnp.ndarray,
+        # [batch_size]
+        bonus_token_ids: jnp.ndarray,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:
+        """
+        Perform rejection sampling on draft tokens with flattened inputs.
+
+        Args:
+            draft_token_ids: Draft token IDs in flattened format [num_tokens].
+            num_draft_tokens: Number of draft tokens per request [batch_size].
+            draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+            target_probs: Target probabilities in flattened format [num_tokens, vocab_size].
+            bonus_token_ids: Bonus token IDs [batch_size].
+            sampling_metadata: Additional metadata needed for sampling.
+            key: JAX random key for non-greedy sampling.
+
+        Returns:
+            output_token_ids: A tensor containing the final output token IDs.
+        """
+        return self.forward(
+            draft_token_ids=draft_token_ids,
+            num_draft_tokens=num_draft_tokens,
+            draft_probs=draft_probs,
+            target_logits=target_logits,
+            bonus_token_ids=bonus_token_ids,
+            sampling_metadata=sampling_metadata,
+            key=key,
+        )
+
+    @functools.partial(jax.jit, static_argnums=(0, ))
+    def forward(
+        self,
+        # [num_tokens] - flattened format
+        draft_token_ids: jnp.ndarray,
+        # [batch_size] - number of draft tokens per request
+        num_draft_tokens: jnp.ndarray,
+        # [num_tokens, vocab_size] - flattened format
+        draft_probs: Optional[jnp.ndarray],
+        # [num_tokens, vocab_size] - flattened format
+        target_logits: jnp.ndarray,
+        # [batch_size]
+        bonus_token_ids: jnp.ndarray,
+        sampling_metadata: TPUSupportedSamplingMetadata,
+        key: Optional[jax.random.PRNGKey] = None,
+    ) -> jnp.ndarray:
+        """
+        Perform rejection sampling on draft tokens with flattened inputs.
+
+        Args:
+            draft_token_ids: Draft token IDs in flattened format [num_tokens].
+            num_draft_tokens: Number of draft tokens per request [batch_size].
+            draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+            target_logits: Target logits in flattened format [num_tokens, vocab_size].
+            bonus_token_ids: Bonus token IDs [batch_size].
+            sampling_metadata: Additional metadata needed for sampling.
+            key: JAX random key for non-greedy sampling.
+
+        Returns:
+            output_token_ids: A tensor containing the final output token IDs.
+        """
+
+        if sampling_metadata.do_sampling:
+            target_probs = _compute_probs(target_logits, num_draft_tokens,
+                                          sampling_metadata)
+        else:
+            target_probs = target_logits
+
+        output_token_ids = rejection_sample(
+            draft_token_ids,
+            num_draft_tokens,
+            draft_probs,
+            target_probs,
+            bonus_token_ids,
+            sampling_metadata,
+            key=key,
+        )
+        return output_token_ids
+
+    @staticmethod
+    def parse_output(
+        output_token_ids: jnp.ndarray,
+        vocab_size: int,
+        num_draft_tokens_cpu: np.ndarray,
+        batch_size: int,
+        padded_tokens_length: int,
+    ) -> list[list[int]]:
+        """Parse the output of the rejection sampler.
+
+        Args:
+            output_token_ids: The sampled token IDs in shape
+                [num_tokens + batch_size]. The first num_tokens elements are
+                the main tokens, and the last batch_size elements are bonus tokens.
+                Rejected tokens are replaced with `PLACEHOLDER_TOKEN_ID`.
+            vocab_size: The size of the vocabulary.
+            num_draft_tokens_cpu: Number of draft tokens per request [batch_size]
+                as a numpy array on CPU.
+            batch_size: The number of requests in the batch.
+            padded_tokens_length: The padded length of the main tokens in the output.
+
+        Returns:
+            A list of lists of token IDs.
+        """
+        # Convert JAX array to numpy for easier manipulation
+        output_token_ids_np = np.asarray(output_token_ids)
+
+        # Split main tokens and bonus tokens
+        main_tokens = output_token_ids_np[:
+                                          padded_tokens_length]  # [num_tokens]
+        bonus_tokens = output_token_ids_np[
+            padded_tokens_length:]  # [batch_size]
+
+        # Reconstruct per-sequence outputs
+        outputs = []
+        start_idx = 0
+
+        for i in range(batch_size):
+            seq_length = int(num_draft_tokens_cpu[i])
+            end_idx = start_idx + seq_length
+
+            # Get main tokens for this sequence
+            seq_main_tokens = main_tokens[start_idx:end_idx]
+
+            # Filter out placeholder tokens
+            valid_main_tokens = seq_main_tokens[
+                (seq_main_tokens != PLACEHOLDER_TOKEN_ID)
+                & (seq_main_tokens < vocab_size)]
+
+            # Add bonus token if it's valid
+            bonus_token = bonus_tokens[i]
+            if bonus_token != PLACEHOLDER_TOKEN_ID and bonus_token < vocab_size:
+                seq_tokens = np.concatenate([valid_main_tokens, [bonus_token]])
+            else:
+                seq_tokens = valid_main_tokens
+
+            outputs.append(seq_tokens.tolist())
+            start_idx = end_idx
+
+        return outputs
+
+
+def _compute_probs(
+    logits: jnp.ndarray,
+    num_draft_tokens: jnp.ndarray,
+    sampling_metadata: TPUSupportedSamplingMetadata,
+) -> jnp.ndarray:
+    """
+    Apply top-k, top-p, and temperature to logits and compute probabilities.
+    """
+    total_tokens = logits.shape[0]
+    segment_ids, _ = _get_segment_info(num_draft_tokens, total_tokens)
+
+    # Expand sampling params from [batch_size] to [num_tokens]
+    top_k = sampling_metadata.top_k[segment_ids]
+    top_p = sampling_metadata.top_p[segment_ids]
+    temperatures = sampling_metadata.temperature[segment_ids]
+
+    # Apply top-k and top-p masking
+    logits = logits.astype(jnp.float32)
+    # Only apply top-k masking if k > 0 for each token
+    should_apply_topk = jnp.expand_dims(top_k > 0, axis=-1)
+    topk_masked = topk_mask(logits, top_k, replace_val=-jnp.inf)
+    logits = jnp.where(should_apply_topk, topk_masked, logits)
+
+    # Only apply top-p masking if p < 1.0 for each token
+    should_apply_topp = jnp.expand_dims(top_p < 1.0, axis=-1)
+    topp_masked = topp_mask(logits, top_p, replace_val=-jnp.inf)
+    logits = jnp.where(should_apply_topp, topp_masked, logits)
+
+    # Apply temperature scaling
+    temperatures = jnp.expand_dims(temperatures, axis=-1)
+    # Add epsilon to avoid division by zero
+    logits /= (temperatures + 1e-9)
+
+    return jax.nn.softmax(logits, axis=-1)
+
+
+def _get_segment_info(num_draft_tokens: jax.Array, total_tokens: int):
+    """Helper to create segment IDs and per-segment indices."""
+    batch_size = num_draft_tokens.shape[0]
+
+    # `segment_ids` assigns a unique ID to each token, corresponding to its
+    # sequence in the batch. E.g., [0, 0, 0, 1, 1, 2, 2, 2, 2] for sequences [3, 2, 4].
+    segment_ids = jnp.repeat(jnp.arange(batch_size),
+                             num_draft_tokens,
+                             total_repeat_length=total_tokens)
+
+    # `group_indices` creates a within-segment index for each token.
+    # E.g., [0, 1, 2, 0, 1, 0, 1, 2, 3] for the example above.
+    segment_starts = jnp.concatenate(
+        [jnp.array([0]), jnp.cumsum(num_draft_tokens)[:-1]])
+    broadcast_starts = jnp.repeat(segment_starts,
+                                  num_draft_tokens,
+                                  total_repeat_length=total_tokens)
+    group_indices = jnp.arange(total_tokens) - broadcast_starts
+    return segment_ids, group_indices
+
+
+def _sample_recovered_tokens(
+    draft_token_ids: jax.Array,
+    draft_probs: Optional[jax.Array],
+    target_probs: jax.Array,
+    key: jax.random.PRNGKey,
+) -> jax.Array:
+    """
+    Sample recovered tokens using the Gumbel-Max trick.
+    This is used when a draft token is rejected in random sampling.
+    """
+    if draft_probs is not None:
+        # The new distribution is p' = max(p_target - p_draft, 0)
+        new_dist = jnp.maximum(target_probs - draft_probs, 0)
+    else:
+        # If no draft probs, the new distribution is the target distribution
+        # with the draft token's probability zeroed out.
+        vocab_size = target_probs.shape[-1]
+        mask = jax.nn.one_hot(draft_token_ids, vocab_size, dtype=jnp.bool)
+        new_dist = target_probs * ~mask
+
+    # Gumbel-Max trick to sample from the new distribution:
+    # y = argmax(log(p') + g) where g ~ Gumbel(0,1)
+    # This is equivalent to argmax(p' / q) where q ~ Exponential(1)
+    q = jax.random.exponential(key, shape=new_dist.shape)
+
+    # Add a small epsilon to avoid division by zero
+    recovered_token_ids = jnp.argmax(new_dist / (q + 1e-9), axis=-1)
+    return recovered_token_ids
+
+
+def rejection_sample(
+    # [num_tokens] - flattened format
+    draft_token_ids: jnp.ndarray,
+    # [batch_size] - JAX array
+    num_draft_tokens: jnp.ndarray,
+    # [num_tokens, vocab_size] - flattened format
+    draft_probs: Optional[jnp.ndarray],
+    # [num_tokens, vocab_size] - flattened format
+    target_probs: jnp.ndarray,
+    # [batch_size]
+    bonus_token_ids: jnp.ndarray,
+    sampling_metadata: TPUSupportedSamplingMetadata,
+    key: Optional[jax.random.PRNGKey] = None,
+) -> jnp.ndarray:
+    """
+    Perform rejection sampling on draft tokens with flattened inputs.
+
+    Args:
+        draft_token_ids: Draft token IDs in flattened format [num_tokens].
+        num_draft_tokens: Number of draft tokens per request [batch_size].
+        draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+        target_probs: Target probabilities in flattened format [num_tokens, vocab_size].
+        bonus_token_ids: Bonus token IDs [batch_size].
+        sampling_metadata: Sampling metadata.
+        key: JAX random key for non-greedy sampling.
+
+    Returns:
+        output_token_ids: Output token IDs [num_tokens + batch_size].
+    """
+    if sampling_metadata.do_sampling is False:
+        greedy_output = _greedy_rejection_sample_with_segment(
+            draft_token_ids, target_probs, num_draft_tokens, bonus_token_ids)
+        return greedy_output
+
+    # Random path
+    if key is None:
+        raise ValueError(
+            "A random key must be provided for non-greedy sampling.")
+
+    random_output = _random_rejection_sample_with_segment(
+        draft_token_ids,
+        draft_probs,
+        target_probs,
+        num_draft_tokens,
+        bonus_token_ids,
+        key,
+    )
+
+    return random_output
+
+
+def _random_rejection_sample_with_segment(
+    draft_token_ids: jax.Array,
+    draft_probs: Optional[jax.Array],
+    target_probs: jax.Array,
+    num_draft_tokens: jax.Array,
+    bonus_token_ids: jax.Array,
+    key: jax.random.PRNGKey,
+) -> jax.Array:
+    """
+    Performs random speculative decoding validation in a vectorized, jittable manner.
+    """
+    total_tokens = draft_token_ids.shape[0]
+    batch_size = num_draft_tokens.shape[0]
+
+    # Split random key
+    uniform_key, recover_key = jax.random.split(key)
+
+    # --- Step 1: Get Segment Info ---
+    segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                   total_tokens)
+
+    # --- Step 2: Acceptance/Rejection Logic ---
+    if draft_probs is not None:
+        draft_token_probs = jnp.take_along_axis(draft_probs,
+                                                draft_token_ids[:, None],
+                                                axis=-1).squeeze(-1)
+    else:
+        draft_token_probs = 1.0
+
+    target_token_probs = jnp.take_along_axis(target_probs,
+                                             draft_token_ids[:, None],
+                                             axis=-1).squeeze(-1)
+
+    uniform_probs = jax.random.uniform(uniform_key, shape=(total_tokens, ))
+
+    # Acceptance condition: p_target(d) / p_draft(d) >= u
+    ratio = target_token_probs / (draft_token_probs + 1e-9)
+    accepted = ratio >= uniform_probs
+
+    # --- Step 3: Find First Rejection ---
+    rejections = ~accepted
+    large_value = total_tokens
+    rejection_indices = jnp.where(rejections, group_indices, large_value)
+
+    first_rejection_idx_per_segment = jax.ops.segment_min(
+        data=rejection_indices.astype(jnp.int32),
+        segment_ids=segment_ids,
+        num_segments=batch_size,
+        indices_are_sorted=True,
+    )
+
+    max_int = jnp.iinfo(jnp.int32).max
+    first_rejection_idx_per_segment = jnp.where(
+        first_rejection_idx_per_segment == max_int, large_value,
+        first_rejection_idx_per_segment)
+
+    # --- Step 4: Sample Recovered Tokens ---
+    recovered_token_ids = _sample_recovered_tokens(draft_token_ids,
+                                                   draft_probs, target_probs,
+                                                   recover_key)
+
+    # --- Step 5: Generate Main Token Output ---
+    first_rejection_idx_broadcast = jnp.repeat(
+        first_rejection_idx_per_segment,
+        num_draft_tokens,
+        total_repeat_length=total_tokens)
+
+    main_tokens = jnp.where(
+        group_indices < first_rejection_idx_broadcast, draft_token_ids,
+        jnp.where(group_indices == first_rejection_idx_broadcast,
+                  recovered_token_ids, PLACEHOLDER_TOKEN_ID))
+
+    # --- Step 6: Handle Bonus Tokens ---
+    all_accepted = first_rejection_idx_per_segment == large_value
+    no_draft_tokens = num_draft_tokens == 0
+    should_get_bonus = all_accepted | no_draft_tokens
+    bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                             PLACEHOLDER_TOKEN_ID)
+
+    # --- Step 7: Concatenate ---
+    return jnp.concatenate([main_tokens, bonus_tokens])
+
+
+# TODO(pooyam): Optimize/Profile this implementation further. Currently, I just want working e2e. There might be overheads with `parse_output` that can be optimized on TPU.
+# I should Benchmark against the following approaches:
+# - Using `jax.lax.segment_xyz`` to work with flattened inputs instead of batched inputs.
+# - Using vectorized implementation using `cumprod` and other masking tricks.
+# - A pallas kernel similar to the Triton implementation.
+# - Scan based approach.
+# Overall, I expect XLA to optimize the scan-based approach pretty well, but
+# it would be good to compare performance against other methods.
+def _greedy_rejection_sample_with_segment(
+    draft_token_ids: jax.Array,
+    target_probs: jax.Array,
+    num_draft_tokens: jax.Array,
+    bonus_token_ids: jax.Array,
+) -> jax.Array:
+    """
+    Performs greedy speculative decoding validation in a vectorized, jittable manner.
+
+    This function compares draft tokens with the target model's outputs. For each
+    sequence in the batch, it accepts tokens as long as the draft and target match.
+    When a mismatch occurs, it takes the target model's token and invalidates the
+    rest of the tokens in that sequence by setting them to -1.
+
+    Args:
+        draft_token_ids: A 1D JAX array (num_tokens,) of integers representing the
+            concatenated draft tokens for all sequences in the batch.
+        target_probs: A 2D JAX array (num_tokens, vocab_size) of floats representing
+            the concatenated target model's probabilities.
+        num_draft_tokens: A 1D JAX array (batch_size,) of integers specifying the
+            number of draft tokens for each sequence in the batch.
+        bonus_token_ids: A 1D JAX array (batch_size,) of integers representing the
+            bonus token for each sequence.
+
+    Returns:
+        A 1D JAX array (num_tokens + batch_size,) containing the validated token
+        sequence followed by bonus tokens (or -1 if not accepted).
+    """
+    # Get target argmax
+    target_logits_argmax = jnp.argmax(target_probs, axis=-1)
+
+    # --- Step 1: Create Segment IDs and Per-Segment Indices ---
+    total_tokens = draft_token_ids.shape[0]
+    batch_size = num_draft_tokens.shape[0]
+    segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                   total_tokens)
+
+    # --- Step 2: Find the First Mismatch in Each Segment ---
+
+    # Find all mismatches between draft and target tokens.
+    mismatches = draft_token_ids != target_logits_argmax
+
+    # To find the *first* mismatch, we use a trick with segment_min.
+    # We create an array where mismatched positions hold their `group_index`
+    # and matched positions hold a large value.
+    large_value = total_tokens
+    mismatch_indices = jnp.where(mismatches, group_indices, large_value)
+
+    # `segment_min` finds the minimum `mismatch_index` for each segment. This
+    # effectively gives us the `group_index` of the first mismatch.
+    # For sequences with no mismatches, the result will be `large_value`.
+    first_mismatch_idx_per_segment = jax.ops.segment_min(
+        data=mismatch_indices.astype(jnp.int32),
+        segment_ids=segment_ids,
+        num_segments=batch_size,
+        indices_are_sorted=True,
+    )
+
+    # Handle empty segments (where num_draft_tokens is 0). `segment_min` returns
+    # the dtype's max value for empty segments; we replace it with our large_value
+    # for consistency.
+    max_int = jnp.iinfo(jnp.int32).max
+    first_mismatch_idx_per_segment = jnp.where(
+        first_mismatch_idx_per_segment == max_int, large_value,
+        first_mismatch_idx_per_segment)
+
+    # --- Step 3: Broadcast Mismatch Info and Generate Main Token Output ---
+
+    # Broadcast the first mismatch index back to the original token dimension.
+    first_mismatch_idx_broadcast = jnp.repeat(first_mismatch_idx_per_segment,
+                                              num_draft_tokens,
+                                              total_repeat_length=total_tokens)
+
+    # The final logic for main tokens:
+    # A token is valid if its `group_index` is less than or equal to the
+    # index of the first mismatch in its segment.
+    # - If `group_index < first_mismatch_idx`, the draft was correct.
+    # - If `group_index == first_mismatch_idx`, this is the correction token.
+    # - If `group_index > first_mismatch_idx`, the token is invalid (-1).
+    main_tokens = jnp.where(group_indices <= first_mismatch_idx_broadcast,
+                            target_logits_argmax, PLACEHOLDER_TOKEN_ID)
+
+    # --- Step 4: Handle Bonus Tokens ---
+
+    # A sequence gets its bonus token if there were no mismatches
+    # (first_mismatch_idx_per_segment == large_value)
+    all_accepted = first_mismatch_idx_per_segment == large_value
+
+    # For sequences with no draft tokens, we should still give them the bonus token
+    # since there's nothing to reject
+    no_draft_tokens = num_draft_tokens == 0
+    should_get_bonus = all_accepted | no_draft_tokens
+
+    bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                             PLACEHOLDER_TOKEN_ID)
+
+    # --- Step 5: Concatenate Main Tokens and Bonus Tokens ---
+
+    output = jnp.concatenate([main_tokens, bonus_tokens])
+
+    return output
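Both validation paths above reduce "find the first rejected draft position per request" to a segment-min over flattened tokens. A minimal, self-contained JAX sketch of that trick, using made-up toy arrays (not code from the wheel):

import jax
import jax.numpy as jnp

# Toy flattened inputs: 3 requests with 3, 2, and 4 draft tokens.
num_draft_tokens = jnp.array([3, 2, 4])
draft_ids = jnp.array([5, 7, 9, 1, 2, 4, 4, 8, 8])
target_ids = jnp.array([5, 7, 3, 1, 2, 4, 6, 8, 8])  # target argmax per position

total_tokens = draft_ids.shape[0]
batch_size = num_draft_tokens.shape[0]

# Same construction as _get_segment_info: per-token request ID and within-request index.
segment_ids = jnp.repeat(jnp.arange(batch_size), num_draft_tokens,
                         total_repeat_length=total_tokens)
starts = jnp.concatenate([jnp.array([0]), jnp.cumsum(num_draft_tokens)[:-1]])
group_indices = jnp.arange(total_tokens) - jnp.repeat(
    starts, num_draft_tokens, total_repeat_length=total_tokens)

# Mismatched positions keep their within-request index; matches get a large sentinel.
mismatch_indices = jnp.where(draft_ids != target_ids, group_indices, total_tokens)
first_mismatch = jax.ops.segment_min(mismatch_indices.astype(jnp.int32),
                                     segment_ids,
                                     num_segments=batch_size,
                                     indices_are_sorted=True)
print(first_mismatch)  # [2 9 1]: request 0 rejects at position 2, request 1 accepts all, request 2 rejects at position 1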
tpu_inference/layers/jax/sample/sampling.py
@@ -0,0 +1,96 @@
+import functools
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from vllm.v1.outputs import LogprobsTensors
+
+from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.sample.sampling_metadata import \
+    TPUSupportedSamplingMetadata
+
+_SAMPLING_EPS = 1e-5
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=["mesh"],
+)
+def sample(
+    rng: jax.Array,
+    mesh: Mesh,
+    logits: jax.Array,
+    tpu_sampling_metadata: TPUSupportedSamplingMetadata,
+) -> jax.Array:
+    # (B, vocab_size)
+    if tpu_sampling_metadata.do_sampling:
+        # Unshard the logits explicity to avoid latency increase.
+        logits = jax.lax.with_sharding_constraint(
+            logits, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
+    greedy_sampled = jnp.argmax(logits, axis=-1)
+    if not tpu_sampling_metadata.do_sampling:
+        return greedy_sampled
+
+    logits = logits.astype(jnp.float32)
+    logits = topk_mask(logits, tpu_sampling_metadata.top_k, replace_val=-1e12)
+    logits = topp_mask(logits, tpu_sampling_metadata.top_p, replace_val=-1e12)
+
+    temperatures = tpu_sampling_metadata.temperature.astype(logits.dtype)
+    temperatures = jnp.expand_dims(temperatures, axis=-1)
+    logits /= temperatures
+
+    # (batch_size,)
+    next_tokens = jax.random.categorical(rng, logits)
+    # Note: avoid using the sample result when temperature < _SAMPLING_EPS
+    # If temperature < 0, logits /= temperatures will flip the result, causing error.
+    return jnp.where(tpu_sampling_metadata.temperature < _SAMPLING_EPS,
+                     greedy_sampled, next_tokens)
+
+
+def compute_logprobs(logits: jax.Array) -> jax.Array:
+    return jax.nn.log_softmax(logits, axis=-1)
+
+
+def gather_logprobs(
+    logprobs: jax.Array,
+    token_ids: jax.Array,
+    num_logprobs: int,
+) -> LogprobsTensors:
+    """
+    Gather logprobs for topk and sampled/prompt token.
+
+    Args:
+        logprobs: (num tokens) x (vocab) tensor
+        token_ids: prompt tokens (if prompt logprobs)
+            or sampled tokens (if sampled
+            logprobs); 1D token ID tensor
+            with (num tokens) elements
+        num_logprobs: minimum number of logprobs to
+            retain per token
+
+
+    Returns:
+        Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+        Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+        Sampled token rank tensor, (num tokens)
+    """
+    # Find the topK values.
+    topk_logprobs, topk_indices = jax.lax.top_k(logprobs, k=num_logprobs)
+
+    # Get with the logprob of the prompt or sampled token.
+    token_ids = jnp.expand_dims(token_ids, axis=-1)
+    token_logprobs = jnp.take_along_axis(logprobs, token_ids, axis=-1)
+
+    # Compute the ranks of the actual token.
+    token_ranks = jnp.sum(logprobs >= token_logprobs, axis=-1)
+
+    # Concatenate together with the topk.
+    indices = jnp.concatenate((token_ids, topk_indices), axis=1)
+    logprobs = jnp.concatenate((token_logprobs, topk_logprobs), axis=1)
+
+    # Use int32 to reduce the tensor size.
+    indices = jnp.int32(indices)
+
+    return LogprobsTensors(indices, logprobs, token_ranks)
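The rank returned by gather_logprobs is computed without a sort: a token's rank is the count of vocabulary entries whose logprob is at least its own. A small standalone sketch of that computation with toy numbers (not code from the wheel):

import jax.numpy as jnp

# One token position over a vocabulary of 3, with probabilities 0.1, 0.6, 0.3.
logprobs = jnp.log(jnp.array([[0.1, 0.6, 0.3]]))
token_ids = jnp.array([[2]])  # the sampled/prompt token at this position

token_logprob = jnp.take_along_axis(logprobs, token_ids, axis=-1)  # logprob of token 2
rank = jnp.sum(logprobs >= token_logprob, axis=-1)                 # entries at least as likely
print(int(rank[0]))  # 2 -> token 2 is the second most likely entry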
tpu_inference/layers/jax/sample/sampling_metadata.py
@@ -0,0 +1,76 @@
+import functools
+from dataclasses import dataclass
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import Mesh
+
+from tpu_inference.runner.input_batch import InputBatch
+from tpu_inference.utils import device_array
+
+DEFAULT_SAMPLING_PARAMS = dict(
+    temperature=-1.0,
+    top_k=0,
+    top_p=1.0,
+)
+
+
+@functools.partial(
+    jax.tree_util.register_dataclass,
+    data_fields=[
+        "temperature",
+        "top_k",
+        "top_p",
+    ],
+    meta_fields=["do_sampling", "logprobs"],
+)
+@dataclass
+class TPUSupportedSamplingMetadata:
+    temperature: Optional[jnp.ndarray] = None
+    top_k: Optional[jnp.ndarray] = None
+    top_p: Optional[jnp.ndarray] = None
+    do_sampling: bool = False
+    logprobs: bool = False
+
+    @classmethod
+    def from_input_batch(
+        cls,
+        mesh: Mesh,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        sharding: Optional[jax.sharding.Sharding] = None,
+    ) -> "TPUSupportedSamplingMetadata":
+        needs_logprobs = input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False
+        if input_batch.all_greedy:
+            return cls(do_sampling=False, logprobs=needs_logprobs)
+        num_reqs = input_batch.num_reqs
+
+        def fill_slice(cpu_torch_tensor: torch.Tensor,
+                       fill_val: float) -> torch.Tensor:
+            # Pad value is the default one.
+            cpu_torch_tensor[num_reqs:padded_num_reqs] = fill_val
+            return cpu_torch_tensor
+
+        temp_tensor = fill_slice(input_batch.temperature_cpu,
+                                 DEFAULT_SAMPLING_PARAMS["temperature"])
+        top_k_tensor = fill_slice(input_batch.top_k_cpu,
+                                  DEFAULT_SAMPLING_PARAMS["top_k"])
+        top_p_tensor = fill_slice(input_batch.top_p_cpu,
+                                  DEFAULT_SAMPLING_PARAMS["top_p"])
+
+        # Slice persistent device tensors to a fixed pre-compiled padded shape.
+        return cls(
+            temperature=device_array(mesh,
+                                     temp_tensor[:padded_num_reqs],
+                                     sharding=sharding),
+            top_p=device_array(mesh,
+                               top_p_tensor[:padded_num_reqs],
+                               sharding=sharding),
+            top_k=device_array(mesh,
+                               top_k_tensor[:padded_num_reqs],
+                               sharding=sharding),
+            do_sampling=not input_batch.all_greedy,
+            logprobs=needs_logprobs,
+        )