tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
from typing import Any, Callable, Optional, Union
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import torch
|
|
6
|
+
from jax.experimental.layout import Format, Layout
|
|
7
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
|
+
from torch.nn.parameter import Parameter
|
|
9
|
+
from torchax.interop import jax_view, torch_view
|
|
10
|
+
from torchax.ops.mappings import t2j
|
|
11
|
+
from vllm.attention.layer import Attention
|
|
12
|
+
from vllm.logger import init_logger
|
|
13
|
+
from vllm.model_executor.layers.fused_moe.layer import (
|
|
14
|
+
FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
|
|
15
|
+
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
|
16
|
+
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
|
|
17
|
+
from vllm.model_executor.layers.linear import (LinearBase,
|
|
18
|
+
UnquantizedLinearMethod)
|
|
19
|
+
from vllm.model_executor.layers.quantization import \
|
|
20
|
+
register_quantization_config
|
|
21
|
+
from vllm.model_executor.layers.quantization.base_config import (
|
|
22
|
+
QuantizationConfig, QuantizeMethodBase)
|
|
23
|
+
|
|
24
|
+
from tpu_inference import envs
|
|
25
|
+
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
|
|
26
|
+
from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
|
|
27
|
+
get_tpu_quant_method)
|
|
28
|
+
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
|
|
29
|
+
from tpu_inference.layers.vllm.linear_common import (
|
|
30
|
+
reorder_concatenated_tensor_for_sharding,
|
|
31
|
+
slice_sharded_tensor_for_concatenation, torch_to_jax_param)
|
|
32
|
+
from tpu_inference.layers.vllm.quantization.common import (
|
|
33
|
+
JaxCommonConfig, JaxCommonLinearConfig)
|
|
34
|
+
|
|
35
|
+
# Short alias used when spelling out partition specs in the sharding calls.
P = PartitionSpec
# Module-level logger created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
    """Quantization config for layers whose weights stay unquantized.

    The class is registered under the name returned by
    ``get_tpu_quant_method(UNQUANTIZED)`` and hands out the quant-method
    objects for linear and fused-MoE layers.
    """

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this config."""
        return UNQUANTIZED

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        supported = [torch.float32, torch.float16, torch.bfloat16]
        return supported

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability; 0 means always supported."""
        return 0

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the extra config files to look for; none are required."""
        return []

    @classmethod
    def from_config(cls, _: dict[str, Any]) -> "VllmUnquantizedConfig":
        """Build the config; the supplied dictionary is not consulted."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        """Pick the quant method matching the type of ``layer``.

        Args:
            layer: The module a quant method is requested for.
            prefix: Name prefix of the layer; not used here.

        Returns:
            A linear method for ``LinearBase`` layers, a fused-MoE method for
            ``FusedMoE`` layers, and ``None`` for everything else.
        """
        if isinstance(layer, LinearBase):
            return VllmUnquantizedLinearMethod(self.get_linear_config(layer))
        if isinstance(layer, FusedMoE):
            return VllmUnquantizedFusedMoEMethod(self.get_moe_config(layer),
                                                 self.mesh)
        # Attention layers, like any other layer type, get no quant method.
        return None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
    """Unquantized linear method that evaluates the matmul with JAX."""

    def __init__(self, jax_config: JaxCommonLinearConfig):
        # Holds the mesh, sharding specs and fusion settings used below.
        self.jax_config = jax_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace the layer's weight (and bias) with converted parameters.

        The bias is converted only when the layer has one and does not skip
        the bias add.
        """
        cfg = self.jax_config

        def convert(tensor, spec):
            # Shared conversion for weight and bias; only the spec differs.
            return torch_to_jax_param(
                tensor,
                NamedSharding(cfg.mesh, spec),
                cfg.output_sizes,
                cfg.n_shards,
                cfg.fuse_matmuls,
            )

        converted_weight = convert(layer.weight, cfg.weight_sharding)
        delattr(layer, "weight")
        layer.weight = converted_weight

        if layer.bias is None or layer.skip_bias_add:
            return

        if layer.return_bias:
            logger.warning_once("Bias might return incorrect value.")

        converted_bias = convert(layer.bias, cfg.bias_sharding)
        delattr(layer, "bias")
        layer.bias = converted_bias

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on ``x`` inside a named JAX scope.

        Input and output are re-sharded in place whenever the config reports
        a sharding for them.
        """
        cfg = self.jax_config
        with jax.named_scope(layer._get_name()):
            in_sharding = cfg.get_input_sharding(x)
            if in_sharding:
                x.shard_(NamedSharding(cfg.mesh, in_sharding))

            # One fused matmul or one matmul per weight shard.
            matmul = (self._apply_fused
                      if cfg.fuse_matmuls else self._apply_split)
            out = matmul(layer, x, bias)

            out_sharding = cfg.get_output_sharding(out)
            if out_sharding:
                out.shard_(NamedSharding(cfg.mesh, out_sharding))

            return out

    def _apply_fused(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the output with a single einsum over the fused weight."""
        activations = jax_view(x)
        kernel = jax_view(layer.weight)

        fused = jnp.einsum("mn,pn->mp", activations, kernel)
        if bias is not None and not layer.skip_bias_add:
            fused = fused + bias.jax()

        # Split the fused result into pieces, then join them on the last axis.
        pieces = slice_sharded_tensor_for_concatenation(
            fused, self.jax_config.output_sizes, self.jax_config.n_shards)
        return torch_view(jnp.concatenate(pieces, axis=-1))

    def _apply_split(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute one einsum per weight in the list and concatenate."""
        assert isinstance(layer.weight, torch.nn.ParameterList)

        activations = x.jax()
        add_bias = bias is not None and not layer.skip_bias_add
        partials = []
        for idx, shard in enumerate(layer.weight):
            partial = jnp.einsum("mn,pn->mp", activations, jax_view(shard))
            if add_bias:
                partial = partial + jax_view(bias[idx])
            partials.append(partial)
        return torch_view(jnp.concatenate(partials, axis=-1))
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
    """Unquantized fused-MoE method backed by JAX implementations.

    When ``envs.USE_MOE_EP_KERNEL`` is set and the layer uses expert
    parallelism, weights are laid out for and dispatched to ``fused_ep_moe``;
    otherwise ``fused_moe_func_padded`` is used.
    """

    def __init__(self,
                 moe: FusedMoEConfig,
                 mesh: Mesh,
                 ep_axis_name: str = 'model'):
        """Store the mesh, the kernel switch and the EP mesh-axis name.

        Args:
            moe: MoE configuration forwarded to the base class.
            mesh: JAX mesh the expert weights are placed on.
            ep_axis_name: Mesh axis used for expert parallelism on the
                kernel path.
        """
        super().__init__(moe)
        self.mesh = mesh
        self.use_kernel = envs.USE_MOE_EP_KERNEL
        self.ep_axis_name = ep_axis_name
        # TODO: Use autotune table once we have it.
        self.block_size = {
            "bt": 16,
            "bf": 384,
            "bd1": 512,
            "bd2": 512,
            "btc": 16,
            "bfc": 384,
            "bd1c": 256,
            "bd2c": 256,
        }

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Selecting gemm implementation is currently not supported.")

    def _put_with_sharding(self, array: jax.Array,
                           spec: PartitionSpec) -> jax.Array:
        """Place ``array`` on the mesh with ``spec``.

        The layout passed to ``Layout`` is the dimension order
        ``(0, 1, ..., ndim - 1)``.
        """
        layout = Layout(tuple(range(array.ndim)))
        return jax.device_put(
            array, Format(layout, NamedSharding(self.mesh, spec)))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the expert weights to JAX, shard them, and reattach them.

        Three layouts are produced depending on the configuration: the EP
        kernel layout, the plain EP layout, or the tensor-parallel layout in
        which the intermediate dimension is split over the "model" axis.
        """
        assert isinstance(layer, FusedMoE)
        has_bias = self.moe.has_bias
        available_devices = self.mesh.devices.flatten()
        # NOTE(review): indentation was lost in the rendering this block was
        # recovered from; the extent of the `with` block is reconstructed.
        with jax.default_device(available_devices[0]):
            w13_weight = t2j(layer.w13_weight, use_dlpack=False)
            w2_weight = t2j(layer.w2_weight, use_dlpack=False)

            if has_bias:
                w13_bias = t2j(layer.w13_bias, use_dlpack=False)
                w2_bias = t2j(layer.w2_bias, use_dlpack=False)

            if layer.activation == "swigluoai":
                # When using swigluoai, vLLM splits gmm output in an
                # interleaved way. However, interleaved split is not
                # performant on TPU. Therefore, we preprocess the weight so
                # that splitting gmm output by middle can still get the same
                # result.
                w1_weight = w13_weight[:, ::2, :]
                w3_weight = w13_weight[:, 1::2, :]
                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)

                if has_bias:
                    w1_bias = w13_bias[:, ::2]
                    w3_bias = w13_bias[:, 1::2]
                    w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

            if self.use_kernel and layer.use_ep:
                # Kernel expects:
                # w13: (num_experts, 2, hidden_size, intermediate_size)
                # w2: (num_experts, intermediate_size, hidden_size)
                # Current format:
                # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
                # w2_weight: (num_experts, hidden_size, intermediate_size)
                num_experts = w13_weight.shape[0]
                intermediate_size = w13_weight.shape[1] // 2
                hidden_size = w13_weight.shape[2]

                # Reshape and transpose w13_weight to
                # (num_experts, 2, hidden_size, intermediate_size)
                w13_reshaped = w13_weight.reshape(num_experts, 2,
                                                  intermediate_size,
                                                  hidden_size)
                w13_weight_transposed = jnp.transpose(w13_reshaped,
                                                      (0, 1, 3, 2))

                # Transpose w2_weight to
                # (num_experts, intermediate_size, hidden_size)
                w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))

                # Apply EP sharding on the same mesh axis that `apply` hands
                # to the kernel, so weights and kernel agree on the axis.
                ep_axis = self.ep_axis_name
                w13_weight = self._put_with_sharding(
                    w13_weight_transposed, P(ep_axis, None, None, None))
                w2_weight = self._put_with_sharding(
                    w2_weight_transposed, P(ep_axis, None, None))

                if has_bias:
                    w13_bias = w13_bias.reshape(num_experts, 2,
                                                intermediate_size)

                    # Apply EP sharding
                    w13_bias = self._put_with_sharding(
                        w13_bias, P(ep_axis, None, None))
                    w2_bias = self._put_with_sharding(w2_bias,
                                                      P(ep_axis, None))

            elif layer.use_ep:
                # Non-kernel EP path: experts are split over "model".
                w13_weight = self._put_with_sharding(
                    w13_weight, P("model", None, None))
                w2_weight = self._put_with_sharding(w2_weight,
                                                    P("model", None, None))

                if has_bias:
                    w13_bias = self._put_with_sharding(
                        w13_bias, P("model", None))
                    w2_bias = self._put_with_sharding(w2_bias,
                                                      P("model", None))

            else:
                # Tensor-parallel path: split the intermediate dimension.
                intermediate_size = w13_weight.shape[1] // 2
                assert intermediate_size == w2_weight.shape[-1]
                output_sizes = [intermediate_size, intermediate_size]
                n_shards = self.mesh.shape["model"]
                assert intermediate_size % n_shards == 0
                w13_weight = reorder_concatenated_tensor_for_sharding(
                    w13_weight, output_sizes, n_shards, dim=1)
                w13_weight = self._put_with_sharding(
                    w13_weight, P(None, "model", None))
                w2_weight = self._put_with_sharding(w2_weight,
                                                    P(None, None, "model"))

                if has_bias:
                    w13_bias = reorder_concatenated_tensor_for_sharding(
                        w13_bias, output_sizes, n_shards, dim=1)
                    w13_bias = self._put_with_sharding(
                        w13_bias, P(None, "model"))
                    w2_bias = self._put_with_sharding(w2_bias, P(None, None))

        layer.w13_weight = Parameter(torch_view(w13_weight),
                                     requires_grad=False)
        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)

        if has_bias:
            layer.w13_bias = Parameter(torch_view(w13_bias),
                                       requires_grad=False)
            layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the MoE layer on ``x`` and return the result as a torch view.

        Only ``x``, ``router_logits``, ``top_k``, ``renormalize``,
        ``global_num_experts``, ``scoring_func`` and ``activation`` are read
        here; the remaining arguments are accepted but not used.

        Raises:
            NotImplementedError: If ``scoring_func`` is not ``"softmax"``.
        """
        assert isinstance(layer, FusedMoE)
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")

        if self.use_kernel and layer.use_ep:
            # NOTE(review): this call is not given the bias tensors,
            # `renormalize` or `activation` -- confirm that is intended.
            output = fused_ep_moe(
                mesh=self.mesh,
                tokens=jax_view(x),
                w1=jax_view(layer.w13_weight),
                w2=jax_view(layer.w2_weight),
                gating_output=jax_view(router_logits),
                top_k=top_k,
                ep_axis_name=self.ep_axis_name,
                **self.block_size,
            )
        else:
            # Use the original implementation
            output = fused_moe_func_padded(
                jax_view(x),
                jax_view(layer.w13_weight),
                jax_view(layer.w2_weight),
                jax_view(layer.w13_bias) if self.moe.has_bias else None,
                jax_view(layer.w2_bias) if self.moe.has_bias else None,
                jax_view(router_logits),
                topk=top_k,
                global_num_experts=global_num_experts,
                renormalize=renormalize,
                reduce_results=layer.reduce_results,
                mesh=self.mesh,
                use_ep=layer.use_ep,
                activation=activation,
            )

        return torch_view(output)
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import torch
|
|
6
|
+
import torchax
|
|
7
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
|
+
from torch.nn import Parameter
|
|
9
|
+
from torch.utils import _pytree as pytree
|
|
10
|
+
from torchax.interop import jax_view, torch_view
|
|
11
|
+
from torchax.ops.mappings import t2j
|
|
12
|
+
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|
13
|
+
MergedColumnParallelLinearWithLoRA,
|
|
14
|
+
MergedQKVParallelLinearWithLoRA,
|
|
15
|
+
QKVParallelLinearWithLoRA,
|
|
16
|
+
ReplicatedLinearWithLoRA,
|
|
17
|
+
RowParallelLinearWithLoRA)
|
|
18
|
+
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
|
|
19
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
20
|
+
ParallelLMHead, VocabParallelEmbedding)
|
|
21
|
+
|
|
22
|
+
from tpu_inference.logger import init_logger
|
|
23
|
+
|
|
24
|
+
# Short alias used when spelling out partition specs below.
P = PartitionSpec

# Module-level logger created through the package's `init_logger` helper.
logger = init_logger(__name__)

# Torch floating-point dtypes and their JAX counterparts. The lookup in
# `_convert_to_torchax_and_shard` falls back to jnp.float32 for any dtype
# that is not listed here.
TORCH_TO_JAX_DTYPE_MAP = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def shard_model_to_tpu(model: torch.nn.Module,
                       mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
    """
    Shard the model weights and move them to TPU.
    At the same time, also turn the weight tensors into torchax tensors so that
    jax code can interop with it and the overall program can be traced and
    compiled in XLA.
    Args:
        model: A PyTorch model whose weights are on CPU main memory.
        mesh: JAX mesh object for sharding.
    Returns:
        Dictionary of parameters and buffers that will be used as arguments of
        torch.func.functional_call
    """

    # NOTE(review): indentation was lost in the rendering this block was
    # recovered from; the extent of the `with` block is reconstructed --
    # confirm against the upstream file.
    with jax.default_device(jax.devices("cpu")[0]):
        # Module-specific sharding first.
        _shard_module_to_tpu(model, mesh)

        params, buffers = _extract_all_params_buffers(model)

        # For other weight tensors, replicate them on all the TPU chips.
        # Only tensors for which `_tensor_is_in_cpu` returns True are mapped.
        params, buffers = pytree.tree_map_only(
            _tensor_is_in_cpu,
            lambda tensor: _shard_tensor_to_tpu_replicated(tensor, mesh),
            (params, buffers))

        # On a name clash the buffer entry overrides the parameter entry.
        return {**params, **buffers}
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def update_lora(model: torch.nn.Module,
                initial_params_buffers) -> dict[str, torchax.torch.Tensor]:
    """Refresh the stacked LoRA entries of ``initial_params_buffers``.

    Every parameter or buffer of ``model`` whose name contains
    ``lora_a_stacked`` or ``lora_b_stacked`` overwrites the entry of the same
    name. The mapping is modified in place and also returned.
    """
    params, buffers = _extract_all_params_buffers(model)
    for k, tensor in {**params, **buffers}.items():
        is_lora_entry = 'lora_a_stacked' in k or 'lora_b_stacked' in k
        if not is_lora_entry:
            continue
        assert k in initial_params_buffers, f"{k} not in initial_params_buffers"
        initial_params_buffers[k] = tensor

    return initial_params_buffers
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _extract_all_params_buffers(model: torch.nn.Module):
|
|
77
|
+
return dict(model.named_parameters()), dict(model.named_buffers())
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _tensor_is_in_cpu(tensor: torch.Tensor) -> bool:
    """Tell whether ``tensor`` still has to be moved off the CPU.

    Returns True when the tensor is not a torchax tensor at all, or when it
    is a torchax tensor whose ``jax_device`` equals the first JAX CPU device.
    """
    # Check if the tensor hasn't been converted to a torchax tensor yet.
    if not isinstance(tensor, torchax.tensor.Tensor):
        return True
    # Check if the torchax tensor is still on the CPU.
    return tensor.jax_device == jax.devices('cpu')[0]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                  sharding: NamedSharding) -> torch.Tensor:
    """Turn ``tensor`` into a torchax tensor placed with ``sharding``.

    With the ``VLLM_TPU_USING_PATHWAYS`` environment variable set, torch
    tensors go through a float32 NumPy copy; otherwise the tensor is viewed
    as, or converted to, a JAX array and placed with ``_sharded_device_put``.
    """
    # NOTE(review): any non-empty value of the variable, including "0",
    # counts as enabled here -- confirm that is the intended contract.
    using_pathways = os.getenv("VLLM_TPU_USING_PATHWAYS", False)
    if using_pathways and isinstance(tensor, torch.Tensor):
        host_array = tensor.detach().cpu().to(torch.float32).numpy()
        # NOTE(review): dtypes missing from the map end up as float32.
        target_dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
        placed = jax.device_put(host_array, sharding)
        return torch_view(placed.astype(target_dtype))

    if isinstance(tensor, torchax.tensor.Tensor):
        jax_array = jax_view(tensor)
    else:
        jax_array = t2j(tensor)
    return torch_view(_sharded_device_put(jax_array, sharding))
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
                                    mesh: Mesh) -> torchax.tensor.Tensor:
    """Convert ``tensor`` using an empty partition spec on ``mesh``."""
    replicated = NamedSharding(mesh, P())
    return _convert_to_torchax_and_shard(tensor, replicated)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                    mesh: Mesh) -> None:
    """Shard the embedding weight's first dimension over the 'model' axis."""
    weight_sharding = NamedSharding(mesh, P('model', None))
    sharded_weight = _convert_to_torchax_and_shard(layer.weight,
                                                   weight_sharding)
    layer.weight = Parameter(sharded_weight, requires_grad=False)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
    """Shard the LM head weight, and its bias if present, over 'model'."""
    # TODO(qihqi): currently this is not handling case of tie_word_weights=True.
    # if that config is set, then we should not create new weights but reuse the
    # weight from VocabParallelEmbedding
    weight_sharding = NamedSharding(mesh, P('model', None))
    sharded_weight = _convert_to_torchax_and_shard(layer.weight,
                                                   weight_sharding)
    layer.weight = Parameter(sharded_weight, requires_grad=False)

    if layer.bias is None:
        return

    bias_sharding = NamedSharding(mesh, P('model'))
    sharded_bias = _convert_to_torchax_and_shard(layer.bias, bias_sharding)
    layer.bias = Parameter(sharded_bias, requires_grad=False)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
                                       mesh: Mesh) -> None:
    """Replace both stacked LoRA lists with replicated torchax copies."""
    # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
    replicated_a = torch.nn.ParameterList()
    replicated_b = torch.nn.ParameterList()

    for slice_idx in range(layer.n_slices):
        lora_a = layer.lora_a_stacked[slice_idx]
        replicated_a.append(_shard_tensor_to_tpu_replicated(lora_a, mesh))
        lora_b = layer.lora_b_stacked[slice_idx]
        replicated_b.append(_shard_tensor_to_tpu_replicated(lora_b, mesh))

    layer.lora_a_stacked = replicated_a
    layer.lora_b_stacked = replicated_b
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _shard_column_linear_lora(layer: ColumnParallelLinearWithLoRA,
                              mesh: Mesh) -> None:
    """Shard the LoRA weights of a column-parallel linear layer.

    LoRA A slices are placed unpartitioned. LoRA B slices are partitioned on
    their third dimension along the 'model' mesh axis. The layer's
    ``lora_a_stacked`` and ``lora_b_stacked`` are replaced with new
    ``ParameterList`` objects.

    Args:
        layer: LoRA layer to modify in place; must have ``n_slices > 0``.
        mesh: The device mesh the shardings are built on.
    """
    assert layer.n_slices > 0, "layer.n_slices should be greater than 0"

    # lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank],
    # so the 'model' axis lands on out_features.
    b_sharding = NamedSharding(mesh, P(None, None, 'model', None))

    lora_a_on_tpu = torch.nn.ParameterList()
    lora_b_on_tpu = torch.nn.ParameterList()
    for slice_idx in range(layer.n_slices):
        # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
        a_slice = layer.lora_a_stacked[slice_idx]
        lora_a_on_tpu.append(_shard_tensor_to_tpu_replicated(a_slice, mesh))

        b_slice = layer.lora_b_stacked[slice_idx]
        lora_b_on_tpu.append(
            _convert_to_torchax_and_shard(b_slice, b_sharding))

    layer.lora_a_stacked = lora_a_on_tpu
    layer.lora_b_stacked = lora_b_on_tpu
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _shard_qkv_linear_lora(layer: ColumnParallelLinearWithLoRA,
                           mesh: Mesh) -> None:
    """Shard a QKV LoRA layer; identical to the column-parallel handling.

    Args:
        layer: LoRA layer, modified in place.
        mesh: The device mesh the shardings are built on.
    """
    # NOTE(review): the parameter is annotated with the base column-parallel
    # class, while MODULE_TYPE_TO_SHARDING_FUNC registers this function for
    # QKVParallelLinearWithLoRA.
    _shard_column_linear_lora(layer, mesh)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _shard_merged_column_parallel_linear_lora(
        layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
    """Shard a merged column-parallel LoRA layer.

    Delegates unchanged to ``_shard_column_linear_lora``.

    Args:
        layer: LoRA layer, modified in place.
        mesh: The device mesh the shardings are built on.
    """
    _shard_column_linear_lora(layer, mesh)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _shard_merged_qkv_parallel_linear_lora(
        layer: MergedQKVParallelLinearWithLoRA, mesh: Mesh) -> None:
    """Shard a merged QKV-parallel LoRA layer.

    Delegates unchanged to ``_shard_column_linear_lora``.

    Args:
        layer: LoRA layer, modified in place.
        mesh: The device mesh the shardings are built on.
    """
    _shard_column_linear_lora(layer, mesh)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
                                    mesh: Mesh) -> None:
    """Place a row-parallel LoRA layer's A/B weights unpartitioned.

    Delegates unchanged to ``_shard_base_linear_lora_replicated``.

    Args:
        layer: LoRA layer, modified in place.
        mesh: The device mesh the tensors are placed on.
    """
    _shard_base_linear_lora_replicated(layer, mesh)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# NOTE: Ordering is important as it calls first matched type of a given module.
