tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# This file contains end-to-end tests for structured decoding.
|
|
16
|
+
#
|
|
17
|
+
# Structured decoding allows constraining the model's output to follow a
|
|
18
|
+
# specific format, such as choosing from a predefined set of options or
|
|
19
|
+
# following a JSON schema. This is useful for classification tasks,
|
|
20
|
+
# structured data extraction, and ensuring outputs conform to expected formats.
|
|
21
|
+
|
|
22
|
+
# The tests in this file verify that:
|
|
23
|
+
# 1. Choice-based structured decoding correctly constrains output to valid options
|
|
24
|
+
# 2. The model produces deterministic results when given structured constraints
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
from vllm import LLM, SamplingParams
|
|
29
|
+
from vllm.sampling_params import StructuredOutputsParams
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_structured_decoding():
    """Check that choice-constrained decoding returns one of the choices.

    Builds an ``LLM`` for ``meta-llama/Llama-3.2-1B-Instruct``, generates a
    completion for a single sentiment-classification prompt with
    ``StructuredOutputsParams(choice=...)`` attached to the sampling
    parameters, and asserts that the generated text is exactly one of the
    allowed choices.
    """
    llm = LLM(model='meta-llama/Llama-3.2-1B-Instruct',
              max_model_len=1024,
              max_num_seqs=1,
              enable_prefix_caching=False)

    # The only two strings the model is allowed to emit.
    choices = ['Positive', 'Negative']
    structured_outputs_params = StructuredOutputsParams(choice=choices)
    sampling_params = SamplingParams(
        structured_outputs=structured_outputs_params)
    outputs = llm.generate(
        prompts="Classify this sentiment: tpu-inference is wonderful!",
        sampling_params=sampling_params,
    )
    # One prompt was submitted, so inspect the first completion of the
    # first request output.
    assert outputs[0].outputs[0].text in choices
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import unittest
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Mock VllmConfig and its nested configs to avoid dependencies on the actual
|
|
20
|
+
# classes, which can be complex to instantiate for testing.
|
|
21
|
+
class MockVllmConfig:
    """Lightweight stand-in for ``VllmConfig`` used by the executor tests.

    Every nested config is a ``MagicMock``; only the parallel and sharding
    fields set explicitly below carry concrete values.
    """

    def __init__(self):
        # Parallelism settings given concrete values for the tests.
        self.parallel_config = MagicMock()
        self.parallel_config.world_size = 4
        self.parallel_config.tensor_parallel_size = 2
        self.parallel_config.pipeline_parallel_size = 1
        self.parallel_config.ray_workers_use_nsight = False
        self.parallel_config.placement_group = None
        self.parallel_config.max_parallel_loading_workers = None

        self.sharding_config = MagicMock()
        self.sharding_config.total_devices = 2

        # Remaining sub-configs are opaque mocks with no preset attributes.
        self.model_config = MagicMock()
        self.cache_config = MagicMock()
        self.lora_config = MagicMock()
        self.load_config = MagicMock()
        self.scheduler_config = MagicMock()
        self.speculative_config = MagicMock()
        self.prompt_adapter_config = MagicMock()
        self.observability_config = MagicMock()
        self.device_config = MagicMock()
        self.ec_transfer_config = MagicMock()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# The first patch replaces the vLLM parent ``__init__`` with a no-op and,
# because it supplies its own replacement, injects no mock argument. The
# remaining seven patches each inject one mock into every ``test_*`` method,
# bottom-up: the decorator closest to the class is the first argument.
@patch(
    "vllm.v1.executor.ray_distributed_executor.RayDistributedExecutor.__init__",
    lambda x, y: None)
@patch("tpu_inference.executors.ray_distributed_executor.envs")
@patch("tpu_inference.executors.ray_distributed_executor.ray")
@patch("tpu_inference.executors.ray_distributed_executor.current_platform")
@patch("tpu_inference.executors.ray_distributed_executor.get_ip",
       return_value="127.0.0.1")
@patch("tpu_inference.executors.ray_distributed_executor.get_open_port",
       return_value=12345)
