tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/runner/compilation_manager.py CHANGED

@@ -1,13 +1,13 @@
+import os
 import time
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

 import jax
 import jax.numpy as jnp
 import numpy as np
-import vllm.envs as
+import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec

-import tpu_inference.envs as envs
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import ShardingAxisName
@@ -15,8 +15,6 @@ from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.utils import device_array

 if TYPE_CHECKING:
@@ -32,10 +30,10 @@ class CompilationManager:

     def __init__(self, runner: "TPUModelRunner"):
         self.runner = runner
-        if not
+        if not envs.VLLM_DISABLE_COMPILE_CACHE:
             logger.info("Enabling JAX compile cache.")
             jax.config.update("jax_compilation_cache_dir",
-
+                              envs.VLLM_XLA_CACHE_PATH)

     def _create_dummy_tensor(self,
                              shape: Tuple[int, ...],
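The hunk above moves the compile-cache switches to vLLM's `envs` module (`VLLM_DISABLE_COMPILE_CACHE`, `VLLM_XLA_CACHE_PATH`). As a rough standalone sketch of what this wiring does, using only the flag names visible in the diff (the fallback path below is invented): JAX's persistent compilation cache writes compiled executables to a directory and reuses them across runs.

```python
# Minimal sketch (not tpu-inference code): enable JAX's persistent
# compilation cache the way the hunk above does. Env-var names mirror the
# diff; the /tmp fallback is an assumption for illustration.
import os

import jax
import jax.numpy as jnp

if not os.environ.get("VLLM_DISABLE_COMPILE_CACHE"):
    jax.config.update("jax_compilation_cache_dir",
                      os.environ.get("VLLM_XLA_CACHE_PATH", "/tmp/jax_cache"))

@jax.jit
def double(x):
    return x * 2

double(jnp.ones((8,)))  # first call compiles and caches; reruns reuse it
```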
@@ -69,7 +67,8 @@ class CompilationManager:
         logger.info("Compilation finished in %.2f [secs].", end - start)

     def capture_model(self) -> None:
-        if
+        if os.getenv("SKIP_JAX_PRECOMPILE",
+                     False) or self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")

@@ -82,8 +81,6 @@ class CompilationManager:
         self._precompile_backbone_with_inputs_embeds()
         if self.runner.scheduler_config.async_scheduling:
             self._precompile_substitute_placeholder_token()
-        if not self.runner.is_last_rank:
-            return
         self._precompile_select_from_array()
         self._precompile_compute_logits()
         self._precompile_disagg_utils()
@@ -123,15 +120,8 @@ class CompilationManager:
             num_tokens=num_tokens,
         )

-    def _precompile_backbone_helper(self,
-                                    name,
-                                    *,
-                                    input_ids,
-                                    positions,
-                                    inputs_embeds,
-                                    intermediate_tensors=None,
-                                    is_first_rank=True,
-                                    is_last_rank=True) -> None:
+    def _precompile_backbone_helper(self, name, *, input_ids, positions,
+                                    inputs_embeds) -> None:
         num_tokens = None
         if input_ids is not None:
             num_tokens = input_ids.shape[0]
@@ -145,6 +135,12 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

         # Keep existing pattern for complex array operations
+        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
+        block_tables = block_tables.reshape(-1)
+        block_tables = device_array(self.runner.mesh,
+                                    block_tables,
+                                    sharding=dp_sharding)
+
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -156,49 +152,26 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)

-
-
-
-
-
-
-
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
-        attention_metadata_gid = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
-        if not self.runner.use_hybrid_kvcache:
-            # all layers share the same attention metadata
-            uniform_attention_metadata = attention_metadata_gid
-        else:
-            for layer_name in kv_cache_group.layer_names:
-                attention_metadata_per_layer[
-                    layer_name] = attention_metadata_gid
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )

         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
-            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
-            intermediate_tensors,
-            is_first_rank,
-            is_last_rank,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-
-                intermediate_tensors, is_first_rank, is_last_rank)
+                layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states

@@ -206,10 +179,6 @@ class CompilationManager:
             self.runner.lora_config, np.array([num_tokens],
                                               dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
-            if self.runner.use_hybrid_kvcache:
-                attention_metadata = attention_metadata_per_layer
-            else:
-                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -217,13 +186,9 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
-                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
-                intermediate_tensors,
-                is_first_rank,
-                is_last_rank,
                 num_tokens=num_tokens,
             )

@@ -274,7 +239,6 @@ class CompilationManager:
         )

     def _precompile_backbone_text_only(self) -> None:
-        hidden_size = self.runner.model_config.get_hidden_size()
         for num_tokens in self.runner.num_tokens_paddings:
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
@@ -284,28 +248,10 @@ class CompilationManager:
                                                   dp_sharding)
             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                   dp_sharding)
-
-
-
-
-            else:
-                hidden_states = self._create_dummy_tensor(
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone",
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=None,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone",
+                                             input_ids=input_ids,
+                                             positions=positions,
+                                             inputs_embeds=None)

     def _precompile_backbone_with_inputs_embeds(self) -> None:
         hidden_size = self.runner.model_config.get_hidden_size()
@@ -319,28 +265,10 @@ class CompilationManager:
             else:
                 positions = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32)
-
-
-
-
-                    (num_tokens, hidden_size), jnp.bfloat16)
-                residual = self._create_dummy_tensor((num_tokens, hidden_size),
-                                                     jnp.bfloat16)
-                intermediate_tensors = JaxIntermediateTensors(
-                    tensors={
-                        "hidden_states": hidden_states,
-                        "residual": residual
-                    })
-            else:
-                intermediate_tensors = None
-            self._precompile_backbone_helper(
-                f"worker{self.runner.rank} backbone with embeds",
-                input_ids=None,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors,
-                is_first_rank=is_first_rank,
-                is_last_rank=is_last_rank)
+            self._precompile_backbone_helper("backbone with embeds",
+                                             input_ids=None,
+                                             positions=positions,
+                                             inputs_embeds=inputs_embeds)

     def _precompile_select_from_array_helper(
         self,
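The `_precompile_backbone_helper` hunks above drop the pipeline-parallel arguments (`intermediate_tensors`, `is_first_rank`, `is_last_rank`) and compile the backbone once per padded token count with dummy inputs. The idea behind this kind of shape-bucketed precompilation can be sketched with JAX's ahead-of-time API, independent of tpu-inference (the bucket sizes and the toy `backbone` below are invented for illustration):

```python
# Illustrative sketch, not the package's API: compile a jitted function
# for each padded token count up front so serving never triggers XLA
# compilation on the hot path.
import jax
import jax.numpy as jnp

def backbone(input_ids, positions):  # toy stand-in for the model backbone
    return (input_ids * positions).astype(jnp.float32)

num_tokens_paddings = [16, 32, 64]  # hypothetical padding buckets
compiled = {}
for n in num_tokens_paddings:
    # ShapeDtypeStruct describes shape/dtype only, so no real device
    # buffers are needed to lower and compile.
    spec = jax.ShapeDtypeStruct((n,), jnp.int32)
    compiled[n] = jax.jit(backbone).lower(spec, spec).compile()

# At serving time, pad the batch to a known bucket and reuse the executable.
out = compiled[16](jnp.zeros(16, jnp.int32), jnp.arange(16, dtype=jnp.int32))
```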
@@ -408,7 +336,7 @@ class CompilationManager:
             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
-            name=
+            name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
@@ -419,8 +347,7 @@ class CompilationManager:
         if self.runner.speculative_config:
             vocab_size = self.runner.model_config.get_vocab_size()
             self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select bonus tokens for spec decoding",
+                name="select bonus tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_reqs_paddings,
                 hidden_dim=vocab_size,
@@ -428,8 +355,7 @@ class CompilationManager:
                     PartitionSpec(None, "model")),
             )
             self._precompile_select_from_array_helper(
-                name=
-                f"worker{self.runner.rank} select target tokens for spec decoding",
+                name="select target tokens for spec decoding",
                 source_paddings=self.runner.num_logits_paddings,
                 indices_paddings=self.runner.num_logits_paddings,
                 hidden_dim=vocab_size,
@@ -452,7 +378,7 @@ class CompilationManager:
                 np.array([num_reqs], dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
             self._run_compilation(
-
+                "compute_logits",
                 self.runner.compute_logits_fn,
                 self.runner.state,
                 hidden_states,
@@ -494,7 +420,7 @@ class CompilationManager:
             do_sampling=do_sampling,
         )
         self._run_compilation(
-
+            "sample",
             sample,
             self.runner.rng_params_for_sampling,
             self.runner.mesh,
@@ -535,7 +461,7 @@ class CompilationManager:
         logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
         token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
         self._run_compilation(
-
+            "gather_logprobs",
             self.runner._compute_and_gather_logprobs,
             logits,
             token_ids,
@@ -587,7 +513,7 @@ class CompilationManager:
             do_sampling=do_sampling)

         self._run_compilation(
-
+            compilation_name,
             self.runner.rejection_sampler,
             draft_token_ids,
             num_draft_tokens,
@@ -604,9 +530,7 @@ class CompilationManager:
     def _precompile_eagle3_helpers(self) -> None:
         logger.info(
             "Compiling eagle3 jitted helpers with different input shapes.")
-
-        draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
-        )
+        hidden_size = self.runner.model_config.get_hidden_size()
         dtype = self.runner.model_config.dtype

         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
@@ -653,11 +577,10 @@ class CompilationManager:

         for num_logits in self.runner.num_logits_paddings:
             hidden_states = self._create_dummy_tensor(
-                (num_logits,
+                (num_logits, hidden_size), jnp.bfloat16)
             self._run_compilation(
                 "eagle3_get_draft_token_ids",
                 self.runner.drafter._get_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 num_logits=num_logits,
             )
@@ -665,8 +588,8 @@ class CompilationManager:
         input_ids_loop = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32,
             NamedSharding(self.runner.mesh, PartitionSpec()))
-
-            (self.runner.max_num_reqs,
+        target_hidden_state_loop = self._create_dummy_tensor(
+            (self.runner.max_num_reqs, hidden_size), dtype,
             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
         next_token_ids = self._create_dummy_tensor(
             (self.runner.max_num_reqs, ), jnp.int32)
@@ -674,12 +597,9 @@ class CompilationManager:
             (self.runner.max_num_reqs, ), jnp.int32)
         for num_tokens in self.runner.num_tokens_paddings:
             aux_hidden_states = [
-                self._create_dummy_tensor((num_tokens,
-
-                self._create_dummy_tensor((num_tokens,
-                                          dtype),
-                self._create_dummy_tensor((num_tokens, target_hidden_size),
-                                          dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
+                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
             ]

             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -702,23 +622,23 @@ class CompilationManager:
                 num_reqs,
             ):
                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-
-
-
+                    token_indices, query_start_loc, seq_lens, input_ids,
+                    aux_hidden_states, attention_metadata, next_token_ids,
+                    num_reqs)
                 return target_hidden_states, input_ids, last_token_indices

             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
             aux_hidden_states = [
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
                 self._create_dummy_tensor(
-                    (num_tokens,
+                    (num_tokens, hidden_size), jnp.bfloat16,
                     NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                   None))),
             ]
@@ -750,17 +670,17 @@ class CompilationManager:
                 state,
                 kv_caches,
                 input_ids,
-
+                target_hidden_states,
                 attention_metadata,
             ):
                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
-                    state, kv_caches, input_ids,
+                    state, kv_caches, input_ids, target_hidden_states,
                     attention_metadata)
                 self.runner.kv_caches = kv_caches
                 return hidden_states

-
-                (num_tokens,
+            target_hidden_states = self._create_dummy_tensor(
+                (num_tokens, hidden_size), dtype,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
             input_ids = self._create_dummy_tensor(
                 (num_tokens, ), jnp.int32,
@@ -771,7 +691,7 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids,
-
+                target_hidden_states,
                 attention_metadata,
                 num_tokens=num_tokens,
             )
@@ -781,7 +701,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_prepare_hidden_states_and_input_ids",
                 self.runner.drafter._prepare_hidden_states_and_input_ids,
-                self.runner.drafter.state,
                 aux_hidden_states,
                 query_start_loc,
                 target_token_ids,
@@ -804,19 +723,18 @@ class CompilationManager:
                 self.runner.drafter.state,
                 self.runner.kv_caches,
                 input_ids_loop,
-
+                target_hidden_state_loop,
                 attention_metadata,
                 num_tokens=num_tokens,
             )

             hidden_states = self._create_dummy_tensor(
-                (num_tokens,
+                (num_tokens, hidden_size), jnp.bfloat16,
                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

             self._run_compilation(
                 "eagle3_select_inputs_for_loop_speculation",
                 self.runner.drafter._select_inputs_for_loop_speculation,
-                self.runner.drafter.state,
                 positions,
                 hidden_states,
                 hidden_states,
@@ -827,7 +745,6 @@ class CompilationManager:
             self._run_compilation(
                 "eagle3_select_draft_token_ids",
                 self.runner.drafter._select_draft_token_ids,
-                self.runner.drafter.state,
                 hidden_states,
                 last_token_indices,
                 num_tokens=num_tokens,
tpu_inference/runner/kv_cache.py CHANGED

@@ -1,16 +1,15 @@
 import functools
+import math
 from typing import TYPE_CHECKING, Dict, List

 import jax
 import jax.numpy as jnp
-import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
+from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
-from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -176,11 +175,6 @@ class KVCacheManager:
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
-        self.runner.block_tables_cpu = [
-            np.zeros((self.runner.max_num_reqs,
-                      cdiv(self.runner.max_model_len, block_size)),
-                     dtype=np.int32) for block_size in block_sizes
-        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
@@ -196,7 +190,7 @@ class KVCacheManager:
             num_blocks = kv_cache_tensor.size // page_size_bytes
             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
             # num_blocks must be a multiple of dp_size
-            num_blocks = (num_blocks
+            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
             # NOTE: we'll multiply the num_kv_heads by 2 in the function
             kv_cache = create_kv_caches(
                 num_blocks=num_blocks,
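The `num_blocks` line above rounds the block count up to a multiple of `dp_size` so the KV cache divides evenly across data-parallel shards. A self-contained sketch of that arithmetic:

```python
# Sketch of the rounding above: pad n up to the next multiple of m.
import math

def round_up_to_multiple(n: int, m: int) -> int:
    return math.ceil(n / m) * m

assert round_up_to_multiple(1000, 8) == 1000  # already a multiple
assert round_up_to_multiple(1001, 8) == 1008  # padded up to the next one
```

An all-integer equivalent, `((n + m - 1) // m) * m`, gives the same result without float division.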
@@ -289,8 +283,13 @@ class KVCacheManager:

         def _update_layer(cache, slices):
             """The function to apply to each layer's cache and slices."""
-            reshaped_slices = slices.reshape(-1, block_size,
-
+            reshaped_slices = slices.reshape(-1, 1, block_size,
+                                             *slices.shape[1:])
+            for (i, block_idx) in enumerate(block_numbers):
+                cache = jax.lax.dynamic_update_slice_in_dim(cache,
+                                                            reshaped_slices[i],
+                                                            block_idx,
+                                                            axis=0)
             return cache

         return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)
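The `_update_layer` rewrite above reshapes the incoming slices so each update is a single block along the cache's leading axis, then writes blocks one at a time with `jax.lax.dynamic_update_slice_in_dim`. A toy version with assumed shapes (a block-major cache of `(num_blocks, block_size, head_dim)`, which may differ from the package's actual layout):

```python
# Toy sketch (assumed shapes, not the package's cache layout): write one
# block-sized slab into a block-major cache at a runtime block index.
import jax
import jax.numpy as jnp

num_blocks, block_size, head_dim = 8, 4, 2
cache = jnp.zeros((num_blocks, block_size, head_dim))
slab = jnp.ones((1, block_size, head_dim))  # one block, like reshaped_slices[i]

@jax.jit
def write_block(cache, slab, block_idx):
    # Places `slab` at offset `block_idx` along axis 0 and returns the
    # updated array; the index may be a traced (runtime) value.
    return jax.lax.dynamic_update_slice_in_dim(cache, slab, block_idx, axis=0)

cache = write_block(cache, slab, jnp.int32(3))  # block 3 is now all ones
```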
@@ -343,12 +342,16 @@ class KVCacheManager:
         """
         if block_ids == list(range(block_ids[0],
                                    block_ids[0] + len(block_ids))):
-
-
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
+                    self.runner.kv_caches, block_ids[0], len(block_ids))

         else:
-
-
+            with runner_utils.LatencyTracker(
+                    "BatchedGatherKVSlices-for-blocks"):
+                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
+                    self.runner.kv_caches, jnp.array(block_ids))
         return batched_kv_cache_per_layer

     def transfer_kv_cache(self,
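Both gather paths above are now wrapped in `runner_utils.LatencyTracker`. The tracker itself isn't shown in this diff; a hypothetical context-manager sketch of the pattern follows (note that with JAX's asynchronous dispatch, an accurate timer would also block on the result, e.g. via `jax.block_until_ready`, before exiting):

```python
# Hypothetical LatencyTracker-style helper; the actual runner_utils
# implementation may differ. Times a named region and logs the duration.
import time
from contextlib import contextmanager

@contextmanager
def latency_tracker(name: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1e3
        print(f"{name}: {elapsed_ms:.2f} ms")

with latency_tracker("BatchedGatherKVSlices-for-blocks"):
    total = sum(range(1_000_000))  # stand-in for the jitted KV gather
```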
@@ -437,7 +440,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     start_block,
                 )
-            jax.block_until_ready(self.runner.kv_caches)
         else:
             with runner_utils.LatencyTracker(
                     f"JittedInsertKVCache-b{len(block_numbers)}"):
@@ -449,7 +451,6 @@ class KVCacheManager:
                     kv_cache_slices,
                     jnp.array(block_numbers),
                 )
-            jax.block_until_ready(self.runner.kv_caches)

         logger.debug(
             f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")

tpu_inference/runner/persistent_batch_manager.py CHANGED

@@ -14,13 +14,12 @@ class PersistentBatchManager:
     def __init__(self, requests: Dict[str, CachedRequestState],
                  input_batch: InputBatch, encoder_cache: Dict[str,
                                                               'jax.Array'],
-                 uses_mrope: bool, model_config
+                 uses_mrope: bool, model_config):
         self.requests = requests
         self.input_batch = input_batch
         self.encoder_cache = encoder_cache
         self.uses_mrope = uses_mrope
         self.model_config = model_config
-        self.is_last_rank = is_last_rank

     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
         """ Reorder the sheduled requests to RPA kernel friendly distribution
@@ -180,35 +179,9 @@ class PersistentBatchManager:
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
-            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-
-            if not self.is_last_rank:
-                # When using PP, the scheduler sends the sampled tokens back,
-                # because there's no direct communication between the first-
-                # stage worker and the last-stage worker.
-                new_token_ids = req_data.new_token_ids[i]
-                # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
-                elif num_output_tokens < len(req_state.output_token_ids):
-                    del req_state.output_token_ids[num_output_tokens:]
-                if req_index is not None:
-                    end_idx = (self.input_batch.num_prompt_tokens[req_index] +
-                               num_output_tokens)
-                    self.input_batch.num_tokens[req_index] = end_idx
-                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
-
-            # Update the block IDs.
             if not resumed_from_preemption:
                 if new_block_ids is not None:
                     # Append the new blocks to the existing block IDs.
@@ -221,6 +194,7 @@ class PersistentBatchManager:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids

+            req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -235,18 +209,6 @@ class PersistentBatchManager:
                 self.input_batch.block_table.append_row(
                     new_block_ids, req_index)

-            # For the last rank, we don't need to update the token_ids_cpu
-            # because the sampled tokens are already cached.
-            if not self.is_last_rank:
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                self.input_batch.num_tokens[req_index] = end_token_index
-
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, ())

tpu_inference/runner/structured_decoding_manager.py CHANGED

@@ -61,10 +61,11 @@ class StructuredDecodingManager:
         self.runner.require_structured_out_cpu.fill(0)

         sorted_struct_requests = sorted(
-            grammar_output.structured_output_request_ids)
+            grammar_output.structured_output_request_ids.items(),
+            key=lambda item: item[1])

         cumulative_mask_idx = 0
-        for req_id in sorted_struct_requests:
+        for req_id, _ in sorted_struct_requests:
             if req_id not in self.runner.input_batch.req_id_to_index:
                 continue
             batch_index = self.runner.input_batch.req_id_to_index[req_id]