tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tests/layers/vllm/test_unquantized.py ADDED
@@ -0,0 +1,621 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import tempfile
+
+ import jax
+ import pytest
+ import torch
+ import torchax
+ from jax._src import test_util as jtu
+ from jax.sharding import NamedSharding, PartitionSpec
+ from torchax.interop import torch_view
+ from torchax.ops.mappings import j2t, t2j
+ from vllm.config import ParallelConfig, set_current_vllm_config
+ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                              init_distributed_environment)
+ from vllm.engine.arg_utils import EngineArgs
+ from vllm.forward_context import set_forward_context
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase,
+                                                MergedColumnParallelLinear,
+                                                QKVParallelLinear,
+                                                RowParallelLinear)
+ from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+ from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.unquantized import (
+     VllmUnquantizedConfig, VllmUnquantizedFusedMoEMethod,
+     VllmUnquantizedLinearMethod)
+
+ from . import utils as test_utils
+
+ P = PartitionSpec
+ MODELS = ["Qwen/Qwen2-1.5B-Instruct"]
+
+
+ @pytest.fixture(autouse=True)
+ def setup_environment():
+     # This is a fake config used for init dist env.
+     # RowParallelLinear needs dist env to be initialized.
+     engine_args = EngineArgs(
+         model=MODELS[0],
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+
+     vllm_config = engine_args.create_engine_config()
+
+     with set_current_vllm_config(vllm_config):
+         temp_file = tempfile.mkstemp()[1]
+         init_distributed_environment(
+             1,
+             0,
+             local_rank=0,
+             distributed_init_method=f"file://{temp_file}",
+             backend="gloo")
+         ensure_model_parallel_initialized(1, 1)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_quant_override(model, mesh):
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     assert isinstance(quant_config, VllmUnquantizedConfig)
+     assert quant_config.vllm_config == vllm_config
+     assert quant_config.mesh == mesh
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ def test_loading_model(model, mesh):
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = torch.bfloat16
+     vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     vllm_config.device_config.device = "cpu"
+
+     vllm_model = vllm_get_model(vllm_config=vllm_config)
+     layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+     for layer in layers:
+         assert isinstance(layer.quant_config, VllmUnquantizedConfig)
+         assert isinstance(layer.quant_method, VllmUnquantizedLinearMethod)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_row_parallel_linear(model, bias, mesh, enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(row_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(row_linear.bias.data)
+
+     row_linear.weight.data = weight_data
+     if bias:
+         row_linear.bias.data = bias_data
+     row_linear = row_linear.to('cpu')
+     row_linear.quant_method.process_weights_after_loading(row_linear)
+     output = row_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_row_linear = RowParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_row_linear.weight.data = weight_data
+     if bias:
+         jax_row_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_row_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_row_linear.quant_method.process_weights_after_loading(
+             jax_row_linear)
+         jax_output = jax_row_linear(jax_input_tensor)
+     # j2t() doesn't support bfloat16, so we cast it into float32 as an intermediate step.
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_column_parallel_linear(model, bias, mesh, enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(column_linear.bias.data)
+
+     column_linear.weight.data = weight_data
+     if bias:
+         column_linear.bias.data = bias_data
+     column_linear = column_linear.to('cpu')
+     column_linear.quant_method.process_weights_after_loading(column_linear)
+     output = column_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_column_linear = ColumnParallelLinear(
+             input_size=4096,
+             output_size=8192,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+
+     jax_column_linear.weight.data = weight_data
+     if bias:
+         jax_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_column_linear.quant_method.process_weights_after_loading(
+             jax_column_linear)
+         jax_output = jax_column_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     with set_current_vllm_config(vllm_config):
+         qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(qkv_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(qkv_linear.bias.data)
+
+     qkv_linear.weight.data = weight_data
+     if bias:
+         qkv_linear.bias.data = bias_data
+     qkv_linear = qkv_linear.to('cpu')
+     qkv_linear.quant_method.process_weights_after_loading(qkv_linear)
+     output = qkv_linear(input_tensor).to(dtype)
+
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_qkv_linear = QKVParallelLinear(
+             hidden_size=4096,
+             head_size=128,
+             total_num_heads=32,
+             total_num_kv_heads=8,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_qkv_linear.weight.data = weight_data
+     if bias:
+         jax_qkv_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_qkv_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_qkv_linear.quant_method.process_weights_after_loading(
+             jax_qkv_linear)
+         jax_output = jax_qkv_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("model", MODELS)
+ @pytest.mark.parametrize("bias", [False, True])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("fuse_matmuls", [False, True])
+ @pytest.mark.parametrize("enable_sp", [False, True])
+ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
+                                        enable_sp):
+     dtype = torch.bfloat16
+
+     engine_args = EngineArgs(
+         model=model,
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+     input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+     input_tensor = input_tensor.to('cpu')
+
+     # Call vLLM code
+     with set_current_vllm_config(vllm_config):
+         merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+         )
+
+     weight_data = torch.rand_like(merged_column_linear.weight.data) / 10
+     if bias:
+         bias_data = torch.rand_like(merged_column_linear.bias.data)
+
+     merged_column_linear.weight.data = weight_data
+     if bias:
+         merged_column_linear.bias.data = bias_data
+     merged_column_linear = merged_column_linear.to('cpu')
+     merged_column_linear.quant_method.process_weights_after_loading(
+         merged_column_linear)
+     output = merged_column_linear(input_tensor).to(dtype)
+
+     # Call tpu_inference code
+     vllm_config.model_config.dtype = dtype
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         jax_merged_column_linear = MergedColumnParallelLinear(
+             input_size=4096,
+             output_sizes=[14336] * 2,
+             bias=bias,
+             params_dtype=dtype,
+             return_bias=False,
+             quant_config=quant_config,
+         )
+     jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+     jax_merged_column_linear.weight.data = weight_data
+     if bias:
+         jax_merged_column_linear.bias.data = bias_data
+
+     jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+     jax_input_tensor.apply_jax_(jax.device_put,
+                                 NamedSharding(mesh, P(None, None)))
+     with torchax.default_env():
+         assert isinstance(jax_merged_column_linear.quant_method,
+                           VllmUnquantizedLinearMethod)
+         jax_merged_column_linear.quant_method.process_weights_after_loading(
+             jax_merged_column_linear)
+         jax_output = jax_merged_column_linear(jax_input_tensor)
+     jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+     torch.testing.assert_close(output, jax_output)
+
+
+ @pytest.mark.parametrize("use_ep", [True, False])
+ @pytest.mark.parametrize("mesh", [
+     test_utils.get_spmd_mesh(1),
+     test_utils.get_spmd_mesh(jax.local_device_count())
+ ])
+ @pytest.mark.parametrize("num_tokens", [8])
+ @pytest.mark.parametrize("intermediate_size", [1024, 2048])
+ @pytest.mark.parametrize("hidden_size", [128, 512])
+ @pytest.mark.parametrize("num_experts", [8])
+ @pytest.mark.parametrize("topk", [2])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ @pytest.mark.parametrize("activation", ["silu", "swigluoai"])
+ def test_fused_moe(use_ep, mesh, num_tokens, intermediate_size, hidden_size,
+                    num_experts, topk, has_bias, activation):
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+     score = torch.randn((num_tokens, num_experts), dtype=dtype)
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=False,
+             renormalize=False,
+             tp_size=1,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+             activation=activation,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = use_ep
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(expected,
+                                actual,
+                                check_device=False,
+                                atol=1e-1,
+                                rtol=1e-1)
+
+
+ @pytest.mark.parametrize("mesh",
+                          [test_utils.get_spmd_mesh(jax.local_device_count())])
+ @pytest.mark.parametrize("num_tokens", [128, 512])
+ @pytest.mark.parametrize("intermediate_size", [512])
+ @pytest.mark.parametrize("hidden_size", [512])
+ @pytest.mark.parametrize("num_experts", [32])
+ @pytest.mark.parametrize("topk", [8])
+ @pytest.mark.parametrize("has_bias", [False, True])
+ def test_fused_moe_use_kernel(mesh, num_tokens, intermediate_size, hidden_size,
+                               num_experts, topk, has_bias):
+
+     # TODO(Qiliang Cui): Remove when issue is resolved.
+     if not jtu.is_device_tpu_at_least(version=7):
+         pytest.skip(allow_module_level=True, reason="Expected TPUv7+")
+
+     torch.manual_seed(42)
+     dtype = torch.bfloat16
+
+     a = torch.randn((num_tokens, hidden_size), dtype=dtype) / 10
+     w1 = torch.randn(
+         (num_experts, 2 * intermediate_size, hidden_size), dtype=dtype) / 10
+     w2 = torch.randn(
+         (num_experts, hidden_size, intermediate_size), dtype=dtype) / 10
+
+     w1_bias = w2_bias = None
+     if has_bias:
+         w1_bias = torch.randn(
+             (num_experts, 2 * intermediate_size), dtype=dtype) / 10
+         w2_bias = torch.randn((num_experts, hidden_size), dtype=dtype) / 10
+
+     # Use deterministic gating_output generation (same logic as fused_moe_v1_test.py)
+     # Generate base gating scores with deterministic pattern
+     score = (
+         torch.randn((num_tokens, num_experts), dtype=torch.float32) +
+         torch.arange(num_tokens * num_experts, dtype=torch.float32).reshape(
+             num_tokens, num_experts) / 100)
+
+     # Generate unique top-k indices
+     generator = torch.Generator()
+     generator.manual_seed(42)
+     top_k_indices = torch.randint(0,
+                                   num_experts - 1, (num_tokens, topk),
+                                   dtype=torch.int32,
+                                   generator=generator)
+
+     # Add one-hot encoding weighted by 10 to ensure selected experts have highest scores
+     one_hot = torch.nn.functional.one_hot(top_k_indices.long(),
+                                           num_classes=num_experts).float()
+     one_hot = one_hot.sum(dim=1) * 10
+     score = (score + one_hot).to(dtype)
+
+     engine_args = EngineArgs(
+         model="Qwen/Qwen2-1.5B-Instruct",
+         max_model_len=64,
+         max_num_batched_tokens=64,
+         max_num_seqs=4,
+     )
+     vllm_config = engine_args.create_engine_config()
+     vllm_config.model_config.dtype = dtype
+     vllm_config.parallel_config = ParallelConfig(
+         tensor_parallel_size=mesh.devices.size)
+
+     quant_config = get_tpu_quantization_config(vllm_config, mesh)
+     with set_current_vllm_config(vllm_config):
+         vllm_fused_moe = FusedMoE(
+             num_experts=num_experts,
+             top_k=topk,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             reduce_results=True,
+             renormalize=False,
+             tp_size=mesh.devices.size,
+             dp_size=1,
+             quant_config=quant_config,
+             has_bias=has_bias,
+         )
+     vllm_fused_moe.moe_parallel_config.use_ep = True
+     vllm_fused_moe.quant_method.use_kernel = True
+
+     vllm_fused_moe.w13_weight.data = w1
+     vllm_fused_moe.w2_weight.data = w2
+     if has_bias:
+         vllm_fused_moe.w13_bias.data = w1_bias
+         vllm_fused_moe.w2_bias.data = w2_bias
+
+     expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
+                                   vllm_fused_moe.top_k,
+                                   vllm_fused_moe.renormalize,
+                                   vllm_fused_moe.activation)
+
+     with torchax.default_env(), set_forward_context(None, vllm_config):
+         assert isinstance(vllm_fused_moe.quant_method,
+                           VllmUnquantizedFusedMoEMethod)
+         jax_a = a.to('jax')
+         score = score.to('jax')
+
+         vllm_fused_moe.quant_method.process_weights_after_loading(
+             vllm_fused_moe)
+         vllm_fused_moe.quant_method.block_size = {
+             "bt": 32,
+             "bf": 512,
+             "bd1": 512,
+             "bd2": 512,
+             "btc": 32,
+             "bfc": 256,
+             "bd1c": 256,
+             "bd2c": 256,
+         }
+         actual = vllm_fused_moe(jax_a, score)
+
+     torch.testing.assert_close(
+         expected,
+         actual,
+         check_device=False,
+         atol=1e-2,
+         rtol=1e-2,
+     )
tests/layers/vllm/utils.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import jax
+ import torch
+ import torch.nn.functional as F
+
+
+ def get_spmd_mesh(num_devices: int = 1):
+     axis_names = ("data", "model")
+     devices = sorted(jax.devices(), key=lambda d: d.id)[0:num_devices]
+     mesh_shape = (1, len(devices))
+     return jax.make_mesh(mesh_shape, axis_names, devices=devices)
+
+
+ def find_all_layer_type(module: torch.nn.Module, layer_type: torch.nn.Module):
+     ret = []
+     for name, child in module.named_children():
+         if isinstance(child, layer_type):
+             ret.append(child)
+         else:
+             ret.extend(find_all_layer_type(child, layer_type))
+     return ret
+
+
+ # TODO(kyuyeunk): Consolidate all reference implementations used for unit tests
+ # into a single file.
+ def ref_moe(x, router_logits, w1, w2, w1_bias, w2_bias, top_k, renormalize,
+             activation):
+
+     expert_weights = F.softmax(router_logits, dim=-1)
+     expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+     if renormalize:
+         expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+     x = torch.einsum("ti,eoi->teo", x, w1)
+     if w1_bias is not None:
+         x += w1_bias.unsqueeze(0)
+
+     match activation:
+         case "silu":
+             x1, x3 = x.chunk(chunks=2, dim=-1)
+             x = F.silu(x1) * x3
+         case "swigluoai":
+             x1, x3 = x[..., ::2], x[..., 1::2]
+             x1 = x1.clamp(min=None, max=7.0)
+             x3 = x3.clamp(min=-7.0, max=7.0)
+             gated_activation = x1 * torch.sigmoid(x1 * 1.702)
+             x = gated_activation * (x3 + 1)
+         case _:
+             raise NotImplementedError(
+                 f"No reference implementation for {activation} activation")
+
+     x = torch.einsum("teo,eio->tei", x, w2)
+     if w2_bias is not None:
+         x += w2_bias.unsqueeze(0)
+
+     seq_indexes = torch.arange(x.shape[0]).unsqueeze(1)
+     x = x[seq_indexes, expert_indices]
+
+     return torch.einsum("tai,ta->ti", x, expert_weights)
tests/lora/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.