tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/lora/test_layers.py
@@ -0,0 +1,666 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import random
+ from typing import Optional
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.config import LoRAConfig
+ # yapf conflicts with isort for this block
+ # yapf: disable
+ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                               LoRAMapping, MergedColumnParallelLinearWithLoRA,
+                               MergedQKVParallelLinearWithLoRA,
+                               QKVParallelLinearWithLoRA,
+                               ReplicatedLinearWithLoRA,
+                               RowParallelLinearWithLoRA)
+ # yapf: enable
+ from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
+ from vllm.lora.punica_wrapper import get_punica_wrapper
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                ReplicatedLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.utils import set_random_seed
+ from vllm.platforms import current_platform
+
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+ from tpu_inference.layers.vllm.quantization.unquantized import \
+     VllmUnquantizedLinearMethod
+ from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu
+
+ from .utils import DummyLoRAManager
+
+ P = PartitionSpec
+
+ TOLERANCES = {
+     torch.float16: (5e-3, 5e-3),
+     torch.float32: (5e-3, 5e-3),
+     torch.bfloat16: (3e-2, 2e-2),
+ }
+
+ pytestmark = pytest.mark.skipif(not current_platform.is_tpu(),
+                                 reason="This test is only for TPU platform.")
+
+ # prefill stage (True) or decode stage (False)
+ STAGES = [True, False]
+
+
+ def check_punica_wrapper(punica_wrapper) -> bool:
+     from tpu_inference.lora.torch_punica_tpu import PunicaWrapperTPU
+     return type(punica_wrapper) is PunicaWrapperTPU
+
+
+ def get_random_index_to_id(num_loras: int,
+                            num_slots: int,
+                            log: bool = True) -> list[Optional[int]]:
+     """Creates a random index_to_lora_id mapping: slot[index] = lora_id.
+
+     Args:
+         num_loras: The number of active loras in the mapping.
+         num_slots: The number of slots in the mapping. Must be at least
+             num_loras.
+         log: Whether to log the output.
+
+     Returns:
+         index_to_lora_id: a random index_to_lora_id mapping.
+     """
+
+     if num_loras > num_slots:
+         raise ValueError(
+             f"num_loras is higher than num_slots: {num_loras} > {num_slots}. "
+             "num_loras must be less than or equal to num_slots.")
+
+     slots: list[Optional[int]] = [None] * num_slots
+     random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist()
+     for lora_id, slot_idx in enumerate(random_slot_selections, start=1):
+         # lora_id starts at 1.
+         slots[slot_idx] = lora_id
+
+     if log:
+         print(f"Created lora_id_to_index mapping: {slots}.")
+
+     return slots
+
+
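To make the slot/id relationship concrete, here is a minimal sketch of how the helper above is consumed (editorial illustration, not part of the packaged file; the slot selection is random, so the values in the comment are only an example):

    index_to_id = get_random_index_to_id(num_loras=3, num_slots=6, log=False)
    # e.g. [None, 2, None, 1, 3, None]: slot 1 holds lora_id 2, slot 3 holds
    # lora_id 1, slot 4 holds lora_id 3; None marks a free slot.
    assert sum(lora_id is not None for lora_id in index_to_id) == 3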
+ def populate_loras(
+     index_to_id: list[Optional[int]],
+     lora_layer: BaseLayerWithLoRA,
+     baselayer_weights: torch.Tensor,
+     repeats: int = 1,
+ ) -> tuple[dict[int, LoRALayerWeights], dict[int, list[LoRALayerWeights]]]:
+     """Populates the lora weights (lora_a and lora_b) in the lora layer
+     (BaseLayerWithLoRA).
+
+     Args:
+         index_to_id: a list of lora ids. The index of the lora id
+             represents which memory slot the lora matrices are
+             stored in. A None value indicates a free slot.
+         lora_layer: the LoRA layer to populate.
+         baselayer_weights: the PyTorch tensor containing the layer's
+             weights.
+         repeats: must only be set for column parallel packed
+             layers. Indicates the number of loras to compose
+             together to create a single lora layer.
+
+     Returns:
+         lora_dict: a dict[int, LoRALayerWeights] that maps each lora ID to
+             the corresponding (possibly packed) lora weights.
+         sublora_dict: a dict[int, list[LoRALayerWeights]] that maps each
+             lora ID to its list of sublora weights.
+     """
+
+     # Dictionary that maps the lora ID to the
+     # corresponding lora weights.
+     lora_dict: dict[int, LoRALayerWeights] = dict()
+
+     # Dictionary that maps the lora ID to the
+     # corresponding subloras.
+     sublora_dict: dict[int, list[LoRALayerWeights]] = dict()
+
+     for slot_idx, lora_id in enumerate(index_to_id):
+         if lora_id is not None:
+             subloras: list[LoRALayerWeights] = []
+             sublora_len = baselayer_weights.shape[0] // repeats
+             for i in range(repeats):
+                 sublora = DummyLoRAManager(
+                     baselayer_weights.device).init_random_lora(
+                         module_name=f"fake_{i}",
+                         weight=baselayer_weights,
+                     )
+                 sublora.lora_b = sublora.lora_b[(sublora_len *
+                                                  i):(sublora_len * (i + 1)), :]
+                 sublora.optimize()
+                 subloras.append(sublora)
+
+             lora = PackedLoRALayerWeights.pack(
+                 subloras) if repeats > 1 else subloras[0]
+
+             # Some of the layer's lora tensors are torchax tensors, so math
+             # ops (e.g. the slice op) must run inside the torchax env.
+             with torchax.default_env():
+                 lora_layer.set_lora(
+                     slot_idx,
+                     lora_a=lora.lora_a,
+                     lora_b=lora.lora_b,
+                 )
+
+             lora_dict[lora_id] = lora
+             sublora_dict[lora_id] = subloras
+
+     return lora_dict, sublora_dict
+
+
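For packed layers (repeats > 1), each sublora's lora_b is cut from one equal slice of the base weight's output dimension before packing. A minimal sketch of that slicing, assuming a 128-row base weight and repeats == 2 (editorial illustration, not part of the packaged file):

    sublora_len = 128 // 2  # baselayer_weights.shape[0] // repeats
    for i in range(2):
        start, end = sublora_len * i, sublora_len * (i + 1)
        print(f"sublora {i} takes lora_b rows {start}:{end}")
    # sublora 0 takes lora_b rows 0:64
    # sublora 1 takes lora_b rows 64:128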
+ def create_random_inputs(
+     active_lora_ids: list[int],
+     num_inputs: int,
+     input_size: tuple[int, ...],
+     input_range: tuple[float, float],
+     input_type: torch.dtype = torch.int,
+     device: torch.device = "cpu",
+ ) -> tuple[list[torch.Tensor], list[int], list[int]]:
+     """Creates random inputs.
+
+     Args:
+         active_lora_ids: lora IDs of active lora weights.
+         num_inputs: the number of inputs to create, i.e. the number of
+             requests.
+         input_size: the size of each individual input, i.e. the number of
+             tokens.
+         input_range: the range of values to include in the input.
+             input_range[0] <= possible input values < input_range[1]
+         input_type: the type of values in the input.
+
+     Returns:
+         inputs: a list of num_inputs torch tensors, each of shape
+             `input_size`.
+         index_mapping: maps each input token to a lora ID.
+         prompt_mapping: maps each request to a lora ID.
+     """
+
+     low, high = input_range
+
+     inputs: list[torch.Tensor] = []
+     index_mapping: list[int] = []
+     prompt_mapping: list[int] = []
+
+     for _ in range(num_inputs):
+         if input_type == torch.int:
+             inputs.append(
+                 torch.randint(low=int(low),
+                               high=int(high),
+                               size=input_size,
+                               device=device))
+         else:
+             inputs.append(
+                 torch.rand(size=input_size, dtype=input_type, device=device) *
+                 high + low)
+
+         lora_id = random.choice(active_lora_ids)
+         index_mapping += [lora_id] * input_size[0]
+         prompt_mapping += [lora_id]
+
+     return inputs, index_mapping, prompt_mapping
+
+
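A small sketch of the shapes this helper returns, assuming two one-token requests (editorial illustration, not part of the packaged file; the chosen lora IDs are random):

    inputs, index_mapping, prompt_mapping = create_random_inputs(
        active_lora_ids=[1, 4],
        num_inputs=2,
        input_size=(1, 64),
        input_range=(0, 1),
        input_type=torch.bfloat16)
    # len(inputs) == 2 and each tensor has shape (1, 64).
    # index_mapping gets one entry per token, prompt_mapping one per request,
    # e.g. index_mapping == [4, 1] and prompt_mapping == [4, 1].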
+ @torch.inference_mode()
+ @pytest.mark.parametrize("num_loras", [1, 4, 9])
+ @pytest.mark.parametrize("repeats", [1, 2, 3])
+ @pytest.mark.parametrize("stage", [True, False])
+ def test_column_parallel_packed(dist_init, num_loras, repeats, stage) -> None:
+     set_random_seed(6)
+
+     max_loras = 9
+     max_lora_rank = 8
+     lora_config = LoRAConfig(
+         max_loras=max_loras,
+         max_lora_rank=max_lora_rank,
+         fully_sharded_loras=False,
+         lora_dtype=torch.bfloat16,
+     )
+     vllm_config = dist_init
+     vllm_config.lora_config = lora_config
+
+     mesh = _create_mesh()
+     linear, lora_linear = _create_column_parallel_packed_layer(
+         repeats, vllm_config, mesh)
+     _verify_lora_linear_layer(linear, lora_linear)
+
+     # After we create the lora_config, the linear layer and the lora layer,
+     # here are the steps to do next:
+     # - create a punica wrapper.
+     # - associate the punica wrapper with the lora layer.
+     # - populate the lora matrices in the lora layer: use non-zero values
+     #   for testing lora and zero values for testing the case where the
+     #   layer doesn't have lora.
+     # - create inputs and lora_mapping.
+     # - update the metadata of the punica wrapper.
+     # - convert the inputs to be torchax tensors.
+     # - then run a forward on the lora layer to get the actual output.
+     # - then run a reference implementation as the expected output.
+
+     # Create a punica wrapper and associate it with the lora linear layer.
+     max_num_batched_tokens = 8192
+     max_batches = 256
+     with torchax.default_env():
+         punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                             max_batches,
+                                             'jax',
+                                             max_loras=max_loras)
+     assert check_punica_wrapper(punica_wrapper)
+     lora_linear.set_mapping(punica_wrapper)
+
+     # Populate lora matrices (lora_a and lora_b) in the lora layer.
+     index_to_id = get_random_index_to_id(num_loras, max_loras)
+     # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+     lora_dict, sublora_dict = populate_loras(
+         index_to_id,
+         lora_layer=lora_linear,
+         baselayer_weights=linear.weight,
+         repeats=repeats,
+     )
+
+     # Create inputs and lora mappings.
+     # inputs: list[torch.Tensor] of size num_inputs. inputs[i] corresponds
+     #     to one request whose tokens form a tensor of shape [num_tokens, 64].
+     # index_mapping: list[int]
+     # prompt_mapping: list[int]
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=list(lora_dict.keys()),
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+
+     expected_results: list[torch.Tensor] = []
+     for input_, lora_id in zip(inputs, prompt_mapping):
+         # linear(input_) returns (output, output_bias); we only need the
+         # first element.
+         result = linear(input_)[0]
+         subloras = sublora_dict[lora_id]
+         for i, sublora in enumerate(subloras):
+             result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+                    (i + 1)] += (input_ @ sublora.lora_a.T @ sublora.lora_b.T *
+                                 sublora.scaling)
+         expected_results.append(result)
+     expected_result = torch.cat(expected_results)
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+     # print(
+     #     f'Output max diff: {torch.max(torch.abs(expected_result - actual_result_cpu))}'
+     # )
+     # print(
+     #     f'Output mean diff: {torch.mean(torch.abs(expected_result - actual_result_cpu))}'
+     # )
+
+     # Check that resetting the lora weights succeeds.
+     # Here we set all lora weights to be empty.
+     for slot_idx in range(max_loras):
+         lora_linear.reset_lora(slot_idx)
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=[0],  # different from the above create_random_inputs
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+     expected_result = linear(torch.cat(inputs))[0]
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+
+ @torch.inference_mode()
+ @pytest.mark.parametrize("num_loras", [1, 4, 9])
+ @pytest.mark.parametrize("layer_type", ["row", "column", "replicated"])
+ @pytest.mark.parametrize("stage", [True, False])
+ def test_linear_parallel(dist_init, num_loras, layer_type, stage) -> None:
+     set_random_seed(6)
+
+     max_loras = 9
+     max_lora_rank = 8
+     lora_config = LoRAConfig(
+         max_loras=max_loras,
+         max_lora_rank=max_lora_rank,
+         fully_sharded_loras=False,
+         lora_dtype=torch.bfloat16,
+     )
+     vllm_config = dist_init
+     vllm_config.lora_config = lora_config
+
+     mesh = _create_mesh()
+     linear, lora_linear = _create_random_linear_parallel_layer(
+         layer_type, vllm_config, mesh)
+     _verify_lora_linear_layer(linear, lora_linear)
+
+     max_num_batched_tokens = 8192
+     max_batches = 256
+     with torchax.default_env():
+         punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
+                                             max_batches,
+                                             'jax',
+                                             max_loras=max_loras)
+     assert check_punica_wrapper(punica_wrapper)
+     lora_linear.set_mapping(punica_wrapper)
+
+     # Populate lora matrices (lora_a and lora_b) in the lora layer.
+     index_to_id = get_random_index_to_id(num_loras, max_loras)
+     # lora_dict: lora_id -> LoRALayerWeights|PackedLoRALayerWeights
+     lora_dict, sublora_dict = populate_loras(
+         index_to_id,
+         lora_layer=lora_linear,
+         baselayer_weights=linear.weight,
+     )
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=list(lora_dict.keys()),
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+
+     expected_results: list[torch.Tensor] = []
+     for input_, lora_id in zip(inputs, prompt_mapping):
+         result = linear(input_)[0]
+         lora = lora_dict[lora_id]
+         lora_result = input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
+         result += lora_result
+         expected_results.append(result)
+     expected_result = torch.cat(expected_results)
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+     # Check that resetting the lora weights succeeds.
+     # Here we set all lora weights to be empty.
+     for slot_idx in range(max_loras):
+         lora_linear.reset_lora(slot_idx)
+
+     inputs, index_mapping, prompt_mapping = create_random_inputs(
+         active_lora_ids=[0],  # different from the above create_random_inputs
+         num_inputs=32,
+         input_size=(1, 64),
+         input_range=(0, 1),
+         input_type=torch.bfloat16,
+         device='cpu')
+     _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config)
+
+     with torchax.default_env():
+         torchax_inputs = _shard_and_move_inputs_to_tpu(inputs, mesh)
+         actual_result = lora_linear(torchax_inputs)[0]
+     expected_result = linear(torch.cat(inputs))[0]
+
+     rtol, atol = TOLERANCES[actual_result.dtype]
+     with torchax.default_env():
+         actual_result_cpu = actual_result.to('cpu')
+         torch.testing.assert_close(actual_result_cpu,
+                                    expected_result,
+                                    rtol=rtol,
+                                    atol=atol)
+
+
+ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
+     # We first create a base linear layer, then a lora layer to wrap it.
+     if layer_type == "row":
+
+         def _create_row_linear():
+             return RowParallelLinear(
+                 64,  # input_size
+                 64,  # output_size
+                 bias=False,
+                 params_dtype=torch.bfloat16)
+
+         linear = _create_row_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_row_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            RowParallelLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+     elif layer_type == "column":
+
+         def _create_column_linear():
+             return ColumnParallelLinear(64,
+                                         64,
+                                         bias=False,
+                                         params_dtype=torch.bfloat16)
+
+         linear = _create_column_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_column_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            ColumnParallelLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+
+     elif layer_type == "replicated":
+
+         def _create_replicated_linear():
+             return ReplicatedLinear(64,
+                                     64,
+                                     bias=False,
+                                     params_dtype=torch.bfloat16)
+
+         linear = _create_replicated_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_replicated_linear()
+         lora_linear = _create_lora_wrapper(linear,
+                                            base_linear,
+                                            ReplicatedLinearWithLoRA,
+                                            vllm_config=vllm_config,
+                                            mesh=mesh)
+
+     else:
+         raise NotImplementedError("Unknown layer type: {}".format(layer_type))
+
+     return linear, lora_linear
+
+
+ def _get_devices():
+     return jax.devices()
+
+
+ def _create_mesh():
+     axis_names = ("data", "model")
+     devices = _get_devices()
+     mesh_shape = (1, len(devices))
+     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
+     return mesh
+
+
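On a host with, say, four TPU devices, the mesh built above has shape (1, 4) over the axes ("data", "model"). A sketch of how such a mesh is used for sharding (editorial illustration assuming four devices, not part of the packaged file):

    mesh = _create_mesh()
    # Split a tensor's second dimension across the chips on the "model"
    # axis; the first dimension stays replicated.
    sharding = NamedSharding(mesh, P(None, "model"))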
+ def _verify_lora_linear_layer(linear, lora_linear):
+     with torchax.default_env():
+         # lora_linear.weight has type torchax.tensor.Tensor;
+         # the BaseLinearLayerWithLoRA.weight property guarantees this.
+         # If len(devices) != 1, the `reorder_concatenated_tensor_for_sharding`
+         # function may reorder the out_features dimension of the weight
+         # matrix, so the check below would fail.
+         if len(_get_devices()) == 1:
+             assert torch.equal(linear.weight.data,
+                                lora_linear.weight.to('cpu'))
+
+
+ def _shard_and_move_inputs_to_tpu(inputs, mesh):
+     processed_inputs = []
+     for input in inputs:
+         # Without `torch_view`, you get an error `AttributeError:
+         # 'jaxlib._jax.ArrayImpl' object has no attribute 'apply_jax_'`.
+         # Without `t2j`, you get an error `AttributeError: 'Tensor' object
+         # has no attribute 'apply_jax_'`.
+         jax_input = torch_view(t2j(input))
+         jax_input.apply_jax_(jax.device_put,
+                              NamedSharding(mesh, P(None, None)))
+         processed_inputs.append(jax_input)
+     return torch.cat(processed_inputs)
+
+
+ def _update_punica_wrapper_metadata(punica_wrapper, index_mapping,
+                                     prompt_mapping, stage, index_to_id,
+                                     lora_config):
+     lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+     with torchax.default_env():
+         # Here we move the metadata from cpu to tpu.
+         punica_wrapper.update_metadata(
+             lora_mapping,
+             index_to_id,
+             lora_config.max_loras,
+             vocab_size=512,
+         )
+     assert jax_view(punica_wrapper._lora_indices_per_batch).platform(
+     ) == 'tpu', 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+     assert isinstance(
+         jax_view(punica_wrapper._lora_indices_per_batch).sharding,
+         jax.sharding.SingleDeviceSharding
+     ), 'punica_wrapper._lora_indices_per_batch should have been moved to TPU.'
+
+
+ def _create_column_parallel_packed_layer(repeats, vllm_config, mesh):
+     # We first create a base linear layer, then a lora layer to wrap it.
+     if repeats == 2:
+         # In e2e, MergedColumnParallelLinear is created when we load the
+         # model. The base_layer weights are sharded and moved to TPU in
+         # VllmUnquantizedLinearMethod.process_weights_after_loading.
+         def _create_merged_column_linear():
+             return MergedColumnParallelLinear(
+                 64,  # input_size
+                 [64] * repeats,  # output_size
+                 bias=False,
+                 params_dtype=torch.bfloat16)
+
+         linear = _create_merged_column_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_merged_column_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            MergedColumnParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+     elif repeats == 3:
+
+         def _create_qkv_linear():
+             return QKVParallelLinear(64,
+                                      64,
+                                      32,
+                                      bias=False,
+                                      params_dtype=torch.bfloat16)
+
+         linear = _create_qkv_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_qkv_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            MergedQKVParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+     else:
+
+         def _create_qkv_linear():
+             return QKVParallelLinear(64,
+                                      64,
+                                      32,
+                                      bias=False,
+                                      params_dtype=torch.bfloat16)
+
+         linear = _create_qkv_linear()
+         linear.weight.data = torch.rand_like(linear.weight.data)
+
+         base_linear = _create_qkv_linear()
+         lora_linear = _create_lora_wrapper(linear, base_linear,
+                                            QKVParallelLinearWithLoRA,
+                                            vllm_config, mesh, repeats)
+
+     return linear, lora_linear
+
+
+ def _create_lora_wrapper(linear,
+                          base_linear,
+                          lora_cls,
+                          vllm_config,
+                          mesh,
+                          repeats=1):
+     base_linear.weight.data = linear.weight.data
+     jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
+     linear_method = VllmUnquantizedLinearMethod(jax_config)
+     base_linear.quant_method = linear_method
+     # Here base_linear.weight is moved to TPU and sharded.
+     linear_method.process_weights_after_loading(base_linear)
+     assert jax_view(base_linear.weight).platform(
+     ) == 'tpu', 'base_linear.weight should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(base_linear.weight).sharding, jax.sharding.
+         SingleDeviceSharding), 'base_linear.weight should have been sharded.'
+
+     lora_linear = lora_cls(base_linear)
+
+     lora_config = vllm_config.lora_config
+     max_loras = lora_config.max_loras
+     with torchax.default_env():
+         lora_linear.create_lora_weights(max_loras, lora_config)
+         # In e2e, the lora layer's weights are moved to TPU in
+         # _shard_module_to_tpu.
+         _shard_module_to_tpu(lora_linear, mesh)
+
+     assert jax_view(lora_linear.lora_a_stacked[0]).platform(
+     ) == 'tpu', 'lora_a_stacked should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(lora_linear.lora_a_stacked[0]).sharding, jax.sharding.
+         SingleDeviceSharding), 'lora_a_stacked should have been sharded.'
+     assert jax_view(lora_linear.lora_b_stacked[0]).platform(
+     ) == 'tpu', 'lora_b_stacked should have been moved to TPU.'
+     assert not isinstance(
+         jax_view(lora_linear.lora_b_stacked[0]).sharding, jax.sharding.
+         SingleDeviceSharding), 'lora_b_stacked should have been sharded.'
+     n_slices = repeats
+     assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+         lora_linear.lora_b_stacked) == n_slices)
+
+     return lora_linear
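Assuming a TPU host with this wheel and the test dependencies installed, the tests above would typically be invoked through pytest, for example (editorial sketch, not part of the packaged file):

    import pytest
    pytest.main(["tests/lora/test_layers.py", "-k", "test_linear_parallel"])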