tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""Utilities for downloading model weights from HuggingFace."""
|
|
2
15
|
|
|
3
16
|
import functools
|
|
@@ -13,10 +26,12 @@ from typing import Any, Optional
|
|
|
13
26
|
import jax
|
|
14
27
|
import jax.numpy as jnp
|
|
15
28
|
import torch
|
|
29
|
+
import torchax
|
|
16
30
|
from flax import nnx
|
|
17
31
|
from jax.sharding import Mesh, NamedSharding
|
|
18
32
|
from jax.sharding import PartitionSpec as P
|
|
19
33
|
from safetensors import safe_open
|
|
34
|
+
from vllm.config import VllmConfig
|
|
20
35
|
|
|
21
36
|
from tpu_inference import envs, utils
|
|
22
37
|
from tpu_inference.logger import init_logger
|
|
@@ -65,7 +80,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
|
|
|
65
80
|
def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
|
|
66
81
|
for key, new_shape in shape_map.items():
|
|
67
82
|
if key in param_key:
|
|
68
|
-
|
|
83
|
+
try:
|
|
84
|
+
#TODO:(gpolovets) Add validation on whether reshape preserves data layout.
|
|
85
|
+
return jnp.reshape(param_tensor, new_shape)
|
|
86
|
+
except TypeError:
|
|
87
|
+
raise TypeError(
|
|
88
|
+
f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
|
|
89
|
+
)
|
|
69
90
|
return param_tensor # Base case / no-op
|
|
70
91
|
|
|
71
92
|
|
|
@@ -265,15 +286,16 @@ def get_default_maps(model_config, mesh: Mesh,
|
|
|
265
286
|
bias_pad_map=bias_pad_keys)
|
|
266
287
|
|
|
267
288
|
|
|
268
|
-
def
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
289
|
+
def _load_and_shard_weight(vllm_config,
|
|
290
|
+
params: nnx.State,
|
|
291
|
+
shardings: Any,
|
|
292
|
+
metadata_map: MetadataMap,
|
|
293
|
+
mesh: Mesh,
|
|
294
|
+
hf_key: str,
|
|
295
|
+
hf_weight: jax.Array,
|
|
296
|
+
keep_original_dtype_keys_regex: list[str]
|
|
297
|
+
| None = None,
|
|
298
|
+
pp_missing_layers: list[str] | None = None):
|
|
277
299
|
name_map = metadata_map.name_map
|
|
278
300
|
reshape_keys = metadata_map.reshape_map
|
|
279
301
|
bias_reshape_keys = metadata_map.bias_reshape_map
|
|
@@ -290,6 +312,131 @@ def _load_hf_weights_on_thread(vllm_config,
|
|
|
290
312
|
head_dim = utils.get_padded_head_dim(head_dim_original)
|
|
291
313
|
head_dim_pad = head_dim - head_dim_original
|
|
292
314
|
|
|
315
|
+
# Check if the key should retain its original dtype
|
|
316
|
+
keep_original_dtype = False
|
|
317
|
+
if keep_original_dtype_keys_regex:
|
|
318
|
+
for pattern in keep_original_dtype_keys_regex:
|
|
319
|
+
if re.match(pattern, hf_key):
|
|
320
|
+
keep_original_dtype = True
|
|
321
|
+
break
|
|
322
|
+
|
|
323
|
+
# Converting to config's dtype
|
|
324
|
+
if not keep_original_dtype and hf_weight.dtype != model_config.dtype:
|
|
325
|
+
logger.warning(
|
|
326
|
+
f"Converting dtype for {hf_key} from {hf_weight.dtype} to {model_config.dtype}"
|
|
327
|
+
)
|
|
328
|
+
hf_weight = hf_weight.astype(model_config.dtype)
|
|
329
|
+
|
|
330
|
+
if hf_key.endswith(".weight"):
|
|
331
|
+
hf_key = hf_key.removesuffix(".weight")
|
|
332
|
+
|
|
333
|
+
# Find the corresponding model key using the HF key
|
|
334
|
+
if "layers" in hf_key:
|
|
335
|
+
layer_num = re.search(r"layers\.(\d+)", hf_key).group(1)
|
|
336
|
+
layer_key = re.sub(r"layers\.\d+", "layers.*", hf_key)
|
|
337
|
+
model_key = name_map[layer_key]
|
|
338
|
+
model_key = re.sub(r"layers\.\*", f"layers.{layer_num}", model_key)
|
|
339
|
+
elif "blocks" in hf_key:
|
|
340
|
+
layer_num = re.search(r"blocks\.(\d+)", hf_key).group(1)
|
|
341
|
+
layer_key = re.sub(r"blocks\.\d+", "blocks.*", hf_key)
|
|
342
|
+
model_key = name_map[layer_key]
|
|
343
|
+
model_key = re.sub(r"blocks\.\*", f"blocks.{layer_num}", model_key)
|
|
344
|
+
else:
|
|
345
|
+
if hf_key not in name_map and hf_key == "lm_head":
|
|
346
|
+
logger.warning(f"Skip loading {hf_key} due to tie_word_embeddings")
|
|
347
|
+
return
|
|
348
|
+
if hf_key not in name_map and "t2d" in hf_key:
|
|
349
|
+
logger.warning(
|
|
350
|
+
f"Skip loading {hf_key} as it's not used in eagle-3 for now")
|
|
351
|
+
return
|
|
352
|
+
model_key = name_map.get(hf_key, hf_key)
|
|
353
|
+
|
|
354
|
+
if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
|
|
355
|
+
logger.warning(
|
|
356
|
+
f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
|
|
357
|
+
return
|
|
358
|
+
model_weight, model_sharding = get_param_and_sharding(
|
|
359
|
+
params, shardings, model_key)
|
|
360
|
+
|
|
361
|
+
logger.debug(
|
|
362
|
+
"before transform | "
|
|
363
|
+
f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
if hf_key.endswith(".bias"):
|
|
367
|
+
for key in bias_reshape_keys:
|
|
368
|
+
if key in hf_key:
|
|
369
|
+
hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
|
|
370
|
+
if head_dim_pad > 0:
|
|
371
|
+
hf_weight = jnp.pad(hf_weight, ((0, 0), (0, head_dim_pad)))
|
|
372
|
+
break
|
|
373
|
+
else:
|
|
374
|
+
for key in reshape_keys:
|
|
375
|
+
if key in hf_key:
|
|
376
|
+
hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
|
|
377
|
+
if head_dim_pad > 0:
|
|
378
|
+
if "o_proj" in key:
|
|
379
|
+
hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
|
|
380
|
+
(0, head_dim_pad)))
|
|
381
|
+
else:
|
|
382
|
+
hf_weight = jnp.pad(hf_weight,
|
|
383
|
+
((0, 0), (0, head_dim_pad),
|
|
384
|
+
(0, 0)))
|
|
385
|
+
break
|
|
386
|
+
for key in transpose_keys:
|
|
387
|
+
if key in hf_key:
|
|
388
|
+
hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
|
|
389
|
+
break
|
|
390
|
+
|
|
391
|
+
# Pad num-kv-heads
|
|
392
|
+
if hf_key.endswith(".bias"):
|
|
393
|
+
for key, value in bias_pad_keys.items():
|
|
394
|
+
dim = value[0]
|
|
395
|
+
dim_size = value[1]
|
|
396
|
+
if key in hf_key and dim_size != 0:
|
|
397
|
+
hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
|
|
398
|
+
break
|
|
399
|
+
else:
|
|
400
|
+
for key, value in pad_keys.items():
|
|
401
|
+
dim = value[0]
|
|
402
|
+
dim_size = value[1]
|
|
403
|
+
if key in hf_key and dim_size != 0:
|
|
404
|
+
hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
|
|
405
|
+
break
|
|
406
|
+
|
|
407
|
+
logger.debug(
|
|
408
|
+
"after transform | "
|
|
409
|
+
f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
if head_dim_pad == 0:
|
|
413
|
+
assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
|
|
414
|
+
|
|
415
|
+
# Update the model weight
|
|
416
|
+
spec = model_weight.sharding.spec if isinstance(
|
|
417
|
+
model_weight.sharding, NamedSharding) else model_weight.sharding
|
|
418
|
+
model_weight.value = shard(hf_weight, spec)
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
|
|
422
|
+
has_digit = any(char.isdigit() for char in hf_key)
|
|
423
|
+
# add the suffix after digits to avoid it matches "layers.10" with "layers.1"
|
|
424
|
+
suffix = "." if has_digit else ""
|
|
425
|
+
return any(f'{pp_missing_layer}{suffix}' in hf_key
|
|
426
|
+
for pp_missing_layer in pp_missing_layers)
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _load_hf_weights_on_thread(
|
|
430
|
+
vllm_config: VllmConfig,
|
|
431
|
+
params: nnx.State,
|
|
432
|
+
metadata_map: "MetadataMap",
|
|
433
|
+
mesh: Mesh,
|
|
434
|
+
weights_file: str,
|
|
435
|
+
filter_regex: Optional[str] = None,
|
|
436
|
+
keep_original_dtype_keys_regex: Optional[list[str]] = None,
|
|
437
|
+
pp_missing_layers: list[str] | None = None,
|
|
438
|
+
):
|
|
439
|
+
"""Loads weights from a single weights file."""
|
|
293
440
|
try:
|
|
294
441
|
shardings = nnx.get_named_sharding(params, mesh)
|
|
295
442
|
except TypeError:
|
|
@@ -297,160 +444,92 @@ def _load_hf_weights_on_thread(vllm_config,
|
|
|
297
444
|
|
|
298
445
|
for hf_key, hf_weight in model_weights_single_file_generator(
|
|
299
446
|
weights_file, framework="flax", filter_regex=filter_regex):
|
|
447
|
+
_load_and_shard_weight(
|
|
448
|
+
vllm_config,
|
|
449
|
+
params,
|
|
450
|
+
shardings,
|
|
451
|
+
metadata_map,
|
|
452
|
+
mesh,
|
|
453
|
+
hf_key,
|
|
454
|
+
hf_weight,
|
|
455
|
+
keep_original_dtype_keys_regex,
|
|
456
|
+
pp_missing_layers,
|
|
457
|
+
)
|
|
300
458
|
|
|
301
|
-
# Check if the key should be excluded
|
|
302
|
-
if exclude_regex:
|
|
303
|
-
should_exclude = False
|
|
304
|
-
for pattern in exclude_regex:
|
|
305
|
-
if re.search(pattern, hf_key):
|
|
306
|
-
logger.info(
|
|
307
|
-
f"Excluding {hf_key} based on pattern {pattern}")
|
|
308
|
-
should_exclude = True
|
|
309
|
-
break
|
|
310
|
-
if should_exclude:
|
|
311
|
-
continue
|
|
312
|
-
|
|
313
|
-
# Check if the key should retain its original dtype
|
|
314
|
-
keep_original_dtype = False
|
|
315
|
-
if keep_original_dtype_keys_regex:
|
|
316
|
-
for pattern in keep_original_dtype_keys_regex:
|
|
317
|
-
if re.match(pattern, hf_key):
|
|
318
|
-
keep_original_dtype = True
|
|
319
|
-
break
|
|
320
459
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
f"Skip loading {hf_key} due to tie_word_embeddings")
|
|
346
|
-
continue
|
|
347
|
-
if hf_key not in name_map and "t2d" in hf_key:
|
|
348
|
-
logger.warning(
|
|
349
|
-
f"Skip loading {hf_key} as it's not used in eagle-3 for now"
|
|
350
|
-
)
|
|
460
|
+
def load_hf_weights(
|
|
461
|
+
vllm_config: VllmConfig,
|
|
462
|
+
model: nnx.Module,
|
|
463
|
+
metadata_map: "MetadataMap",
|
|
464
|
+
mesh: Mesh,
|
|
465
|
+
filter_regex: Optional[str] = None,
|
|
466
|
+
is_draft_model: bool = False,
|
|
467
|
+
keep_original_dtype_keys_regex: Optional[list[str]] = None,
|
|
468
|
+
pp_missing_layers: list[str] | None = None,
|
|
469
|
+
):
|
|
470
|
+
"""Load weights into a JAX model from either an iterator or files."""
|
|
471
|
+
params = nnx.state(model)
|
|
472
|
+
try:
|
|
473
|
+
shardings = nnx.get_named_sharding(params, mesh)
|
|
474
|
+
except TypeError:
|
|
475
|
+
shardings = params
|
|
476
|
+
weights_iterator = None
|
|
477
|
+
if hasattr(vllm_config.model_config, "model_weights_iterator"):
|
|
478
|
+
weights_iterator = vllm_config.model_config.model_weights_iterator
|
|
479
|
+
env = torchax.default_env()
|
|
480
|
+
# The weights_iterator is used in RunAI model streamer integration.
|
|
481
|
+
if weights_iterator is not None:
|
|
482
|
+
for hf_key, hf_weight in weights_iterator:
|
|
483
|
+
if filter_regex and not re.match(filter_regex, hf_key):
|
|
351
484
|
continue
|
|
352
|
-
model_key = name_map.get(hf_key, hf_key)
|
|
353
|
-
model_weight, model_sharding = get_param_and_sharding(
|
|
354
|
-
params, shardings, model_key)
|
|
355
|
-
|
|
356
|
-
logger.debug(
|
|
357
|
-
"before transform | "
|
|
358
|
-
f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
|
|
359
|
-
)
|
|
360
485
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
hf_weight = jnp.reshape(hf_weight, bias_reshape_keys[key])
|
|
365
|
-
if head_dim_pad > 0:
|
|
366
|
-
hf_weight = jnp.pad(hf_weight,
|
|
367
|
-
((0, 0), (0, head_dim_pad)))
|
|
368
|
-
break
|
|
369
|
-
else:
|
|
370
|
-
for key in reshape_keys:
|
|
371
|
-
if key in hf_key:
|
|
372
|
-
hf_weight = jnp.reshape(hf_weight, reshape_keys[key])
|
|
373
|
-
if head_dim_pad > 0:
|
|
374
|
-
if "o_proj" in key:
|
|
375
|
-
hf_weight = jnp.pad(hf_weight, ((0, 0), (0, 0),
|
|
376
|
-
(0, head_dim_pad)))
|
|
377
|
-
else:
|
|
378
|
-
hf_weight = jnp.pad(hf_weight,
|
|
379
|
-
((0, 0), (0, head_dim_pad),
|
|
380
|
-
(0, 0)))
|
|
381
|
-
break
|
|
382
|
-
for key in transpose_keys:
|
|
383
|
-
if key in hf_key:
|
|
384
|
-
hf_weight = jnp.transpose(hf_weight, transpose_keys[key])
|
|
385
|
-
break
|
|
386
|
-
|
|
387
|
-
# Pad num-kv-heads
|
|
388
|
-
if hf_key.endswith(".bias"):
|
|
389
|
-
for key, value in bias_pad_keys.items():
|
|
390
|
-
dim = value[0]
|
|
391
|
-
dim_size = value[1]
|
|
392
|
-
if key in hf_key and dim_size != 0:
|
|
393
|
-
hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
|
|
394
|
-
break
|
|
395
|
-
else:
|
|
396
|
-
for key, value in pad_keys.items():
|
|
397
|
-
dim = value[0]
|
|
398
|
-
dim_size = value[1]
|
|
399
|
-
if key in hf_key and dim_size != 0:
|
|
400
|
-
hf_weight = jnp.repeat(hf_weight, dim_size, axis=dim)
|
|
401
|
-
break
|
|
402
|
-
|
|
403
|
-
logger.debug(
|
|
404
|
-
"after transform | "
|
|
405
|
-
f"{hf_key}: {hf_weight.shape} --> {model_key}: {model_weight.value.shape} {model_sharding}"
|
|
406
|
-
)
|
|
486
|
+
# Since the weights_iterator yields Pytorch tensors (torch.Tensor),
|
|
487
|
+
# we need to convert them to JAX arrays (jax.Array).
|
|
488
|
+
hf_weight_jax = env.t2j_copy(hf_weight)
|
|
407
489
|
|
|
408
|
-
|
|
409
|
-
assert model_weight.value.shape == hf_weight.shape, f"{hf_key}: {model_weight.value.shape} != {hf_weight.shape}"
|
|
410
|
-
|
|
411
|
-
# Update the model weight
|
|
412
|
-
spec = model_weight.sharding.spec if isinstance(
|
|
413
|
-
model_weight.sharding, NamedSharding) else model_weight.sharding
|
|
414
|
-
model_weight.value = shard(hf_weight, spec)
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
def load_hf_weights(vllm_config,
|
|
418
|
-
model: nnx.Module,
|
|
419
|
-
metadata_map: MetadataMap,
|
|
420
|
-
mesh: Mesh,
|
|
421
|
-
filter_regex: str | None = None,
|
|
422
|
-
is_draft_model: bool = False,
|
|
423
|
-
keep_original_dtype_keys_regex: list[str] | None = None,
|
|
424
|
-
exclude_regex: list[str] | None = None):
|
|
425
|
-
"""Load weights from all model weights files to the model, run in multi threads."""
|
|
426
|
-
if is_draft_model:
|
|
427
|
-
model_path = vllm_config.speculative_config.draft_model_config.model
|
|
428
|
-
else:
|
|
429
|
-
model_path = vllm_config.model_config.model
|
|
430
|
-
weights_files = get_model_weights_files(
|
|
431
|
-
model_path, vllm_config.load_config.download_dir)
|
|
432
|
-
params = nnx.state(model)
|
|
433
|
-
max_workers = min(64, len(weights_files))
|
|
434
|
-
# NOTE(xiang): Disable multi-threading mode if running on multi-host.
|
|
435
|
-
# Because multi-threading would cause different JAX processes to load
|
|
436
|
-
# different weights at the same time.
|
|
437
|
-
if envs.TPU_MULTIHOST_BACKEND == "ray":
|
|
438
|
-
max_workers = 1
|
|
439
|
-
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
440
|
-
futures = [
|
|
441
|
-
executor.submit(
|
|
442
|
-
_load_hf_weights_on_thread,
|
|
490
|
+
_load_and_shard_weight(
|
|
443
491
|
vllm_config,
|
|
444
492
|
params,
|
|
493
|
+
shardings,
|
|
445
494
|
metadata_map,
|
|
446
495
|
mesh,
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
keep_original_dtype_keys_regex
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
496
|
+
hf_key,
|
|
497
|
+
hf_weight_jax,
|
|
498
|
+
keep_original_dtype_keys_regex,
|
|
499
|
+
pp_missing_layers=pp_missing_layers,
|
|
500
|
+
)
|
|
501
|
+
else:
|
|
502
|
+
# File-based path (multi-threaded)
|
|
503
|
+
if is_draft_model:
|
|
504
|
+
model_path = vllm_config.speculative_config.draft_model_config.model
|
|
505
|
+
else:
|
|
506
|
+
model_path = vllm_config.model_config.model
|
|
507
|
+
weights_files = get_model_weights_files(
|
|
508
|
+
model_path, vllm_config.load_config.download_dir)
|
|
509
|
+
max_workers = min(64, len(weights_files))
|
|
510
|
+
# NOTE(xiang): Disable multi-threading mode if running on multi-host.
|
|
511
|
+
# Because multi-threading would cause different JAX processes to load
|
|
512
|
+
# different weights at the same time.
|
|
513
|
+
if envs.TPU_MULTIHOST_BACKEND == "ray":
|
|
514
|
+
max_workers = 1
|
|
515
|
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
|
516
|
+
futures = [
|
|
517
|
+
executor.submit(
|
|
518
|
+
_load_hf_weights_on_thread,
|
|
519
|
+
vllm_config,
|
|
520
|
+
params,
|
|
521
|
+
metadata_map,
|
|
522
|
+
mesh,
|
|
523
|
+
weights_file,
|
|
524
|
+
filter_regex=filter_regex,
|
|
525
|
+
keep_original_dtype_keys_regex=
|
|
526
|
+
keep_original_dtype_keys_regex,
|
|
527
|
+
pp_missing_layers=pp_missing_layers,
|
|
528
|
+
) for weights_file in weights_files
|
|
529
|
+
]
|
|
530
|
+
for future in futures:
|
|
531
|
+
future.result()
|
|
532
|
+
|
|
454
533
|
check_all_loaded(params)
|
|
455
534
|
nnx.update(model, params)
|
|
456
535
|
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import copy
|
|
2
16
|
import functools
|
|
3
17
|
from collections.abc import Sequence
|
|
@@ -9,6 +23,7 @@ import jax
|
|
|
9
23
|
import torch
|
|
10
24
|
import torch.nn
|
|
11
25
|
import torchax
|
|
26
|
+
import vllm.envs as vllm_envs
|
|
12
27
|
from flax.typing import PRNGKey
|
|
13
28
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
14
29
|
from torchax.interop import jax_view, torch_view
|
|
@@ -22,8 +37,10 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
|
|
|
22
37
|
from vllm.sequence import IntermediateTensors
|
|
23
38
|
|
|
24
39
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
40
|
+
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
41
|
+
from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
|
|
42
|
+
shard_model_to_tpu
|
|
25
43
|
from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
|
|
26
|
-
from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
|
|
27
44
|
from tpu_inference.logger import init_logger
|
|
28
45
|
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
29
46
|
JaxIntermediateTensors
|
|
@@ -118,9 +135,16 @@ class VllmModelWrapper:
|
|
|
118
135
|
"torch._sync",
|
|
119
136
|
return_value=None) if use_random_weights else nullcontext()
|
|
120
137
|
|
|
138
|
+
# By default load weights to the CPU device first. If we are running
|
|
139
|
+
# under Pathways, this would cause weights to be loaded on a CPU-only
|
|
140
|
+
# node, so we'll need to remove this context.
|
|
141
|
+
jax_context = jax.default_device(
|
|
142
|
+
jax.devices("cpu")
|
|
143
|
+
[0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
|
|
144
|
+
|
|
121
145
|
# Load the vLLM model and wrap it into a new model whose forward
|
|
122
146
|
# function can calculate the hidden_state and logits.
|
|
123
|
-
with load_context,
|
|
147
|
+
with load_context, jax_context:
|
|
124
148
|
vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
|
|
125
149
|
lora_manager = None
|
|
126
150
|
if vllm_config_for_load.lora_config is not None:
|
|
@@ -189,7 +213,7 @@ class VllmModelWrapper:
|
|
|
189
213
|
kwargs={
|
|
190
214
|
"input_ids": torch_view(input_ids),
|
|
191
215
|
"positions": torch_view(input_positions),
|
|
192
|
-
"intermediate_tensors":
|
|
216
|
+
"intermediate_tensors": intermediate_tensors,
|
|
193
217
|
"inputs_embeds": None,
|
|
194
218
|
},
|
|
195
219
|
tie_weights=False,
|
|
@@ -212,8 +236,10 @@ class VllmModelWrapper:
|
|
|
212
236
|
|
|
213
237
|
@functools.partial(
|
|
214
238
|
jax.jit,
|
|
215
|
-
out_shardings=(NamedSharding(
|
|
216
|
-
|
|
239
|
+
out_shardings=(NamedSharding(
|
|
240
|
+
self.mesh,
|
|
241
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
242
|
+
ShardingAxisName.MLP_TENSOR))),
|
|
217
243
|
)
|
|
218
244
|
def compute_logits_func(
|
|
219
245
|
params_and_buffers: Any,
|
|
@@ -255,7 +281,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
|
|
|
255
281
|
vllm_config,
|
|
256
282
|
device,
|
|
257
283
|
model.embedding_modules,
|
|
258
|
-
model.embedding_padding_modules,
|
|
259
284
|
)
|
|
260
285
|
return lora_manager, lora_manager.create_lora_manager(model)
|
|
261
286
|
|
|
@@ -269,10 +294,9 @@ def replace_set_lora(model):
|
|
|
269
294
|
index: int,
|
|
270
295
|
lora_a: torch.Tensor,
|
|
271
296
|
lora_b: torch.Tensor,
|
|
272
|
-
embeddings_tensor: Optional[torch.Tensor],
|
|
273
297
|
):
|
|
274
298
|
with torchax.default_env():
|
|
275
|
-
self._original_set_lora(index, lora_a, lora_b
|
|
299
|
+
self._original_set_lora(index, lora_a, lora_b)
|
|
276
300
|
|
|
277
301
|
def _tpu_reset_lora(self, index: int):
|
|
278
302
|
with torchax.default_env():
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from contextlib import contextmanager
|
|
2
16
|
from dataclasses import dataclass
|
|
3
17
|
from typing import Dict, List, Optional
|
|
@@ -1,2 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
# ruff: noqa
|
|
2
16
|
from tpu_inference.platforms.tpu_platform import TpuPlatform
|