@patch(
    "tpu_inference.executors.ray_distributed_executor.available_resources_per_node"
)
@patch("tpu_inference.executors.ray_distributed_executor._wait_until_pg_ready")
class TestTpuRayDistributedExecutor(unittest.TestCase):
    """Tests for the TPU ``RayDistributedExecutor`` with Ray fully mocked.

    Ray, the current platform, the env module and the IP/port/resource
    helpers of ``tpu_inference.executors.ray_distributed_executor`` are all
    patched, so no cluster is contacted.
    """

    def setUp(self):
        # The class under test is imported here rather than at module level.
        # NOTE(review): class-level ``patch`` decorators wrap ``test_*``
        # methods only, so the patches are not active during ``setUp``
        # itself; the lazy import just defers loading the project module.
        from tpu_inference.executors.ray_distributed_executor import \
            RayDistributedExecutor
        self.RayDistributedExecutor = RayDistributedExecutor

        self.vllm_config = MockVllmConfig()
        # Reset placement group for each test as it might be modified.
        self.vllm_config.parallel_config.placement_group = None
        self.vllm_config.kv_transfer_config = None

    def test_init_executor_basic_flow(self, mock_wait_until_pg_ready,
                                      mock_avail_resources, mock_get_port,
                                      mock_get_ip, mock_platform, mock_ray,
                                      mock_envs):
        """``_init_executor`` initialises Ray and creates four workers."""
        # --- Setup mocks ---
        mock_envs.VLLM_USE_RAY_COMPILED_DAG = True
        mock_envs.VLLM_USE_RAY_SPMD_WORKER = True
        mock_envs.VLLM_RAY_BUNDLE_INDICES = ""

        mock_platform.ray_device_key = "TPU"
        mock_platform.device_name = "tpu"
        mock_platform.device_control_env_var = "TPU_VISIBLE_CHIPS"
        mock_platform.additional_env_vars = []

        mock_ray.is_initialized.return_value = False
        mock_ray.nodes.return_value = [{"Resources": {"TPU": 4}}]
        mock_ray.get_runtime_context.return_value.get_node_id.return_value = "node_1"
        mock_avail_resources.return_value = {"node_1": {"TPU": 4}}

        mock_wait_until_pg_ready.return_value = None

        mock_placement_group = MagicMock()
        mock_placement_group.bundle_specs = [{"TPU": 1}] * 4
        mock_ray.util.placement_group.return_value = mock_placement_group

        mock_worker = MagicMock()
        mock_worker.get_node_and_gpu_ids.remote.return_value = [("node_1",
                                                                 [0, 1, 2, 3])]
        mock_ray.remote.return_value.remote.return_value = mock_worker

        # Simulate remote calls on the worker
        mock_ray.get.side_effect = [
            ["127.0.0.1"] * 4,  # worker_ips
            *[("node_1", [i]) for i in range(4)]  # worker_node_and_tpu_ids
        ]

        executor = self.RayDistributedExecutor(self.vllm_config)
        # Members of the parent class (its __init__ is patched to a no-op,
        # so they have to be set by hand).
        executor.uses_ray = True
        executor.vllm_config = self.vllm_config
        executor.parallel_config = self.vllm_config.parallel_config
        executor.collective_rpc = MagicMock()
        executor.collective_rpc.return_value = None

        # --- Initialization ---
        executor._init_executor()

        # --- Assertions ---
        mock_ray.init.assert_called_once()
        self.assertIsNotNone(executor.parallel_config.placement_group)
        self.assertEqual(len(executor.workers), 4)

    def test_initialize_ray_cluster_no_tpu_on_driver_raises_error(
            self, mock_wait_until_pg_ready, mock_avail_resources,
            mock_get_port, mock_get_ip, mock_platform, mock_ray, mock_envs):
        """``_initialize_ray_cluster`` raises when the driver has no TPU."""
        # --- Setup Mocks ---
        mock_platform.ray_device_key = "TPU"
        mock_platform.device_name = "tpu"

        mock_ray.is_initialized.return_value = False
        mock_ray.nodes.return_value = [{"Resources": {"TPU": 4}}]
        mock_ray.get_runtime_context.return_value.get_node_id.return_value = "driver_node"
        # Simulate no TPUs on the driver node
        mock_avail_resources.return_value = {
            "driver_node": {
                "CPU": 8
            },
            "worker_node": {
                "TPU": 4
            }
        }

        executor = self.RayDistributedExecutor(self.vllm_config)
        executor.vllm_config = self.vllm_config
        executor.parallel_config = self.vllm_config.parallel_config

        # --- Test and Assert ---
        with self.assertRaisesRegex(ValueError,
                                    "Current node has no TPU available"):
            executor._initialize_ray_cluster()

    def test_init_workers_ray_sorts_correctly(self, mock_wait_until_pg_ready,
                                              mock_avail_resources,
                                              mock_get_port, mock_get_ip,
                                              mock_platform, mock_ray,
                                              mock_envs):
        """``_init_workers_ray`` orders workers with the driver's IP first."""
        # --- Setup Mocks ---
        mock_envs.VLLM_RAY_BUNDLE_INDICES = ""
        mock_platform.ray_device_key = "TPU"
        mock_get_ip.return_value = "10.0.0.1"  # Driver IP

        mock_pg = MagicMock()
        mock_pg.bundle_specs = [{"TPU": 1}] * 4

        mock_workers = [MagicMock() for _ in range(4)]
        mock_ray.remote.return_value.return_value.remote.side_effect = mock_workers

        # Simulate IPs for workers created with ranks 0, 1, 2, 3
        worker_ips = ["10.0.0.2", "10.0.0.3", "10.0.0.1", "10.0.0.4"]
        mock_ray.get.side_effect = [
            worker_ips,  # worker_ips
            *[('node_1', ['0', '1', '2', '3']),
              ('node_2', ['4', '5', '6', '7']),
              ('node_3', ['8', '9', '10', '11']),
              ('node_4', ['12', '13', '14', '15'])]  # worker_node_and_tpu_ids
        ]

        executor = self.RayDistributedExecutor(self.vllm_config)
        executor.use_ray_spmd_worker = True
        executor.parallel_config = self.vllm_config.parallel_config
        executor.vllm_config = self.vllm_config
        executor.parallel_config.ray_workers_use_nsight = False
        executor.collective_rpc = MagicMock()
        executor.collective_rpc.return_value = None

        # --- Call method under test ---
        executor._init_workers_ray(mock_pg)

        # --- Assertions ---
        # Expected order: the worker on the driver IP first, then the rest
        # in ascending IP order (every IP is unique in this fixture).
        # Original workers: 0 (10.0.0.2), 1 (10.0.0.3), 2 (10.0.0.1), 3 (10.0.0.4)
        # Sorted workers: 2 (driver), 0, 1, 3
        self.assertEqual(executor.workers, [
            mock_workers[2], mock_workers[0], mock_workers[1], mock_workers[3]
        ])
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from types import SimpleNamespace
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pytest
|
|
22
|
+
from flax import nnx
|
|
23
|
+
from flax.typing import PRNGKey
|
|
24
|
+
from jax.sharding import Mesh
|
|
25
|
+
|
|
26
|
+
from tpu_inference.experimental.llama3_jax_stashed import (Llama3WeightLoader,
|
|
27
|
+
LlamaForCausalLM)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MockParam:
    """A mock for a parameter used in the Llama model.

    Exposes a ``value`` whose ``shape`` defaults to ``(32, 128)`` and a
    ``sharding`` object whose ``spec`` is ``None``.
    """

    def __init__(self, shape=(32, 128)):
        self.value = SimpleNamespace(shape=shape)
        # The sharding spec is accessed during weight loading
        self.sharding = SimpleNamespace(spec=None)

    # Allow the mock parameter's value to be updated
    def __setattr__(self, name, value):
        # "value" is written straight into the instance dict; every other
        # attribute goes through the default object.__setattr__.
        if name == "value":
            self.__dict__[name] = value
        else:
            super().__setattr__(name, value)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MockVllmConfig:
    """A mock VllmConfig sufficient for testing the Llama3 model.

    Args:
        model_name: Stored as ``model_config.model``.
        random_weights: Stored under ``additional_config["random_weights"]``.
        tensor_parallelism: Stored under
            ``additional_config["sharding"]["sharding_strategy"]``.
    """

    def __init__(self,
                 model_name: str,
                 random_weights: bool = False,
                 tensor_parallelism: int = 1):
        self.model_config = SimpleNamespace(model=model_name,
                                            dtype="bfloat16",
                                            hf_overrides={},
                                            override_generation_config={})
        self.load_config = MagicMock()
        self.additional_config = {
            "random_weights": random_weights,
            "sharding": {
                "sharding_strategy": {
                    "tensor_parallelism": tensor_parallelism
                }
            }
        }

        # NOTE (jacobplatin): we could add a quantized KV cache test, but
        # we'll skip it for now.
        self.cache_config = MagicMock(cache_dtype="auto")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.fixture(scope="module")
