tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
|
|
15
|
+
import functools
|
|
16
|
+
from dataclasses import replace
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
import numpy as np
|
|
22
|
+
from flax import nnx
|
|
23
|
+
from jax import lax
|
|
24
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
25
|
+
from vllm.config import VllmConfig
|
|
26
|
+
|
|
27
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
28
|
+
from tpu_inference.logger import init_logger
|
|
29
|
+
from tpu_inference.models.common.model_loader import get_model
|
|
30
|
+
from tpu_inference.runner import utils as runner_utils
|
|
31
|
+
from tpu_inference.utils import device_array
|
|
32
|
+
|
|
33
|
+
# Module-level logger for the Eagle3 proposer.
logger = init_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Eagle3Proposer:
|
|
37
|
+
"""A proposer for speculative decoding using the Eagle3 method.
|
|
38
|
+
|
|
39
|
+
This class is responsible for loading the draft model and generating draft
|
|
40
|
+
tokens based on the target model's outputs.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
def __init__(
    self,
    vllm_config: VllmConfig,
    runner: Any,  # TPUModelRunner
):
    """Set up the Eagle3 proposer from the vLLM config and its owning runner.

    Args:
        vllm_config: The vLLM configuration. Its ``speculative_config`` must
            not be None.
        runner: The TPUModelRunner that owns this proposer. Its ``mesh`` and
            ``max_num_tokens`` attributes are read here.
    """
    self.vllm_config = vllm_config
    spec_cfg = vllm_config.speculative_config
    self.speculative_config = spec_cfg
    assert spec_cfg is not None

    # Settings taken from the speculative-decoding config.
    self.draft_model_config = spec_cfg.draft_model_config
    self.method = spec_cfg.method

    # Values borrowed from the owning runner.
    self.runner = runner
    self.mesh = runner.mesh
    self.num_speculative_tokens = spec_cfg.num_speculative_tokens
    self.block_size = vllm_config.cache_config.block_size
    # PRNG key seeded from the model config; load_model() passes it to get_model().
    self.rng_key = jax.random.key(vllm_config.model_config.seed)
    self.max_num_tokens = runner.max_num_tokens
    self.token_arange = jnp.arange(self.max_num_tokens)
|
|
68
|
+
|
|
69
|
+
def load_model(self, target_model: Any) -> None:
    """Load the draft model and decide which token embedding it should use.

    The draft model is built with ``get_model(..., is_draft_model=True)``.
    The draft model is pointed at the target model's ``model.embed`` when
    either of these holds:

    - the draft state has no ``embed_tokens`` or its embedding is all zeros;
    - its embedding is element-wise equal to the target's.

    Otherwise the draft keeps its own ``embed_tokens``.

    Args:
        target_model: The already-loaded target model. Its
            ``model.embed`` (and ``model.embed.embedding``) are read.
    """
    (self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _,
     self.state, _, _) = get_model(self.vllm_config,
                                   self.rng_key,
                                   self.mesh,
                                   is_draft_model=True)

    draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
    # Treat a missing or all-zero embedding as "not provided". Use logical
    # `not` rather than bitwise `~`: `~` only negates correctly for boolean
    # arrays and yields a truthy int (-1 / -2) for a plain Python bool.
    if draft_embed_tokens is None or not jnp.any(
            draft_embed_tokens.embedding):
        logger.info(
            "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
        )
        self.state.model.embed_tokens = target_model.model.embed
    elif jnp.array_equal(draft_embed_tokens.embedding,
                         target_model.model.embed.embedding):
        logger.info(
            "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
        )
        self.state.model.embed_tokens = target_model.model.embed
    else:
        logger.info("Draft model has its own embed_tokens.")
|
|
89
|
+
|
|
90
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _prepare_input_ids(
        self, query_start_loc: jax.Array, target_token_ids: jax.Array,
        next_token_ids: jax.Array,
        num_reqs: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the draft model's input IDs by shifting target tokens left by one.

    Every token moves one slot to the left. The final slot of each active
    request (index ``< num_reqs``) is then overwritten with that request's
    entry in ``next_token_ids``. Padded requests keep the shifted value. A
    fixed-size masked scatter is used so a dynamic ``num_reqs`` stays
    traceable under ``jax.jit``.

    Returns:
        ``(input_ids, last_token_indices)``. ``last_token_indices[i]`` is the
        flat index of request ``i``'s final token
        (``query_start_loc[i + 1] - 1``).
    """
    # Flat index of each request's final token.
    last_idx = query_start_loc[1:] - 1
    shifted = jnp.roll(target_token_ids, -1, axis=0)

    is_active = jnp.arange(last_idx.shape[0]) < num_reqs
    # Padded requests write back the value already present, i.e. a no-op.
    fill = jnp.where(is_active, next_token_ids, shifted[last_idx])
    return shifted.at[last_idx].set(fill), last_idx
