tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,600 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import math
+ from dataclasses import asdict, dataclass
+ from typing import TYPE_CHECKING, List, Optional
+
+ import jax.numpy as jnp
+ import numpy as np
+ from jax.sharding import Mesh
+
+ from tpu_inference import envs, utils
+
+ if TYPE_CHECKING:
+     from vllm.v1.configs.vllm_config import VllmConfig
+
+ MESH_AXIS_NAMES = ("data", "attn_dp", "expert", "model")
+ MESH_AXIS_NAMES_2D = ('data', 'model')
+
+
+ class ShardingAxisNameBase:
+     """Base class for sharding axis names."""
+     SEQUENCE = ('data', 'attn_dp')
+     ATTN_DATA = ('data', 'attn_dp')
+     MLP_DATA = 'data'
+     ATTN_HEAD = 'model'
+     ATTN_TENSOR = None
+     MLP_TENSOR = ('attn_dp', 'model', 'expert')
+     MOE_TENSOR = ('attn_dp', 'model')
+     EXPERT = ('attn_dp', 'expert', 'model')
+     VOCAB = ('expert', 'model')
+
+
+ class ShardingAxisName2D:
+     """Sharding axis names for 2D data parallelism scenarios.
+
+     NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh for now.
+     We should use ShardingAxisNameBase once the new MoE kernel supports
+     more general mesh shapes. Until then, these are the default sharding axes.
+     """
+     SEQUENCE = 'data'
+     ATTN_DATA = 'data'
+     MLP_DATA = 'data'
+     ATTN_HEAD = 'model'
+     ATTN_TENSOR = None
+     MLP_TENSOR = 'model'
+     MOE_TENSOR = 'model'
+     EXPERT = 'model'
+     VOCAB = ('data', 'model')
+
+
+ try:
+     _use_base_sharding = envs.NEW_MODEL_DESIGN
+     if _use_base_sharding:
+         ShardingAxisName = ShardingAxisNameBase
+     else:
+         ShardingAxisName = ShardingAxisName2D
+ except Exception:
+     ShardingAxisName = ShardingAxisName2D
+
+
+ @dataclass
+ class ShardingStrategy:
+     """Defines the high-level parallelism strategy.
+
+     This class specifies how many ways each type of parallelism (tensor,
+     expert, sequence, data) should be distributed across the available
+     devices.
+
+     Attributes:
+         tensor_parallelism: The degree of tensor parallelism (e.g., splitting
+             weights of a single layer).
+         expert_parallelism: The degree of expert parallelism for MoE models.
+         sequence_parallelism: The degree of sequence parallelism (splitting
+             activations along the sequence length dimension).
+         data_parallelism: The degree of data parallelism (splitting the batch
+             across devices).
+         attention_data_parallelism: The degree of data parallelism applied to
+             the attention layers only.
+     """
+     tensor_parallelism: int = 1
+     expert_parallelism: int = 1
+     sequence_parallelism: int = 1
+     data_parallelism: int = 1
+     attention_data_parallelism: int = 1
+
+
+ class ShardingConfigManager:
+     """Manages sharding configuration parsing and access from vLLM config.
+
+     Usage:
+         sharding_config = ShardingConfigManager.from_vllm_config(vllm_config)
+         tp_size = sharding_config.tp_size
+
+     During initialization, we set `vllm_config.sharding_config` to
+     `ShardingConfigManager.from_vllm_config(vllm_config)`, so you can access
+     `vllm_config.sharding_config.tp_size` directly.
+     """
+
+     def __init__(self,
+                  sharding_strategy: ShardingStrategy,
+                  device_indexes: Optional[List[int]] = None):
+         self.sharding_strategy: ShardingStrategy = sharding_strategy
+         self.device_indexes: Optional[List[int]] = device_indexes
+         self._total_devices: int = int(
+             math.prod(asdict(sharding_strategy).values()))
+         if device_indexes:
+             assert self._total_devices == len(device_indexes)
+
+     @classmethod
+     def from_vllm_config(cls,
+                          vllm_config: 'VllmConfig') -> 'ShardingConfigManager':
+         sharding_strategy = vllm_config.additional_config.get(
+             "sharding", {}).get("sharding_strategy", {})
+         parallel_config = vllm_config.parallel_config
+         tensor_parallelism = parallel_config.tensor_parallel_size
+         data_parallelism = parallel_config.data_parallel_size
+         expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
+         sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
+         device_indexes = sharding_strategy.get("device_indexes", None)
+
+         enable_dp_attention = sharding_strategy.get("enable_dp_attention",
+                                                     False)
+         if enable_dp_attention:
+             # Replicate attention layer when num_kv_heads < TP
+             num_kv_heads = (1 if vllm_config.model_config.use_mla else
+                             vllm_config.model_config.get_total_num_kv_heads())
+             cache_dtype = vllm_config.cache_config.cache_dtype
+             if cache_dtype == 'auto':
+                 cache_dtype = vllm_config.model_config.dtype
+             kv_dtype = utils.get_jax_dtype_from_str_dtype(
+                 cache_dtype) or jnp.bfloat16
+             packing = 4 // jnp.dtype(kv_dtype).itemsize
+             # When num_kv_heads * 2 / packing < TP, tensor parallelism would
+             # duplicate KV heads across devices, wasting kv cache memory.
+             # Use attention DP instead to reduce per-device num_kv_heads and
+             # eliminate this waste.
+             num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
+             attn_dp = max(
+                 int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache),
+                 1)
+             tensor_parallelism = tensor_parallelism // attn_dp
+         else:
+             attn_dp = 1
+
+         sharding_strategy = ShardingStrategy(
+             tensor_parallelism=tensor_parallelism,
+             data_parallelism=data_parallelism,
+             expert_parallelism=expert_parallelism,
+             sequence_parallelism=sequence_parallelism,
+             attention_data_parallelism=attn_dp)
+
+         # Must override here to avoid vLLM spinning up multiple DP engines.
+         if vllm_config.parallel_config.data_parallel_size > 1:
+             vllm_config.parallel_config.data_parallel_size = 1
+             vllm_config.parallel_config.data_parallel_rank = 0
+             vllm_config.parallel_config.data_parallel_size_local = 1
+
+         cls.validate(vllm_config, sharding_strategy)
+         return cls(sharding_strategy, device_indexes)
+
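To make the attention-DP arithmetic in `from_vllm_config` concrete, here is a worked example; the model shape is hypothetical, but the formulas are the ones above:

    # Hypothetical model: 8 KV heads, bfloat16 KV cache, TP = 64.
    packing = 4 // 2                               # bfloat16 itemsize is 2 -> packing = 2
    heads_per_device = (8 * 2) / packing           # = 8.0 KV-cache head slots per device
    attn_dp = max(int(64 // heads_per_device), 1)  # = 8-way attention data parallelism
    tensor_parallelism = 64 // attn_dp             # = 8-way TP remains for attention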
+     @classmethod
+     def validate(cls, vllm_config, sharding_strategy):
+         total_dp_size = (sharding_strategy.data_parallelism *
+                          sharding_strategy.attention_data_parallelism)
+         if total_dp_size > 1:
+             if vllm_config.speculative_config is not None:
+                 raise ValueError(
+                     f"Speculative decoding is not supported with data "
+                     f"parallelism (DP size: {total_dp_size}). Please disable "
+                     f"speculative decoding or set data parallelism to 1.")
+             if vllm_config.lora_config is not None:
+                 raise ValueError(
+                     f"LoRA is not supported with data parallelism "
+                     f"(DP size: {total_dp_size}). Please disable LoRA or "
+                     f"set data parallelism to 1.")
+         if sharding_strategy.attention_data_parallelism > 1:
+             if not envs.NEW_MODEL_DESIGN:
+                 raise ValueError(
+                     "Attention DP requires NEW_MODEL_DESIGN to be enabled. "
+                     "Please set NEW_MODEL_DESIGN=True.")
+
+     @property
+     def total_dp_size(self) -> int:
+         return (self.sharding_strategy.data_parallelism *
+                 self.sharding_strategy.attention_data_parallelism)
+
+     @property
+     def model_dp_size(self) -> int:
+         return self.sharding_strategy.data_parallelism
+
+     @property
+     def attn_dp_size(self) -> int:
+         return self.sharding_strategy.attention_data_parallelism
+
+     @property
+     def tp_size(self) -> int:
+         return self.sharding_strategy.tensor_parallelism
+
+     @property
+     def expert_size(self) -> int:
+         return self.sharding_strategy.expert_parallelism
+
+     @property
+     def sequence_size(self) -> int:
+         return self.sharding_strategy.sequence_parallelism
+
+     @property
+     def total_devices(self) -> int:
+         return self._total_devices
+
+     def __str__(self):
+         return (f"ShardingConfigManager(total_devices={self.total_devices}, "
+                 f"sharding_strategy={self.sharding_strategy}, "
+                 f"device_indexes={self.device_indexes})")
+
+
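A minimal sketch of constructing the manager directly, bypassing `from_vllm_config` (the parallelism values here are arbitrary):

    strategy = ShardingStrategy(tensor_parallelism=4, data_parallelism=2)
    manager = ShardingConfigManager(strategy)
    assert manager.total_devices == 8  # product of all parallelism degrees
    assert manager.tp_size == 4 and manager.model_dp_size == 2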
+ #TODO: split this into block-specific sharding configs, i.e. AttentionShardingConfig, MoEShardingConfig
+ @dataclass
+ class ShardingRulesConfig:
+     """Holds detailed sharding configurations for individual tensors, i.e. the logical rules.
+
+     Each attribute in this class corresponds to a specific weight or activation
+     tensor within a transformer model. The value of each attribute is a
+     tuple of logical mesh axis names (e.g., 'data', 'seq', 'model'), which
+     defines how the corresponding tensor's dimensions are partitioned across
+     the device mesh. The dimension order in the attribute name (e.g., `btd`
+     for batch, sequence, d_model) maps directly to the sharding tuple.
+
+     TODO: update the mesh axis names to be clear and reduce confusion between prefill & generate
+     """
+
+     # Activation for attn input: (Batch * Sequence, Dim)
+     activation_attention_td: tuple = (None, None)
+     # Activation for attn out: (Batch * Sequence, Dim)
+     activation_attention_out_td: tuple = (None, None)
+     # Activation for q projection input: (Batch * Sequence, Dim)
+     activation_q_td: tuple = (None, None)
+     # Attention out activation after projection: (Batch * Sequence, NumHeads, HeadDim)
+     attn_o_tnh: tuple = (None, None, None)
+     # Q vector: (Batch * Sequence, NumHeads, HeadDim)
+     query_tnh: tuple = (None, None, None)
+     # K/V vector: (Batch * Sequence, NumKVHeads, HeadDim)
+     keyvalue_skh: tuple = (None, None, None)
+
+     # Attention Q weight: (Dim, NumHeads, HeadDim)
+     attn_q_weight_dnh: tuple = (None, None, None)
+     # Attention K weight: (Dim, NumKVHeads, HeadDim)
+     attn_k_weight_dkh: tuple = (None, None, None)
+     # Attention V weight: (Dim, NumKVHeads, HeadDim)
+     attn_v_weight_dkh: tuple = (None, None, None)
+     # Attention out weight: (NumHeads, HeadDim, Dim)
+     attn_o_weight_nhd: tuple = (None, None, None)
+
+     # Activation for ffw input: (Batch * Sequence, Dim)
+     activation_ffw_td: tuple = (None, None)
+
+     # Activation for ffw input: (Batch * Sequence, Expert, Dim)
+     activation_ffw_ted: tuple = (None, None, None)
+
+     # FFW hidden activation: (Batch * Sequence, FfwDim)
+     ffw_hidden_tf: tuple = (None, None)
+
+     # FFW up/gate weight: (Dim, FfwDim)
+     ffw_weight_df: tuple = (None, None)
+     # FFW down weight: (FfwDim, Dim)
+     ffw_weight_fd: tuple = (None, None)
+     # MoE gate/up weights: (NumExperts, Dim, FfwDim)
+     moe_weights_edf: tuple = (None, None, None)
+     # MoE down weights: (NumExperts, FfwDim, Dim)
+     moe_weights_efd: tuple = (None, None, None)
+     # MoE router weights: (Dim, NumExperts)
+     moe_router_de: tuple = (None, None)
+     # MoE router bias weights: (NumExperts,)
+     moe_router_bias_e: tuple = (None, )
+
+     # Embedding weight: (VocabSize, Dim)
+     emb_weight_vd: tuple = (None, None)
+     # Activation between layers: (Batch * Sequence, Dim)
+     activation_td: tuple = (None, None)
+     # Final activation before logits: (Batch * Sequence, Dim)
+     prelogit_td: tuple = (None, None)
+     # Logit activation: (Batch * Sequence, VocabSize)
+     logits_tv: tuple = (None, None)
+     # RMS norm scale weight: (Dim,)
+     norm_scale: tuple = (None, )
+     # Vocab projection weight (tied embeddings): (Dim, VocabSize)
+     vocab_vd: tuple = (None, None)
+     vocab_dv: tuple = (None, None)
+
+
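Each rule tuple has one entry per tensor dimension, naming the mesh axis that dimension is split over (None means replicated). A short sketch of how such a rule would typically be consumed, assuming the mesh comes from `build_mesh` defined below:

    import jax
    from jax.sharding import NamedSharding, PartitionSpec

    mesh = build_mesh(jax.devices(), {"tensor_parallelism": len(jax.devices())})
    rules = ShardingRulesConfig(ffw_weight_df=(None, 'model'))
    # Dim is replicated; FfwDim is split across the 'model' mesh axis.
    sharding = NamedSharding(mesh, PartitionSpec(*rules.ffw_weight_df))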
+ class ShardingConfig:
+     """Container for operation-specific sharding configurations.
+
+     This class holds two separate `ShardingRulesConfig` objects, one for the
+     'prefill' phase and one for the 'generate' (or decode) phase of model
+     execution. This allows tailoring sharding strategies to the different
+     computational patterns of each phase.
+
+     Example Sharding Strategy and Configuration:
+
+     Sharding Strategy defines the high-level parallelism dimensions.
+     For a device mesh like `Mesh((2, 4, 4, 4), ('data', 'seq', 'expert', 'model'))`
+     on 128 devices:
+       - data: Data Parallelism (2-way)
+       - seq: Sequence Parallelism (4-way)
+       - expert: Expert Parallelism (4-way)
+       - model: Tensor Parallelism (4-way)
+
+     ShardingConfig then maps tensor dimensions to these logical mesh axes.
+     For example, a tensor with shape (Batch, Sequence, Dimension) could be
+     sharded differently for prefill and decode/generate operations:
+
+     - Prefill (long sequences, small batch):
+       Sharding the sequence dim on the 'seq' axis is often efficient.
+       `prefill_rules.activation_attention_btd = (None, 'seq', 'model')`
+
+     - Generate (short sequences, large batch):
+       Sharding the batch dim on the 'data' axis is often efficient.
+       `generate_rules.activation_attention_btd = ('data', None, 'model')`
+     """
+
+     def __init__(self,
+                  prefill_rules=None,
+                  generate_rules=None,
+                  default_rules_cls=ShardingRulesConfig):
+         """Initializes the ShardingConfig.
+
+         Args:
+             prefill_rules: A `ShardingRulesConfig` for the prefill phase.
+                 If None, a default config is created.
+             generate_rules: A `ShardingRulesConfig` for the generate phase.
+                 If None, a default config is created.
+             default_rules_cls: The default sharding rules class to use.
+         """
+         # Use a factory pattern to avoid mutable default arguments
+         self.default_rules_cls = default_rules_cls
+         self.prefill_rules = (prefill_rules if prefill_rules is not None
+                               else default_rules_cls())
+         self.generate_rules = (generate_rules if generate_rules is not None
+                                else default_rules_cls())
+
+
+ def build_mesh(devices, strategy: dict[str, int]) -> Mesh:
+     """Constructs a JAX device mesh from a sharding strategy.
+
+     This function creates a logical grid of devices based on the parallelism
+     degrees defined in the strategy. The logical axis names ('data', 'expert',
+     'seq', 'model') are used to map tensor dimensions to the physical device
+     grid.
+
+     Args:
+         devices: The devices to arrange into a mesh.
+         strategy: A dictionary mapping parallelism kinds to degrees, taken
+             from the upper-level config.
+
+     Returns:
+         A JAX `Mesh` object.
+     """
+     axis_order = {
+         "data": strategy.get("data_parallelism", 1),
+         "expert": strategy.get("expert_parallelism", 1),
+         "seq": strategy.get("sequence_parallelism", 1),
+         "model": strategy.get("tensor_parallelism", 1),
+     }
+     # TODO: add logic to infer axis when the degree is -1
+     mesh_axis_names = []
+     mesh_shape = []
+     for axis, dim in axis_order.items():
+         mesh_axis_names.append(axis)
+         mesh_shape.append(dim)
+
+     if not mesh_shape:
+         # Default to data parallelism if no other strategy is specified.
+         mesh_shape = [1]
+         mesh_axis_names = ['data']
+
+     devices = np.asarray(devices).reshape(mesh_shape)
+     return Mesh(devices, axis_names=tuple(mesh_axis_names))
+
+
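A short usage sketch for `build_mesh`; the strategy values are arbitrary and assume 8 available devices:

    import jax

    strategy = {"data_parallelism": 2, "tensor_parallelism": 4}
    # Unspecified axes default to 1, so the mesh shape is (2, 1, 1, 4).
    mesh = build_mesh(jax.devices(), strategy)
    print(mesh.axis_names)  # ('data', 'expert', 'seq', 'model')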
+ class Sharding:
+     """Generates and manages sharding configurations based on a high-level strategy.
+
+     This class populates a `ShardingConfig` with detailed tensor sharding
+     rules for both prefill and generation phases. It also allows for runtime
+     overrides of these rules.
+
+     Attributes:
+         sharding_cfg: The generated `ShardingConfig` with detailed rules.
+     """
+
+     def __init__(self,
+                  prefill_rules: dict | None = None,
+                  generate_rules: dict | None = None,
+                  default_rules_cls=ShardingRulesConfig,
+                  vllm_config: Optional['VllmConfig'] = None):
+         """Initializes the Sharding manager.
+
+         Args:
+             prefill_rules: A dictionary of overrides for the prefill sharding
+                 config. Keys are attribute names in `ShardingRulesConfig`,
+                 and values are the new sharding tuples.
+             generate_rules: A dictionary of overrides for the generate
+                 sharding config.
+             default_rules_cls: The default sharding rules class to use.
+             vllm_config: The vLLM config to read overrides from when none
+                 are passed programmatically.
+         """
+         self.vllm_config = vllm_config
+         self.default_rules_cls = default_rules_cls
+         self.sharding_cfg = self.make_sharding_config(
+             default_rules_cls=default_rules_cls,
+             prefill_overrides=prefill_rules,
+             generate_overrides=generate_rules)
+
+     def _get_overrides(self, sharding_phase: str):
+         """Return the overrides from the vLLM config for the given sharding phase."""
+         overrides = {}
+         try:
+             overrides = self.vllm_config.additional_config["sharding"][
+                 "logical_rules"]["all"]
+         except KeyError:
+             pass
+
+         try:
+             additional_overrides = self.vllm_config.additional_config[
+                 "sharding"]["logical_rules"][f"{sharding_phase}"]
+             overrides.update(additional_overrides)
+         except KeyError:
+             pass
+         return overrides
+
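For reference, the `additional_config` layout that `_get_overrides` reads would look like the following; this is inferred from the key lookups above, and the rule names are attributes of `ShardingRulesConfig`:

    additional_config = {
        "sharding": {
            "logical_rules": {
                "all": {"activation_td": ("data", None)},
                "prefill": {"keyvalue_skh": (("data", "attn_dp"), "model", None)},
                "generate": {"logits_tv": ("data", "model")},
            }
        }
    }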
+     def __str__(self):
+         """Succinct representation of relevant Sharding settings and overrides."""
+         output_str = f" Using {self.default_rules_cls.__name__} logical rules.\n"
+         output_str += f" {self.__class__.__name__} overrides:\n"
+         output_str += f" prefill logical_rule overrides:\n {json.dumps(self._get_overrides('prefill'), indent=4, default=str)}\n\n"
+         output_str += f" generate logical_rule overrides:\n {json.dumps(self._get_overrides('generate'), indent=4, default=str)}\n\n"
+         return output_str
+
+     def validate_sharding_strategy(self):
+         """Validates that the sharding strategy is compatible with the environment.
+
+         This method is currently a placeholder; it will check that the product
+         of parallelism degrees matches the number of available devices.
+         """
+         #TODO: check num_devices % parallelism == 0
+         #TODO: check num_devices == multiply(parallelism(with inferred))
+         return
+
+     def get_sharding_cfg(self) -> ShardingConfig:
+         """Returns the generated sharding configuration."""
+         return self.sharding_cfg
+
+     def _apply_overrides(self, config_obj: ShardingRulesConfig,
+                          overrides: dict | None):
+         """Applies runtime overrides to a sharding configuration object.
+
+         Args:
+             config_obj: The sharding configuration object (e.g., prefill_rules)
+                 to be updated.
+             overrides: A dictionary where keys are attribute names of the
+                 config object and values are the new sharding tuples.
+
+         Raises:
+             AttributeError: If a key in the overrides dictionary is not a
+                 valid attribute of the configuration object.
+         """
+         for key, value in overrides.items():
+             if hasattr(config_obj, key):
+                 setattr(config_obj, key, value)
+             else:
+                 # Raise an error for invalid keys to prevent silent failures
+                 raise AttributeError(
+                     f"'{key}' is not a valid attribute of {type(config_obj).__name__}"
+                 )
+
+     def _make_default_sharding_config(self, prefill_rules, generate_rules):
+         # Populate Prefill Config
+         # During prefill, sequence length is long, so we shard along the
+         # sequence axis.
+         prefill_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
+                                                  ShardingAxisName.ATTN_TENSOR)
+         prefill_rules.activation_attention_out_td = (
+             ShardingAxisName.ATTN_DATA, ShardingAxisName.ATTN_TENSOR)
+         prefill_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
+                                          ShardingAxisName.ATTN_TENSOR)
+         #TODO: the default qkv and kvcache is sharded on head dim
+         # We may change it after we finalize the KVCache design
+         prefill_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
+                                     ShardingAxisName.ATTN_HEAD, None)
+         prefill_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
+                                    ShardingAxisName.ATTN_HEAD, None)
+         prefill_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
+                                       ShardingAxisName.ATTN_HEAD, None)
+
+         # Populate Generate (Decode) Config
+         # During decode, batch size is the large dimension, so we shard along
+         # the batch axis.
+         generate_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
+                                                   ShardingAxisName.ATTN_TENSOR)
+         generate_rules.activation_attention_out_td = (
+             ShardingAxisName.MLP_DATA, ShardingAxisName.ATTN_TENSOR)
+         generate_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
+                                           ShardingAxisName.ATTN_TENSOR)
+         #TODO: the default qkv and kvcache is sharded on head dim
+         # We may change it after we finalize the KVCache design
+         generate_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
+                                      ShardingAxisName.ATTN_HEAD, None)
+         generate_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
+                                     ShardingAxisName.ATTN_HEAD, None)
+         generate_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
+                                        ShardingAxisName.ATTN_HEAD, None)
+         generate_rules.attn_q_weight_dnh = (None, ShardingAxisName.ATTN_HEAD,
+                                             ShardingAxisName.ATTN_TENSOR)
+         generate_rules.attn_k_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
+                                             ShardingAxisName.ATTN_TENSOR)
+         generate_rules.attn_v_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
+                                             ShardingAxisName.ATTN_TENSOR)
+         generate_rules.attn_o_weight_nhd = (ShardingAxisName.ATTN_HEAD, None,
+                                             ShardingAxisName.ATTN_TENSOR)
+         generate_rules.activation_ffw_td = (ShardingAxisName.MLP_DATA, None)
+         generate_rules.activation_ffw_ted = (ShardingAxisName.MLP_DATA,
+                                              ShardingAxisName.EXPERT, None)
+         generate_rules.ffw_hidden_tf = (ShardingAxisName.MLP_DATA,
+                                         ShardingAxisName.MLP_TENSOR)
+         # FFW weights are typically sharded along the hidden dimension (F).
+         generate_rules.ffw_weight_df = (None, ShardingAxisName.MLP_TENSOR)
+         generate_rules.ffw_weight_fd = (ShardingAxisName.MLP_TENSOR, None)
+         # MoE weights are sharded along the expert axis and the hidden
+         # dimension.
+         generate_rules.moe_weights_edf = (ShardingAxisName.EXPERT, None,
+                                           ShardingAxisName.MOE_TENSOR)
+         generate_rules.moe_weights_efd = (ShardingAxisName.EXPERT,
+                                           ShardingAxisName.MOE_TENSOR, None)
+         generate_rules.moe_router_de = (None, ShardingAxisName.EXPERT)
+
+         # Embedding weight: (VocabSize, Dim)
+         generate_rules.emb_weight_vd = (ShardingAxisName.MLP_TENSOR, None)
+         generate_rules.activation_td = (ShardingAxisName.MLP_DATA,
+                                         ShardingAxisName.ATTN_TENSOR)
+         generate_rules.prelogit_td = (ShardingAxisName.MLP_DATA,
+                                       ShardingAxisName.MLP_TENSOR)
+         generate_rules.logits_tv = (ShardingAxisName.MLP_DATA,
+                                     ShardingAxisName.MLP_TENSOR)
+         generate_rules.vocab_vd = (ShardingAxisName.VOCAB, None)
+         generate_rules.vocab_dv = (None, ShardingAxisName.VOCAB)
+
+     def make_sharding_config(
+             self,
+             default_rules_cls: type[ShardingRulesConfig],
+             prefill_overrides: dict | None = None,
+             generate_overrides: dict | None = None) -> ShardingConfig:
+         """Creates the detailed `ShardingConfig` with specific partitioning
+         rules and applies any runtime overrides.
+
+         This method populates the `prefill_rules` and `generate_rules` with
+         hardcoded sharding rules that are generally effective for transformer
+         models, and then updates them with any provided overrides.
+
+         Args:
+             default_rules_cls: The default sharding rules class to use.
+             prefill_overrides: A dictionary with attribute names and their new
+                 values for the prefill sharding configuration.
+             generate_overrides: A dictionary with attribute names and their
+                 new values for the generate sharding configuration.
+
+         Returns:
+             The populated and overridden `ShardingConfig` object.
+         """
+         #TODO: organize into update_prefill() and update_decode() for each axis
+         #TODO: verify the sharding axes
+         sharding_cfg = ShardingConfig(default_rules_cls=default_rules_cls)
+         prefill_rules = sharding_cfg.prefill_rules
+         generate_rules = sharding_cfg.generate_rules
+
+         # Extract the overrides from the vllm_config if they are not provided
+         # programmatically.
+         if prefill_overrides is None:
+             prefill_overrides = self._get_overrides("prefill")
+         if generate_overrides is None:
+             generate_overrides = self._get_overrides("generate")
+
+         # Apply default sharding configs
+         self._make_default_sharding_config(prefill_rules, generate_rules)
+
+         # Apply the runtime sharding rule overrides
+         self._apply_overrides(prefill_rules, prefill_overrides)
+         self._apply_overrides(generate_rules, generate_overrides)
+
+         return sharding_cfg
+
+     #TODO: Add __repr__
+
+
+ class ShardingInfo:
+     #TODO: a sharding info class for visualizing & debugging the sharding
+     # performance. Will implement it for the next version.
+     pass
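Putting the pieces together, an end-to-end sketch (passing explicit override dicts so no `vllm_config` is needed; the printed axis names assume the default 2D rules):

    import jax

    mesh = build_mesh(jax.devices(), {"tensor_parallelism": len(jax.devices())})
    sharding = Sharding(prefill_rules={}, generate_rules={})
    cfg = sharding.get_sharding_cfg()
    print(cfg.generate_rules.ffw_weight_df)  # (None, 'model') with 2D defaults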
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.