tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import functools
 import os
 from typing import TYPE_CHECKING, Callable, List
@@ -41,10 +42,36 @@ DEFAULT_DEEPSEEK_FP8_CONFIG = {
         "scale_dtype":
         "bfloat16",
         "rules": [
+            # Exclude router from quantization
             {
                 "module_path": ".*.custom_module.router.*",
                 "weight_qtype": None,
             },
+            # Avoid the combine expert ops
+            {
+                "module_path": ".*combine_experts.*",
+                "weight_qtype": None,
+            },
+            # Attention layers: keep FP8 for weights and activations
+            {
+                "module_path": ".*.attn.*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+            # MoE experts: use FP4 for expert weights
+            {
+                "module_path": ".*.custom_module.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
+            # Shared experts: also FP4
+            {
+                "module_path": ".*.shared_experts.*",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": "float8_e4m3fn",
+                "tile_size": 256,
+            },
             {
                 "module_path": ".*",
                 "weight_qtype": "float8_e4m3fn",
@@ -154,12 +181,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")

-    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE

     kv_caches = create_kv_caches(
@@ -169,9 +193,11 @@
         head_size=kv_cache_head_size,
         mesh=mesh,
         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-        cache_dtype=kv_cache_jnp_dtype)
+        cache_dtype=kv_cache_jnp_dtype,
+        use_mla=model.vllm_config.model_config.use_mla,
+    )

-    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+    dp_size = model.vllm_config.sharding_config.total_dp_size

     # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
     input_ids = jax.random.randint(rng,
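The kv-cache dtype handling is now a plain branch: anything other than "auto" is converted to a jnp dtype, and "auto" falls back to the module default. A standalone sketch of that mapping, where the lookup table and the bfloat16 default are illustrative stand-ins for `utils.to_jax_dtype` and `DEFAULT_KV_CACHE_DTYPE`, not the package's actual implementations:

```python
import jax.numpy as jnp

DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16  # assumed default, for illustration only

# Stand-in for utils.to_jax_dtype: map dtype strings to jnp dtypes.
_STR_TO_JNP = {
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
    "float32": jnp.float32,
    "fp8_e4m3": jnp.float8_e4m3fn,
}


def kv_cache_dtype_to_jnp(kv_cache_dtype: str):
    if kv_cache_dtype != "auto":
        return _STR_TO_JNP[kv_cache_dtype]
    return DEFAULT_KV_CACHE_DTYPE


print(kv_cache_dtype_to_jnp("auto"))      # bfloat16 default
print(kv_cache_dtype_to_jnp("fp8_e4m3"))  # float8_e4m3fn
```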
@@ -399,8 +425,7 @@ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:


 def get_default_qwix_quantization_config(
-        model_type: str, quant_method: str,
-        skip_quantization: bool) -> dict | None:
+        hf_config: dict, skip_quantization: bool) -> dict | None:
     """
     Some models are pre-quantized and in those cases, we want to return a default set of
     Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
@@ -418,9 +443,42 @@
     """
     if skip_quantization:
         return None
-    # TODO (jacobplatin): remove this so that we can support various quantization types
+    model_type = hf_config.model_type.lower() if hasattr(
+        hf_config, "model_type") else None
+    quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+        hf_config, "quantization_config") else None
+    # TODO (jacobplatin): remove this so that we can support various quantization types + make
+    # more flexible
+    # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
+    # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-        return DEFAULT_DEEPSEEK_FP8_CONFIG
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP8_CONFIG)
+
+        # Dynamically fetch block size from HF config if available
+        # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
+        # NOTE: if the checkpoint is not 1D subchannel, we will throw an error
+        hf_quant_config = hf_config.quantization_config
+        assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+        block_size = hf_quant_config["weight_block_size"]
+        if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+            assert block_size[
+                0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}!"
+            tile_size = block_size[1]
+            assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+            logger.info(
+                f"Detected DeepSeek tile_size from config: {tile_size}")
+
+            # Update tile_size in the rules, since we might not always use a 1D subchannel size of
+            # 256
+            for rule in config["qwix"]["rules"]:
+                if "tile_size" in rule:
+                    rule["tile_size"] = tile_size
+        else:
+            raise ValueError(
+                f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+            )
+
+        return config
     elif model_type == "llama4" and quant_method == "compressed-tensors":
         return DEFAULT_LLAMA4_FP8_CONFIG
     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
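The DeepSeek branch now derives the sub-channel tile size from the checkpoint's quantization_config instead of hard-coding 256. A small sketch of that extraction and of patching the rules, under the same assumption that only 1-D sub-channel block sizes (`[1, N]`) are supported; `detect_tile_size` is a hypothetical helper, not the package API:

```python
import copy


def detect_tile_size(weight_block_size) -> int:
    """Return the sub-channel tile size from an HF-style weight_block_size."""
    if not (isinstance(weight_block_size, (list, tuple))
            and len(weight_block_size) == 2):
        raise ValueError(f"Invalid weight_block_size: {weight_block_size}")
    first_dim, tile_size = weight_block_size
    assert first_dim == 1, f"Expected 1-D sub-channel blocks, got {weight_block_size}"
    assert tile_size > 1, f"Expected tile_size > 1, got {tile_size}"
    return tile_size


# e.g. a checkpoint declaring 'weight_block_size': [1, 512] in its
# quantization_config; the default rules carry tile_size=256.
default_config = {"qwix": {"rules": [{"module_path": ".*.custom_module.*",
                                      "tile_size": 256}]}}
config = copy.deepcopy(default_config)
for rule in config["qwix"]["rules"]:
    if "tile_size" in rule:
        rule["tile_size"] = detect_tile_size([1, 512])
print(config["qwix"]["rules"][0]["tile_size"])  # 512
```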
@@ -439,14 +497,10 @@ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
     # Qwix quantization config accordingly
     # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
     # use that instead
-    model_type = vllm_config.model_config.hf_config.model_type.lower(
-    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
-    quant_method = vllm_config.model_config.hf_config.quantization_config[
-        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
-                                   "quantization_config") else None
+    hf_config = vllm_config.model_config.hf_config
     default_quantization_config = get_default_qwix_quantization_config(
-        model_type, quant_method,
-        vllm_config.additional_config.get("skip_quantization", False))
+        hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                     False))

     maybe_existing_quantization_config = vllm_config.additional_config.get(
         "quantization")
@@ -503,7 +557,14 @@ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
     else:
-        weight = jax.random.normal(key, param_shape, dtype)
+        # NOTE: _uniform() in random.py does not accept float4_e2m1fn
+        # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit dtypes, got float4_e2m1fn."
+        # Workaround: call function with dtype jnp.float8_e4m3fn and cast back to float4_e2m1fn
+        if dtype != "float4_e2m1fn":
+            weight = jax.random.normal(key, param_shape, dtype)
+        else:
+            weight = jax.random.normal(key, param_shape,
+                                       jnp.float8_e4m3fn).astype(dtype)

     def get_slice(index):
         return weight[index]
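The new branch works around jax.random.normal rejecting 4-bit dtypes: values are sampled in an 8-bit float dtype and cast down, which keeps the requested shape and leaves the sharding logic unchanged. A minimal reproduction of the pattern, assuming a JAX/ml_dtypes build that exposes float4_e2m1fn:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (4, 8)

# jax.random.normal(key, shape, jnp.float4_e2m1fn) raises
# "uniform only accepts 8-, 16-, 32-, or 64-bit dtypes",
# so sample in float8_e4m3fn and cast down afterwards.
weight = jax.random.normal(key, shape, jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
print(weight.dtype)  # float4_e2m1fn
```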
@@ -538,18 +599,16 @@ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
     logger.info("Initializing Qwix-quantized model with random weights...")
     # TODO (jacobplatin): clean up this logic
     scale_dtype = model.weight_loader.scale_dtype
-    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+    scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
         model.weight_loader,
-        'scale_shap_map_for_random_weight_loading') else {}
+        'scale_shape_map_for_random_weight_loading') else {}
     quantization_block_sizes = quantization_config["weight_block_size"]
     assert len(
         quantization_block_sizes
     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
-    quantization_block_size_n, _ = quantization_block_sizes[
-        0], quantization_block_sizes[1]

     # Iterate through all variables and initialize them
-    prev_param_shape = None
+
     for path, param in nnx.iter_graph(model):
         if not isinstance(param, nnx.Variable):
             continue
@@ -559,16 +618,17 @@
         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
         param_shape = param.value.shape
-        # TODO (jacobplatin): clean this up
         if is_qwix_scale:
-            param_shape = scale_shape_map.get(
-                path[3],
-                tuple(dim // quantization_block_size_n
-                      for dim in prev_param_shape))
+            key = f"{path[2]}.{path[3]}"
+
+            if key in scale_shape_map:
+                param_shape = scale_shape_map[key]
+            else:
+                raise ValueError(
+                    f"Scale shape for {key} not found in scale_shape_map.")
         param.value = get_random_sharded_array(
             rng, mesh, param, param_shape, param_dtype,
             ".".join([str(x) for x in path]))
-        prev_param_shape = param_shape

     # Handles the DeepSeek case, where this needs to be called to make the cache weights
     # concrete
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utilities for downloading model weights from HuggingFace."""

 import functools
@@ -67,7 +80,13 @@ def transpose_params(param_key: str, param_tensor: jax.Array, transpose_map):
 def reshape_params(param_key: str, param_tensor: jax.Array, shape_map):
     for key, new_shape in shape_map.items():
         if key in param_key:
-            return jnp.reshape(param_tensor, new_shape)
+            try:
+                #TODO:(gpolovets) Add validation on whether reshape preserves data layout.
+                return jnp.reshape(param_tensor, new_shape)
+            except TypeError:
+                raise TypeError(
+                    f"Cannot reshape for key={key}, new_shape={new_shape}, param_shape={param_tensor.shape}"
+                )
     return param_tensor  # Base case / no-op

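The try/except turns an opaque reshape failure into an error that names the offending key and both shapes. A short usage sketch of the same substring-keyed reshape (the shape_map entries and parameter names below are made up for illustration):

```python
import jax.numpy as jnp


def reshape_with_context(param_key, param_tensor, shape_map):
    for key, new_shape in shape_map.items():
        if key in param_key:  # substring match against the parameter name
            try:
                return jnp.reshape(param_tensor, new_shape)
            except TypeError:
                raise TypeError(
                    f"Cannot reshape for key={key}, new_shape={new_shape}, "
                    f"param_shape={param_tensor.shape}")
    return param_tensor  # no matching key: leave the tensor untouched


shape_map = {"q_proj": (8, 16, 4)}  # illustrative entry: 8*64 == 8*16*4
print(reshape_with_context("layers.0.q_proj.weight",
                           jnp.zeros((8, 64)), shape_map).shape)  # (8, 16, 4)

try:
    reshape_with_context("layers.0.q_proj.weight", jnp.zeros((8, 63)), shape_map)
except TypeError as e:
    print(e)  # names the key, target shape, and actual parameter shape
```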
@@ -275,7 +294,8 @@ def _load_and_shard_weight(vllm_config,
                            hf_key: str,
                            hf_weight: jax.Array,
                            keep_original_dtype_keys_regex: list[str]
-                           | None = None):
+                           | None = None,
+                           pp_missing_layers: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -331,6 +351,10 @@
         return
     model_key = name_map.get(hf_key, hf_key)

+    if pp_missing_layers and _is_pp_missing_layer(hf_key, pp_missing_layers):
+        logger.warning(
+            f"Skip loading {hf_key} as it doesn't belong to this PP stage.")
+        return
     model_weight, model_sharding = get_param_and_sharding(
         params, shardings, model_key)

@@ -394,6 +418,14 @@
     model_weight.value = shard(hf_weight, spec)


+def _is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
+    has_digit = any(char.isdigit() for char in hf_key)
+    # add the suffix after digits to avoid it matches "layers.10" with "layers.1"
+    suffix = "." if has_digit else ""
+    return any(f'{pp_missing_layer}{suffix}' in hf_key
+               for pp_missing_layer in pp_missing_layers)
+
+
 def _load_hf_weights_on_thread(
     vllm_config: VllmConfig,
     params: nnx.State,
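The digit-aware suffix in `_is_pp_missing_layer` is what keeps a missing-layer prefix such as "layers.1" from also matching "layers.10". A standalone sketch of that matching behaviour (the example keys are illustrative):

```python
def is_pp_missing_layer(hf_key: str, pp_missing_layers: list[str]) -> bool:
    has_digit = any(char.isdigit() for char in hf_key)
    # Append "." after numbered prefixes so "layers.1" cannot match "layers.10".
    suffix = "." if has_digit else ""
    return any(f"{layer}{suffix}" in hf_key for layer in pp_missing_layers)


missing = ["model.layers.1"]  # layers owned by another pipeline stage
print(is_pp_missing_layer("model.layers.1.self_attn.q_proj.weight", missing))   # True
print(is_pp_missing_layer("model.layers.10.self_attn.q_proj.weight", missing))  # False
```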
@@ -402,6 +434,7 @@ def _load_hf_weights_on_thread(
     weights_file: str,
     filter_regex: Optional[str] = None,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Loads weights from a single weights file."""
     try:
@@ -420,6 +453,7 @@
                 hf_key,
                 hf_weight,
                 keep_original_dtype_keys_regex,
+                pp_missing_layers,
             )


@@ -431,6 +465,7 @@ def load_hf_weights(
     filter_regex: Optional[str] = None,
     is_draft_model: bool = False,
     keep_original_dtype_keys_regex: Optional[list[str]] = None,
+    pp_missing_layers: list[str] | None = None,
 ):
     """Load weights into a JAX model from either an iterator or files."""
     params = nnx.state(model)
@@ -461,6 +496,7 @@
                 hf_key,
                 hf_weight_jax,
                 keep_original_dtype_keys_regex,
+                pp_missing_layers=pp_missing_layers,
             )
         else:
             # File-based path (multi-threaded)
@@ -488,6 +524,7 @@
                     filter_regex=filter_regex,
                     keep_original_dtype_keys_regex=
                     keep_original_dtype_keys_regex,
+                    pp_missing_layers=pp_missing_layers,
                 ) for weights_file in weights_files
             ]
             for future in futures:
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
 from collections.abc import Sequence
@@ -23,6 +37,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
@@ -197,7 +212,7 @@ class VllmModelWrapper:
             kwargs={
                 "input_ids": torch_view(input_ids),
                 "positions": torch_view(input_positions),
-                "intermediate_tensors": None,
+                "intermediate_tensors": intermediate_tensors,
                 "inputs_embeds": None,
             },
             tie_weights=False,
@@ -220,8 +235,10 @@

         @functools.partial(
            jax.jit,
-            out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec(None, "model"))),
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
@@ -263,7 +280,6 @@ def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
         vllm_config,
         device,
         model.embedding_modules,
-        model.embedding_padding_modules,
     )
     return lora_manager, lora_manager.create_lora_manager(model)

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, List, Optional
@@ -1,2 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # ruff: noqa
 from tpu_inference.platforms.tpu_platform import TpuPlatform
@@ -1,39 +1,35 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

 import jax.numpy as jnp
 import torch
 import vllm.envs as vllm_envs
-from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
-from vllm.sampling_params import SamplingParams, SamplingType

 from tpu_inference import envs
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
+    from vllm.attention.selector import AttentionSelectorConfig
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams, SamplingType
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-    _Backend = None
+    AttentionBackendEnum = None
+    SamplingParams = None
+    SamplingType = None

 logger = init_logger(__name__)

-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -50,25 +46,21 @@ class TpuPlatform(Platform):

     additional_env_vars: list[str] = [
         "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
-        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
+        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
+        "NEW_MODEL_DESIGN"
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import _Backend
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             attn_selector_config: "AttentionSelectorConfig",
+                             **kwargs) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

-        if use_v1:
-            logger.info("Using Pallas V1 backend.")
-            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
-        else:
-            logger.info("Using Pallas backend.")
-            return "vllm.attention.backends.pallas.PallasAttentionBackend"
+        logger.info("Using Pallas V1 backend.")
+        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -152,40 +144,21 @@ class TpuPlatform(Platform):
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"

-        # If we use vLLM's model implementation in PyTorch, we should set it with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-
-        # NOTE(xiang): convert dtype to jnp.dtype
-        # NOTE(wenlong): skip this logic for mm model preprocessing
-        # For mm model preprocessors, it may need the output dtype to be torch.
-        # In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
-        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
-            if impl == "vllm":
-                vllm_config.model_config.dtype = j2t_dtype(
-                    vllm_config.model_config.dtype.dtype)
-
         # TODO(cuiq): remove this dependency.
-        from vllm.v1.attention.backends.pallas import PallasAttentionBackend
-        cache_config.block_size = PallasAttentionBackend.get_page_size(
-            vllm_config)  # type: ignore[assignment]
-        min_page_size = PallasAttentionBackend.get_min_page_size(vllm_config)
-        if min_page_size > cache_config.block_size:
-            logger.warning(
-                "Increase the page size from %s to %s to make sure there's"
-                "no SMEM OOM",
-                cache_config.block_size,
-                min_page_size,
-            )
-            cache_config.block_size = min_page_size  # type: ignore[assignment]
+        if vllm_config.model_config:
+            from vllm.v1.attention.backends.pallas import \
+                PallasAttentionBackend
+            cache_config.block_size = PallasAttentionBackend.get_page_size(
+                vllm_config)  # type: ignore[assignment]
+            min_page_size = PallasAttentionBackend.get_min_page_size(
+                vllm_config)
+            if min_page_size > cache_config.block_size:
+                logger.warning(
+                    "Increase the page size from %s to %s to avoid SMEM OOM",
+                    cache_config.block_size,
+                    min_page_size,
+                )
+                cache_config.block_size = min_page_size  # type: ignore[assignment]

         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
@@ -195,12 +168,12 @@
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on \
-                    single host without pipeline parallelism.")
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on \
-                    single host with pipeline parallelism.")
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -216,19 +189,21 @@

         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"\
-                " without setting `--disable_chunked_mm_input`. " \
-                "Forcing --disable_chunked_mm_input.")
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True

         kv_transfer_config = vllm_config.kv_transfer_config
         if kv_transfer_config is not None:
             assert kv_transfer_config.kv_connector == "TPUConnector"
-        # Late initialization to avoid circular import
-        from tpu_inference.models.jax.utils.quantization.quantization_utils import \
-            update_vllm_config_for_qwix_quantization
-
-        update_vllm_config_for_qwix_quantization(vllm_config)
+        # Late initialization to avoid circular import.
+        # Only perform qwix quantization if it is jax model.
+        if vllm_config.model_config is not None:
+            from tpu_inference.models.jax.utils.qwix.qwix_utils import \
+                update_vllm_config_for_qwix_quantization
+            if vllm_config.model_config:
+                update_vllm_config_for_qwix_quantization(vllm_config)

         from tpu_inference.core.sched.dp_scheduler import \
             update_vllm_config_for_dp_scheduler
@@ -256,10 +231,11 @@
     def validate_request(
         cls,
         prompt: PromptType,
-        params: Union[SamplingParams, PoolingParams],
+        params: Union["SamplingParams", PoolingParams],
         processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
+        from vllm.sampling_params import SamplingParams, SamplingType

         if isinstance(params, SamplingParams):
             if params.sampling_type == SamplingType.RANDOM_SEED:
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.