tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
import os
|
|
2
1
|
import time
|
|
3
|
-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
4
3
|
|
|
5
4
|
import jax
|
|
6
5
|
import jax.numpy as jnp
|
|
7
6
|
import numpy as np
|
|
8
|
-
import vllm.envs as
|
|
7
|
+
import vllm.envs as vllm_envs
|
|
9
8
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
9
|
|
|
10
|
+
import tpu_inference.envs as envs
|
|
11
11
|
from tpu_inference.core.disagg_utils import is_disagg_enabled
|
|
12
12
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
13
13
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -15,6 +15,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
|
|
|
15
15
|
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
16
16
|
TPUSupportedSamplingMetadata
|
|
17
17
|
from tpu_inference.logger import init_logger
|
|
18
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
19
|
+
JaxIntermediateTensors
|
|
18
20
|
from tpu_inference.utils import device_array
|
|
19
21
|
|
|
20
22
|
if TYPE_CHECKING:
|
|
@@ -30,10 +32,10 @@ class CompilationManager:
|
|
|
30
32
|
|
|
31
33
|
def __init__(self, runner: "TPUModelRunner"):
|
|
32
34
|
self.runner = runner
|
|
33
|
-
if not
|
|
35
|
+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
|
|
34
36
|
logger.info("Enabling JAX compile cache.")
|
|
35
37
|
jax.config.update("jax_compilation_cache_dir",
|
|
36
|
-
|
|
38
|
+
vllm_envs.VLLM_XLA_CACHE_PATH)
|
|
37
39
|
|
|
38
40
|
def _create_dummy_tensor(self,
|
|
39
41
|
shape: Tuple[int, ...],
|
|
@@ -67,8 +69,7 @@ class CompilationManager:
|
|
|
67
69
|
logger.info("Compilation finished in %.2f [secs].", end - start)
|
|
68
70
|
|
|
69
71
|
def capture_model(self) -> None:
|
|
70
|
-
if
|
|
71
|
-
False) or self.runner.model_config.enforce_eager:
|
|
72
|
+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
|
|
72
73
|
return
|
|
73
74
|
logger.info("Precompile all the subgraphs with possible input shapes.")
|
|
74
75
|
|
|
@@ -81,6 +82,8 @@ class CompilationManager:
|
|
|
81
82
|
self._precompile_backbone_with_inputs_embeds()
|
|
82
83
|
if self.runner.scheduler_config.async_scheduling:
|
|
83
84
|
self._precompile_substitute_placeholder_token()
|
|
85
|
+
if not self.runner.is_last_rank:
|
|
86
|
+
return
|
|
84
87
|
self._precompile_select_from_array()
|
|
85
88
|
self._precompile_compute_logits()
|
|
86
89
|
self._precompile_disagg_utils()
|
|
@@ -120,8 +123,15 @@ class CompilationManager:
|
|
|
120
123
|
num_tokens=num_tokens,
|
|
121
124
|
)
|
|
122
125
|
|
|
123
|
-
def _precompile_backbone_helper(self,
|
|
124
|
-
|
|
126
|
+
def _precompile_backbone_helper(self,
|
|
127
|
+
name,
|
|
128
|
+
*,
|
|
129
|
+
input_ids,
|
|
130
|
+
positions,
|
|
131
|
+
inputs_embeds,
|
|
132
|
+
intermediate_tensors=None,
|
|
133
|
+
is_first_rank=True,
|
|
134
|
+
is_last_rank=True) -> None:
|
|
125
135
|
num_tokens = None
|
|
126
136
|
if input_ids is not None:
|
|
127
137
|
num_tokens = input_ids.shape[0]
|
|
@@ -135,12 +145,6 @@ class CompilationManager:
|
|
|
135
145
|
ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
|
|
136
146
|
|
|
137
147
|
# Keep existing pattern for complex array operations
|
|
138
|
-
block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
|
|
139
|
-
block_tables = block_tables.reshape(-1)
|
|
140
|
-
block_tables = device_array(self.runner.mesh,
|
|
141
|
-
block_tables,
|
|
142
|
-
sharding=dp_sharding)
|
|
143
|
-
|
|
144
148
|
seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
|
|
145
149
|
jnp.int32, dp_sharding)
|
|
146
150
|
query_start_loc = self._create_dummy_tensor(
|
|
@@ -152,26 +156,49 @@ class CompilationManager:
|
|
|
152
156
|
request_distribution,
|
|
153
157
|
sharding=dp_sharding)
|
|
154
158
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
159
|
+
attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
|
|
160
|
+
uniform_attention_metadata: AttentionMetadata = None
|
|
161
|
+
for kv_cache_gid, kv_cache_group in enumerate(
|
|
162
|
+
self.runner.kv_cache_config.kv_cache_groups):
|
|
163
|
+
block_tables = self.runner.block_tables_cpu[
|
|
164
|
+
kv_cache_gid][:self.runner.max_num_reqs]
|
|
165
|
+
block_tables = block_tables.reshape(-1)
|
|
166
|
+
block_tables = device_array(self.runner.mesh,
|
|
167
|
+
block_tables,
|
|
168
|
+
sharding=dp_sharding)
|
|
169
|
+
|
|
170
|
+
attention_metadata_gid = AttentionMetadata(
|
|
171
|
+
input_positions=positions,
|
|
172
|
+
block_tables=block_tables,
|
|
173
|
+
seq_lens=seq_lens,
|
|
174
|
+
query_start_loc=query_start_loc,
|
|
175
|
+
request_distribution=request_distribution,
|
|
176
|
+
)
|
|
177
|
+
if not self.runner.use_hybrid_kvcache:
|
|
178
|
+
# all layers share the same attention metadata
|
|
179
|
+
uniform_attention_metadata = attention_metadata_gid
|
|
180
|
+
else:
|
|
181
|
+
for layer_name in kv_cache_group.layer_names:
|
|
182
|
+
attention_metadata_per_layer[
|
|
183
|
+
layer_name] = attention_metadata_gid
|
|
162
184
|
|
|
163
185
|
def model_fn_wrapper(
|
|
164
186
|
state,
|
|
165
187
|
kv_caches,
|
|
166
188
|
input_ids,
|
|
167
189
|
attention_metadata,
|
|
190
|
+
positions,
|
|
168
191
|
inputs_embeds,
|
|
169
192
|
layer_name_to_kvcache_index,
|
|
170
193
|
lora_metadata,
|
|
194
|
+
intermediate_tensors,
|
|
195
|
+
is_first_rank,
|
|
196
|
+
is_last_rank,
|
|
171
197
|
):
|
|
172
198
|
kv_caches, hidden_states, _ = self.runner.model_fn(
|
|
173
199
|
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
|
|
174
|
-
layer_name_to_kvcache_index, lora_metadata
|
|
200
|
+
positions, layer_name_to_kvcache_index, lora_metadata,
|
|
201
|
+
intermediate_tensors, is_first_rank, is_last_rank)
|
|
175
202
|
self.runner.kv_caches = kv_caches
|
|
176
203
|
return hidden_states
|
|
177
204
|
|
|
@@ -179,6 +206,10 @@ class CompilationManager:
|
|
|
179
206
|
self.runner.lora_config, np.array([num_tokens],
|
|
180
207
|
dtype=np.int32)):
|
|
181
208
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
209
|
+
if self.runner.use_hybrid_kvcache:
|
|
210
|
+
attention_metadata = attention_metadata_per_layer
|
|
211
|
+
else:
|
|
212
|
+
attention_metadata = uniform_attention_metadata
|
|
182
213
|
self._run_compilation(
|
|
183
214
|
name,
|
|
184
215
|
model_fn_wrapper,
|
|
@@ -186,9 +217,13 @@ class CompilationManager:
|
|
|
186
217
|
self.runner.kv_caches,
|
|
187
218
|
input_ids,
|
|
188
219
|
attention_metadata,
|
|
220
|
+
positions,
|
|
189
221
|
inputs_embeds,
|
|
190
222
|
tuple(self.runner.layer_name_to_kvcache_index.items()),
|
|
191
223
|
lora_metadata,
|
|
224
|
+
intermediate_tensors,
|
|
225
|
+
is_first_rank,
|
|
226
|
+
is_last_rank,
|
|
192
227
|
num_tokens=num_tokens,
|
|
193
228
|
)
|
|
194
229
|
|
|
@@ -239,6 +274,7 @@ class CompilationManager:
|
|
|
239
274
|
)
|
|
240
275
|
|
|
241
276
|
def _precompile_backbone_text_only(self) -> None:
|
|
277
|
+
hidden_size = self.runner.model_config.get_hidden_size()
|
|
242
278
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
243
279
|
dp_sharding = NamedSharding(
|
|
244
280
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
|
|
@@ -248,10 +284,28 @@ class CompilationManager:
|
|
|
248
284
|
dp_sharding)
|
|
249
285
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
|
|
250
286
|
dp_sharding)
|
|
251
|
-
self.
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
287
|
+
is_first_rank = self.runner.is_first_rank
|
|
288
|
+
is_last_rank = self.runner.is_last_rank
|
|
289
|
+
if is_first_rank:
|
|
290
|
+
intermediate_tensors = None
|
|
291
|
+
else:
|
|
292
|
+
hidden_states = self._create_dummy_tensor(
|
|
293
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
294
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
295
|
+
jnp.bfloat16)
|
|
296
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
297
|
+
tensors={
|
|
298
|
+
"hidden_states": hidden_states,
|
|
299
|
+
"residual": residual
|
|
300
|
+
})
|
|
301
|
+
self._precompile_backbone_helper(
|
|
302
|
+
f"worker{self.runner.rank} backbone",
|
|
303
|
+
input_ids=input_ids,
|
|
304
|
+
positions=positions,
|
|
305
|
+
inputs_embeds=None,
|
|
306
|
+
intermediate_tensors=intermediate_tensors,
|
|
307
|
+
is_first_rank=is_first_rank,
|
|
308
|
+
is_last_rank=is_last_rank)
|
|
255
309
|
|
|
256
310
|
def _precompile_backbone_with_inputs_embeds(self) -> None:
|
|
257
311
|
hidden_size = self.runner.model_config.get_hidden_size()
|
|
@@ -265,10 +319,28 @@ class CompilationManager:
|
|
|
265
319
|
else:
|
|
266
320
|
positions = self._create_dummy_tensor((num_tokens, ),
|
|
267
321
|
jnp.int32)
|
|
268
|
-
self.
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
322
|
+
is_first_rank = self.runner.is_first_rank
|
|
323
|
+
is_last_rank = self.runner.is_last_rank
|
|
324
|
+
if not is_first_rank:
|
|
325
|
+
hidden_states = self._create_dummy_tensor(
|
|
326
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
327
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
328
|
+
jnp.bfloat16)
|
|
329
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
330
|
+
tensors={
|
|
331
|
+
"hidden_states": hidden_states,
|
|
332
|
+
"residual": residual
|
|
333
|
+
})
|
|
334
|
+
else:
|
|
335
|
+
intermediate_tensors = None
|
|
336
|
+
self._precompile_backbone_helper(
|
|
337
|
+
f"worker{self.runner.rank} backbone with embeds",
|
|
338
|
+
input_ids=None,
|
|
339
|
+
positions=positions,
|
|
340
|
+
inputs_embeds=inputs_embeds,
|
|
341
|
+
intermediate_tensors=intermediate_tensors,
|
|
342
|
+
is_first_rank=is_first_rank,
|
|
343
|
+
is_last_rank=is_last_rank)
|
|
272
344
|
|
|
273
345
|
def _precompile_select_from_array_helper(
|
|
274
346
|
self,
|
|
@@ -336,7 +408,7 @@ class CompilationManager:
|
|
|
336
408
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
|
|
337
409
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
338
410
|
self._precompile_select_from_array_helper(
|
|
339
|
-
name="select all logits",
|
|
411
|
+
name=f"worker{self.runner.rank} select all logits",
|
|
340
412
|
source_paddings=self.runner.num_tokens_paddings,
|
|
341
413
|
indices_paddings=index_paddings,
|
|
342
414
|
hidden_dim=hsize,
|
|
@@ -347,7 +419,8 @@ class CompilationManager:
|
|
|
347
419
|
if self.runner.speculative_config:
|
|
348
420
|
vocab_size = self.runner.model_config.get_vocab_size()
|
|
349
421
|
self._precompile_select_from_array_helper(
|
|
350
|
-
name=
|
|
422
|
+
name=
|
|
423
|
+
f"worker{self.runner.rank} select bonus tokens for spec decoding",
|
|
351
424
|
source_paddings=self.runner.num_logits_paddings,
|
|
352
425
|
indices_paddings=self.runner.num_reqs_paddings,
|
|
353
426
|
hidden_dim=vocab_size,
|
|
@@ -355,7 +428,8 @@ class CompilationManager:
|
|
|
355
428
|
PartitionSpec(None, "model")),
|
|
356
429
|
)
|
|
357
430
|
self._precompile_select_from_array_helper(
|
|
358
|
-
name=
|
|
431
|
+
name=
|
|
432
|
+
f"worker{self.runner.rank} select target tokens for spec decoding",
|
|
359
433
|
source_paddings=self.runner.num_logits_paddings,
|
|
360
434
|
indices_paddings=self.runner.num_logits_paddings,
|
|
361
435
|
hidden_dim=vocab_size,
|
|
@@ -378,7 +452,7 @@ class CompilationManager:
|
|
|
378
452
|
np.array([num_reqs], dtype=np.int32)):
|
|
379
453
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
380
454
|
self._run_compilation(
|
|
381
|
-
"compute_logits",
|
|
455
|
+
f"worker{self.runner.rank} compute_logits",
|
|
382
456
|
self.runner.compute_logits_fn,
|
|
383
457
|
self.runner.state,
|
|
384
458
|
hidden_states,
|
|
@@ -392,11 +466,12 @@ class CompilationManager:
|
|
|
392
466
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
393
467
|
logits_sharding = NamedSharding(
|
|
394
468
|
self.runner.mesh,
|
|
395
|
-
PartitionSpec(ShardingAxisName.
|
|
469
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
470
|
+
ShardingAxisName.MLP_TENSOR))
|
|
396
471
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
397
472
|
sampling_metadata_sharding = NamedSharding(
|
|
398
473
|
self.runner.mesh, PartitionSpec(
|
|
399
|
-
ShardingAxisName.
|
|
474
|
+
ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
|
|
400
475
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
401
476
|
logits_sharding)
|
|
402
477
|
for do_sampling in (True, False):
|
|
@@ -420,7 +495,7 @@ class CompilationManager:
|
|
|
420
495
|
do_sampling=do_sampling,
|
|
421
496
|
)
|
|
422
497
|
self._run_compilation(
|
|
423
|
-
"sample",
|
|
498
|
+
f"worker{self.runner.rank} sample",
|
|
424
499
|
sample,
|
|
425
500
|
self.runner.rng_params_for_sampling,
|
|
426
501
|
self.runner.mesh,
|
|
@@ -461,7 +536,7 @@ class CompilationManager:
|
|
|
461
536
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
|
|
462
537
|
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
|
|
463
538
|
self._run_compilation(
|
|
464
|
-
"gather_logprobs",
|
|
539
|
+
f"worker{self.runner.rank} gather_logprobs",
|
|
465
540
|
self.runner._compute_and_gather_logprobs,
|
|
466
541
|
logits,
|
|
467
542
|
token_ids,
|
|
@@ -513,7 +588,7 @@ class CompilationManager:
|
|
|
513
588
|
do_sampling=do_sampling)
|
|
514
589
|
|
|
515
590
|
self._run_compilation(
|
|
516
|
-
compilation_name,
|
|
591
|
+
f"worker{self.runner.rank} {compilation_name}",
|
|
517
592
|
self.runner.rejection_sampler,
|
|
518
593
|
draft_token_ids,
|
|
519
594
|
num_draft_tokens,
|
|
@@ -530,7 +605,9 @@ class CompilationManager:
|
|
|
530
605
|
def _precompile_eagle3_helpers(self) -> None:
|
|
531
606
|
logger.info(
|
|
532
607
|
"Compiling eagle3 jitted helpers with different input shapes.")
|
|
533
|
-
|
|
608
|
+
target_hidden_size = self.runner.model_config.get_hidden_size()
|
|
609
|
+
draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
|
|
610
|
+
)
|
|
534
611
|
dtype = self.runner.model_config.dtype
|
|
535
612
|
|
|
536
613
|
num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
|
|
@@ -577,10 +654,11 @@ class CompilationManager:
|
|
|
577
654
|
|
|
578
655
|
for num_logits in self.runner.num_logits_paddings:
|
|
579
656
|
hidden_states = self._create_dummy_tensor(
|
|
580
|
-
(num_logits,
|
|
657
|
+
(num_logits, draft_hidden_size), jnp.bfloat16)
|
|
581
658
|
self._run_compilation(
|
|
582
659
|
"eagle3_get_draft_token_ids",
|
|
583
660
|
self.runner.drafter._get_draft_token_ids,
|
|
661
|
+
self.runner.drafter.state,
|
|
584
662
|
hidden_states,
|
|
585
663
|
num_logits=num_logits,
|
|
586
664
|
)
|
|
@@ -588,8 +666,8 @@ class CompilationManager:
|
|
|
588
666
|
input_ids_loop = self._create_dummy_tensor(
|
|
589
667
|
(self.runner.max_num_reqs, ), jnp.int32,
|
|
590
668
|
NamedSharding(self.runner.mesh, PartitionSpec()))
|
|
591
|
-
|
|
592
|
-
(self.runner.max_num_reqs,
|
|
669
|
+
draft_hidden_state_loop = self._create_dummy_tensor(
|
|
670
|
+
(self.runner.max_num_reqs, draft_hidden_size), dtype,
|
|
593
671
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
594
672
|
next_token_ids = self._create_dummy_tensor(
|
|
595
673
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
@@ -597,9 +675,12 @@ class CompilationManager:
|
|
|
597
675
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
598
676
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
599
677
|
aux_hidden_states = [
|
|
600
|
-
self._create_dummy_tensor((num_tokens,
|
|
601
|
-
|
|
602
|
-
self._create_dummy_tensor((num_tokens,
|
|
678
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
679
|
+
dtype),
|
|
680
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
681
|
+
dtype),
|
|
682
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
683
|
+
dtype),
|
|
603
684
|
]
|
|
604
685
|
|
|
605
686
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
@@ -622,23 +703,23 @@ class CompilationManager:
|
|
|
622
703
|
num_reqs,
|
|
623
704
|
):
|
|
624
705
|
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
|
|
625
|
-
token_indices, query_start_loc,
|
|
626
|
-
aux_hidden_states, attention_metadata,
|
|
627
|
-
num_reqs)
|
|
706
|
+
self.runner.drafter.state, token_indices, query_start_loc,
|
|
707
|
+
seq_lens, input_ids, aux_hidden_states, attention_metadata,
|
|
708
|
+
next_token_ids, num_reqs)
|
|
628
709
|
return target_hidden_states, input_ids, last_token_indices
|
|
629
710
|
|
|
630
711
|
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
631
712
|
aux_hidden_states = [
|
|
632
713
|
self._create_dummy_tensor(
|
|
633
|
-
(num_tokens,
|
|
714
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
634
715
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
635
716
|
None))),
|
|
636
717
|
self._create_dummy_tensor(
|
|
637
|
-
(num_tokens,
|
|
718
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
638
719
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
639
720
|
None))),
|
|
640
721
|
self._create_dummy_tensor(
|
|
641
|
-
(num_tokens,
|
|
722
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
642
723
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
643
724
|
None))),
|
|
644
725
|
]
|
|
@@ -670,17 +751,17 @@ class CompilationManager:
|
|
|
670
751
|
state,
|
|
671
752
|
kv_caches,
|
|
672
753
|
input_ids,
|
|
673
|
-
|
|
754
|
+
draft_hidden_states,
|
|
674
755
|
attention_metadata,
|
|
675
756
|
):
|
|
676
757
|
kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
|
|
677
|
-
state, kv_caches, input_ids,
|
|
758
|
+
state, kv_caches, input_ids, draft_hidden_states,
|
|
678
759
|
attention_metadata)
|
|
679
760
|
self.runner.kv_caches = kv_caches
|
|
680
761
|
return hidden_states
|
|
681
762
|
|
|
682
|
-
|
|
683
|
-
(num_tokens,
|
|
763
|
+
draft_hidden_states = self._create_dummy_tensor(
|
|
764
|
+
(num_tokens, draft_hidden_size), dtype,
|
|
684
765
|
NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
|
|
685
766
|
input_ids = self._create_dummy_tensor(
|
|
686
767
|
(num_tokens, ), jnp.int32,
|
|
@@ -691,7 +772,7 @@ class CompilationManager:
|
|
|
691
772
|
self.runner.drafter.state,
|
|
692
773
|
self.runner.kv_caches,
|
|
693
774
|
input_ids,
|
|
694
|
-
|
|
775
|
+
draft_hidden_states,
|
|
695
776
|
attention_metadata,
|
|
696
777
|
num_tokens=num_tokens,
|
|
697
778
|
)
|
|
@@ -701,6 +782,7 @@ class CompilationManager:
|
|
|
701
782
|
self._run_compilation(
|
|
702
783
|
"eagle3_prepare_hidden_states_and_input_ids",
|
|
703
784
|
self.runner.drafter._prepare_hidden_states_and_input_ids,
|
|
785
|
+
self.runner.drafter.state,
|
|
704
786
|
aux_hidden_states,
|
|
705
787
|
query_start_loc,
|
|
706
788
|
target_token_ids,
|
|
@@ -723,18 +805,19 @@ class CompilationManager:
|
|
|
723
805
|
self.runner.drafter.state,
|
|
724
806
|
self.runner.kv_caches,
|
|
725
807
|
input_ids_loop,
|
|
726
|
-
|
|
808
|
+
draft_hidden_state_loop,
|
|
727
809
|
attention_metadata,
|
|
728
810
|
num_tokens=num_tokens,
|
|
729
811
|
)
|
|
730
812
|
|
|
731
813
|
hidden_states = self._create_dummy_tensor(
|
|
732
|
-
(num_tokens,
|
|
814
|
+
(num_tokens, draft_hidden_size), jnp.bfloat16,
|
|
733
815
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
734
816
|
|
|
735
817
|
self._run_compilation(
|
|
736
818
|
"eagle3_select_inputs_for_loop_speculation",
|
|
737
819
|
self.runner.drafter._select_inputs_for_loop_speculation,
|
|
820
|
+
self.runner.drafter.state,
|
|
738
821
|
positions,
|
|
739
822
|
hidden_states,
|
|
740
823
|
hidden_states,
|
|
@@ -745,6 +828,7 @@ class CompilationManager:
|
|
|
745
828
|
self._run_compilation(
|
|
746
829
|
"eagle3_select_draft_token_ids",
|
|
747
830
|
self.runner.drafter._select_draft_token_ids,
|
|
831
|
+
self.runner.drafter.state,
|
|
748
832
|
hidden_states,
|
|
749
833
|
last_token_indices,
|
|
750
834
|
num_tokens=num_tokens,
|
tpu_inference/runner/kv_cache.py
CHANGED
|
@@ -7,6 +7,7 @@ from jax._src import dtypes
|
|
|
7
7
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
8
|
from torchax.ops.mappings import t2j_dtype
|
|
9
9
|
|
|
10
|
+
import tpu_inference.kernels.mla.v1.kernel as mla
|
|
10
11
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
|
|
11
12
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
|
|
12
13
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -17,9 +18,13 @@ logger = init_logger(__name__)
|
|
|
17
18
|
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
22
|
+
total_num_pages: int,
|
|
23
|
+
page_size: int,
|
|
24
|
+
actual_num_kv_heads: int,
|
|
25
|
+
actual_head_dim: int,
|
|
26
|
+
kv_dtype: any,
|
|
27
|
+
use_mla: bool = False):
|
|
23
28
|
"""Gets the KV cache shape based on the mesh configuration."""
|
|
24
29
|
|
|
25
30
|
model_cnt = mesh.shape["model"]
|
|
@@ -28,15 +33,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
|
|
|
28
33
|
# specific model, rather than being determined by the head_dim. If new
|
|
29
34
|
# models are introduced with a head_dim of 64, this will require additional
|
|
30
35
|
# model-specific adjustments.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
36
|
+
if use_mla:
|
|
37
|
+
get_kv_cache_shape_fn = mla.get_kv_cache_shape
|
|
38
|
+
shape = list(
|
|
39
|
+
get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
|
|
40
|
+
kv_dtype))
|
|
41
|
+
else:
|
|
42
|
+
get_kv_cache_shape_fn = (
|
|
43
|
+
rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
|
|
44
|
+
else rpa.get_kv_cache_shape
|
|
45
|
+
)
|
|
46
|
+
shape = list(
|
|
47
|
+
get_kv_cache_shape_fn(total_num_pages, page_size,
|
|
48
|
+
actual_num_kv_heads // model_cnt,
|
|
49
|
+
actual_head_dim, kv_dtype))
|
|
50
|
+
shape[2] *= model_cnt
|
|
40
51
|
return tuple(shape)
|
|
41
52
|
|
|
42
53
|
|
|
@@ -48,6 +59,7 @@ def create_kv_caches(
|
|
|
48
59
|
mesh: Mesh,
|
|
49
60
|
layer_names: List[str],
|
|
50
61
|
cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
|
|
62
|
+
use_mla: bool = False,
|
|
51
63
|
) -> List[jax.Array]:
|
|
52
64
|
"""
|
|
53
65
|
Creates a list of KV cache where each array mapps to single attention layer.
|
|
@@ -74,12 +86,16 @@ def create_kv_caches(
|
|
|
74
86
|
|
|
75
87
|
cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
|
|
76
88
|
num_kv_heads, head_size,
|
|
77
|
-
cache_dtype)
|
|
89
|
+
cache_dtype, use_mla)
|
|
78
90
|
|
|
79
|
-
|
|
80
|
-
mesh,
|
|
81
|
-
|
|
82
|
-
|
|
91
|
+
if use_mla:
|
|
92
|
+
sharding = NamedSharding(mesh,
|
|
93
|
+
PartitionSpec(ShardingAxisName.MLP_TENSOR))
|
|
94
|
+
else:
|
|
95
|
+
sharding = NamedSharding(
|
|
96
|
+
mesh,
|
|
97
|
+
PartitionSpec(ShardingAxisName.ATTN_DATA, None,
|
|
98
|
+
ShardingAxisName.ATTN_HEAD))
|
|
83
99
|
|
|
84
100
|
def _allocate() -> jax.Array:
|
|
85
101
|
return jnp.empty(
|
|
@@ -94,7 +110,8 @@ def create_kv_caches(
|
|
|
94
110
|
return kv_caches
|
|
95
111
|
|
|
96
112
|
|
|
97
|
-
def
|
|
113
|
+
def get_attention_page_size_bytes(mesh: Mesh,
|
|
114
|
+
kv_cache_specs: dict[str, Any]) -> int:
|
|
98
115
|
"""
|
|
99
116
|
Calculate KV cache page size of RPA kernel.
|
|
100
117
|
|
|
@@ -107,14 +124,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
107
124
|
"""
|
|
108
125
|
|
|
109
126
|
# Import it here to avoid circular import.
|
|
110
|
-
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
127
|
+
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
|
111
128
|
|
|
112
129
|
page_size_bytes_set = set()
|
|
113
130
|
for kv_cache_spec in kv_cache_specs.values():
|
|
114
131
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
|
115
132
|
|
|
116
133
|
dtype = t2j_dtype(kv_cache_spec.dtype)
|
|
117
|
-
bits = dtypes.bit_width(dtype)
|
|
134
|
+
bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
|
|
135
|
+
dtypes.itemsize_bits(dtype))
|
|
136
|
+
use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
|
|
118
137
|
|
|
119
138
|
kv_cache_shape = get_kv_cache_shape_with_mesh(
|
|
120
139
|
mesh=mesh,
|
|
@@ -123,6 +142,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
123
142
|
actual_num_kv_heads=kv_cache_spec.num_kv_heads,
|
|
124
143
|
actual_head_dim=kv_cache_spec.head_size,
|
|
125
144
|
kv_dtype=dtype,
|
|
145
|
+
use_mla=use_mla,
|
|
126
146
|
)
|
|
127
147
|
page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
|
|
128
148
|
page_size_bytes_set.add(page_size_bytes)
|