|
|
115
|
+
|
|
116
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _update_inputs_for_loop_speculation(
    self, positions: jax.Array, seq_lens: jax.Array,
    block_tables: jax.Array
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """Advance per-request inputs by one step for the next draft iteration.

    Positions and sequence lengths grow by one, and ``query_start_loc``
    becomes ``[0, 1, ..., num_reqs]`` (one query token per request).
    Requests whose advanced position reaches ``runner.max_model_len`` stay in
    the batch but are neutralised:

    - their position is clamped to 0;
    - their sequence length is set to 1;
    - their block-table entries are set to -1.

    Draft tokens generated for such requests should be ignored.

    Args:
        positions: Current (unclamped) position per request.
        seq_lens: Current sequence length per request.
        block_tables: Flattened block table. Its length is split evenly across
            the ``seq_lens.shape[0]`` requests.

    Returns:
        ``(positions, clamped_positions, seq_lens, query_start_loc,
        block_tables)`` for the next iteration. The first element is the
        unclamped advanced position.
    """
    limit = self.runner.max_model_len
    next_positions = positions + 1
    overflow = next_positions >= limit

    safe_positions = jnp.where(overflow, 0, next_positions)
    next_seq_lens = jnp.where(overflow, 1, jnp.minimum(seq_lens + 1, limit))

    batch = seq_lens.shape[0]
    query_start_loc = jnp.arange(batch + 1)

    # block_tables is flat: spread each request's overflow flag over its
    # slice so out-of-range requests get every entry set to -1.
    blocks_per_req = block_tables.shape[0] // batch
    row_overflow = jnp.repeat(overflow, blocks_per_req)
    next_block_tables = jnp.where(row_overflow, -1, block_tables)

    unsharded_1d = NamedSharding(self.mesh, PartitionSpec(None, ))
    next_positions = lax.with_sharding_constraint(next_positions, unsharded_1d)
    safe_positions = lax.with_sharding_constraint(safe_positions, unsharded_1d)
    next_seq_lens = lax.with_sharding_constraint(next_seq_lens, unsharded_1d)
    query_start_loc = lax.with_sharding_constraint(
        query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
    next_block_tables = lax.with_sharding_constraint(next_block_tables,
                                                     unsharded_1d)

    return (next_positions, safe_positions, next_seq_lens, query_start_loc,
            next_block_tables)
|
|
158
|
+
|
|
159
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _stack_draft_token_ids(
        self, draft_token_ids_list: list[jax.Array]) -> jnp.ndarray:
    """Combine per-step draft token arrays into a single array.

    Entry ``i`` of ``draft_token_ids_list`` becomes index ``i`` along axis 1
    of the result. For 1-D inputs of length ``n`` the output shape is
    ``(n, len(draft_token_ids_list))``.
    """
    stacked = jnp.stack(draft_token_ids_list, axis=1)
    return stacked
|
|
164
|
+
|
|
165
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _prepare_hidden_states_and_input_ids(
    self,
    state: nnx.State,
    aux_hidden_states: tuple[jax.Array, ...],
    query_start_loc: jax.Array,
    target_token_ids: jax.Array,
    next_token_ids: jax.Array,
    num_reqs: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Build the first-step draft inputs from the target model's outputs.

    The auxiliary hidden states are concatenated along the last axis and
    passed through ``combine_hidden_states_fn``. The draft input IDs come
    from ``_prepare_input_ids``.

    Returns:
        ``(target_hidden_states, input_ids, last_token_indices)``.
    """
    merged = jnp.concatenate(aux_hidden_states, axis=-1)
    combined = self.combine_hidden_states_fn(state, merged)

    draft_input_ids, last_token_indices = self._prepare_input_ids(
        query_start_loc, target_token_ids, next_token_ids, num_reqs)
    # NOTE(pooyam): For now, we don't support multimodal.

    return combined, draft_input_ids, last_token_indices
|
|
184
|
+
|
|
185
|
+
def prepare_inputs(
    self,
    attn_metadata: AttentionMetadata,
    input_ids: jax.Array,
    aux_hidden_states: tuple[jax.Array, ...],
    next_token_ids: jax.Array,
    num_rejected_tokens: Optional[jax.Array] = None,
) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
    """Prepare drafter inputs based on target forward outputs.

    Mirrors the GPU reference logic but adapted to TPU/JAX types:
    - The block table of the last KV cache group (the draft model's) is
      placed into the attention metadata.
    - When ``num_rejected_tokens`` is None, all scheduled tokens are used
      as-is.
    - Otherwise, the last ``num_rejected_tokens[i]`` tokens of each request
      are dropped. The kept tokens are gathered, with the token count padded
      to a size from ``runner.num_tokens_paddings``, and ``query_start_loc``,
      ``seq_lens`` and positions are updated to match.
    - The EAGLE3 hidden input is built by concatenating the auxiliary hidden
      states along the last dimension and combining them with the draft
      model's ``combine_hidden_states_fn``.

    Args:
        attn_metadata: Attention metadata from the target forward pass. On
            the rejection path, its ``query_start_loc_cpu`` and
            ``seq_lens_cpu`` must not be None.
        input_ids: Token IDs scheduled for the target forward pass.
        aux_hidden_states: Non-empty sequence of auxiliary hidden states
            from the target model.
        next_token_ids: Token sampled for each request by the target model.
        num_rejected_tokens: Optional per-request count of rejected draft
            tokens.

    Returns:
        ``(target_hidden_states, input_ids, last_token_indices,
        attn_metadata)``:

        - the combined hidden states and the left-shifted input IDs for the
          first draft step;
        - the flat index of each request's last token;
        - the updated attention metadata.
    """
    assert aux_hidden_states is not None and len(aux_hidden_states) > 0, (
        "EAGLE3 requires auxiliary hidden states from the target model.")

    # The last KV cache group is for the draft model.
    num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
    draft_kv_cache_group_id = num_kv_cache_groups - 1
    block_tables = self.runner.input_batch.block_table[
        draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
    # Number of active requests in this step (un-padded count).
    num_reqs = self.runner.input_batch.num_reqs

    if num_rejected_tokens is None:
        num_reqs = device_array(self.mesh,
                                np.asarray([num_reqs], dtype=jnp.int32))
        attn_metadata = replace(attn_metadata,
                                block_tables=device_array(
                                    self.mesh, block_tables))
        target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
            self.state, aux_hidden_states, attn_metadata.query_start_loc,
            input_ids, next_token_ids, num_reqs)
        return target_hidden_states, input_ids, last_token_indices, attn_metadata

    # Host copies from the metadata prepared by the runner.
    query_start_loc_cpu = attn_metadata.query_start_loc_cpu
    seq_lens_cpu = attn_metadata.seq_lens_cpu
    assert query_start_loc_cpu is not None and seq_lens_cpu is not None

    # Rejection-aware path: compute new per-request lengths and token indices.
    # Convert to host numpy for efficient prefix-sum and repeat ops.
    nrt_cpu = jax.device_get(num_rejected_tokens).astype("int32")

    # query_len_per_req = [q1, q2, ...]
    query_len_per_req = (query_start_loc_cpu[1:] -
                         query_start_loc_cpu[:-1])

    # query_start_loc_cpu, and therefore query_len_per_req, is padded past
    # num_reqs. Padded requests are deliberately given a query length of 1,
    # not 0.
    # NOTE(review): presumably this keeps each padded request's last-token
    # index (query_start_loc[i + 1] - 1) distinct, so the scatter in
    # _prepare_input_ids never sees duplicate indices. With length 0 it would
    # coincide with the preceding request's index; confirm before changing.
    query_len_per_req[num_reqs:] = 1
    # num_tokens_per_req = [q1 - n1, q2 - n2, ...]
    num_tokens_per_req = (query_len_per_req - nrt_cpu)
    assert (num_tokens_per_req
            >= 0).all(), ("num_tokens_per_req must be non-negative")

    # new_query_start_loc = [0, q1-n1, q1+q2-n1-n2, ...]
    # Use numpy for cumsum and then convert back.
    new_query_start_loc_cpu = np.zeros_like(query_start_loc_cpu)
    np.cumsum(num_tokens_per_req, out=new_query_start_loc_cpu[1:])

    # Build token indices selecting the kept tokens from each request.
    total_num_tokens = int(new_query_start_loc_cpu[-1])

    # Pad the token count up to a size from runner.num_tokens_paddings
    # (presumably to bound the number of distinct shapes seen by the jitted
    # helper).
    padded_total_num_tokens = runner_utils.get_padded_token_len(
        self.runner.num_tokens_paddings, total_num_tokens)
    pad_width = padded_total_num_tokens - total_num_tokens
    assert pad_width >= 0, (
        f"total_num_tokens {total_num_tokens} exceeds "
        f"num_tokens_paddings {self.runner.num_tokens_paddings}")

    # Expand request starts: [0, 0, q1-n1, ...,]
    expanded_new_query_start_loc = np.repeat(new_query_start_loc_cpu[:-1],
                                             num_tokens_per_req)
    # Offsets within each request window: [0,1,2, 0,1,2,3, ...]
    token_offsets = np.arange(total_num_tokens, dtype=np.int32)
    token_offsets -= expanded_new_query_start_loc
    # Map into old flat indices by adding original request starts.
    old_query_start_loc_expanded = np.repeat(query_start_loc_cpu[:-1],
                                             num_tokens_per_req)

    token_indices_cpu = token_offsets + old_query_start_loc_expanded
    # Padding entries gather token 0; they only fill the padded shape.
    token_indices_cpu = np.pad(token_indices_cpu, (0, pad_width),
                               "constant",
                               constant_values=0)
    # new_seq_lens = s - n, element-wise over the (padded) request axis.
    new_seq_lens_cpu = seq_lens_cpu - nrt_cpu

    query_start_loc, seq_lens, token_indices, num_reqs, block_tables = device_array(
        self.mesh,
        (new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu,
         np.asarray([num_reqs], dtype=jnp.int32), block_tables))

    attn_metadata = replace(attn_metadata, block_tables=block_tables)
    return self._filter_token_and_prepare_initial_inputs(
        self.state, token_indices, query_start_loc, seq_lens, input_ids,
        aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
|
|
291
|
+
|
|
292
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _filter_token_and_prepare_initial_inputs(
    self,
    state: nnx.State,
    token_indices: jax.Array,
    query_start_loc: jax.Array,
    seq_lens: jax.Array,
    input_ids: jax.Array,
    aux_hidden_states: tuple[jax.Array, ...],
    attn_metadata: AttentionMetadata,
    next_token_ids: jax.Array,
    num_reqs: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
    """Gather the kept tokens and build first-step draft inputs from them.

    ``token_indices`` selects the flat token slots to keep. The following are
    gathered with it:

    - the token IDs;
    - the input positions (along axis 1 when they are 2-D, else axis 0);
    - each auxiliary hidden state.

    A fresh ``AttentionMetadata`` is then built from the new ``seq_lens`` and
    ``query_start_loc``, carrying over ``block_tables`` and
    ``request_distribution``. Finally the gathered values go through
    ``_prepare_hidden_states_and_input_ids``.

    Returns:
        ``(target_hidden_states, input_ids, last_token_indices,
        attn_metadata)``.
    """
    kept_token_ids = input_ids[token_indices]

    old_positions = attn_metadata.input_positions
    if old_positions.ndim == 2:
        kept_positions = old_positions[:, token_indices]
    else:
        kept_positions = old_positions[token_indices]

    new_metadata = AttentionMetadata(
        input_positions=kept_positions,
        block_tables=attn_metadata.block_tables,
        seq_lens=seq_lens,
        query_start_loc=query_start_loc,
        request_distribution=attn_metadata.request_distribution,
    )

    kept_aux_hidden_states = [h[token_indices] for h in aux_hidden_states]
    hidden, draft_input_ids, last_token_indices = (
        self._prepare_hidden_states_and_input_ids(state,
                                                  kept_aux_hidden_states,
                                                  query_start_loc,
                                                  kept_token_ids,
                                                  next_token_ids, num_reqs))

    return hidden, draft_input_ids, last_token_indices, new_metadata
|
|
327
|
+
|
|
328
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _select_draft_token_ids(
    self,
    state: nnx.State,
    hidden_states: jax.Array,
    last_token_indices: jax.Array,
) -> jax.Array:
    """Pick one draft token per request from its last-token hidden state.

    The rows at ``last_token_indices`` are gathered and constrained to an
    unpartitioned 2-D layout on ``self.mesh``. They are then scored by
    ``_get_draft_token_ids``.
    """
    unsharded_2d = NamedSharding(self.mesh, PartitionSpec(None, None))
    picked = lax.with_sharding_constraint(hidden_states[last_token_indices],
                                          unsharded_2d)
    return self._get_draft_token_ids(state, picked)
|
|
340
|
+
|
|
341
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _get_draft_token_ids(self, state: nnx.State,
                         hidden_states: jax.Array) -> jax.Array:
    """Return the argmax token of the draft logits for each hidden state.

    ``compute_logits_fn`` is called with ``None`` as its third (LoRA
    metadata) argument. The result is constrained to
    ``PartitionSpec()`` on ``self.mesh``.
    """
    logits = self.compute_logits_fn(state, hidden_states, None)
    best = jnp.argmax(logits, axis=-1)
    target_sharding = NamedSharding(self.mesh, PartitionSpec())
    return lax.with_sharding_constraint(best, target_sharding)
|
|
349
|
+
|
|
350
|
+
@functools.partial(jax.jit, static_argnums=(0, ))
def _select_inputs_for_loop_speculation(
        self, state: nnx.State, positions: jax.Array, residual: jax.Array,
        hidden_states: jax.Array,
        last_token_indices: jax.Array
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Select per-request inputs for the second and later draft steps.

    Gathers the position and residual row at each request's last token
    (``last_token_indices``). It also picks the first draft token from
    ``hidden_states`` at the same indices via ``_select_draft_token_ids``.
    All three results are constrained to unpartitioned layouts on
    ``self.mesh``.

    Returns:
        ``(positions, residual, draft_token_ids)``.
    """
    positions = positions[last_token_indices]
    residual = residual[last_token_indices]
    draft_token_ids = self._select_draft_token_ids(state, hidden_states,
                                                   last_token_indices)

    positions = lax.with_sharding_constraint(
        positions, NamedSharding(self.mesh, PartitionSpec(None, )))
    residual = lax.with_sharding_constraint(
        residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
    draft_token_ids = lax.with_sharding_constraint(
        draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

    return positions, residual, draft_token_ids
|
|
368
|
+
|
|
369
|
+
def propose(
    self,
    kv_caches: list[jax.Array],
    input_ids: jax.Array,
    attn_metadata: AttentionMetadata,
    last_token_indices,
    target_hidden_states,
) -> tuple[list[jax.Array], jnp.ndarray]:
    """Proposes draft tokens using the draft model.

    The draft model runs once on the inputs prepared in ``prepare_inputs()``.
    When ``num_speculative_tokens > 1``, each newly drafted token is then fed
    back for the remaining steps. Between steps, positions, sequence lengths,
    ``query_start_loc`` and block tables are advanced via
    ``_update_inputs_for_loop_speculation``.

    Args:
        kv_caches: Draft-model KV caches; the updated caches are returned.
        input_ids: Input token ids for the first draft step.
        attn_metadata: Attention metadata for the first step. Later steps
            use a copy produced with ``replace``.
        last_token_indices: Row indices of the first step's outputs to
            sample draft tokens from.
        target_hidden_states: Hidden-state input for the first draft step.

    Returns:
        A tuple containing the updated KV caches and a tensor of proposed
        draft token IDs. NOTE(review): with ``num_speculative_tokens == 1``
        the ids come directly from ``_select_draft_token_ids`` rather than
        ``_stack_draft_token_ids``. Their shape may therefore differ from the
        multi-step ``[batch_size, num_speculative_tokens]`` layout; confirm
        that callers handle both.
    """
    # First draft step. input_ids and target_hidden_states were prepared in
    # prepare_inputs() to improve performance.
    kv_caches, step_hidden, step_residual = self.model_fn(
        self.state,
        kv_caches,
        input_ids,
        target_hidden_states,
        attn_metadata,
    )

    if self.num_speculative_tokens == 1:
        return kv_caches, self._select_draft_token_ids(
            self.state, step_hidden, last_token_indices)

    # carried_hidden is what the next step receives as its hidden-state
    # input: residual[0] of the previous step, gathered at last_token_indices.
    positions, carried_hidden, first_draft = (
        self._select_inputs_for_loop_speculation(
            self.state, attn_metadata.input_positions, step_residual[0],
            step_hidden, last_token_indices))

    drafts = [first_draft]
    for _ in range(1, self.num_speculative_tokens):
        (positions, clamped_positions, new_seq_lens, query_start_loc,
         new_block_tables) = self._update_inputs_for_loop_speculation(
             positions, attn_metadata.seq_lens, attn_metadata.block_tables)

        attn_metadata = replace(
            attn_metadata,
            input_positions=clamped_positions,
            seq_lens=new_seq_lens,
            query_start_loc=query_start_loc,
            block_tables=new_block_tables,
        )
        # The most recent draft token is the next step's input token.
        kv_caches, step_hidden, step_residual = self.model_fn(
            self.state,
            kv_caches,
            drafts[-1],
            carried_hidden,
            attn_metadata,
        )
        carried_hidden = step_residual[0]
        drafts.append(self._get_draft_token_ids(self.state, step_hidden))

    # [batch_size, num_speculative_tokens]
    return kv_caches, self._stack_draft_token_ids(drafts)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import glob
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import requests
|
|
19
|
+
|
|
20
|
+
from tpu_inference import envs
|
|
21
|
+
from tpu_inference.logger import init_logger
|
|
22
|
+
|
|
23
|
+
# Module-level logger used by the TPU metadata and device-detection helpers.
logger = init_logger(__name__)
|
|
24
|
+
|
|
25
|
+
# Base URL for GCE instance attributes. get_tpu_metadata() appends an
# attribute name such as "accelerator-type" to it.
GCE_TPU_ACCELERATOR_ENDPOINT = ("http://metadata.google.internal"
                                "/computeMetadata/v1/instance/attributes/")
# Header sent with every metadata request made by get_tpu_metadata().
GCE_TPU_HEADERS = {
    "Metadata-Flavor": "Google",
}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_tpu_metadata(key: str = "", timeout: float = 10.0) -> str | None:
    """Fetch one instance attribute from the GCE metadata server.

    Args:
        key: Attribute name appended to ``GCE_TPU_ACCELERATOR_ENDPOINT``
            (e.g. ``"accelerator-type"``). An empty key requests the base
            endpoint itself.
        timeout: Timeout in seconds passed to ``requests.get``. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive
            server.

    Returns:
        The response body on HTTP 200 with a non-empty body; otherwise
        ``None``. Failures are logged rather than raised.
    """
    # Build the URL by concatenation: os.path.join is a filesystem helper
    # that uses the OS path separator and discards the base when the key
    # starts with "/".
    url = GCE_TPU_ACCELERATOR_ENDPOINT + key.lstrip("/")
    try:
        response = requests.get(url, headers=GCE_TPU_HEADERS, timeout=timeout)
    except requests.RequestException as e:
        # Also covers requests.Timeout. This is a best-effort probe, so log
        # the failure and let callers fall back.
        logger.error("Unable to poll the TPU GCE Metadata: %s", e)
        return None

    if response.status_code == 200 and response.text:
        return response.text

    logger.error(
        "Unable to poll TPU GCE Metadata. Got status code: %s and content: %s",
        response.status_code, response.text)
    return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_tpu_type() -> str | None:
    """Return the TPU accelerator type string.

    ``envs.TPU_ACCELERATOR_TYPE`` takes precedence. Only when it is ``None``
    is the GCE metadata attribute ``accelerator-type`` queried.

    Returns:
        The accelerator type, or ``None`` when the env setting is unset and
        the metadata lookup fails. The previous ``-> str`` annotation hid
        this ``None`` case.
    """
    tpu_type = envs.TPU_ACCELERATOR_TYPE
    if tpu_type is None:
        tpu_type = get_tpu_metadata(key="accelerator-type")
    return tpu_type
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_node_name() -> str | None:
    """Return the name of this TPU node.

    ``envs.TPU_NAME`` is used when it is truthy. Otherwise (unset or empty)
    the GCE metadata attribute ``instance-id`` is queried.

    Returns:
        The node name, or ``None`` when the env setting is unset and the
        metadata lookup fails. The previous ``-> str`` annotation hid this
        ``None`` case.
    """
    tpu_name = envs.TPU_NAME
    if not tpu_name:
        tpu_name = get_tpu_metadata(key="instance-id")
    return tpu_name
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_node_worker_id() -> int:
    """For multi-host TPU VM, return the worker id of the current node.

    Lookup order:

    1. ``envs.TPU_WORKER_ID``.
    2. The GCE metadata attribute ``agent-worker-number``, used only when
       step 1 yields ``None``.

    When both yield ``None``, 0 is returned. The value found is converted
    with ``int()``, so a non-numeric value raises ``ValueError``.
    """
    raw_id = envs.TPU_WORKER_ID
    if raw_id is None:
        raw_id = get_tpu_metadata(key="agent-worker-number")
        if raw_id is None:
            # No source knows the worker id; fall back to worker 0.
            return 0
    return int(raw_id)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_num_cores_per_chip() -> int:
    """Return the number of cores per chip for the current TPU type.

    Types starting with ``"v5litepod"`` or ``"v6e"`` report 1 core. Every
    other type, including an undetectable one, reports 2.
    """
    tpu_type = get_tpu_type()
    # get_tpu_type() returns None when neither the env setting nor the
    # metadata server yields a type. Treat that like any other unrecognized
    # type instead of raising AttributeError on None.startswith.
    if tpu_type and tpu_type.startswith(("v5litepod", "v6e")):
        return 1
    return 2
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_num_chips() -> int:
    """Return the number of TPU chips visible on this host.

    ``/dev/accel*`` device nodes are counted first. If there are none, the
    numerically named entries under ``/dev/vfio`` are counted instead. When
    ``/dev/vfio`` cannot be listed, an error is logged and 0 is returned.
    """
    accel_files = glob.glob("/dev/accel*")
    if accel_files:
        return len(accel_files)

    try:
        vfio_entries = os.listdir("/dev/vfio")
    except OSError as e:
        # Broader than FileNotFoundError, so PermissionError or
        # NotADirectoryError also degrade to "0 chips" instead of crashing
        # this best-effort probe.
        logger.error("Failed to detect number of TPUs: %s", e)
        return 0

    # Only purely numeric entry names are counted; others are skipped.
    return sum(1 for entry in vfio_entries if entry.isdigit())
|