tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,416 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional, Union
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from jax.experimental.layout import Format, Layout
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn.parameter import Parameter
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.attention.layer import Attention
+ from vllm.logger import init_logger
+ from vllm.model_executor.layers.fused_moe.layer import (
+     FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
+ from vllm.model_executor.layers.fused_moe.modular_kernel import (
+     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
+ from vllm.model_executor.layers.linear import (LinearBase,
+                                                UnquantizedLinearMethod)
+ from vllm.model_executor.layers.quantization import \
+     register_quantization_config
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig, QuantizeMethodBase)
+
+ from tpu_inference import envs
+ from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+ from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
+                                                         get_tpu_quant_method)
+ from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+ from tpu_inference.layers.vllm.linear_common import (
+     reorder_concatenated_tensor_for_sharding,
+     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
+ from tpu_inference.layers.vllm.quantization.common import (
+     JaxCommonConfig, JaxCommonLinearConfig)
+
+ P = PartitionSpec
+ logger = init_logger(__name__)
+
+
+ def align_to(a, b):
+     return (a + b - 1) // b * b
+
+
+ @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
+ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+
+     @classmethod
+     def get_name(cls) -> str:
+         return UNQUANTIZED
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+         return [torch.float32, torch.float16, torch.bfloat16]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 0  # Always supported
+
+     @classmethod
+     def get_config_filenames(cls) -> list[str]:
+         return []  # No extra configs required.
+
+     @classmethod
+     def from_config(cls, _: dict[str, Any]) -> "VllmUnquantizedConfig":
+         return cls()
+
+     def get_quant_method(self, layer: torch.nn.Module,
+                          prefix: str) -> Optional[QuantizeMethodBase]:
+         if isinstance(layer, LinearBase):
+             linear_config = self.get_linear_config(layer)
+             return VllmUnquantizedLinearMethod(linear_config)
+         if isinstance(layer, FusedMoE):
+             moe_config = self.get_moe_config(layer)
+             return VllmUnquantizedFusedMoEMethod(moe_config, self.mesh)
+         if isinstance(layer, Attention):
+             return None
+         return None
+
+
+ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
+
+     def __init__(self, jax_config: JaxCommonLinearConfig):
+         self.jax_config = jax_config
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         weight = torch_to_jax_param(
+             layer.weight,
+             NamedSharding(self.jax_config.mesh,
+                           self.jax_config.weight_sharding),
+             self.jax_config.output_sizes,
+             self.jax_config.n_shards,
+             self.jax_config.fuse_matmuls,
+         )
+         delattr(layer, "weight")
+         layer.weight = weight
+
+         if layer.bias is not None and not layer.skip_bias_add:
+             if layer.return_bias:
+                 logger.warning_once("Bias might return incorrect value.")
+
+             bias = torch_to_jax_param(
+                 layer.bias,
+                 NamedSharding(self.jax_config.mesh,
+                               self.jax_config.bias_sharding),
+                 self.jax_config.output_sizes,
+                 self.jax_config.n_shards,
+                 self.jax_config.fuse_matmuls,
+             )
+             delattr(layer, "bias")
+             layer.bias = bias
+
+     def apply(self,
+               layer: torch.nn.Module,
+               x: torch.Tensor,
+               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer, LinearBase)
+
+         with jax.named_scope(layer._get_name()):
+             if in_sharding := self.jax_config.get_input_sharding(x):
+                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
+
+             if self.jax_config.fuse_matmuls:
+                 out = self._apply_fused(layer, x, bias)
+             else:
+                 out = self._apply_split(layer, x, bias)
+
+             if out_sharding := self.jax_config.get_output_sharding(out):
+                 out.shard_(NamedSharding(self.jax_config.mesh, out_sharding))
+
+         return out
+
+     def _apply_fused(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x_jax = jax_view(x)
+         weight_jax = jax_view(layer.weight)
+
+         outs = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+         if bias is not None and not layer.skip_bias_add:
+             outs += bias.jax()
+
+         outs = slice_sharded_tensor_for_concatenation(
+             outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+     def _apply_split(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+         assert isinstance(layer.weight, torch.nn.ParameterList)
+
+         x_jax = x.jax()
+         outs = []
+         for i, weight in enumerate(layer.weight):
+             weight_jax = jax_view(weight)
+
+             out = jnp.einsum("mn,pn->mp", x_jax, weight_jax)
+             if bias is not None and not layer.skip_bias_add:
+                 out += jax_view(bias[i])
+
+             outs.append(out)
+         out = jnp.concatenate(outs, axis=-1)
+         return torch_view(out)
+
+
+ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
+
+     def __init__(self,
+                  moe: FusedMoEConfig,
+                  mesh: Mesh,
+                  ep_axis_name: str = 'model'):
+         super().__init__(moe)
+         self.mesh = mesh
+         self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+         self.ep_axis_name = ep_axis_name
+         # TODO: Use autotune table once we have it.
+         self.block_size = {
+             "bt": 64,
+             "bf": 1024,
+             "bd1": 1536,
+             "bd2": 1536,
+             "btc": 64,
+             "bfc": 1024,
+             "bd1c": 1536,
+             "bd2c": 1536,
+         }
+
+     def select_gemm_impl(
+         self,
+         prepare_finalize: FusedMoEPrepareAndFinalize,
+         moe: FusedMoEConfig,
+         layer: torch.nn.Module,
+     ) -> FusedMoEPermuteExpertsUnpermute:
+         raise NotImplementedError(
+             "Selecting gemm implementation is currently not supported.")
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         assert isinstance(layer, FusedMoE)
+         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+
+         num_experts, hidden_size, intermediate_size = w2_weight.shape
+
+         if self.moe.has_bias:
+             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+         if layer.activation == "swigluoai":
+             # When using swigluoai, vLLM splits the gmm output in an
+             # interleaved way. However, an interleaved split is not performant
+             # on TPU. Therefore, we preprocess the weight so that splitting the
+             # gmm output down the middle still gives the same result.
+             w1_weight = w13_weight[:, ::2, :]
+             w3_weight = w13_weight[:, 1::2, :]
+             w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+             if self.moe.has_bias:
+                 w1_bias = w13_bias[:, ::2]
+                 w3_bias = w13_bias[:, 1::2]
+                 w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+         if self.use_kernel:
+             # Kernel expects:
+             #   w13: (num_experts, 2, hidden_size, intermediate_size)
+             #   w2: (num_experts, intermediate_size, hidden_size)
+             # Current format:
+             #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+             #   w2_weight: (num_experts, hidden_size, intermediate_size)
+             num_experts = w13_weight.shape[0]
+             intermediate_size = w13_weight.shape[1] // 2
+             hidden_size = w13_weight.shape[2]
+
+             padded_intermediate_size = align_to(intermediate_size, 256)
+             padded_hidden_size = align_to(hidden_size, 256)
+
+             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+             w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
+                                             hidden_size)
+             w13_weight = jnp.swapaxes(w13_weight, 3, 2)
+
+             w2_weight = jnp.swapaxes(w2_weight, 2, 1)
+
+             w13_weight = jnp.pad(
+                 w13_weight,
+                 ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
+                  (0, padded_intermediate_size - intermediate_size)),
+                 constant_values=0)
+
+             w2_weight = jnp.pad(
+                 w2_weight,
+                 ((0, 0), (0, padded_intermediate_size - intermediate_size),
+                  (0, padded_hidden_size - hidden_size)),
+                 constant_values=0)
+
+             # Apply EP sharding
+             ep_sharding = NamedSharding(self.mesh, P("model"))
+
+             w13_weight = jax.device_put(
+                 w13_weight,
+                 Format(Layout((0, 1, 2, 3)),
+                        NamedSharding(self.mesh, P("model", None, None, None))))
+             w2_weight = jax.device_put(
+                 w2_weight,
+                 Format(Layout((0, 1, 2)),
+                        NamedSharding(self.mesh, P("model", None, None))))
+
+             if self.moe.has_bias:
+                 w13_bias = w13_bias.astype(jnp.float32).reshape(
+                     num_experts, 2, 1, intermediate_size)
+                 w2_bias = w2_bias.astype(jnp.float32).reshape(
+                     num_experts, 1, hidden_size)
+
+                 w13_bias = jnp.pad(
+                     w13_bias,
+                     ((0, 0), (0, 0), (0, 0),
+                      (0, padded_intermediate_size - intermediate_size)),
+                     constant_values=0)
+
+                 w2_bias = jnp.pad(w2_bias,
+                                   ((0, 0), (0, 0),
+                                    (0, padded_hidden_size - hidden_size)),
+                                   constant_values=0)
+
+                 # Apply EP sharding
+                 w13_bias = jax.device_put(
+                     w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                 w2_bias = jax.device_put(
+                     w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
+         else:
+             if self.moe.has_bias:
+                 w13_bias = jnp.expand_dims(w13_bias, 1)
+                 w2_bias = jnp.expand_dims(w2_bias, 1)
+
+             if layer.use_ep:
+                 ep_sharding = NamedSharding(self.mesh, P("model"))
+                 w13_weight = jax.device_put(
+                     w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                 w2_weight = jax.device_put(
+                     w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                 if self.moe.has_bias:
+                     w13_bias = jax.device_put(
+                         w13_bias, Format(Layout((0, 1, 2)), ep_sharding))
+                     w2_bias = jax.device_put(
+                         w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
+
+             else:
+                 output_sizes = [intermediate_size, intermediate_size]
+                 n_shards = self.mesh.shape["model"]
+                 assert intermediate_size % n_shards == 0
+
+                 w13_weight = reorder_concatenated_tensor_for_sharding(
+                     w13_weight, output_sizes, n_shards, dim=1)
+                 w13_weight = jax.device_put(
+                     w13_weight,
+                     Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P(None, "model", None))))
+                 w2_weight = jax.device_put(
+                     w2_weight,
+                     Format(Layout((0, 1, 2)),
+                            NamedSharding(self.mesh, P(None, None, "model"))))
+
+                 if self.moe.has_bias:
+                     w13_bias = reorder_concatenated_tensor_for_sharding(
+                         w13_bias, output_sizes, n_shards, dim=2)
+
+                     w13_bias = jax.device_put(
+                         w13_bias,
+                         Format(
+                             Layout((0, 1, 2)),
+                             NamedSharding(self.mesh, P(None, None, "model"))))
+                     w2_bias = jax.device_put(
+                         w2_bias,
+                         Format(Layout((0, 1, 2)),
+                                NamedSharding(self.mesh, P(None, None, None))))
+
+         layer.w13_weight = Parameter(torch_view(w13_weight),
+                                      requires_grad=False)
+         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+
+         if self.moe.has_bias:
+             layer.w13_bias = Parameter(torch_view(w13_bias),
+                                        requires_grad=False)
+             layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         assert isinstance(layer, FusedMoE)
+         if layer.scoring_func != "softmax":
+             raise NotImplementedError(
+                 "Only softmax is supported for scoring_func")
+
+         x = jax_view(x)
+         w13_weight = jax_view(layer.w13_weight)
+         w2_weight = jax_view(layer.w2_weight)
+         w13_bias = w2_bias = None
+         if self.moe.has_bias:
+             w13_bias = jax_view(layer.w13_bias)
+             w2_bias = jax_view(layer.w2_bias)
+         gating_output = jax_view(router_logits)
+
+         if self.use_kernel:
+             actual_hidden_size = x.shape[-1]
+             padding_size = w13_weight.shape[-2] - actual_hidden_size
+             x = jnp.pad(x, ((0, 0), (0, padding_size)))
+             output = fused_ep_moe(
+                 mesh=self.mesh,
+                 tokens=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 b1=w13_bias,
+                 b2=w2_bias,
+                 gating_output=gating_output,
+                 top_k=layer.top_k,
+                 ep_axis_name=self.ep_axis_name,
+                 renormalize_topk_logits=layer.renormalize,
+                 act_fn=layer.activation,
+                 **self.block_size,
+             )[:, :actual_hidden_size]
+         else:
+             output = fused_moe_func(
+                 hidden_states=x,
+                 w1=w13_weight,
+                 w2=w2_weight,
+                 w1_scale=None,
+                 w2_scale=None,
+                 w1_bias=w13_bias,
+                 w2_bias=w2_bias,
+                 gating_output=gating_output,
+                 topk=layer.top_k,
+                 renormalize=layer.renormalize,
+                 mesh=self.mesh,
+                 use_ep=layer.use_ep,
+                 activation=layer.activation,
+             )
+
+         return torch_view(output)
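
The swigluoai preprocessing in VllmUnquantizedFusedMoEMethod.process_weights_after_loading above is easiest to check with toy shapes: de-interleaving the fused w13 rows once at load time makes a cheap middle split of the GMM output equivalent to vLLM's interleaved split. A minimal sketch of that equivalence (the toy sizes and variable names are illustrative only, not part of the package):

    import jax.numpy as jnp

    # Toy fused w13 for one expert: 2 * intermediate_size = 6 rows, hidden = 2.
    w13 = jnp.arange(12).reshape(1, 6, 2)

    # vLLM's convention interleaves w1 and w3 rows along axis 1.
    w1_interleaved = w13[:, ::2, :]
    w3_interleaved = w13[:, 1::2, :]

    # Load-time reorder used above: all w1 rows first, then all w3 rows.
    w13_reordered = jnp.concatenate([w1_interleaved, w3_interleaved], axis=1)

    # A plain middle split now recovers exactly the same two halves.
    w1_mid, w3_mid = jnp.split(w13_reordered, 2, axis=1)
    assert (w1_mid == w1_interleaved).all() and (w3_mid == w3_interleaved).all()
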
@@ -0,0 +1,244 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ import torchax
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torch.nn import Parameter
+ from torch.utils import _pytree as pytree
+ from torchax.interop import jax_view, torch_view
+ from torchax.ops.mappings import t2j
+ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
+                               MergedColumnParallelLinearWithLoRA,
+                               MergedQKVParallelLinearWithLoRA,
+                               QKVParallelLinearWithLoRA,
+                               ReplicatedLinearWithLoRA,
+                               RowParallelLinearWithLoRA)
+ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead, VocabParallelEmbedding)
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+
+ P = PartitionSpec
+
+ logger = init_logger(__name__)
+
+ TORCH_TO_JAX_DTYPE_MAP = {
+     torch.float32: jnp.float32,
+     torch.float16: jnp.float16,
+     torch.bfloat16: jnp.bfloat16,
+ }
+
+
+ def shard_model_to_tpu(model: torch.nn.Module,
+                        mesh: Mesh) -> dict[str, torchax.torch.Tensor]:
+     """
+     Shard the model weights and move them to TPU.
+     At the same time, also turn the weight tensors into torchax tensors so that
+     jax code can interop with them and the overall program can be traced and
+     compiled in XLA.
+     Args:
+         model: A PyTorch model whose weights are on CPU main memory.
+         mesh: JAX mesh object for sharding.
+     Returns:
+         Dictionary of parameters and buffers that will be used as arguments of
+         torch.func.functional_call
+     """
+
+     with jax.default_device(jax.devices("cpu")[0]):
+         _shard_module_to_tpu(model, mesh)
+
+     params, buffers = _extract_all_params_buffers(model)
+
+     # For the other weight tensors, replicate them on all the TPU chips.
+     params, buffers = pytree.tree_map_only(
+         _tensor_is_in_cpu,
+         lambda tensor: _shard_tensor_to_tpu_replicated(tensor, mesh),
+         (params, buffers))
+
+     return {**params, **buffers}
+
+
+ def update_lora(model: torch.nn.Module,
+                 initial_params_buffers) -> dict[str, torchax.torch.Tensor]:
+     params, buffers = _extract_all_params_buffers(model)
+     params_buffers = {**params, **buffers}
+     for k, v in params_buffers.items():
+         if 'lora_a_stacked' in k or 'lora_b_stacked' in k:
+             assert k in initial_params_buffers, f"{k} not in initial_params_buffers"
+             initial_params_buffers[k] = v
+
+     return initial_params_buffers
+
+
+ def _extract_all_params_buffers(model: torch.nn.Module):
+     return dict(model.named_parameters()), dict(model.named_buffers())
+
+
+ def _tensor_is_in_cpu(tensor: torch.tensor) -> bool:
+     # Check if the tensor hasn't been converted to a torchax tensor yet.
+     if not isinstance(tensor, torchax.tensor.Tensor):
+         return True
+     # Check if the torchax tensor is still on CPU.
+     return tensor.jax_device == jax.devices('cpu')[0]
+
+
+ def _convert_to_torchax_and_shard(tensor: torch.Tensor,
+                                   sharding: NamedSharding) -> torch.Tensor:
+     if os.getenv("VLLM_TPU_USING_PATHWAYS", False) and isinstance(
+             tensor, torch.Tensor):
+         np_tensor = tensor.detach().cpu().to(torch.float32).numpy()
+         dtype = TORCH_TO_JAX_DTYPE_MAP.get(tensor.dtype, jnp.float32)
+         return torch_view(jax.device_put(np_tensor, sharding).astype(dtype))
+     else:
+         if isinstance(tensor, torchax.tensor.Tensor):
+             tensor = jax_view(tensor)
+         else:
+             tensor = t2j(tensor)
+         return torch_view(_sharded_device_put(tensor, sharding))
+
+
+ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
+                                     mesh: Mesh) -> torchax.tensor.Tensor:
+     return _convert_to_torchax_and_shard(tensor, NamedSharding(mesh, P()))
+
+
+ def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
+                                     mesh: Mesh) -> None:
+     weight = _convert_to_torchax_and_shard(
+         layer.weight, NamedSharding(mesh, P('model', None)))
+     layer.weight = Parameter(weight, requires_grad=False)
+
+
+ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
+     # TODO(qihqi): currently this is not handling the case of
+     # tie_word_weights=True. If that config is set, then we should not create
+     # new weights but reuse the weight from VocabParallelEmbedding.
+     weight = _convert_to_torchax_and_shard(
+         layer.weight, NamedSharding(mesh, P('model', None)))
+     layer.weight = Parameter(weight, requires_grad=False)
+     if layer.bias is not None:
+         bias = _convert_to_torchax_and_shard(layer.bias,
+                                              NamedSharding(mesh, P('model')))
+         layer.bias = Parameter(bias, requires_grad=False)
+
+
+ def _shard_base_linear_lora_replicated(layer: BaseLinearLayerWithLoRA,
+                                        mesh: Mesh) -> None:
+     # NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
+     sharded_lora_a_tpu = torch.nn.ParameterList()
+     sharded_lora_b_tpu = torch.nn.ParameterList()
+
+     for i in range(layer.n_slices):
+         sharded_lora_a_tpu.append(
+             _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+         sharded_lora_b_tpu.append(
+             _shard_tensor_to_tpu_replicated(layer.lora_b_stacked[i], mesh))
+
+     layer.lora_a_stacked = sharded_lora_a_tpu
+     layer.lora_b_stacked = sharded_lora_b_tpu
+
+
+ def _shard_column_linear_lora(layer: ColumnParallelLinearWithLoRA,
+                               mesh: Mesh) -> None:
+     assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
+     # lora_a_stacked[i] has shape [max_loras, 1, max_lora_rank, in_features]
+     sharded_lora_a_tpu = torch.nn.ParameterList()
+     sharded_lora_b_tpu = torch.nn.ParameterList()
+
+     # lora_b_stacked[i] has shape [max_loras, 1, out_features, max_lora_rank]
+     lora_b_partition_spec = P(None, None, 'model', None)
+     lora_b_sharding = NamedSharding(mesh, lora_b_partition_spec)
+     for i in range(layer.n_slices):
+         sharded_lora_a_tpu.append(
+             _shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))
+
+         sharded_lora_b_tpu.append(
+             _convert_to_torchax_and_shard(layer.lora_b_stacked[i],
+                                           lora_b_sharding))
+
+     layer.lora_a_stacked = sharded_lora_a_tpu
+     layer.lora_b_stacked = sharded_lora_b_tpu
+
+
+ def _shard_qkv_linear_lora(layer: ColumnParallelLinearWithLoRA,
+                            mesh: Mesh) -> None:
+     _shard_column_linear_lora(layer, mesh)
+
+
+ def _shard_merged_column_parallel_linear_lora(
+         layer: MergedColumnParallelLinearWithLoRA, mesh: Mesh) -> None:
+     _shard_column_linear_lora(layer, mesh)
+
+
+ def _shard_merged_qkv_parallel_linear_lora(
+         layer: MergedQKVParallelLinearWithLoRA, mesh: Mesh) -> None:
+     _shard_column_linear_lora(layer, mesh)
+
+
+ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
+                                     mesh: Mesh) -> None:
+     _shard_base_linear_lora_replicated(layer, mesh)
+
+
+ # NOTE: Ordering is important as it calls the first matched type of a given module.
+ MODULE_TYPE_TO_SHARDING_FUNC = [
+     # Shard embedding layers
+     (ParallelLMHead, _shard_lm_head),
+     (VocabParallelEmbedding, _shard_vocab_parallel_embedding),
+     # Shard LoRA layers
+     (ColumnParallelLinearWithLoRA, _shard_column_linear_lora),
+     (QKVParallelLinearWithLoRA, _shard_qkv_linear_lora),
+     (MergedColumnParallelLinearWithLoRA,
+      _shard_merged_column_parallel_linear_lora),
+     (MergedQKVParallelLinearWithLoRA, _shard_merged_qkv_parallel_linear_lora),
+     (RowParallelLinearWithLoRA, _shard_row_parallel_linear_lora),
+     (ReplicatedLinearWithLoRA, _shard_base_linear_lora_replicated),
+ ]
+
+
+ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
+     for path, module in model.named_modules():
+         for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
+             if type(module) is module_type:
+                 logger.debug("shard %s with %s", path, sharding_func)
+                 sharding_func(module, mesh)
+                 break
+
+
+ def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
+     if isinstance(tensor, tuple):
+         return tuple(_sharded_device_put(t, sharding) for t in tensor)
+     multihost_backend = envs.TPU_MULTIHOST_BACKEND
+     if multihost_backend != "ray":
+         return jax.device_put(tensor, sharding)
+
+     # NOTE: Here, num_global_devices != num_local_devices, meaning we are in a
+     # multi-host setup. Each host runs the same process, and each process only
+     # needs to handle the devices accessible to this host.
+     shape = tensor.shape
+     x_split = [
+         jax.device_put(tensor[i], device) for device, i in
+         sharding.addressable_devices_indices_map(shape).items()
+     ]
+     return jax.make_array_from_single_device_arrays(shape,
+                                                     sharding,
+                                                     x_split,
+                                                     dtype=tensor.dtype)
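
The ray branch of _sharded_device_put above uses a standard JAX multi-host pattern: each process places only the shards for its addressable devices and then stitches them into one logically global array. A minimal single-host sketch of the same pattern in plain JAX (the toy shapes and the single 'model' mesh axis are assumptions for illustration; it runs on CPU as well):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-axis mesh over whatever devices are available.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("model",))
    sharding = NamedSharding(mesh, PartitionSpec("model"))

    # Size the toy array by the device count so axis 0 shards evenly.
    tensor = jnp.arange(4 * len(devices), dtype=jnp.float32)
    shape = tensor.shape

    # Place each addressable slice on its device, as _sharded_device_put does.
    x_split = [
        jax.device_put(tensor[idx], device) for device, idx in
        sharding.addressable_devices_indices_map(shape).items()
    ]

    # Reassemble the per-device pieces into one global array.
    global_arr = jax.make_array_from_single_device_arrays(shape, sharding, x_split)
    assert bool((np.asarray(global_arr) == np.asarray(tensor)).all())
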
@@ -0,0 +1,10 @@
+ # SPDX-License-Identifier: Apache-2.0
+
+ from vllm.logger import _VllmLogger
+ from vllm.logger import init_logger as init_vllm_logger
+
+
+ def init_logger(name: str) -> _VllmLogger:
+     # Prepend the root "vllm" to the module path to use vllm's configured logger.
+     patched_name = "vllm." + name
+     return init_vllm_logger(patched_name)
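
A short usage note on the wrapper above: passing a module's __name__ places the logger under the vllm hierarchy, so tpu_inference modules pick up whatever handlers and levels vLLM has configured. A hedged illustration (the module name is just an example):

    from tpu_inference.logger import init_logger

    # The underlying logger is created as "vllm.tpu_inference.worker.tpu_worker",
    # so it inherits vLLM's logging configuration.
    logger = init_logger("tpu_inference.worker.tpu_worker")
    logger.info("hello from a tpu_inference module")
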
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.