# `_shard_module_to_tpu` scans this list top to bottom, compares
# `type(module)` by identity against each entry, and stops at the first hit,
# so every class that needs sharding must have its own entry here.
MODULE_TYPE_TO_SHARDING_FUNC = [
    # Embedding layers.
    (ParallelLMHead, _shard_lm_head),
    (VocabParallelEmbedding, _shard_vocab_parallel_embedding),
    # LoRA layers.
    (ColumnParallelLinearWithLoRA, _shard_column_linear_lora),
    (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora),
    (
        MergedColumnParallelLinearWithLoRA,
        _shard_merged_column_parallel_linear_lora,
    ),
    (
        MergedQKVParallelLinearWithLoRA,
        _shard_merged_qkv_parallel_linear_lora,
    ),
    (RowParallelLinearWithLoRA, _shard_row_parallel_linear_lora),
    (ReplicatedLinearWithLoRA, _shard_base_linear_lora_replicated),
]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
    """Apply the registered sharding function to each submodule of ``model``.

    Every module yielded by ``model.named_modules()`` is looked up in
    ``MODULE_TYPE_TO_SHARDING_FUNC`` by exact type (identity, not
    ``isinstance``). The first matching entry's function is called with the
    module and ``mesh``. Modules with no matching entry are left untouched.

    Args:
        model: Root module to traverse.
        mesh: The device mesh passed to each sharding function.
    """
    for path, module in model.named_modules():
        module_cls = type(module)
        # First table entry registered for exactly this class, if any.
        sharding_func = next(
            (func for registered_cls, func in MODULE_TYPE_TO_SHARDING_FUNC
             if module_cls is registered_cls),
            None,
        )
        if sharding_func is None:
            continue
        logger.debug("shard %s with %s", path, sharding_func)
        sharding_func(module, mesh)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
    """Place ``tensor`` on devices according to ``sharding``.

    A tuple input is handled element by element and a tuple is returned.
    Otherwise the behaviour depends on the ``TPU_MULTIHOST_BACKEND``
    environment variable (compared case-insensitively):

    * anything other than ``"ray"``: a single ``jax.device_put`` call;
    * ``"ray"``: the array is assembled from per-device pieces, one for each
      device addressable from this process.

    Args:
        tensor: Array (or tuple of arrays) to place.
        sharding: Target sharding; on the ``"ray"`` path it must provide
            ``addressable_devices_indices_map``.

    Returns:
        The placed array, or a tuple of placed arrays for a tuple input.
    """
    if isinstance(tensor, tuple):
        return tuple(_sharded_device_put(item, sharding) for item in tensor)

    import os
    backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if backend == "ray":
        # NOTE: per the original note, num_global_devices != num_local_devices
        # here, i.e. a multi-host setup. Each host runs the same process and
        # each process only needs to handle the devices accessible to it.
        global_shape = tensor.shape
        index_by_device = sharding.addressable_devices_indices_map(
            global_shape)
        per_device_arrays = [
            jax.device_put(tensor[index], device)
            for device, index in index_by_device.items()
        ]
        return jax.make_array_from_single_device_arrays(
            global_shape, sharding, per_device_arrays, dtype=tensor.dtype)

    return jax.device_put(tensor, sharding)
|
tpu_inference/logger.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
from vllm.logger import _VllmLogger
|
|
4
|
+
from vllm.logger import init_logger as init_vllm_logger
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def init_logger(name: str) -> _VllmLogger:
    """Return a vLLM logger whose name is ``name`` prefixed with "vllm.".

    Args:
        name: Logger name to prefix, typically a module path.

    Returns:
        The logger produced by vLLM's ``init_logger`` for the prefixed name.
    """
    # Prepend the root "vllm" to the module path to use vllm's configured
    # logger.
    return init_vllm_logger(".".join(("vllm", name)))
|
|
File without changes
|