tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/common/sharding.py (new file)
@@ -0,0 +1,582 @@
+import json
+import math
+import os
+from dataclasses import asdict, dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import jax.numpy as jnp
+import numpy as np
+from jax.sharding import Mesh
+
+from tpu_inference import utils
+
+if TYPE_CHECKING:
+    from vllm.v1.configs.vllm_config import VllmConfig
+
+MESH_AXIS_NAMES = ("data", "attn_dp", "expert", "model")
+MESH_AXIS_NAMES_2D = ('data', 'model')
+
+
+class ShardingAxisNameBase:
+    """Base class for sharding axis names."""
+    SEQUENCE = ('data', 'attn_dp')
+    ATTN_DATA = ('data', 'attn_dp')
+    MLP_DATA = 'data'
+    ATTN_HEAD = 'model'
+    ATTN_TENSOR = None
+    MLP_TENSOR = ('attn_dp', 'model', 'expert')
+    MOE_TENSOR = ('attn_dp', 'model')
+    EXPERT = ('attn_dp', 'expert', 'model')
+    VOCAB = ('expert', 'model')
+
+
+class ShardingAxisName2D:
+    """Sharding axis names for 2D data parallelism scenarios.
+    NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh for now.
+    We should use ShardingAxisNameBase once the new MoE kernel supports
+    more general mesh shapes. For now, this is the default sharding axes.
+    """
+    SEQUENCE = 'data'
+    ATTN_DATA = 'data'
+    MLP_DATA = 'data'
+    ATTN_HEAD = 'model'
+    ATTN_TENSOR = None
+    MLP_TENSOR = 'model'
+    MOE_TENSOR = 'model'
+    EXPERT = 'model'
+    VOCAB = ('data', 'model')
+
+
+try:
+    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    if _use_base_sharding:
+        ShardingAxisName = ShardingAxisNameBase
+    else:
+        ShardingAxisName = ShardingAxisName2D
+except Exception:
+    ShardingAxisName = ShardingAxisName2D
+
+
+@dataclass
+class ShardingStrategy:
+    """Defines the high-level parallelism strategy.
+
+    This class specifies how many ways each type of parallelism (tensor, expert,
+    sequence, data) should be distributed across the available devices.
+
+    Attributes:
+        tensor_parallelism: The degree of tensor parallelism (e.g., splitting
+            weights of a single layer).
+        expert_parallelism: The degree of expert parallelism for MoE models.
+        sequence_parallelism: The degree of sequence parallelism (splitting
+            activations along the sequence length dimension).
+        data_parallelism: The degree of data parallelism (splitting the batch
+            across devices).
+    """
+    tensor_parallelism: int = 1
+    expert_parallelism: int = 1
+    sequence_parallelism: int = 1
+    data_parallelism: int = 1
+    attention_data_parallelism: int = 1
+
+
+class ShardingConfigManager:
+    """Manages sharding configuration parsing and access from vLLM config.
+
+    Usage:
+        sharding_config = ShardingConfigManager.from_vllm_config(vllm_config)
+        tp_size = sharding_config.tp_size
+
+    During initialization, we set `vllm_config.sharding_config` to
+    `ShardingConfigManager.from_vllm_config(vllm_config)`, so you can access
+    `vllm_config.sharding_config.tp_size` directly.
+    """
+
+    def __init__(self,
+                 sharding_strategy: ShardingStrategy,
+                 device_indexes: Optional[List] = None):
+
+        self.sharding_strategy: ShardingStrategy = sharding_strategy
+        self.device_indexes: Optional[List[int]] = device_indexes
+        self._total_devices: int = int(
+            math.prod(asdict(sharding_strategy).values()))
+        if device_indexes:
+            assert self._total_devices == len(device_indexes)
+
+    @classmethod
+    def from_vllm_config(cls,
+                         vllm_config: 'VllmConfig') -> 'ShardingConfigManager':
+
+        sharding_strategy = vllm_config.additional_config.get(
+            "sharding", {}).get("sharding_strategy", {})
+        parallel_config = vllm_config.parallel_config
+        tensor_parallelism = parallel_config.tensor_parallel_size
+        data_parallelism = parallel_config.data_parallel_size
+        expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
+        sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
+        device_indexes = sharding_strategy.get("device_indexes", None)
+
+        enable_dp_attention = sharding_strategy.get("enable_dp_attention",
+                                                    False)
+        if enable_dp_attention:
+            # Replicate attention layer when num_kv_heads < TP
+            num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
+            kv_dtype = utils.get_jax_dtype_from_str_dtype(
+                vllm_config.cache_config.cache_dtype) or jnp.bfloat16
+            packing = 4 // jnp.dtype(kv_dtype).itemsize
+            # When num_kv_heads * 2 / packing < TP, tensor parallelism would
+            # duplicate KV heads across devices, wasting kv cache memory.
+            # Use attention DP instead to reduce per-device num_kv_heads and
+            # eliminate this waste.
+            num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
+            attn_dp = max(
+                int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache),
+                1)
+            tensor_parallelism = tensor_parallelism // attn_dp
+        else:
+            attn_dp = 1
+
+        sharding_strategy = ShardingStrategy(
+            tensor_parallelism=tensor_parallelism,
+            data_parallelism=data_parallelism,
+            expert_parallelism=expert_parallelism,
+            sequence_parallelism=sequence_parallelism,
+            attention_data_parallelism=attn_dp)
+
+        # Must override here to avoid vLLM spinning up multiple DP engines.
+        if vllm_config.parallel_config.data_parallel_size > 1:
+            vllm_config.parallel_config.data_parallel_size = 1
+            vllm_config.parallel_config.data_parallel_rank = 0
+            vllm_config.parallel_config.data_parallel_size_local = 1
+
+        cls.validate(vllm_config, sharding_strategy)
+        return cls(sharding_strategy, device_indexes)
+
+    @classmethod
+    def validate(cls, vllm_config, sharding_strategy):
+        total_dp_size = sharding_strategy.data_parallelism * sharding_strategy.attention_data_parallelism
+        if total_dp_size > 1:
+            if vllm_config.speculative_config is not None:
+                raise ValueError(
+                    f"Speculative decoding is not supported with data parallelism "
+                    f"(DP size: {total_dp_size}). Please disable speculative decoding or "
+                    f"set data parallelism to 1.")
+            if vllm_config.lora_config is not None:
+                raise ValueError(
+                    f"LoRA is not supported with data parallelism "
+                    f"(DP size: {total_dp_size}). Please disable LoRA or "
+                    f"set data parallelism to 1.")
+            if not os.environ.get("NEW_MODEL_DESIGN", False):
+                raise ValueError(
+                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "NEW_MODEL_DESIGN=True.")
+
+    @property
+    def total_dp_size(self) -> int:
+        return self.sharding_strategy.data_parallelism * self.sharding_strategy.attention_data_parallelism
+
+    @property
+    def model_dp_size(self) -> int:
+        return self.sharding_strategy.data_parallelism
+
+    @property
+    def attn_dp_size(self) -> int:
+        return self.sharding_strategy.attention_data_parallelism
+
+    @property
+    def tp_size(self) -> int:
+        return self.sharding_strategy.tensor_parallelism
+
+    @property
+    def expert_size(self) -> int:
+        return self.sharding_strategy.expert_parallelism
+
+    @property
+    def sequence_size(self) -> int:
+        return self.sharding_strategy.sequence_parallelism
+
+    @property
+    def total_devices(self) -> int:
+        return self._total_devices
+
+    def __str__(self):
+        return (f"ShardingConfigManager(total_devices={self.total_devices}, "
+                f"sharding_strategy={self.sharding_strategy}, "
+                f"device_indexes={self.device_indexes})")
+
+
+#TODO split this into block unique sharding config, i.e. attentionShardingConfig, MoEShardingConfig
+@dataclass
+class ShardingRulesConfig:
+    """Holds detailed sharding configurations for individual tensors, namely logical rules.
+
+    Each attribute in this class corresponds to a specific weight or activation
+    tensor within a transformer model. The value of each attribute is a
+    tuple of logical mesh axis names (e.g., 'dp', 'sp', 'tp'), which defines
+    how the corresponding tensor's dimensions are partitioned across the device mesh.
+    The dimension order in the attribute name (e.g., `btd` for batch, sequence,
+    d_model) maps directly to the sharding tuple.
+
+    TODO: update the mesh axis names to be clear and reduce confusion between prefill & generate
+    """
+
+    # Activation for attn input: (Batch * Sequence, Dim)
+    activation_attention_td: tuple = (None, None)
+    # Activation for attn out: (Batch * Sequence, Dim)
+    activation_attention_out_td: tuple = (None, None)
+    # Activation for q projection input: (Batch * Sequence, Dim)
+    activation_q_td: tuple = (None, None)
+    # Attention Out activation after projection: (Batch * Sequence, NumHeads, HeadDim)
+    attn_o_tnh: tuple = (None, None, None)
+    # Q vector: (Batch * Sequence, NumHeads, HeadDim)
+    query_tnh: tuple = (None, None, None)
+    # K/V vector: (Batch * Sequence, NumKVHeads, HeadDim)
+    keyvalue_skh: tuple = (None, None, None)
+
+    # Attention Q weight: (Dim, NumHeads, HeadDim)
+    attn_q_weight_dnh: tuple = (None, None, None)
+    # Attention K weight: (Dim, NumKVHeads, HeadDim)
+    attn_k_weight_dkh: tuple = (None, None, None)
+    # Attention V weight: (Dim, NumKVHeads, HeadDim)
+    attn_v_weight_dkh: tuple = (None, None, None)
+    # Attention Out weight: (NumHeads, HeadDim, Dim)
+    attn_o_weight_nhd: tuple = (None, None, None)
+
+    # Activation for ffw input: (Batch * Sequence, Dim)
+    activation_ffw_td: tuple = (None, None)
+
+    # Activation for ffw input: (Batch * Sequence, Expert, Dim)
+    activation_ffw_ted: tuple = (None, None, None)
+
+    # FFW hidden activation: (Batch * Sequence, FfwDim)
+    ffw_hidden_tf: tuple = (None, None)
+
+    # FFW up/gate weight: (Dim, FfwDim)
+    ffw_weight_df: tuple = (None, None)
+    # FFW down weight: (FfwDim, Dim)
+    ffw_weight_fd: tuple = (None, None)
+    # MoE gate/up weights: (NumExperts, Dim, FfwDim)
+    moe_weights_edf: tuple = (None, None, None)
+    # MoE down weights: (NumExperts, FfwDim, Dim)
+    moe_weights_efd: tuple = (None, None, None)
+    # MoE router weights: (Dim, NumExperts)
+    moe_router_de: tuple = (None, None)
+    # MoE router bias weights: (NumExperts,)
+    moe_router_bias_e: tuple = (None, )
+
+    # Embedding weight: (VocabSize, Dim)
+    emb_weight_vd: tuple = (None, None)
+    # Activation between layers: (Batch * Sequence, Dim)
+    activation_td: tuple = (None, None)
+    # Final activation before logits: (Batch * Sequence, Dim)
+    prelogit_td: tuple = (None, None)
+    # Logit activation: (Batch * Sequence, VocabSize)
+    logits_tv: tuple = (None, None)
+    # RMS norm scale weight: (Dim,)
+    norm_scale: tuple = (None)
+    # Vocab projection weight (tied embeddings): (Dim, VocabSize)
+    vocab_vd: tuple = (None, None)
+    vocab_dv: tuple = (None, None)
+
+
+class ShardingConfig:
+    """Container for operation-specific sharding configurations.
+
+    This class holds two separate `ShardingRulesConfig` objects, one for the
+    'prefill' phase and one for the 'generate' (or decode) phase of model
+    execution. This allows tailoring sharding strategies to the different
+    computational patterns of each phase.
+
+    Example Sharding Strategy and Configuration:
+
+    Sharding Strategy defines the high-level parallelism dimensions.
+    For a device mesh like `Mesh((2, 4, 4, 4), ('data', 'seq', 'expert', 'tensor'))` on 128 devices:
+    - data: Data Parallelism (2-way)
+    - seq: Sequence Parallelism (4-way)
+    - expert: Expert Parallelism (4-way)
+    - tensor: Tensor Parallelism (4-way)
+
+    ShardingConfig then maps tensor dimensions to these logical mesh axes.
+    For example, a tensor with shape (Batch, Sequence, Dimension) could be sharded
+    differently for prefill and decode/generate operations:
+
+    - Prefill (long sequences, small batch):
+      Sharding the sequence dim on the 'seq' axis is often efficient.
+      `prefill_rules.activation_attention_btd = (None, 'seq', 'tensor')`
+
+    - Generate (short sequences, large batch):
+      Sharding the batch dim on the 'data' axis is often efficient.
+      `generate_rules.activation_attention_btd = ('data', None, 'tensor')`
+    """
+
+    def __init__(self,
+                 prefill_rules=None,
+                 generate_rules=None,
+                 default_rules_cls=ShardingRulesConfig):
+        """Initializes the ShardingConfig.
+
+        Args:
+            prefill_rules: A `ShardingRulesConfig` for the prefill phase.
+                If None, a default config is created.
+            generate_rules: A `ShardingRulesConfig` for the generate phase.
+                If None, a default config is created.
+            default_rules_cls: The default sharding rules (class) to use.
+        """
+        # Use a factory pattern to avoid mutable default arguments
+        self.default_rules_cls = default_rules_cls
+        self.prefill_rules = prefill_rules if prefill_rules is not None else default_rules_cls(
+        )
+        self.generate_rules = generate_rules if generate_rules is not None else default_rules_cls(
+        )
+
+
+def build_mesh(devices, strategy: dict[str, int]) -> Mesh:
+    """Constructs a JAX device mesh from a sharding strategy.
+
+    This method creates a logical grid of devices based on the parallelism
+    degrees defined in the strategy. The logical axis names ('dp', 'ep',
+    'sp', 'tp') are used to map tensor dimensions to the physical device grid.
+
+    Args:
+        strategy: A dictionary from upper level config.
+
+    Returns:
+        A JAX `Mesh` object.
+    """
+
+    axis_order = {
+        "data": strategy.get("data_parallelism", 1),
+        "expert": strategy.get("expert_parallelism", 1),
+        "seq": strategy.get("sequence_parallelism", 1),
+        "model": strategy.get("tensor_parallelism", 1),
+    }
+    # TODO: add logic to infer axis when the degree is -1
+    mesh_axis_names = []
+    mesh_shape = []
+    for axis, dim in axis_order.items():
+        mesh_axis_names.append(axis)
+        mesh_shape.append(dim)
+
+    if not mesh_shape:
+        mesh_shape = [1]
+        mesh_axis_names = [
+            'data'
+        ]  # default to data parallelism if no other strategy is specified
+
+    devices = np.asarray(devices).reshape(mesh_shape)
+    return Mesh(devices, axis_names=tuple(mesh_axis_names))
+
+
+class Sharding:
+    """Generates and manages sharding configurations based on a high-level strategy.
+
+    This class populates a `ShardingConfig` with detailed tensor sharding
+    rules for both prefill and generation phases. It also allows for runtime
+    overrides of these rules.
+
+    Attributes:
+        sharding_cfg: The generated `ShardingConfig` with detailed rules.
+    """
+
+    def __init__(self,
+                 prefill_rules: dict | None = None,
+                 generate_rules: dict | None = None,
+                 default_rules_cls=ShardingRulesConfig,
+                 vllm_config: 'VllmConfig' = None):
+        """Initializes the Sharding manager.
+
+        Args:
+            prefill_rules: A dictionary of overrides for the prefill
+                sharding config. Keys are attribute names in `ShardingRulesConfig`,
+                and values are the new sharding tuples.
+            generate_rules: A dictionary of overrides for the generate
+                sharding config.
+        """
+        self.vllm_config = vllm_config
+        self.default_rules_cls = default_rules_cls
+        self.sharding_cfg = self.make_sharding_config(
+            default_rules_cls=default_rules_cls,
+            prefill_overrides=prefill_rules,
+            generate_overrides=generate_rules)
+
+    def _get_overrides(self, sharding_phase: str):
+        """Return the overrides from the vLLM config for the given sharding phase."""
+        overrides = {}
+        try:
+            overrides = self.vllm_config.additional_config["sharding"][
+                "logical_rules"]["all"]
+        except KeyError:
+            pass
+
+        try:
+            additional_overrides = self.vllm_config.additional_config[
+                "sharding"]["logical_rules"][f"{sharding_phase}"]
+            overrides.update(additional_overrides)
+        except KeyError:
+            pass
+        return overrides
+
+    def __str__(self):
+        """Succinct representation of relevant Sharding settings and overrides."""
+        output_str = f" Using {self.default_rules_cls.__name__} logical rules.\n"
+        output_str += f" {self.__class__.__name__:} overrides:\n"
+        output_str += f" prefill logical_rule overrides:\n {json.dumps(self._get_overrides('prefill'), indent=4, default=str)}\n\n"
+        output_str += f" generate logical_rule overrides:\n {json.dumps(self._get_overrides('generate'), indent=4, default=str)}\n\n"
+        return output_str
+
+    def validate_sharding_strategy(self, ):
+        """Validates if the sharding strategy is compatible with the environment.
+
+        This method is a placeholder now, and will check if the product of parallelism degrees
+        matches the number of available devices.
+        """
+        #TODO: check num_devices % parallelism == 0
+        #TODO: check num_devices == multiply(parallelism(with inferred))
+        return
+
+    def get_sharding_cfg(self) -> ShardingConfig:
+        """Returns the generated sharding configuration."""
+        return self.sharding_cfg
+
+    def _apply_overrides(self, config_obj: ShardingRulesConfig,
+                         overrides: dict | None):
+        """Applies runtime overrides to a sharding configuration object.
+
+        Args:
+            config_obj: The sharding configuration object (e.g., prefill_rules)
+                to be updated.
+            overrides: A dictionary where keys are attribute names of the config
+                object and values are the new sharding tuples.
+
+        Raises:
+            AttributeError: If a key in the overrides dictionary is not a valid
+                attribute of the configuration object.
+        """
+        for key, value in overrides.items():
+            if hasattr(config_obj, key):
+                setattr(config_obj, key, value)
+            else:
+                # Raise an error for invalid keys to prevent silent failures
+                raise AttributeError(
+                    f"'{key}' is not a valid attribute of {type(config_obj).__name__}"
+                )
+
+    def _make_default_sharding_config(self, prefill_rules, generate_rules):
+
+        # Populate Prefill Config
+        # During prefill, sequence length is long, so we shard along the sequence axis.
+        prefill_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
+                                                 ShardingAxisName.ATTN_TENSOR)
+        prefill_rules.activation_attention_out_td = (
+            ShardingAxisName.ATTN_DATA, ShardingAxisName.ATTN_TENSOR)
+        prefill_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
+                                         ShardingAxisName.ATTN_TENSOR)
+        #TODO: the default qkv and kvcache is sharded on head dim
+        # We may change it after we finalize the KVCache design
+        prefill_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
+                                    ShardingAxisName.ATTN_HEAD, None)
+        prefill_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
+                                   ShardingAxisName.ATTN_HEAD, None)
+        prefill_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
+                                      ShardingAxisName.ATTN_HEAD, None)
+
+        # Populate Generate (Decode) Config
+        # During decode, batch size is the large dimension, so we shard along the batch axis.
+        generate_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
+                                                  ShardingAxisName.ATTN_TENSOR)
+        generate_rules.activation_attention_out_td = (
+            ShardingAxisName.MLP_DATA, ShardingAxisName.ATTN_TENSOR)
+        generate_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
+                                          ShardingAxisName.ATTN_TENSOR)
+        #TODO: the default qkv and kvcache is sharded on head dim
+        # We may change it after we finalize the KVCache design
+        generate_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
+                                     ShardingAxisName.ATTN_HEAD, None)
+        generate_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
+                                    ShardingAxisName.ATTN_HEAD, None)
+        generate_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
+                                       ShardingAxisName.ATTN_HEAD, None)
+        generate_rules.attn_q_weight_dnh = (None, ShardingAxisName.ATTN_HEAD,
+                                            ShardingAxisName.ATTN_TENSOR)
+        generate_rules.attn_k_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
+                                            ShardingAxisName.ATTN_TENSOR)
+        generate_rules.attn_v_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
+                                            ShardingAxisName.ATTN_TENSOR)
+        generate_rules.attn_o_weight_nhd = (ShardingAxisName.ATTN_HEAD, None,
+                                            ShardingAxisName.ATTN_TENSOR)
+        generate_rules.activation_ffw_td = (ShardingAxisName.MLP_DATA, None)
+        generate_rules.activation_ffw_ted = (ShardingAxisName.MLP_DATA,
+                                             ShardingAxisName.EXPERT, None)
+        generate_rules.ffw_hidden_tf = (ShardingAxisName.MLP_DATA,
+                                        ShardingAxisName.MLP_TENSOR)
+        # FFW weights are typically sharded along the hidden dimension (F).
+        generate_rules.ffw_weight_df = (None, ShardingAxisName.MLP_TENSOR)
+        generate_rules.ffw_weight_fd = (ShardingAxisName.MLP_TENSOR, None)
+        # MoE weights are sharded along the expert axis and the hidden dimension.
+        generate_rules.moe_weights_edf = (ShardingAxisName.EXPERT, None,
+                                          ShardingAxisName.MOE_TENSOR)
+        generate_rules.moe_weights_efd = (ShardingAxisName.EXPERT,
+                                          ShardingAxisName.MOE_TENSOR, None)
+        generate_rules.moe_router_de = (None, ShardingAxisName.EXPERT)
+
+        # Embedding weight: (VocabSize, Dim)
+        generate_rules.emb_weight_vd = (ShardingAxisName.MLP_TENSOR, None)
+        generate_rules.activation_td = (ShardingAxisName.MLP_DATA,
+                                        ShardingAxisName.ATTN_TENSOR)
+        generate_rules.prelogit_td = (ShardingAxisName.MLP_DATA,
+                                      ShardingAxisName.MLP_TENSOR)
+        generate_rules.logits_tv = (ShardingAxisName.MLP_DATA,
+                                    ShardingAxisName.MLP_TENSOR)
+        generate_rules.vocab_vd = (ShardingAxisName.VOCAB, None)
+        generate_rules.vocab_dv = (None, ShardingAxisName.VOCAB)
+
+    def make_sharding_config(
+            self,
+            default_rules_cls: ShardingRulesConfig,
+            prefill_overrides: dict | None = None,
+            generate_overrides: dict | None = None) -> ShardingConfig:
+        """Creates the detailed `ShardingConfig` with specific partitioning rules
+        and applies any runtime overrides.
+
+        This method populates the `prefill_rules` and
+        `generate_rules` with hardcoded sharding rules that are generally
+        effective for transformer models, and then updates them with any provided
+        overrides.
+
+        Args:
+            prefill_overrides: A dictionary with attribute names and their new values
+                for the prefill sharding configuration.
+            generate_overrides: A dictionary with attribute names and their new values
+                for the generate sharding configuration.
+
+        Returns:
+            The populated and overridden `ShardingConfig` object.
+        """
+        #TODO: organize into update_prefill() and update_decode for each axis
+        #TODO: verify the sharding axes
+        sharding_cfg = ShardingConfig(default_rules_cls=default_rules_cls)
+        prefill_rules = sharding_cfg.prefill_rules
+        generate_rules = sharding_cfg.generate_rules
+
+        # Extract the overrides from the vllm_config if they are not provided programmatically.
+        if prefill_overrides is None:
+            prefill_overrides = self._get_overrides("prefill")
+        if generate_overrides is None:
+            generate_overrides = self._get_overrides("generate")
+
+        # Apply default sharding configs
+        self._make_default_sharding_config(prefill_rules, generate_rules)
+
+        # Apply overriding the runtime sharding rules
+        self._apply_overrides(prefill_rules, prefill_overrides)
+        self._apply_overrides(generate_rules, generate_overrides)
+
+        return sharding_cfg
+
+    #TODO: Add __repr__
+
+
+class ShardingInfo:
+    #TODO a sharding info class for visualizing & debugging the sharding performance
+    # Will implement it for the next version
+    pass
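
To make the enable_dp_attention branch in ShardingConfigManager.from_vllm_config above concrete, here is a small worked example of the attn_dp arithmetic. The numbers (8 KV heads, TP degree 16, bfloat16 KV cache) are illustrative assumptions, not values taken from the package.

# Illustrative walk-through of the attn_dp computation (assumed inputs, plain Python).
num_kv_heads = 8              # assumed total KV heads of the model
tensor_parallelism = 16       # assumed requested tensor-parallel degree
itemsize = 2                  # bfloat16 KV-cache entries are 2 bytes

packing = 4 // itemsize                                        # -> 2
kv_heads_per_device_in_cache = (num_kv_heads * 2) / packing    # -> 8.0
attn_dp = max(int(tensor_parallelism // kv_heads_per_device_in_cache), 1)  # -> 2
tensor_parallelism //= attn_dp                                 # -> 8
print(attn_dp, tensor_parallelism)                             # 2 8

With these assumed inputs the config ends up with attention_data_parallelism=2 and tensor_parallelism=8, which avoids duplicating KV heads across tensor-parallel devices.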
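
Below is a minimal usage sketch of the build_mesh and Sharding APIs defined in this file, assuming the module path tpu_inference/layers/common/sharding.py from the file list above and a host whose JAX-visible devices are all placed on the tensor ('model') axis. The override key ffw_weight_df is one of the ShardingRulesConfig attributes; everything else here is an assumption for illustration, not a documented entry point of the package.

import jax

from tpu_inference.layers.common.sharding import Sharding, build_mesh

devices = jax.devices()
# Put every visible device on the tensor-parallel ('model') axis; the
# parallelism degrees must multiply to len(devices) for the reshape to succeed.
strategy = {
    "data_parallelism": 1,
    "expert_parallelism": 1,
    "sequence_parallelism": 1,
    "tensor_parallelism": len(devices),
}
mesh = build_mesh(devices, strategy)
print(mesh)  # Mesh with axis names ('data', 'expert', 'seq', 'model')

# Build the logical sharding rules, overriding one generate-phase rule.
# Passing explicit dicts (even an empty one for prefill) skips reading
# overrides from a vllm_config, which is not constructed in this sketch.
sharding = Sharding(prefill_rules={},
                    generate_rules={"ffw_weight_df": (None, "model")})
cfg = sharding.get_sharding_cfg()
print(cfg.generate_rules.ffw_weight_df)  # (None, 'model')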