def mesh():
    """Yield a module-scoped JAX mesh with all axes the tests require.

    The sharding logic expects 'data', 'model', and 'expert' axes, so the
    local devices are arranged as a 3D ``(num_devices, 1, 1)`` array. This
    satisfies the sharding rules even on a single device. The test is
    skipped when ``jax.devices()`` is empty.
    """
    if not jax.devices():
        pytest.skip("No JAX devices available for mesh creation.")

    devices = np.array(jax.local_devices())
    # Reshape devices into a 3D array to name 3 axes: data, model, and expert.
    # The 'model' and 'expert' axes will have a size of 1.
    num_devices = len(devices)
    device_mesh = devices.reshape((num_devices, 1, 1))

    # Keep the mesh context active for as long as the fixture is in use.
    with Mesh(device_mesh, axis_names=('data', 'model', 'expert')) as m:
        yield m
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.fixture
def rng() -> PRNGKey:
    """Provides a reusable JAX PRNGKey built from the fixed seed 42."""
    return jax.random.PRNGKey(42)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.fixture
def mock_vllm_config_8b() -> MockVllmConfig:
    """Return a mock config whose model name is "meta-llama/Llama-3-8B"."""
    return MockVllmConfig(model_name="meta-llama/Llama-3-8B")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@pytest.fixture
