tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import jax
|
|
16
|
+
from jax.sharding import Mesh, NamedSharding
|
|
17
|
+
from jax.sharding import PartitionSpec as P
|
|
18
|
+
|
|
19
|
+
from tpu_inference import envs
|
|
20
|
+
from tpu_inference.kernels.quantized_matmul.kernel import (
|
|
21
|
+
quantized_matmul_kernel, xla_quantized_matmul)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
                             mesh: Mesh, weight_sharding: P) -> jax.Array:
    """
    Run a quantized matmul, optionally through the sharded kernel.

    When ``envs.ENABLE_QUANTIZED_MATMUL_KERNEL`` is falsy the call is
    forwarded unchanged to ``xla_quantized_matmul``. Otherwise the quantized
    matmul kernel is executed per shard under ``jax.shard_map``.

    Args:
        x: Activation.
        w_q: Weight quantized array. [n_output_features, n_input_features]
        w_s: Weight quantization scale. [n_output_features]
        mesh: Mesh to shard on.
        weight_sharding: PartitionSpec for the weight tensor. It is unpacked
            into exactly two entries: (output axis, input axis).

    Returns:
        Output of the quantized matmul.
    """
    # NOTE (jacobplatin/kyuyeunk): the kernel has shown numeric issues (NaNs),
    # so it only runs when explicitly enabled through the env flag.
    if not envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
        return xla_quantized_matmul(x, w_q, w_s)

    out_axis, in_axis = weight_sharding

    # Derive the remaining partition specs from the weight's two axes.
    activation_spec = P(None, in_axis)
    scale_spec = P(out_axis)
    result_spec = P(None, out_axis)

    x = jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, activation_spec))

    def _per_shard_matmul(x_local, w_q_local, w_s_local):
        partial = quantized_matmul_kernel(x_local,
                                          w_q_local,
                                          w_s_local,
                                          x_q_dtype=w_q_local.dtype)
        # If the input-feature axis is sharded, each shard only holds a
        # partial sum, so reduce across that mesh axis.
        if in_axis:
            partial = jax.lax.psum(partial, axis_name=in_axis)
        return partial

    sharded_matmul = jax.shard_map(
        _per_shard_matmul,
        mesh=mesh,
        in_specs=(activation_spec, weight_sharding, scale_spec),
        out_specs=result_spec,
        check_vma=False,
    )
    return sharded_matmul(x, w_q, w_s)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,7 +1,18 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
