tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -10,16 +10,17 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import
+import torch
+import vllm.envs as envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import
+from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
-from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
+from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -34,7 +35,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,8 +48,6 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
-from tpu_inference.models.jax.jax_intermediate_tensor import \
-    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -66,7 +64,7 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function
+                                 time_function)
 
 logger = init_logger(__name__)
 
@@ -80,6 +78,17 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )
 
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -144,7 +153,6 @@ class ExecuteModelState:
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     kv_connector_output: Optional[KVConnectorOutput]
     logits_indices_selector: Optional[List[int]] = None
-    padded_num_reqs: Optional[int] = None
 
 
 @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
@@ -182,28 +190,18 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
-def
-        logits_indices_selector=None,
-        cu_num_generated_tokens=None):
-    """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
-    log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
-    logprobs_list = logprobs_tensors.logprobs.tolist()
-    selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
-
-    if logits_indices_selector is not None:
-        log_token_ids_list = [
-            log_token_ids_list[i] for i in logits_indices_selector
-        ]
-        logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
-        selected_token_ranks_list = [
-            selected_token_ranks_list[i] for i in logits_indices_selector
-        ]
-
+def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
     return LogprobsLists(
-        logprob_token_ids=
-
-
-
+        logprob_token_ids=[
+            logprobs_lists.logprob_token_ids[i]
+            for i in logits_indices_selector
+        ],
+        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
+        sampled_token_ranks=[
+            logprobs_lists.sampled_token_ranks[i]
+            for i in logits_indices_selector
+        ],
+        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
     )
 
 
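The replacement collapses the old `_jax_logprobs_to_lists` helper (whose `def` line is truncated in this extraction) into a pure reordering step: conversion to Python lists now happens via `logprobs.tolists()` at the call sites, and `_reorder_logits_indices` only maps rows back to the pre-data-parallel-shuffle order. A self-contained sketch of those semantics, using a stand-in namedtuple rather than vllm's actual `LogprobsLists` class:

```python
from collections import namedtuple

# Stand-in for vllm's LogprobsLists, just for illustration.
LogprobsLists = namedtuple(
    "LogprobsLists",
    ["logprob_token_ids", "logprobs", "sampled_token_ranks",
     "cu_num_generated_tokens"])

def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
    # Pick row i of every per-request list for each index in the selector,
    # restoring the request order from before the DP shuffle.
    return LogprobsLists(
        logprob_token_ids=[logprobs_lists.logprob_token_ids[i]
                           for i in logits_indices_selector],
        logprobs=[logprobs_lists.logprobs[i]
                  for i in logits_indices_selector],
        sampled_token_ranks=[logprobs_lists.sampled_token_ranks[i]
                             for i in logits_indices_selector],
        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
    )

lists = LogprobsLists([[1], [2], [3]], [[-0.1], [-0.2], [-0.3]],
                      [0, 1, 2], None)
reordered = _reorder_logits_indices(lists, [2, 0, 1])
assert reordered.logprob_token_ids == [[3], [1], [2]]
```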
@@ -213,9 +211,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self,
         vllm_config: VllmConfig,
         devices: List[Any],
-        rank: int = 0,
-        is_first_rank: bool = True,
-        is_last_rank: bool = True,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -234,9 +229,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
-        self.rank = rank
-        self.is_first_rank = is_first_rank
-        self.is_last_rank = is_last_rank
 
         self._init_random()
         self._init_mesh()
@@ -247,21 +239,31 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-
-
-            self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
+        self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
             self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config
+            self.uses_mrope, self.model_config)
         self.lora_utils = LoraUtils(self)
 
-
-        if cache_dtype == "auto":
-
-
+        cache_config = self.cache_config
+        if cache_config.cache_dtype == "auto":
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
+                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
+            elif isinstance(model_dtype, torch.dtype):
+                self.kv_cache_dtype = model_dtype
+            else:
+                raise ValueError(
+                    "KV cache is unsupported for model_dtype of %s",
+                    model_dtype)
+        else:
+            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
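The new constructor block resolves `kv_cache_dtype` from three possible forms of the model dtype: a vllm-style string, a JAX scalar type carrying a `.dtype`, or a `torch.dtype`. One nit visible in the hunk: `raise ValueError("... %s", model_dtype)` uses logging-style arguments, so the `%s` is never interpolated and the exception carries a tuple. A sketch of the same cascade, with an abridged dtype table and assuming torchax's `j2t_dtype` maps a numpy/JAX dtype to its torch counterpart (consistent with its use above):

```python
import jax.numpy as jnp
import torch
from torchax.ops.mappings import j2t_dtype

TPU_STR_DTYPE_TO_TORCH_DTYPE = {"bfloat16": torch.bfloat16,
                                "fp8": torch.float8_e4m3fn}  # abridged

def resolve_kv_cache_dtype(cache_dtype, model_dtype):
    # Sketch of the cascade in TPUModelRunner.__init__, not the actual code.
    if cache_dtype != "auto":
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if isinstance(model_dtype, str):               # e.g. "bfloat16"
        return TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
    if isinstance(getattr(model_dtype, "dtype", None), jnp.dtype):
        return j2t_dtype(model_dtype.dtype)        # e.g. jnp.bfloat16
    if isinstance(model_dtype, torch.dtype):       # already a torch dtype
        return model_dtype
    raise ValueError(f"KV cache is unsupported for model_dtype of {model_dtype!r}")

assert resolve_kv_cache_dtype("auto", jnp.bfloat16) == torch.bfloat16
```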
@@ -275,7 +277,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if
+        if os.getenv("NEW_MODEL_DESIGN", False):
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
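Note the flag semantics here: `os.getenv("NEW_MODEL_DESIGN", False)` returns the raw string whenever the variable is set, so any non-empty value, including `NEW_MODEL_DESIGN=0`, takes the new-mesh path; only an unset variable yields `False`. A stricter parse, sketched as a hypothetical helper rather than anything in this diff:

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    # Hypothetical helper: treat common falsy spellings as False instead
    # of relying on string truthiness (os.getenv returns a str when set,
    # so NEW_MODEL_DESIGN=0 would otherwise count as enabled).
    val = os.getenv(name)
    if val is None:
        return default
    return val.strip().lower() not in ("", "0", "false", "no", "off")

os.environ["NEW_MODEL_DESIGN"] = "0"
assert bool(os.getenv("NEW_MODEL_DESIGN", False)) is True   # the gotcha
assert env_flag("NEW_MODEL_DESIGN") is False                # stricter parse
```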
@@ -286,7 +288,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices =
+        num_slices = int(os.environ.get('NUM_SLICES', 1))
 
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -355,7 +357,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir =
+        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -397,7 +399,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=
+            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -421,14 +423,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
         self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
-
-
-        # in kv_cache_manager's maybe_reinitialize_input_batch.
-        self.block_tables_cpu = [
-            np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
-                     dtype=np.int32)
-        ]
-
+        self.block_table_cpu = np.zeros(
+            (self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
         self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
                                             dtype=np.int32)
         self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -462,6 +458,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # tensors for structured decoding
         self.vocab_size = self.model_config.get_vocab_size()
+        if self.lora_config is not None:
+            # lora_config.lora_extra_vocab_size is the "Maximum size of extra vocabulary that can be present in a LoRA adapter" per https://github.com/vanbasten23/vllm/blob/7f4a8b6705622fde952a2e633e86716f902d6e1b/vllm/config.py#L3040
+            self.vocab_size += self.lora_config.lora_extra_vocab_size
         self.grammar_bitmask_cpu = np.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
             dtype=np.int32,
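The padded vocabulary now feeds the structured-decoding bitmask, whose row length is `cdiv(vocab_size, 32)` because each int32 word packs 32 per-token mask bits. A small worked sketch (the vocab and LoRA sizes below are illustrative, not values from this package):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.math_utils.cdiv.
    return -(a // -b)

vocab_size = 152_064          # illustrative base vocab
lora_extra_vocab_size = 256   # illustrative lora_config.lora_extra_vocab_size
total = vocab_size + lora_extra_vocab_size

# Each int32 word of the grammar bitmask holds 32 token bits.
words_per_row = cdiv(total, 32)
assert words_per_row * 32 >= total > (words_per_row - 1) * 32
```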
@@ -506,14 +505,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         self.rng_params_for_sampling = nnx.Rngs(
             jax.random.key(self.model_config.seed)).params()
-        self.is_multimodal_model = (
-
-
-                self.model_config.hf_config, "architectures"
-            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            != "Llama4ForConditionalGeneration")
+        self.is_multimodal_model = (self.model_config.is_multimodal_model
+                                    and self.get_multimodal_embeddings_fn
+                                    is not None)
 
         logger.info(f"Init model | "
                     f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
@@ -526,7 +520,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.kv_cache_config = kv_cache_config
-        self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
        self.kv_caches = []
         self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
         if has_kv_transfer_group():
@@ -539,12 +532,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[
-    ) -> ModelRunnerOutput |
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> ModelRunnerOutput | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output
+        _, output = self._execute_model(scheduler_output)
         return output
 
     def sample_tokens(
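With the intermediate-tensor plumbing gone, `execute_model` keeps its deferred-sampling contract: it stashes results in `execute_model_state`, returns `None`, and expects `sample_tokens` to run next, raising the state error above otherwise. A toy sketch of that two-phase handshake (class and field names here are illustrative, not the runner's API):

```python
class TwoPhaseRunner:
    """Toy model of the execute_model()/sample_tokens() handshake."""

    def __init__(self):
        self._pending = None  # plays the role of execute_model_state

    def execute_model(self, batch):
        if self._pending is not None:
            raise RuntimeError("State error: sample_tokens() must be called "
                               "after execute_model() returns None.")
        self._pending = {"logits": [len(tok) for tok in batch]}
        return None  # output is deferred to sample_tokens()

    def sample_tokens(self):
        state, self._pending = self._pending, None
        return max(state["logits"])  # stand-in for real sampling

runner = TwoPhaseRunner()
assert runner.execute_model(["ab", "cdef"]) is None
assert runner.sample_tokens() == 4
```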
@@ -557,17 +550,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
          aux_hidden_states, spec_decode_metadata, kv_connector_output,
-         logits_indices_selector
-
-
-
-
-
-
-
-
-         self.execute_model_state.padded_num_reqs)
+         logits_indices_selector) = (
+             self.execute_model_state.scheduler_output,
+             self.execute_model_state.attn_metadata,
+             self.execute_model_state.input_ids,
+             self.execute_model_state.hidden_states,
+             self.execute_model_state.logits,
+             self.execute_model_state.aux_hidden_states,
+             self.execute_model_state.spec_decode_metadata,
+             self.execute_model_state.kv_connector_output,
+             self.execute_model_state.logits_indices_selector)
         self.execute_model_state = None
 
         if grammar_output is not None:
@@ -581,10 +573,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             logits,
             arange,
         )
-        return self._sample_from_logits(
-
-
-
+        return self._sample_from_logits(scheduler_output, attn_metadata,
+                                        input_ids, hidden_states, logits,
+                                        aux_hidden_states,
+                                        spec_decode_metadata,
+                                        kv_connector_output,
+                                        logits_indices_selector)
 
     def _modify_prev_results(self):
         # If copy to host has not been done, we just wait.
@@ -670,9 +664,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-
-    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
-               | None]:
+    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -695,23 +687,13 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
         (
             input_ids,
-            input_positions,
             attn_metadata,
             _,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
-            padded_num_reqs,
         ) = self._prepare_inputs(scheduler_output)
 
-        is_llama_guard_4 = (
-            hasattr(
-                self.model_config.hf_config, "architectures"
-            ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
-            and len(self.model_config.hf_config.architectures) >= 1
-            and self.model_config.hf_config.architectures[0]
-            == "Llama4ForConditionalGeneration")
-
         # multi-modal support
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
@@ -719,13 +701,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mm_manager.execute_mm_encoder(scheduler_output)
             mm_embeds = self.mm_manager.gather_mm_embeddings(
                 scheduler_output, input_ids.shape[0])
-        #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
-        elif is_llama_guard_4 and any(
-                self.mm_manager.runner.requests[req_id].mm_features
-                for req_id in self.mm_manager.runner.input_batch.req_ids):
-            raise NotImplementedError(
-                "Llama Guard 4 (JAX) currently supports only text inputs. "
-                "Multimodal processing not yet implemented.")
         else:
             mm_embeds = []
 
@@ -750,6 +725,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
+
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -757,17 +733,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_ids,
                  attn_metadata,
                  inputs_embeds,
-                 input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
-                 intermediate_tensors,
-                 self.is_first_rank,
-                 self.is_last_rank,
              )
-
-            assert isinstance(hidden_states, JaxIntermediateTensors)
-            hidden_states.kv_connector_output = kv_connector_output
-            return attn_metadata, hidden_states
+
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(
@@ -785,8 +754,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             aux_hidden_states=aux_hidden_states,
             spec_decode_metadata=spec_decode_metadata,
             kv_connector_output=kv_connector_output,
-            logits_indices_selector=logits_indices_selector
-            padded_num_reqs=padded_num_reqs)
+            logits_indices_selector=logits_indices_selector)
         return attn_metadata, None
 
     def _sample_from_logits(
@@ -800,44 +768,23 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         kv_connector_output: Optional[KVConnectorOutput],
         logits_indices_selector: Optional[List[int]] = None,
-        padded_num_reqs: Optional[int] = None,
     ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
-
-
-            self.input_batch.num_reqs, self.max_num_reqs)
-
-        sharding = None
-        if self.dp_size > 1:
-            sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
-
+        padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
+            self.input_batch.num_reqs, self.max_num_reqs)
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
-            self.mesh, self.input_batch, padded_num_reqs
-
-        # TODO(pooyam): Should we move this to `_prepare_inputs`?
-        if tpu_sampling_metadata.do_sampling:
-            self.rng_params_for_sampling, step_rng = jax.random.split(
-                self.rng_params_for_sampling)
-        else:
-            step_rng = self.rng_params_for_sampling
-
+            self.mesh, self.input_batch, padded_num_reqs)
         if spec_decode_metadata is None:
             next_tokens = sample(
-
+                self.rng_params_for_sampling,
                 self.mesh,
                 logits,
                 tpu_sampling_metadata,
             )
         else:
-            if tpu_sampling_metadata.do_sampling:
-                bonus_rng, rejection_rng = jax.random.split(step_rng)
-            else:
-                bonus_rng = step_rng
-                rejection_rng = step_rng
             bonus_logits = self._select_from_array_fn(
                 logits, spec_decode_metadata.bonus_logits_indices)
             bonus_token_ids = sample(
-
+                self.rng_params_for_sampling,
                 self.mesh,
                 bonus_logits,
                 tpu_sampling_metadata,
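This hunk also deletes the per-step `jax.random.split` bookkeeping (`step_rng`, `bonus_rng`, `rejection_rng`) and hands `self.rng_params_for_sampling` directly to every `sample` call; whether that is equivalent depends on how `sample` and the nnx params stream derive fresh keys internally, which this diff does not show. For reference, the removed code followed the textbook JAX splitting pattern:

```python
import jax
import jax.numpy as jnp

# The removed pattern: split the stored key each step so no key is
# reused, then split again for each consumer of randomness.
key = jax.random.key(0)
key, step_key = jax.random.split(key)                  # advance stored key
bonus_key, rejection_key = jax.random.split(step_key)  # one key per consumer

draws = jax.random.categorical(bonus_key, jnp.zeros((4, 8)))
assert draws.shape == (4,)
```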
@@ -851,7 +798,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 target_logits=target_logits,
                 bonus_token_ids=bonus_token_ids,
                 sampling_metadata=tpu_sampling_metadata,
-                key=
+                key=self.rng_params_for_sampling,
             )
 
         if tpu_sampling_metadata.logprobs:
@@ -909,8 +856,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
             if logprobs is not None:
                 # Map logprobs back to the pre-dp shuffling order
-                logprobs_lists =
-
+                logprobs_lists = logprobs.tolists()
+                if logits_indices_selector is not None:
+                    logprobs_lists = _reorder_logits_indices(
+                        logprobs_lists, logits_indices_selector)
 
             else:
                 logprobs_lists = None
@@ -980,8 +929,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         if logprobs is not None:
             # Map logprobs back to the pre-dp shuffling order
-            logprobs_lists =
-
+            logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1329,6 +1280,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
 
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        for dp_rank in range(dp_size):
+            req_offset = dp_rank * max_num_reqs_per_dp_rank
+            _num_reqs = num_req_per_dp_rank[dp_rank]
+
+            block_tables[
+                req_offset:req_offset + _num_reqs, :self.
+                max_num_blocks_per_req] = self.input_batch.block_table[
+                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
                                                    dp_size]
         seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
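The per-rank copy above packs each data-parallel rank's block-table rows into a fixed-stride slab of the shared padded table. A small numpy sketch of that layout (all shapes illustrative):

```python
import numpy as np

# Each DP rank owns rows [rank * stride, rank * stride + num_reqs) of the
# padded table, regardless of how many requests it actually has.
dp_size, max_num_reqs_per_dp_rank, max_blocks = 2, 3, 4
block_tables = np.zeros((dp_size * max_num_reqs_per_dp_rank, max_blocks),
                        dtype=np.int32)
num_req_per_dp_rank = [2, 1]
req_indices_dp = [np.array([0, 1]), np.array([2])]  # rows in the source table
src = np.arange(3 * max_blocks, dtype=np.int32).reshape(3, max_blocks)

for dp_rank in range(dp_size):
    req_offset = dp_rank * max_num_reqs_per_dp_rank
    n = num_req_per_dp_rank[dp_rank]
    block_tables[req_offset:req_offset + n] = src[req_indices_dp[dp_rank]]

# Rank 1's single request lands at row 3, after rank 0's padded slab.
assert (block_tables[3] == src[2]).all()
```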
@@ -1370,59 +1331,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         if self.uses_mrope:
             positions = mrope_positions
 
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         logits_indices_cpu = logits_indices
         seq_lens_cpu = seq_lens
 
-        (input_ids, positions, query_start_loc, seq_lens,
-         request_distribution) = device_array(
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
+         logits_indices, request_distribution) = device_array(
             self.mesh,
-            (input_ids, positions, query_start_loc, seq_lens,
-             request_distribution),
+            (input_ids, positions, block_tables, query_start_loc, seq_lens,
+             logits_indices, request_distribution),
             sharding=data_parallel_attn_sharding,
         )
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            for dp_rank in range(dp_size):
-                req_offset = dp_rank * max_num_reqs_per_dp_rank
-                _num_reqs = num_req_per_dp_rank[dp_rank]
-
-                block_tables[
-                    req_offset:req_offset + _num_reqs, :self.
-                    max_num_blocks_per_req] = self.input_batch.block_table[
-                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(
-                self.mesh,
-                (block_tables),
-                sharding=data_parallel_attn_sharding,
-            )
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution,
-            )
-
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
-
         # Async scheduling: substitute placeholder tokens for DP
         if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
             # Collect all token indices that need substitution across all DP ranks
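With the per-KV-cache-group loop gone, the flattened `block_tables` rides in the same `device_array` tuple as the other inputs: one host-to-device transfer under one sharding instead of an extra transfer per group. `jax.device_put` shows the same pytree batching that `device_array` presumably wraps:

```python
import jax
import numpy as np

# jax.device_put accepts a pytree, so a tuple of host arrays moves to the
# device(s) in one call; device_array in this diff batches inputs the same
# way (with a NamedSharding instead of a single default device).
host_inputs = (np.arange(8, dtype=np.int32),   # e.g. input_ids
               np.zeros(8, dtype=np.int32),    # e.g. positions
               np.zeros(16, dtype=np.int32))   # e.g. flattened block_tables
input_ids, positions, block_tables = jax.device_put(host_inputs)
assert isinstance(input_ids, jax.Array)
```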
@@ -1451,19 +1373,25 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 padded_total_num_scheduled_tokens,
             )
 
-
-
-
-
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution,
+        )
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+
         return (
             input_ids,
-            positions,
             attention_metadata,
             sampling_metadata,
             logits_indices,
             spec_decode_metadata,
             logits_indices_selector,
-            padded_num_reqs,
         )
 
     def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
@@ -1564,6 +1492,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
         mrope_positions = self.mrope_positions_cpu[:, :
                                                    padded_total_num_scheduled_tokens]
+        block_tables = self.block_table_cpu[:self.max_num_reqs]
+        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
+            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
 
         # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc) and some other are to `max_num_reqs` (block table, seq_lens). We should stick to one of them maybe?
         query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
@@ -1592,44 +1523,16 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh, self.input_batch, padded_num_reqs)
         if self.uses_mrope:
             positions = mrope_positions
+
+        # Convert block_tables to 1D on cpu.
+        block_tables = block_tables.reshape(-1)
+
         query_start_loc_cpu = query_start_loc
         seq_lens_cpu = seq_lens
-
-        (input_ids, positions, query_start_loc, seq_lens,
+        (input_ids, positions, block_tables, query_start_loc, seq_lens,
          logits_indices, request_distribution) = device_array(
-            self.mesh, (input_ids, positions,
-                        logits_indices, request_distribution))
-
-        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
-        uniform_attention_metadata: AttentionMetadata = None
-        for kv_cache_gid, kv_cache_group in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_tables = self.block_tables_cpu[kv_cache_gid][:self.
-                                                               max_num_reqs]
-            block_tables[:num_reqs] = (
-                self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
-                [:num_reqs])
-            # Convert block_tables to 1D on cpu.
-            block_tables = block_tables.reshape(-1)
-            block_tables = device_array(self.mesh, (block_tables))
-
-            attention_metadata_gid = AttentionMetadata(
-                input_positions=positions,
-                block_tables=block_tables,
-                seq_lens=seq_lens,
-                query_start_loc=query_start_loc,
-                request_distribution=request_distribution)
-            # This is for making these cpu buffers hidden during tracing
-            attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
-            attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
-
-            if not self.use_hybrid_kvcache:
-                # all layers share the same attention metadata
-                uniform_attention_metadata = attention_metadata_gid
-            else:
-                for layer_name in kv_cache_group.layer_names:
-                    attention_metadata_per_layer[
-                        layer_name] = attention_metadata_gid
+            self.mesh, (input_ids, positions, block_tables, query_start_loc,
+                        seq_lens, logits_indices, request_distribution))
 
         if self.scheduler_config.async_scheduling and len(
                 token_in_tpu_cur_input_indices) > 0:
@@ -1642,15 +1545,20 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.lora_utils.set_active_loras(
             num_scheduled_tokens_per_req, total_num_scheduled_tokens,
             padded_total_num_scheduled_tokens)
-        logits_indices_selector = None
 
-
-
-
-
-
-
-
+        attention_metadata = AttentionMetadata(
+            input_positions=positions,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            query_start_loc=query_start_loc,
+            request_distribution=request_distribution)
+
+        # This is for making these cpu buffers hidden during tracing
+        attention_metadata.query_start_loc_cpu = query_start_loc_cpu
+        attention_metadata.seq_lens_cpu = seq_lens_cpu
+        logits_indices_selector = None
+        return (input_ids, attention_metadata, sampling_metadata,
+                logits_indices, spec_decode_metadata, logits_indices_selector)
 
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
@@ -1710,35 +1618,3 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
-
-    def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
-        num_padded_tokens = runner_utils.get_padded_token_len(
-            self.num_tokens_paddings, num_tokens)
-        sharding = NamedSharding(self.mesh, PartitionSpec())
-        hidden_size = self.model_config.get_hidden_size()
-        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
-                                    dtype=jax_dtype,
-                                    sharding=sharding)
-        tensor_spec = {"hidden_states": spec, "residual": spec}
-        return tensor_spec
-
-    def get_uuid_for_jax_transfer(self,
-                                  scheduler_output: "VllmSchedulerOutput",
-                                  rank: int, step: int) -> int:
-        '''
-        Get a uuid for jax.transfer, here we use the hash of
-        scheduler_output + counter_step + sender's rank
-        '''
-        scheduler_output_str = ""
-        if not scheduler_output.num_scheduled_tokens:
-            scheduler_output_str = "empty_batch"
-        else:
-            scheduler_output_str = str(
-                sorted(scheduler_output.num_scheduled_tokens.items()))
-        unique_str = f'{scheduler_output_str} {step} {rank}'
-        import hashlib
-        hasher = hashlib.sha1()
-        hasher.update(unique_str.encode('utf-8'))
-        return int.from_bytes(hasher.digest()[:8], 'big')