tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,654 @@
+import random
+from typing import Optional
+
+import jax
+import pytest
+import torch
+import torchax
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.config import LoRAConfig
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LoRAMapping, MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLoRA,
+                              QKVParallelLinearWithLoRA,
+                              ReplicatedLinearWithLoRA,
+                              RowParallelLinearWithLoRA)
+# yapf: enable
+from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.punica_wrapper import get_punica_wrapper
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               ReplicatedLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.utils import set_random_seed
+from vllm.platforms import current_platform
+
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu
+
+from .utils import DummyLoRAManager
+
+P = PartitionSpec
+
+TOLERANCES = {
+    torch.float16: (5e-3, 5e-3),
+    torch.float32: (5e-3, 5e-3),
+    torch.bfloat16: (3e-2, 2e-2),
+}
+
+pytestmark = pytest.mark.skipif(not current_platform.is_tpu(),
+                                reason="This test is only for TPU platform.")
+
+# prefill stage(True) or decode stage(False)
+STAGES = [True, False]
+
+
+def check_punica_wrapper(punica_wrapper) -> bool:
+    from tpu_inference.lora.torch_punica_tpu import PunicaWrapperTPU
+    return type(punica_wrapper) is PunicaWrapperTPU
+
+
+def get_random_index_to_id(num_loras: int,
+                           num_slots: int,
+                           log: bool = True) -> list[Optional[int]]:
+    """Creates a random index_to_lora_id mapping: slot[index] = lora_id.
+
+    Args:
+        num_loras: The number of active loras in the mapping.
+        num_slots: The number of slots in the mapping. Must be larger
+            than num_loras.
+        log: Whether to log the output.
+
+    returns:
+        index_to_lora_id: a random index_to_lora_id mapping.
+    """
+
+    if num_loras > num_slots:
+        raise ValueError(
+            f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
+            "num_loras must be less than or equal to num_slots.")
+
+    slots: list[Optional[int]] = [None] * num_slots
+    random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
+    for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
+        # The slot_idx start at 1.
+        slots[slot_idx] = lora_id
+
+    if log:
+        print(f"Created lora_id_to_index mapping: {slots}.")
+
+    return slots
+
+
+def populate_loras(
+    index_to_id: list[Optional[int]],
+    lora_layer: BaseLayerWithLoRA,
+    baselayer_weights: torch.Tensor,
+    generate_embeddings_tensor: int = 0,
+    repeats: int = 1,
+) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
+    """This method populates the lora weights (lora_a and lora_b) in the lora layers (BaseLayerWithLoRA).
+
+    Args:
+        index_to_id: a list of lora ids. The index of the lora id
+            represents which memory slot the lora matrices are
+            stored in. A None value indicates a free slot.
+        lora_layer: the LoRAlayer to populate.
+        baselayer_weights: the PyTorch tensor containing the layer's
+            weights.
+        generate_embeddings_tensor: whether to generate an
+            embeddings tensor for each LoRA.
+        repeats: must only be set for column parallel packed
+            layers. Indicates the number of loras to compose
+            together to create a single lora layer.
+
+    returns:
+        lora_dict: a dictionary dict[int, LoRALayerWeights] that maps the lora ID to the corresponding lora weights.
+        sublora_dict: a dictionary dict[int, list[LoRALayerWeights]] that maps the lora ID to the corresponding lora weights.
+    """
+
+    # Dictionary that maps the lora ID to the
+    # corresponding lora weights.
+    lora_dict: dict[int, LoRALayerWeights] = dict()
+
+    # Dictionary that maps the lora ID to the
+    # corresponding subloras.
+    sublora_dict: dict[int, list[LoRALayerWeights]] = dict()
+
+    for slot_idx, lora_id in enumerate(index_to_id):
+        if lora_id is not None:
+            subloras: list[LoRALayerWeights] = []
+            sublora_len = baselayer_weights.shape[0] // repeats
+            for i in range(repeats):
+                sublora = DummyLoRAManager(
+                    baselayer_weights.device).init_random_lora(
+                        module_name=f"fake_{i}",
+                        weight=baselayer_weights,
+                        generate_embeddings_tensor=generate_embeddings_tensor,
+                    )
+                sublora.lora_b = sublora.lora_b[(sublora_len *
+                                                 i):(sublora_len * (i + 1)), :]
+                sublora.optimize()
+                subloras.append(sublora)
+
+            lora = PackedLoRALayerWeights.pack(
+                subloras) if repeats > 1 else subloras[0]
+
+            # Some of the layer.lora is torchax tensor so it can only do math (slice op) in the torchax env.
+            with torchax.default_env():
+                lora_layer.set_lora(
+                    slot_idx,
+                    lora_a=lora.lora_a,
+                    lora_b=lora.lora_b,
+                    embeddings_tensor=lora.embeddings_tensor,
+                )
+
+            lora_dict[lora_id] = lora
+            sublora_dict[lora_id] = subloras
+
+    return lora_dict, sublora_dict
+
+
+def create_random_inputs(
+    active_lora_ids: list[int],
+    num_inputs: int,
+    input_size: tuple[int, ...],
+    input_range: tuple[float, float],
+    input_type: torch.dtype = torch.int,
+    device: torch.device = "cpu",
+) -> tuple[list[torch.Tensor], list[int], list[int]]:
+    """Creates random inputs.
+
+    Args:
+        active_lora_ids: lora IDs of active lora weights.
+        num_inputs: the number of inputs to create. Or the number of requests.
+        input_size: the size of each individual input. Or the number of tokens.
+        input_range: the range of values to include in the input.
+            input_range[0] <= possible input values < input_range[1]
+        input_type: the type of values in the input.
+
+    returns:
+        inputs: a list of torch tensors of size num_inputs. Each input has shape `input_size`.
+        index_mapping: maps each input token to a lora ID.
+        prompt_mapping: maps each request to a lora ID.
+    """
+
+    low, high = input_range
+
+    inputs: list[torch.Tensor] = []
+    index_mapping: list[int] = []
+    prompt_mapping: list[int] = []
+
+    for _ in range(num_inputs):
+        if input_type == torch.int:
+            inputs.append(
+                torch.randint(low=int(low),
+                              high=int(high),
+                              size=input_size,
+                              device=device))
+        else:
+            inputs.append(
+                torch.rand(size=input_size, dtype=input_type, device=device) *
+                high + low)
+
+        lora_id = random.choice(active_lora_ids)
+        index_mapping += [lora_id] * input_size[0]
+        prompt_mapping += [lora_id]
+
+    return inputs, index_mapping, prompt_mapping
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 4, 9])
+@pytest.mark.parametrize("repeats", [1, 2, 3])
+@pytest.mark.parametrize("stage", [True, False])
+def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+    set_random_seed(6)
+
+    max_loras = 9
+    max_lora_rank = 8
+    lora_config = LoRAConfig(
+        max_loras=max_loras,
+        max_lora_rank=max_lora_rank,
+        fully_sharded_loras=False,
+        lora_dtype=torch.bfloat16,
+    )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config
+
+    mesh = _create_mesh()
+    linear, lora_linear = _create_column_parallel_packed_layer(
+        repeats, vllm_config, mesh)
+    _verify_lora_linear_layer(linear, lora_linear)
+
+    # After we create the lora_config, the linear layer and the lora layer,
+    # here are the steps to do next:
+    # - create a punica wrapper.
+    # - associate the punica wrapper with the lora layer.
+    # - populate the lora matrices in the lora layer: use non-zero values for testing lora and zero values for testing the case where the layer doesn't have lora.
+    # - create inputs and lora_mapping.
+    # - update the metadata of the punica wrapper.
+    # - convert the inputs to be torchax tensors.
+    # - then run a forward on the lora layer to get the actual output.
+    # - then run a reference implementation as the expected output.
+
+    # Create a punica wrapper and associate it with the lora linear layer.
+    max_num_batched_tokens = 8192
+    max_batches = 256
+    with torchax.default_env():
+        punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                            max_batches,
+                                            'jax',
+                                            max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_linear.set_mapping(punica_wrapper)
+
+    # Populate lora matrices (lora_a and lora_b) in the lora layer.
+    index_to_id = get_random_index_to_id(num_loras, max_loras)
+    # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+    lora_dict, sublora_dict = populate_loras(
+        index_to_id,
+        lora_layer=lora_linear,
+        baselayer_weights=linear.weight,
+        repeats=repeats,
+    )
+
+    # Create inputs and lora mappings.
+    # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to a request which has several token of shape=[num_tokens, 64].
+    # index_mapping: list[int]
+    # prompt_mapping: list[int]
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+
+    expected_results: list[torch.Tensor] = []
+    for input_, lora_id in zip(inputs, prompt_mapping):
+        # linear(input_) returns (output, output_bias) so we only need the first one.
+        result = linear(input_)[0]
+        subloras = sublora_dict[lora_id]
+        for i, sublora in enumerate(subloras):
+            result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                   (i + 1)] += (input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                                sublora.scaling)
+        expected_results.append(result)
+    expected_result = torch.cat(expected_results)
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+    # print(
+    #     f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}'
+    # )
+    # print(
+    #     f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}'
+    # )
+
+    # Check that resetting the lora weights succeeds
+    # Here we set all lora weight to be empty.
+    for slot_idx in range(max_loras):
+        lora_linear.reset_lora(slot_idx)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],  # different from the above create_random_inputs
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+    expected_result = linear(torch.cat(inputs))[0]
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("num_loras", [1, 4, 9])
+@pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
+@pytest.mark.parametrize("stage", [True, False])
+def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
+    set_random_seed(6)
+
+    max_loras = 9
+    max_lora_rank = 8
+    lora_config = LoRAConfig(
+        max_loras=max_loras,
+        max_lora_rank=max_lora_rank,
+        fully_sharded_loras=False,
+        lora_dtype=torch.bfloat16,
+    )
+    vllm_config = dist_init
+    vllm_config.lora_config = lora_config
+
+    mesh = _create_mesh()
+    linear, lora_linear = _create_random_linear_parallel_layer(
+        layer_type, vllm_config, mesh)
+    _verify_lora_linear_layer(linear, lora_linear)
+
+    max_num_batched_tokens = 8192
+    max_batches = 256
+    with torchax.default_env():
+        punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                            max_batches,
+                                            'jax',
+                                            max_loras=max_loras)
+    assert check_punica_wrapper(punica_wrapper)
+    lora_linear.set_mapping(punica_wrapper)
+
+    # Populate lora matrices (lora_a and lora_b) in the lora layer.
+    index_to_id = get_random_index_to_id(num_loras, max_loras)
+    # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+    lora_dict, sublora_dict = populate_loras(
+        index_to_id,
+        lora_layer=lora_linear,
+        baselayer_weights=linear.weight,
+    )
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=list(lora_dict.keys()),
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+
+    expected_results: list[torch.Tensor] = []
+    for input_, lora_id in zip(inputs, prompt_mapping):
+        result = linear(input_)[0]
+        lora = lora_dict[lora_id]
+        lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
+        result += lora_result
+        expected_results.append(result)
+    expected_result = torch.cat(expected_results)
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+
+    # Check that resetting the lora weights succeeds
+    # Here we set all lora weight to be empty.
+    for slot_idx in range(max_loras):
+        lora_linear.reset_lora(slot_idx)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[0],  # different from the above create_random_inputs
+        num_inputs=32,
+        input_size=(1, 64),
+        input_range=(0, 1),
+        input_type=torch.bfloat16,
+        device='cpu')
+    _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config)
+
+    with torchax.default_env():
+        torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+        actual_result = lora_linear(torchax_inputs)[0]
+    expected_result = linear(torch.cat(inputs))[0]
+
+    rtol, atol = TOLERANCES[actual_result.dtype]
+    with torchax.default_env():
+        actual_result_cpu = actual_result.to('cpu')
+        torch.testing.assert_close(actual_result_cpu,
+                                   expected_result,
+                                   rtol=rtol,
+                                   atol=atol)
+
+
+def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if layer_type == "row":
+
+        def _create_row_linear():
+            return RowParallelLinear(
+                64,  # input_size
+                64,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_row_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_row_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           RowParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+    elif layer_type == "column":
+
+        def _create_column_linear():
+            return ColumnParallelLinear(64,
+                                        64,
+                                        bias=False,
+                                        params_dtype=torch.bfloat16)
+
+        linear = _create_column_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_column_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ColumnParallelLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    elif layer_type == "replicated":
+
+        def _create_replicated_linear():
+            return ReplicatedLinear(64,
+                                    64,
+                                    bias=False,
+                                    params_dtype=torch.bfloat16)
+
+        linear = _create_replicated_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_replicated_linear()
+        lora_linear = _create_lora_wrapper(linear,
+                                           base_linear,
+                                           ReplicatedLinearWithLoRA,
+                                           vllm_config=vllm_config,
+                                           mesh=mesh)
+
+    else:
+        raise NotImplementedError("Unknown layer type: {}".format(layer_type))
+
+    return linear, lora_linear
+
+
+def _create_mesh():
+    axis_names = ("data", "model")
+    devices = jax.devices()
+    mesh_shape = (1, len(devices))
+    mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
+    return mesh
+
+
+def _verify_lora_linear_layer(linear, lora_linear):
+    with torchax.default_env():
+        # lora_linear.weight has type torchax.tensor.Tensor
+        # BaseLinearLayerWithLoRA.weight property guarantees this.
+        # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
+        # So the below check will fail.
+        if len(jax.devices()) == 1:
+            assert torch.equal(linear.weight.data,
+                               lora_linear.weight.to('cpu'))
+
+
+def _shard_and_move_inputs_to_tpu(inputs, mesh):
+    processed_inputs = []
+    for input in inputs:
+        # without `torch_view`, you get an error `AttributeError: 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`
+        # without `t2j`, you get an error `AttributeError: 'Tensor' object has no attribute 'apply_jax_'`
+        jax_input = torch_view(t2j(input))
+        jax_input.apply_jax_(jax.device_put,
+                             NamedSharding(mesh, P(None, None)))
+        processed_inputs.append(jax_input)
+    return torch.cat(processed_inputs)
+
+
+def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                    prompt_mapping, stage, index_to_id,
+                                    lora_config):
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    with torchax.default_env():
+        # Here we move the metadata from cpu to tpu.
+        punica_wrapper.update_metadata(
+            lora_mapping,
+            index_to_id,
+            lora_config.max_loras,
+            vocab_size=512,
+            extra_vocab_size=lora_config.lora_extra_vocab_size,
+        )
+        assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+        ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+        assert isinstance(
+            jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+            jax.sharding.SingleDeviceSharding
+        ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+
+
+def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
+    # We first create a base linear layer, then a lora layer to wrap it.
+    if repeats == 2:
+        # In e2e, MergedColumnParallelLinear is created when we load the model. The base_layer weights are sharded and moved to TPU in VllmUnquantizedLinearMethod.process_weights_after_loading.
+        def _create_merged_column_linear():
+            return MergedColumnParallelLinear(
+                64,  # input_size
+                [64] * repeats,  # output_size
+                bias=False,
+                params_dtype=torch.bfloat16)
+
+        linear = _create_merged_column_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_merged_column_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedColumnParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
+    elif repeats == 3:
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           MergedQKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
+    else:
+
+        def _create_qkv_linear():
+            return QKVParallelLinear(64,
+                                     64,
+                                     32,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+        linear = _create_qkv_linear()
+        linear.weight.data = torch.rand_like(linear.weight.data)
+
+        base_linear = _create_qkv_linear()
+        lora_linear = _create_lora_wrapper(linear, base_linear,
+                                           QKVParallelLinearWithLoRA,
+                                           vllm_config, mesh, repeats)
+
+    return linear, lora_linear
+
+
+def _create_lora_wrapper(linear,
+                         base_linear,
+                         lora_cls,
+                         vllm_config,
+                         mesh,
+                         repeats=1):
+    base_linear.weight.data = linear.weight.data
+    jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+    linear_method = VllmUnquantizedLinearMethod(jax_config)
+    base_linear.quant_method = linear_method
+    linear_method.process_weights_after_loading(
+        base_linear)  # here base_linear.weight is moved to TPU and sharded.
+    assert jax_view(base_linear.weight).platform(
+    ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(base_linear.weight).sharding, jax.sharding.
+        SingleDeviceSharding), 'base_linear.weight should have been sharded.'
+
+    lora_linear = lora_cls(base_linear)
+
+    lora_config = vllm_config.lora_config
+    max_loras = lora_config.max_loras
+    with torchax.default_env():
+        lora_linear.create_lora_weights(max_loras, lora_config)
+        # In the e2e, the lora_layer's weight is moved to TPU in _shard_module_to_tpu.
+        _shard_module_to_tpu(lora_linear, mesh)
+
+    assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+    ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+        SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+    assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+    ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+    assert not isinstance(
+        jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+        SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+    n_slices = repeats
+    assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+        lora_linear.lora_b_stacked) == n_slices)
+
+    return lora_linear
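
Editor's note on the reference computation: the expected-result loops in both tests above build their baseline as the base linear output plus the LoRA delta `input_ @ lora_a.T @ lora_b.T * scaling`. The snippet below is a minimal standalone sketch of that math for a single token and a single unpacked LoRA; it is not part of the package, the shapes and values are arbitrary, and the tensor names simply mirror those used in the test file.

import torch

# Minimal sketch (illustrative only, not from tpu-inference):
# the reference LoRA math the tests assert against.
in_features, out_features, rank, scaling = 64, 64, 8, 1.0
x = torch.rand(1, in_features, dtype=torch.bfloat16)           # one token
weight = torch.rand(out_features, in_features, dtype=torch.bfloat16)
lora_a = torch.rand(rank, in_features, dtype=torch.bfloat16)   # [rank, in_features]
lora_b = torch.rand(out_features, rank, dtype=torch.bfloat16)  # [out_features, rank]

base = x @ weight.T                         # base linear layer output
delta = x @ lora_a.T @ lora_b.T * scaling   # LoRA contribution
expected = base + delta                     # what the tests compare against
print(expected.shape)                       # torch.Size([1, 64])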