2
14
|
|
|
3
15
|
import jax
|
|
4
|
-
import jax.numpy as jnp
|
|
5
16
|
import torch
|
|
6
17
|
import torchax
|
|
7
18
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
@@ -9,6 +20,7 @@ from torch.nn import Parameter
|
|
|
9
20
|
from torch.utils import _pytree as pytree
|
|
10
21
|
from torchax.interop import jax_view, torch_view
|
|
11
22
|
from torchax.ops.mappings import t2j
|
|
23
|
+
from vllm import envs as vllm_envs
|
|
12
24
|
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|
13
25
|
MergedColumnParallelLinearWithLoRA,
|
|
14
26
|
MergedQKVParallelLinearWithLoRA,
|
|
@@ -20,18 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
|
20
32
|
ParallelLMHead, VocabParallelEmbedding)
|
|
21
33
|
|
|
22
34
|
from tpu_inference import envs
|
|
35
|
+
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
23
36
|
from tpu_inference.logger import init_logger
|
|
37
|
+
from tpu_inference.utils import to_jax_dtype
|
|
24
38
|
|
|
25
39
|
P = PartitionSpec
|
|
26
40
|
|
|
27
41
|
logger = init_logger(__name__)
|
|
28
42
|
|
|
29
|
-
TORCH_TO_JAX_DTYPE_MAP = {
|
|
30
|
-
torch.float32: jnp.float32,
|
|
31
|
-
torch.float16: jnp.float16,
|
|
32
|
-
torch.bfloat16: jnp.bfloat16,
|
|
33
|
-
}
|
|
34
|
-
|
|
35
43
|
|
|
36
44
|
def shard_model_to_tpu(model: torch.nn.Module,
|
|
37
45
|
mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
|
|
@@ -88,10 +96,9 @@ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
|
|
|
88
96
|
|
|
89
97
|
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
|
|
90
98
|
sharding: NamedSharding) -> torch.Tensor:
|
|
91
|
-
if
|
|
92
|
-
tensor, torch.Tensor):
|
|
99
|
+
if vllm_envs.VLLM_TPU_USING_PATHWAYS and isinstance(tensor, torch.Tensor):
|
|
93
100
|
np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
|
|
94
|
-
dtype =
|
|
101
|
+
dtype = to_jax_dtype(tensor.dtype)
|
|
95
102
|
return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
|
|
96
103
|
else:
|
|
97
104
|
if isinstance(tensor, torchax.tensor.Tensor):
|
|
@@ -109,7 +116,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
|
|
|
109
116
|
def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
|
|
110
117
|
mesh: Mesh) -> None:
|
|
111
118
|
weight = _convert_to_torchax_and_shard(
|
|
112
|
-
layer.weight, NamedSharding(mesh, P(
|
|
119
|
+
layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
|
|
120
|
+
None)))
|
|
113
121
|
layer.weight = Parameter(weight, requires_grad=False)
|
|
114
122
|
|
|
115
123
|
|
|
@@ -118,11 +126,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
|
|
|
118
126
|
# if that config is set, then we should not create new weights but reuse the
|
|
119
127
|
# weight from VocabParallelEmbedding
|
|
120
128
|
weight = _convert_to_torchax_and_shard(
|
|
121
|
-
layer.weight, NamedSharding(mesh, P(
|
|
129
|
+
layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
|
|
130
|
+
None)))
|
|
122
131
|
layer.weight = Parameter(weight, requires_grad=False)
|
|
123
132
|
if layer.bias is not None:
|
|
124
|
-
bias = _convert_to_torchax_and_shard(
|
|
125
|
-
|
|
133
|
+
bias = _convert_to_torchax_and_shard(
|
|
134
|
+
layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
|
|
126
135
|
layer.bias = Parameter(bias, requires_grad=False)
|
|
127
136
|
|
|
128
137
|
|
|
@@ -0,0 +1,369 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass, fields
|
|
16
|
+
|
|
17
|
+
import jax
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
from jax.experimental.layout import Format, Layout, with_layout_constraint
|
|
20
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
21
|
+
from torchax.tensor import Tensor
|
|
22
|
+
|
|
23
|
+
from tpu_inference.layers.common.quantization import quantize_tensor
|
|
24
|
+
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
25
|
+
from tpu_inference.layers.common.utils import \
|
|
26
|
+
reorder_concatenated_tensor_for_sharding
|
|
27
|
+
from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
|
|
28
|
+
from tpu_inference.utils import align_to
|
|
29
|
+
|
|
30
|
+
P = PartitionSpec
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@jax.tree_util.register_dataclass
@dataclass
class FusedMoEWeights:
    """Bundle of the tensors that make up a fused MoE layer.

    Every field may hold either a jax array or a torchax tensor. The scale
    and bias fields are optional and are ``None`` when not present.
    """
    # Combined w1/w3 projection weight.
    w13_weight: jax.Array | Tensor
    # Quantization scale of w13_weight; None while the weight is unquantized.
    w13_weight_scale: jax.Array | Tensor | None
    # Optional bias paired with w13_weight.
    w13_bias: jax.Array | Tensor | None
    # w2 projection weight.
    w2_weight: jax.Array | Tensor
    # Quantization scale of w2_weight; None while the weight is unquantized.
    w2_weight_scale: jax.Array | Tensor | None
    # Optional bias paired with w2_weight.
    w2_bias: jax.Array | Tensor | None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def quantize_moe_weights(
    weights: FusedMoEWeights,
    dtype: jnp.dtype,
    block_size: int | None,
) -> FusedMoEWeights:
    """Quantize fused MoE weights to the requested dtype and block size.

    The ``weights`` object passed in is updated in place and is also the
    object that gets returned.

    Args:
        weights: fused moe weights. Both scale fields must still be ``None``.
        dtype: dtype to quantize into.
        block_size: block quantization size. If ``None``, the whole last
            dimension of each weight is used as a single block, i.e.
            per-channel quantization. If quantization enlarges the last
            dimension of one weight, the matching dimension of the other
            weight (w13_weight <-> w2_weight), of its scale and of its bias
            is zero-padded by the same amount.

    Returns:
        Quantized fused moe weights that may have also been padded.
    """
    # A scale that is already set would mean the weights were quantized
    # earlier, so both scales are required to be None here.
    assert weights.w13_weight_scale is None
    assert weights.w2_weight_scale is None

    w13 = weights.w13_weight
    w2 = weights.w2_weight

    # Remember the sizes before quantization to detect padding afterwards.
    _, hidden_before, intermediate_before = w2.shape

    # Per-channel quantization == one block covering the full last dim.
    w13_block = w13.shape[-1] if block_size is None else block_size
    w2_block = w2.shape[-1] if block_size is None else block_size

    # presumably `2` is the axis to quantize along and `True` enables
    # padding -- confirm against quantize_tensor.
    w13_q, w13_scale = quantize_tensor(dtype, w13, 2, w13_block, True)
    w2_q, w2_scale = quantize_tensor(dtype, w2, 2, w2_block, True)

    intermediate_growth = w2_q.shape[-1] - intermediate_before
    hidden_growth = w13_q.shape[-1] - hidden_before

    # Any growth of the last dim of one weight is mirrored onto dim 1 of the
    # other weight. Padding happens after quantization because the padding
    # value can affect quantization numerics.
    # NOTE(review): the w13 padding is appended after the last row of dim 1;
    # confirm consumers that split dim 1 into two halves tolerate a non-zero
    # pad here.
    w13_pad = [[0, 0], [0, 2 * intermediate_growth], [0, 0]]
    w2_pad = [[0, 0], [0, hidden_growth], [0, 0]]

    weights.w13_weight = jnp.pad(w13_q, w13_pad)
    weights.w13_weight_scale = jnp.pad(w13_scale, w13_pad)
    weights.w2_weight = jnp.pad(w2_q, w2_pad)
    weights.w2_weight_scale = jnp.pad(w2_scale, w2_pad)

    # Biases are padded with only the first two pad-width entries.
    if weights.w13_bias is not None:
        weights.w13_bias = jnp.pad(weights.w13_bias, w13_pad[:2])
    if weights.w2_bias is not None:
        weights.w2_bias = jnp.pad(weights.w2_bias, w2_pad[:2])

    return weights
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def process_moe_weights(
    weights: FusedMoEWeights,
    moe_backend: FusedMoEBackend,
    w13_reorder_size: int | None = None,
    w13_interleave: bool = False,
) -> FusedMoEWeights:
    """Rearrange fused MoE weights into the layout a backend expects.

    Args:
        weights: fused moe weights.
        moe_backend: backend type the weights should be processed for.
        w13_reorder_size: only used when backend type is GMM_TP. To
            eliminate collective operations under tensor parallelism,
            w13_weight is grouped into w13_reorder_size chunks where each
            chunk stores both w1 and w3 weights.
        w13_interleave: set when the loaded w13_weight is interleaved, i.e.
            even index elements are w1 and odd index elements are w3. The
            tensors are un-interleaved so the first half is w1 and the
            second half is w3.

    Returns:
        A new FusedMoEWeights holding the tensors processed for the
        specified backend.
    """

    def _evens_then_odds(arr):
        # Gather the even entries of dim 1 first, then the odd entries.
        return jnp.concat([arr[:, ::2], arr[:, 1::2]], axis=1)

    def _to_float32(arr):
        # Optional tensors stay None; everything else is cast to float32.
        return None if arr is None else arr.astype(jnp.float32)

    w13 = weights.w13_weight
    w13_scale = weights.w13_weight_scale
    w13_b = weights.w13_bias
    w2 = weights.w2_weight
    w2_scale = weights.w2_weight_scale
    w2_b = weights.w2_bias

    num_experts, hidden_size, intermediate_size = w2.shape

    if w13_interleave:
        w13 = _evens_then_odds(w13)
        if w13_scale is not None:
            w13_scale = _evens_then_odds(w13_scale)
        if w13_b is not None:
            w13_b = _evens_then_odds(w13_b)

    # Scales and biases are always carried in float32 from here on.
    w13_scale = _to_float32(w13_scale)
    w2_scale = _to_float32(w2_scale)
    w13_b = _to_float32(w13_b)
    w2_b = _to_float32(w2_b)

    if moe_backend == FusedMoEBackend.FUSED_MOE:
        # Kernel expects:
        # w13: (num_experts, 2, hidden_size, intermediate_size)
        # w2: (num_experts, intermediate_size, hidden_size)
        # Current format:
        # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
        # w2_weight: (num_experts, hidden_size, intermediate_size)

        # Fused moe kernel expects dims to be multiple of 256.
        pad_inter = align_to(intermediate_size, 256) - intermediate_size
        pad_hidden = align_to(hidden_size, 256) - hidden_size

        w13 = w13.reshape(num_experts, 2, intermediate_size, hidden_size)

        # Move the non-contracting dim to the right-most position.
        w13 = jnp.swapaxes(w13, 2, 3)
        w2 = jnp.swapaxes(w2, 1, 2)

        # Workaround for JAX error "must have valid byte strides"
        w13 = with_layout_constraint(w13, Layout((0, 1, 2, 3)))
        w2 = with_layout_constraint(w2, Layout((0, 1, 2)))

        w13 = jnp.pad(
            w13,
            ((0, 0), (0, 0), (0, pad_hidden), (0, pad_inter)),
        )
        w2 = jnp.pad(
            w2,
            ((0, 0), (0, pad_inter), (0, pad_hidden)),
        )

        # NOTE(review): the scale pads below reuse the weight pad widths on
        # the dim produced by the `-1` reshape entry -- confirm this is
        # intended when that dim is not the same size as the weight dim.
        if w13_scale is not None:
            w13_scale = w13_scale.reshape(num_experts, 2, intermediate_size,
                                          1, -1)
            w13_scale = jnp.swapaxes(w13_scale, 2, 4)
            w13_scale = jnp.pad(
                w13_scale,
                ((0, 0), (0, 0), (0, pad_hidden), (0, 0), (0, pad_inter)),
            )
        if w2_scale is not None:
            w2_scale = w2_scale.reshape(num_experts, hidden_size, 1, -1)
            w2_scale = jnp.swapaxes(w2_scale, 1, 3)
            w2_scale = jnp.pad(
                w2_scale,
                ((0, 0), (0, pad_inter), (0, 0), (0, pad_hidden)),
            )

        if w13_b is not None:
            w13_b = w13_b.reshape(num_experts, 2, 1, intermediate_size)
            w13_b = jnp.pad(
                w13_b,
                ((0, 0), (0, 0), (0, 0), (0, pad_inter)),
            )
        if w2_b is not None:
            w2_b = w2_b.reshape(num_experts, 1, hidden_size)
            w2_b = jnp.pad(
                w2_b,
                ((0, 0), (0, 0), (0, pad_hidden)),
            )

    elif (moe_backend == FusedMoEBackend.GMM_EP
          or moe_backend == FusedMoEBackend.GMM_TP):
        # Scales: swap dims 1 and 2, then insert a singleton dim at index 2.
        if w13_scale is not None:
            w13_scale = jnp.expand_dims(jnp.swapaxes(w13_scale, 1, 2), 2)
        if w2_scale is not None:
            w2_scale = jnp.expand_dims(jnp.swapaxes(w2_scale, 1, 2), 2)

        # Biases: insert a singleton dim at index 1.
        if w13_b is not None:
            w13_b = jnp.expand_dims(w13_b, 1)
        if w2_b is not None:
            w2_b = jnp.expand_dims(w2_b, 1)

        if moe_backend == FusedMoEBackend.GMM_TP:
            assert w13_reorder_size is not None
            assert intermediate_size % w13_reorder_size == 0

            # w13 is a concatenation of two parts of intermediate_size each;
            # reorder every w13 tensor along its own output dim.
            output_sizes = [intermediate_size, intermediate_size]
            w13 = reorder_concatenated_tensor_for_sharding(
                w13,
                output_sizes,
                w13_reorder_size,
                dim=1,
            )
            if w13_scale is not None:
                w13_scale = reorder_concatenated_tensor_for_sharding(
                    w13_scale,
                    output_sizes,
                    w13_reorder_size,
                    dim=3,
                )
            if w13_b is not None:
                w13_b = reorder_concatenated_tensor_for_sharding(
                    w13_b,
                    output_sizes,
                    w13_reorder_size,
                    dim=2,
                )

    return FusedMoEWeights(
        w13_weight=w13,
        w13_weight_scale=w13_scale,
        w13_bias=w13_b,
        w2_weight=w2,
        w2_weight_scale=w2_scale,
        w2_bias=w2_b,
    )
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def shard_moe_weights(
|
|
291
|
+
weights: FusedMoEWeights,
|
|
292
|
+
moe_backend: FusedMoEBackend,
|
|
293
|
+
mesh: Mesh,
|
|
294
|
+
) -> FusedMoEWeights:
|
|
295
|
+
|
|
296
|
+
match moe_backend:
|
|
297
|
+
case FusedMoEBackend.FUSED_MOE | FusedMoEBackend.GMM_EP:
|
|
298
|
+
ep_sharding = NamedSharding(mesh, P(ShardingAxisName.EXPERT))
|
|
299
|
+
weight_shardings = FusedMoEWeights(
|
|
300
|
+
w13_weight=ep_sharding,
|
|
301
|
+
w13_weight_scale=ep_sharding,
|
|
302
|
+
w13_bias=ep_sharding,
|
|
303
|
+
w2_weight=ep_sharding,
|
|
304
|
+
w2_weight_scale=ep_sharding,
|
|
305
|
+
w2_bias=ep_sharding,
|
|
306
|
+
)
|
|
307
|
+
case FusedMoEBackend.GMM_TP:
|
|
308
|
+
# When using per-channel, in_dim // block_size == 1. This means we
|
|
309
|
+
# are unable to shard w2_weight_scale along 1st dim. Therefore, we
|
|
310
|
+
# fully replicate it instead.
|
|
311
|
+
if (weights.w2_weight_scale is not None
|
|
312
|
+
and weights.w2_weight_scale.shape[1] == 1):
|
|
313
|
+
w2_weight_scale_p_spec = P()
|
|
314
|
+
else:
|
|
315
|
+
w2_weight_scale_p_spec = P(None, ShardingAxisName.MLP_TENSOR)
|
|
316
|
+
weight_shardings = FusedMoEWeights(
|
|
317
|
+
w13_weight=NamedSharding(
|
|
318
|
+
mesh,
|
|
319
|
+
P(None, ShardingAxisName.MLP_TENSOR, None),
|
|
320
|
+
), # (num_experts, out_dim, in_dim)
|
|
321
|
+
w13_weight_scale=NamedSharding(
|
|
322
|
+
mesh,
|
|
323
|
+
P(None, None, None, ShardingAxisName.MLP_TENSOR),
|
|
324
|
+
), # (num_experts, in_dim // block_size, 1, out_dim)
|
|
325
|
+
w13_bias=NamedSharding(
|
|
326
|
+
mesh,
|
|
327
|
+
P(None, None, ShardingAxisName.MLP_TENSOR),
|
|
328
|
+
), # (num_experts, 1, out_dim)
|
|
329
|
+
w2_weight=NamedSharding(
|
|
330
|
+
mesh,
|
|
331
|
+
P(None, None, ShardingAxisName.MLP_TENSOR),
|
|
332
|
+
), # (num_experts, out_dim, in_dim)
|
|
333
|
+
w2_weight_scale=NamedSharding(
|
|
334
|
+
mesh, w2_weight_scale_p_spec
|
|
335
|
+
), # (num_experts, in_dim // block_size, 1, out_dim)
|
|
336
|
+
w2_bias=NamedSharding(
|
|
337
|
+
mesh,
|
|
338
|
+
P(None, None, None),
|
|
339
|
+
), # (num_experts, 1, out_dim)
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
match moe_backend:
|
|
343
|
+
case FusedMoEBackend.FUSED_MOE:
|
|
344
|
+
weight_layouts = FusedMoEWeights(
|
|
345
|
+
w13_weight=Layout((0, 1, 2, 3)),
|
|
346
|
+
w13_weight_scale=Layout((0, 1, 2, 3, 4)),
|
|
347
|
+
w13_bias=Layout((0, 1, 2, 3)),
|
|
348
|
+
w2_weight=Layout((0, 1, 2)),
|
|
349
|
+
w2_weight_scale=Layout((0, 1, 2, 3)),
|
|
350
|
+
w2_bias=Layout((0, 1, 2)),
|
|
351
|
+
)
|
|
352
|
+
case FusedMoEBackend.GMM_TP | FusedMoEBackend.GMM_EP:
|
|
353
|
+
weight_layouts = FusedMoEWeights(
|
|
354
|
+
w13_weight=Layout((0, 1, 2)),
|
|
355
|
+
w13_weight_scale=Layout((0, 1, 2, 3)),
|
|
356
|
+
w13_bias=Layout((0, 1, 2)),
|
|
357
|
+
w2_weight=Layout((0, 1, 2)),
|
|
358
|
+
w2_weight_scale=Layout((0, 1, 2, 3)),
|
|
359
|
+
w2_bias=Layout((0, 1, 2)),
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
for field in fields(FusedMoEWeights):
|
|
363
|
+
key = field.name
|
|
364
|
+
if (weight := getattr(weights, key, None)) is not None:
|
|
365
|
+
layout = getattr(weight_layouts, key)
|
|
366
|
+
sharding = getattr(weight_shardings, key)
|
|
367
|
+
weight = jax.device_put(weight, Format(layout, sharding))
|
|
368
|
+
setattr(weights, key, weight)
|
|
369
|
+
return weights
|