tpu_inference-0.0.1rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tests/lora/test_layers.py
@@ -0,0 +1,648 @@
+ import random
+ from typing import Optional
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.config import LoRAConfig
+ # yapf conflicts with isort for this block
+ # yapf: disable
+ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                               LoRAMapping, MergedColumnParallelLinearWithLoRA,
+                               MergedQKVParallelLinearWithLoRA,
+                               QKVParallelLinearWithLoRA,
+                               ReplicatedLinearWithLoRA,
+                               RowParallelLinearWithLoRA)
+ # yapf: enable
+ from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
+ from vllm.lora.punica_wrapper import get_punica_wrapper
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                ReplicatedLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.utils import set_random_seed
+ from vllm.platforms import current_platform
+
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+ from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu
+
+ from .utils import DummyLoRAManager
+
+ P = PartitionSpec
+
+ TOLERANCES = {
+     torch.float16: (5e-3, 5e-3),
+     torch.float32: (5e-3, 5e-3),
+     torch.bfloat16: (3e-2, 2e-2),
+ }
+
+ pytestmark = pytest.mark.skipif(not current_platform.is_tpu(),
+                                 reason="This test is only for TPU platform.")
+
+ # prefill stage (True) or decode stage (False)
+ STAGES = [True, False]
+
+
+ def check_punica_wrapper(punica_wrapper) -> bool:
+     from tpu_inference.lora.torch_punica_tpu import PunicaWrapperTPU
+     return type(punica_wrapper) is PunicaWrapperTPU
+
+
+ def get_random_index_to_id(num_loras: int,
+                            num_slots: int,
+                            log: bool = True) -> list[Optional[int]]:
+     """Creates a random index_to_lora_id mapping: slot[index] = lora_id.
+
+     Args:
+         num_loras: The number of active loras in the mapping.
+         num_slots: The number of slots in the mapping. Must be at least
+             num_loras.
+         log: Whether to log the output.
+
+     Returns:
+         index_to_lora_id: a random index_to_lora_id mapping.
+     """
+
+     if num_loras > num_slots:
+         raise ValueError(
+             f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
+             "num_loras must be less than or equal to num_slots.")
+
+     slots: list[Optional[int]] = [None] * num_slots
+     random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
+     for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
+         # lora_ids start at 1; slot indices are 0-based.
+         slots[slot_idx] = lora_id
+
+     if log:
+         print(f"Created index_to_lora_id mapping: {slots}.")
+
+     return slots
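+ # Illustrative example (not part of the original file): with num_loras=2 and
+ # num_slots=4, one possible result is [None, 2, None, 1], i.e. slot 1 holds
+ # lora_id 2, slot 3 holds lora_id 1, and the remaining slots are free.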
+
+
+ def populate_loras(
+     index_to_id: list[Optional[int]],
+     lora_layer: BaseLayerWithLoRA,
+     baselayer_weights: torch.Tensor,
+     repeats: int = 1,
+ ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
+     """Populates the lora weights (lora_a and lora_b) in the lora layer
+     (BaseLayerWithLoRA).
+
+     Args:
+         index_to_id: a list of lora ids. The index of the lora id
+             represents which memory slot the lora matrices are
+             stored in. A None value indicates a free slot.
+         lora_layer: the LoRA layer to populate.
+         baselayer_weights: the PyTorch tensor containing the base layer's
+             weights.
+         repeats: must only be set for column parallel packed
+             layers. Indicates the number of loras to compose
+             together to create a single lora layer.
+
+     Returns:
+         lora_dict: dict[int, LoRALayerWeights] mapping each lora ID to the
+             (possibly packed) lora weights that were set on the layer.
+         sublora_dict: dict[int, list[LoRALayerWeights]] mapping each lora ID
+             to the list of sub-loras that were packed together.
+     """
+
+     # Dictionary that maps the lora ID to the
+     # corresponding lora weights.
+     lora_dict: dict[int, LoRALayerWeights] = dict()
+
+     # Dictionary that maps the lora ID to the
+     # corresponding subloras.
+     sublora_dict: dict[int, list[LoRALayerWeights]] = dict()
+
+     for slot_idx, lora_id in enumerate(index_to_id):
+         if lora_id is not None:
+             subloras: list[LoRALayerWeights] = []
+             sublora_len = baselayer_weights.shape[0] // repeats
+             for i in range(repeats):
+                 sublora = DummyLoRAManager(
+                     baselayer_weights.device).init_random_lora(
+                         module_name=f"fake_{i}",
+                         weight=baselayer_weights,
+                     )
+                 sublora.lora_b = sublora.lora_b[(sublora_len *
+                                                  i):(sublora_len * (i + 1)), :]
+                 sublora.optimize()
+                 subloras.append(sublora)
+
+             lora = PackedLoRALayerWeights.pack(
+                 subloras) if repeats > 1 else subloras[0]
+
+             # Some of the layer's lora tensors are torchax tensors, so the
+             # math (e.g. the slice op) can only run inside the torchax env.
+             with torchax.default_env():
+                 lora_layer.set_lora(
+                     slot_idx,
+                     lora_a=lora.lora_a,
+                     lora_b=lora.lora_b,
+                 )
+
+             lora_dict[lora_id] = lora
+             sublora_dict[lora_id] = subloras
+
+     return lora_dict, sublora_dict
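+ # Illustrative example (not part of the original file): with
+ # index_to_id=[None, 1, None] and repeats=2, one packed lora is set into
+ # slot 1, and the function returns lora_dict == {1: <PackedLoRALayerWeights>}
+ # and sublora_dict == {1: [sublora_0, sublora_1]}, where sublora_i covers
+ # rows [i * sublora_len, (i + 1) * sublora_len) of lora_b.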
+
+
+ def create_random_inputs(
+     active_lora_ids: list[int],
+     num_inputs: int,
+     input_size: tuple[int, ...],
+     input_range: tuple[float, float],
+     input_type: torch.dtype = torch.int,
+     device: torch.device = "cpu",
+ ) -> tuple[list[torch.Tensor], list[int], list[int]]:
+     """Creates random inputs.
+
+     Args:
+         active_lora_ids: lora IDs of active lora weights.
+         num_inputs: the number of inputs to create, i.e. the number of
+             requests.
+         input_size: the size of each individual input, i.e. the number of
+             tokens.
+         input_range: the range of values to include in the input.
+             input_range[0] <= possible input values < input_range[1]
+         input_type: the dtype of the values in the input.
+
+     Returns:
+         inputs: a list of num_inputs torch tensors, each of shape
+             `input_size`.
+         index_mapping: maps each input token to a lora ID.
+         prompt_mapping: maps each request to a lora ID.
+     """
+
+     low, high = input_range
+
+     inputs: list[torch.Tensor] = []
+     index_mapping: list[int] = []
+     prompt_mapping: list[int] = []
+
+     for _ in range(num_inputs):
+         if input_type == torch.int:
+             inputs.append(
+                 torch.randint(low=int(low),
+                               high=int(high),
+                               size=input_size,
+                               device=device))
+         else:
+             # Scale torch.rand's [0, 1) output into [low, high).
+             inputs.append(
+                 torch.rand(size=input_size, dtype=input_type, device=device) *
+                 (high - low) + low)
+
+         lora_id = random.choice(active_lora_ids)
+         index_mapping += [lora_id] * input_size[0]
+         prompt_mapping += [lora_id]
+
+     return inputs, index_mapping, prompt_mapping
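+ # Illustrative example (not part of the original file): with num_inputs=2,
+ # input_size=(1, 64) and active_lora_ids=[3, 7], a possible outcome is
+ # inputs == [t0, t1] (each of shape [1, 64]), index_mapping == [7, 3] (one
+ # entry per token), and prompt_mapping == [7, 3] (one entry per request).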
+
+
+ @torch.inference_mode()
+ @pytest.mark.parametrize("num_loras", [1, 4, 9])
+ @pytest.mark.parametrize("repeats", [1, 2, 3])
+ @pytest.mark.parametrize("stage", STAGES)
+ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+     set_random_seed(6)
+
+     max_loras = 9
+     max_lora_rank = 8
+     lora_config = LoRAConfig(
+         max_loras=max_loras,
+         max_lora_rank=max_lora_rank,
+         fully_sharded_loras=False,
+         lora_dtype=torch.bfloat16,
+     )
+     vllm_config = dist_init
+     vllm_config.lora_config = lora_config
+
+     mesh = _create_mesh()
+     linear, lora_linear = _create_column_parallel_packed_layer(
+         repeats, vllm_config, mesh)
+     _verify_lora_linear_layer(linear, lora_linear)
+
+     # After creating the lora_config, the linear layer and the lora layer,
+     # the remaining steps are:
+     # - create a punica wrapper.
+     # - associate the punica wrapper with the lora layer.
+     # - populate the lora matrices in the lora layer: non-zero values test
+     #   the lora path, zero values test the case where the layer has no lora.
+     # - create inputs and the lora_mapping.
+     # - update the metadata of the punica wrapper.
+     # - convert the inputs to torchax tensors.
+     # - run a forward pass on the lora layer to get the actual output.
+     # - run a reference implementation to get the expected output.
+
+     # Create a punica wrapper and associate it with the lora linear layer.
+     max_num_batched_tokens = 8192
+     max_batches = 256
+     with torchax.default_env():
+         punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                             max_batches,
+                                             'jax',
+                                             max_loras=max_loras)
+     assert check_punica_wrapper(punica_wrapper)
+     lora_linear.set_mapping(punica_wrapper)
+
+     # Populate lora matrices (lora_a and lora_b) in the lora layer.
+     index_to_id = get_random_index_to_id(num_loras, max_loras)
+     # lora_dict: lora_id -> LoRALayerWeights | PackedLoRALayerWeights
+     lora_dict, sublora_dict = populate_loras(
+         index_to_id,
+         lora_layer=lora_linear,
+         baselayer_weights=linear.weight,
+         repeats=repeats,
+     )
+
+     # Create inputs and lora mappings.
+     # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds to
+     #     a request whose tokens form a tensor of shape [num_tokens, 64].
+     # index_mapping: list[int]
+     # prompt_mapping: list[int]
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=list(lora_dict.keys()),
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+
+     expected_results: list[torch.Tensor] = []
+     for input_, lora_id in zip(inputs, prompt_mapping):
+         # linear(input_) returns (output, output_bias); we only need the
+         # output.
+         result = linear(input_)[0]
+         subloras = sublora_dict[lora_id]
+         for i, sublora in enumerate(subloras):
+             result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                    (i + 1)] += (input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                                 sublora.scaling)
+         expected_results.append(result)
+     expected_result = torch.cat(expected_results)
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+     # print(
+     #     f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}'
+     # )
+     # print(
+     #     f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}'
+     # )
+
+     # Check that resetting the lora weights succeeds.
+     # Here we set all lora weights to be empty.
+     for slot_idx in range(max_loras):
+         lora_linear.reset_lora(slot_idx)
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=[0],  # different from the create_random_inputs above
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+     expected_result = linear(torch.cat(inputs))[0]
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+
+ @torch.inference_mode()
+ @pytest.mark.parametrize("num_loras", [1, 4, 9])
+ @pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
+ @pytest.mark.parametrize("stage", STAGES)
+ def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
+     set_random_seed(6)
+
+     max_loras = 9
+     max_lora_rank = 8
+     lora_config = LoRAConfig(
+         max_loras=max_loras,
+         max_lora_rank=max_lora_rank,
+         fully_sharded_loras=False,
+         lora_dtype=torch.bfloat16,
+     )
+     vllm_config = dist_init
+     vllm_config.lora_config = lora_config
+
+     mesh = _create_mesh()
+     linear, lora_linear = _create_random_linear_parallel_layer(
+         layer_type, vllm_config, mesh)
+     _verify_lora_linear_layer(linear, lora_linear)
+
+     max_num_batched_tokens = 8192
+     max_batches = 256
+     with torchax.default_env():
+         punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                             max_batches,
+                                             'jax',
+                                             max_loras=max_loras)
+     assert check_punica_wrapper(punica_wrapper)
+     lora_linear.set_mapping(punica_wrapper)
+
+     # Populate lora matrices (lora_a and lora_b) in the lora layer.
+     index_to_id = get_random_index_to_id(num_loras, max_loras)
+     # lora_dict: lora_id -> LoRALayerWeights | PackedLoRALayerWeights
+     lora_dict, sublora_dict = populate_loras(
+         index_to_id,
+         lora_layer=lora_linear,
+         baselayer_weights=linear.weight,
+     )
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=list(lora_dict.keys()),
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+
+     expected_results: list[torch.Tensor] = []
+     for input_, lora_id in zip(inputs, prompt_mapping):
+         result = linear(input_)[0]
+         lora = lora_dict[lora_id]
+         lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
+         result += lora_result
+         expected_results.append(result)
+     expected_result = torch.cat(expected_results)
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+     # Check that resetting the lora weights succeeds.
+     # Here we set all lora weights to be empty.
+     for slot_idx in range(max_loras):
+         lora_linear.reset_lora(slot_idx)
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=[0],  # different from the create_random_inputs above
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+     expected_result = linear(torch.cat(inputs))[0]
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+
+ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
+     # We first create a base linear layer, then a lora layer to wrap it.
+     if layer_type == "row":
+
+         def _create_row_linear():
+             return RowParallelLinear(
+                 64,  # input_size
+                 64,  # output_size
+                 bias=False,
+                 params_dtype=torch.bfloat16)
+
+         linear = _create_row_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_row_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            RowParallelLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+     elif layer_type == "column":
+
+         def _create_column_linear():
+             return ColumnParallelLinear(64,
+                                         64,
+                                         bias=False,
+                                         params_dtype=torch.bfloat16)
+
+         linear = _create_column_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_column_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            ColumnParallelLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+
+     elif layer_type == "replicated":
+
+         def _create_replicated_linear():
+             return ReplicatedLinear(64,
+                                     64,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+         linear = _create_replicated_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_replicated_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            ReplicatedLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+
+     else:
+         raise NotImplementedError(f"Unknown layer type: {layer_type}")
+
+     return linear, lora_linear
+
+
+ def _create_mesh():
+     axis_names = ("data", "model")
+     devices = jax.devices()
+     mesh_shape = (1, len(devices))
+     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
+     return mesh
+
+
+ def _verify_lora_linear_layer(linear, lora_linear):
+     with torchax.default_env():
+         # lora_linear.weight has type torchax.tensor.Tensor; the
+         # BaseLinearLayerWithLoRA.weight property guarantees this.
+         # If len(devices) != 1, `reorder_concatenated_tensor_for_sharding`
+         # may reorder the out_features dimension of the weight matrix, in
+         # which case the check below would fail.
+         if len(jax.devices()) == 1:
+             assert torch.equal(linear.weight.data,
+                                lora_linear.weight.to('cpu'))
+
+
+ def _shard_and_move_inputs_to_tpu(inputs, mesh):
+     processed_inputs = []
+     for input in inputs:
+         # Without `torch_view`, you get `AttributeError:
+         # 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`.
+         # Without `t2j`, you get `AttributeError: 'Tensor' object has no
+         # attribute 'apply_jax_'`.
+         jax_input = torch_view(t2j(input))
+         jax_input.apply_jax_(jax.device_put,
+                              NamedSharding(mesh, P(None, None)))
+         processed_inputs.append(jax_input)
+     return torch.cat(processed_inputs)
+
+
+ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config):
+     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+     with torchax.default_env():
+         # Here we move the metadata from CPU to TPU.
+         punica_wrapper.update_metadata(
+             lora_mapping,
+             index_to_id,
+             lora_config.max_loras,
+             vocab_size=512,
+         )
+     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+     assert isinstance(
+         jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+         jax.sharding.SingleDeviceSharding
+     ), 'punica_wrapper._lora_indices_per_batch should be placed on a single TPU device.'
+
+
+ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
+     # We first create a base linear layer, then a lora layer to wrap it.
+     if repeats == 2:
+         # In e2e, MergedColumnParallelLinear is created when we load the
+         # model. The base_layer weights are sharded and moved to TPU in
+         # VllmUnquantizedLinearMethod.process_weights_after_loading.
+         def _create_merged_column_linear():
+             return MergedColumnParallelLinear(
+                 64,  # input_size
+                 [64] * repeats,  # output_size
+                 bias=False,
+                 params_dtype=torch.bfloat16)
+
+         linear = _create_merged_column_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_merged_column_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            MergedColumnParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+     elif repeats == 3:
+
+         def _create_qkv_linear():
+             return QKVParallelLinear(64,
+                                      64,
+                                      32,
+                                      bias=False,
+                                      params_dtype=torch.bfloat16)
+
+         linear = _create_qkv_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_qkv_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            MergedQKVParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+     else:
+
+         def _create_qkv_linear():
+             return QKVParallelLinear(64,
+                                      64,
+                                      32,
+                                      bias=False,
+                                      params_dtype=torch.bfloat16)
+
+         linear = _create_qkv_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_qkv_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            QKVParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+
+     return linear, lora_linear
+
+
+ def _create_lora_wrapper(linear,
+                          base_linear,
+                          lora_cls,
+                          vllm_config,
+                          mesh,
+                          repeats=1):
+     base_linear.weight.data = linear.weight.data
+     jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+     linear_method = VllmUnquantizedLinearMethod(jax_config)
+     base_linear.quant_method = linear_method
+     # Here base_linear.weight is sharded and moved to TPU.
+     linear_method.process_weights_after_loading(base_linear)
+     assert jax_view(base_linear.weight).platform(
+     ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(base_linear.weight).sharding, jax.sharding.
+         SingleDeviceSharding), 'base_linear.weight should have been sharded.'
+
+     lora_linear = lora_cls(base_linear)
+
+     lora_config = vllm_config.lora_config
+     max_loras = lora_config.max_loras
+     with torchax.default_env():
+         lora_linear.create_lora_weights(max_loras, lora_config)
+         # In e2e, the lora layer's weights are moved to TPU in
+         # _shard_module_to_tpu.
+         _shard_module_to_tpu(lora_linear, mesh)
+
+     assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+     ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+         SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+     assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+     ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+         SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+     n_slices = repeats
+     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+         lora_linear.lora_b_stacked) == n_slices)
+
+     return lora_linear
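
For orientation, the reference computation both tests assert against is the standard LoRA update: the layer output is the base linear output plus scaling * x @ lora_a.T @ lora_b.T, applied per output slice for packed layers. Below is a minimal sketch of that math in plain PyTorch, with hypothetical shapes (hidden=64, rank=8) and names of our own choosing; it does not use the tpu-inference or vLLM APIs exercised above.

    import torch

    # Minimal sketch of the reference LoRA math checked above (assumed
    # shapes: hidden=64, rank=8; names are illustrative only).
    x = torch.rand(1, 64, dtype=torch.bfloat16)       # one token
    w = torch.rand(64, 64, dtype=torch.bfloat16)      # base weight (out, in)
    lora_a = torch.rand(8, 64, dtype=torch.bfloat16)  # A: (rank, in)
    lora_b = torch.rand(64, 8, dtype=torch.bfloat16)  # B: (out, rank)
    scaling = 0.5

    base_out = x @ w.T                                # shape [1, 64]
    lora_out = x @ lora_a.T @ lora_b.T * scaling      # shape [1, 64]
    expected = base_out + lora_out

    # For packed/merged layers (repeats > 1), each sub-lora updates only its
    # own output slice, mirroring test_column_parallel_packed:
    #   result[:, i*64:(i+1)*64] += x @ sub_a.T @ sub_b.T * sub_scaling

The tolerance table near the top of the file (TOLERANCES) reflects that this comparison is done in low precision: bfloat16 results get looser rtol/atol than float16/float32.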