def mock_vllm_config_70b() -> MockVllmConfig:
    """Return a mock config named "meta-llama/Llama-3-70B-Instruct"."""
    return MockVllmConfig(model_name="meta-llama/Llama-3-70B-Instruct")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@pytest.fixture
def mock_vllm_config_unknown() -> MockVllmConfig:
    """Return a mock config with the unrecognised name "some-other-model"."""
    return MockVllmConfig(model_name="some-other-model")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# --- Test Cases ---
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TestLlamaForCausalLM:
    """Exercises construction and weight loading of LlamaForCausalLM."""

    def test_init_8b_variant(self, mock_vllm_config_8b, rng, mesh):
        """An 8B model name produces a model with hidden size 4096."""
        llama = LlamaForCausalLM(mock_vllm_config_8b, rng, mesh)
        assert llama.hidden_size == 4096
        configured_name = llama.vllm_config.model_config.model.lower()
        assert "8b" in configured_name

    def test_init_70b_variant(self, mock_vllm_config_70b, rng, mesh):
        """A 70B model name produces a model with hidden size 8192."""

        def build_model():
            return LlamaForCausalLM(mock_vllm_config_70b, rng, mesh)

        # The model is constructed through nnx.eval_shape rather than directly.
        llama = nnx.eval_shape(build_model)
        assert llama.hidden_size == 8192
        configured_name = llama.vllm_config.model_config.model.lower()
        assert "70b" in configured_name

    def test_init_unknown_variant_raises_error(self, mock_vllm_config_unknown,
                                               rng, mesh):
        """A model name with no recognised variant raises ValueError."""
        expected_message = "Could not determine Llama3 variant"
        with pytest.raises(ValueError, match=expected_message):
            LlamaForCausalLM(mock_vllm_config_unknown, rng, mesh)

    def test_create_model_with_random_weights(self, mock_vllm_config_8b, rng,
                                              mesh):
        """
        Random initialisation yields jax.Array weights with non-zero spread,
        while the final norm scale is all ones.
        """
        # NOTE(review): the scraped source lost its indentation; the checks are
        # kept inside the mesh context here — confirm against the upstream file.
        with jax.set_mesh(mesh):
            llama = LlamaForCausalLM(vllm_config=mock_vllm_config_8b,
                                     rng=rng,
                                     mesh=mesh,
                                     force_random_weights=True)

            embedding_table = llama.embedder.input_embedding_table_VD.value
            q_kernel = llama.layers[0].attn.kernel_q_proj_DNH.value
            norm_scale = llama.final_norm.scale.value

            for weight in (embedding_table, q_kernel, norm_scale):
                assert isinstance(weight, jax.Array)

            for weight in (embedding_table, q_kernel):
                assert jnp.std(weight) > 0

            assert jnp.all(norm_scale == 1.0)

    @patch("tpu_inference.experimental.llama3_jax_stashed.Llama3WeightLoader")
    def test_load_weights_called_correctly(self, mock_loader_cls, rng, mesh):
        """load_weights builds the loader with 8B dimensions and hands it the model."""
        vllm_config = MockVllmConfig(model_name="llama3-8b",
                                     random_weights=False)
        llama = LlamaForCausalLM(vllm_config, rng, mesh)

        loader = MagicMock()
        mock_loader_cls.return_value = loader
        llama.load_weights(rng, cache_dir="/tmp/cache")

        expected_ctor_kwargs = dict(vllm_config=vllm_config,
                                    hidden_size=4096,
                                    attn_heads=32,
                                    num_key_value_heads=8,
                                    attn_head_dim=128)
        mock_loader_cls.assert_called_once_with(**expected_ctor_kwargs)
        loader.load_weights.assert_called_once_with(llama)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class TestLlama3WeightLoader:
    """Exercises Llama3WeightLoader.load_weights with the HF loader mocked out."""

    @pytest.fixture
    def weight_loader(self):
        """Return a loader built with small, test-sized attention dimensions."""
        loader_config = MockVllmConfig("test-model")
        return Llama3WeightLoader(vllm_config=loader_config,
                                  hidden_size=32,
                                  attn_heads=4,
                                  num_key_value_heads=2,
                                  attn_head_dim=8)

    def test_load_weights_transformation(self, weight_loader, rng, mesh):
        """load_weights delegates to load_hf_weights exactly once."""
        vllm_config = MockVllmConfig("llama3-8b-small-test",
                                     random_weights=False)

        # The model is built from its own config, separate from the loader's.
        llama = LlamaForCausalLM(vllm_config, rng, mesh)

        hf_loader_path = (
            "tpu_inference.experimental.llama3_jax_stashed.load_hf_weights")
        with patch(hf_loader_path) as mock_load:
            weight_loader.load_weights(llama)

            # Only the call count of the patched load_hf_weights is checked.
            mock_load.assert_called_once()
|
tests/kernels/__init__.py
CHANGED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from absl.testing import absltest, parameterized
|
|
8
|
+
from jax._src import test_util as jtu
|
|
9
|
+
|
|
10
|
+
from tpu_inference import utils
|
|
11
|
+
from tpu_inference.kernels.collectives import all_gather_matmul
|
|
12
|
+
|
|
13
|
+
# Let JAX pick up its configuration flags through absl.
jax.config.parse_flags_with_absl()

# Short alias used when building sharding specs below.
P = jax.sharding.PartitionSpec

# Value of TEST_UNDECLARED_OUTPUTS_DIR, or None when the variable is unset.
SpongeDir: str | None = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR')
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class AllGatherMatmulTest(jtu.JaxTestCase):
    """Compares `all_gather_matmul` against a plain `jnp.dot` reference."""

    @parameterized.product(
        grid_k=[1, 2, 3],
        grid_n=[1, 2, 3],
        rhs_transpose=[True, False],
    )
    def test_all_gather_matmul(self, grid_k, grid_n, rhs_transpose):
        """Run the kernel on sharded inputs and check it matches `jnp.dot`.

        Args:
            grid_k: Multiplier applied to the block size `bk` to get the
                contraction dimension `k`.
            grid_n: Multiplier applied to `bn * num_devices` to get the
                output width `n`.
            rhs_transpose: When true the right-hand side is created with
                shape (n, k) and transposed for the reference product;
                otherwise it is created with shape (k, n).
        """
        num_devices = jax.device_count()
        # The guard is an exact match, so the skip message states that instead
        # of claiming there are too few devices (it also fires with more).
        if num_devices != 8:
            self.skipTest(
                f'Test requires exactly 8 devices, found {num_devices}')

        axis_name = 'x'
        mesh = utils.make_optimized_mesh((num_devices, ), (axis_name, ))
        bk, bn = 1024, 1024
        m, k, n = 1024, bk * grid_k, bn * grid_n * num_devices

        # Shapes and shardings do not depend on the iteration, so build them
        # once. The right-hand side is always sharded along its `n` dimension.
        if rhs_transpose:
            y_shape, y_spec = (n, k), P(axis_name, None)
        else:
            y_shape, y_spec = (k, n), P(None, axis_name)
        x_sharding = jax.sharding.NamedSharding(mesh, P(axis_name, None))
        y_sharding = jax.sharding.NamedSharding(mesh, y_spec)

        # Run the test 10 times to expose race conditions as much as possible.
        for i in range(10):
            # Fresh inputs per iteration, seeded by the iteration index.
            prng_key = jax.random.key(1234 + i)
            k0, k1 = jax.random.split(prng_key, 2)
            x = jax.random.normal(k0, (m, k), dtype=jnp.bfloat16)
            y = jax.random.normal(k1, y_shape, dtype=jnp.bfloat16)
            sharded_x = jax.device_put(x, x_sharding)
            sharded_y = jax.device_put(y, y_sharding)

            # Run the all_gather_matmul function
            output = all_gather_matmul.all_gather_matmul(
                sharded_x,
                sharded_y,
                mesh,
                axis_name,
                bk=bk,
                bn=bn,
                rhs_transpose=rhs_transpose,
            )

            # Reference result from an ordinary matmul on the same inputs.
            y_for_dot = sharded_y.T if rhs_transpose else sharded_y
            expected_output = jnp.dot(sharded_x, y_for_dot)
            self.assertAllClose(output, expected_output, atol=1e-2, rtol=1e-2)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# Script entry point: run this module's tests through absl with JAX's loader.
if __name__ == "__main__":
    jax_test_loader = jtu.JaxTestLoader()
    absltest.main(testLoader=jax_test_loader)
|
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
import jax
|
|
2
16
|
import jax.numpy as jnp
|
|
3
17
|
import numpy as np
|