tpu_inference-0.11.1.dev202511150811-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/common/sharding.py (new file)
@@ -0,0 +1,582 @@
import json
import math
import os
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, List, Optional

import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference import utils

if TYPE_CHECKING:
    from vllm.v1.configs.vllm_config import VllmConfig

MESH_AXIS_NAMES = ("data", "attn_dp", "expert", "model")
MESH_AXIS_NAMES_2D = ('data', 'model')


class ShardingAxisNameBase:
    """Base class for sharding axis names."""
    SEQUENCE = ('data', 'attn_dp')
    ATTN_DATA = ('data', 'attn_dp')
    MLP_DATA = 'data'
    ATTN_HEAD = 'model'
    ATTN_TENSOR = None
    MLP_TENSOR = ('attn_dp', 'model', 'expert')
    MOE_TENSOR = ('attn_dp', 'model')
    EXPERT = ('attn_dp', 'expert', 'model')
    VOCAB = ('expert', 'model')


class ShardingAxisName2D:
    """Sharding axis names for 2D data parallelism scenarios.
    NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh for now.
    We should use ShardingAxisNameBase once the new MoE kernel supports
    more general mesh shapes. For now, this is the default sharding axes.
    """
    SEQUENCE = 'data'
    ATTN_DATA = 'data'
    MLP_DATA = 'data'
    ATTN_HEAD = 'model'
    ATTN_TENSOR = None
    MLP_TENSOR = 'model'
    MOE_TENSOR = 'model'
    EXPERT = 'model'
    VOCAB = ('data', 'model')


try:
    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
    if _use_base_sharding:
        ShardingAxisName = ShardingAxisNameBase
    else:
        ShardingAxisName = ShardingAxisName2D
except Exception:
    ShardingAxisName = ShardingAxisName2D
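
The axis-name classes above only pick logical mesh axis names; downstream layers presumably combine a rule tuple with the device mesh via a jax.sharding.PartitionSpec. A minimal sketch of that mapping on the default 2D ('data', 'model') mesh follows; the PartitionSpec/NamedSharding plumbing shown here is an assumption for illustration, not code from this wheel, and it runs on CPU with a single device.

# Sketch: turn a 2D sharding rule into a NamedSharding (illustrative, not from the wheel).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()).reshape(1, -1)      # ('data', 'model') grid, e.g. (1, N)
mesh = Mesh(devices, axis_names=('data', 'model'))

# Rule for an FFW up-projection weight (Dim, FfwDim): replicate Dim, split FfwDim on 'model',
# matching how ShardingAxisName2D.MLP_TENSOR is used for ffw_weight_df further below.
ffw_weight_df = (None, 'model')
sharding = NamedSharding(mesh, PartitionSpec(*ffw_weight_df))
print(sharding.spec)                                    # PartitionSpec(None, 'model')
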
@dataclass
class ShardingStrategy:
    """Defines the high-level parallelism strategy.

    This class specifies how many ways each type of parallelism (tensor, expert,
    sequence, data) should be distributed across the available devices.

    Attributes:
        tensor_parallelism: The degree of tensor parallelism (e.g., splitting
            weights of a single layer).
        expert_parallelism: The degree of expert parallelism for MoE models.
        sequence_parallelism: The degree of sequence parallelism (splitting
            activations along the sequence length dimension).
        data_parallelism: The degree of data parallelism (splitting the batch
            across devices).
    """
    tensor_parallelism: int = 1
    expert_parallelism: int = 1
    sequence_parallelism: int = 1
    data_parallelism: int = 1
    attention_data_parallelism: int = 1
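
For orientation, ShardingConfigManager below derives the total device count as the product of all five parallelism degrees (math.prod over asdict of the strategy). A hypothetical 8-chip sketch, with the import path inferred from the file list above:

# Illustrative only: the strategy degrees must multiply out to the number of devices used.
import math
from dataclasses import asdict

from tpu_inference.layers.common.sharding import ShardingStrategy

strategy = ShardingStrategy(tensor_parallelism=4, data_parallelism=2)
assert math.prod(asdict(strategy).values()) == 8   # expert/sequence/attention-DP stay at 1
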
class ShardingConfigManager:
    """Manages sharding configuration parsing and access from vLLM config.

    Usage:
        sharding_config = ShardingConfigManager.from_vllm_config(vllm_config)
        tp_size = sharding_config.tp_size

    During initialization, we set `vllm_config.sharding_config` to
    `ShardingConfigManager.from_vllm_config(vllm_config)`, so you can access
    `vllm_config.sharding_config.tp_size` directly.
    """

    def __init__(self,
                 sharding_strategy: ShardingStrategy,
                 device_indexes: Optional[List] = None):

        self.sharding_strategy: ShardingStrategy = sharding_strategy
        self.device_indexes: Optional[List[int]] = device_indexes
        self._total_devices: int = int(
            math.prod(asdict(sharding_strategy).values()))
        if device_indexes:
            assert self._total_devices == len(device_indexes)

    @classmethod
    def from_vllm_config(cls,
                         vllm_config: 'VllmConfig') -> 'ShardingConfigManager':

        sharding_strategy = vllm_config.additional_config.get(
            "sharding", {}).get("sharding_strategy", {})
        parallel_config = vllm_config.parallel_config
        tensor_parallelism = parallel_config.tensor_parallel_size
        data_parallelism = parallel_config.data_parallel_size
        expert_parallelism = sharding_strategy.get("expert_parallelism", 1)
        sequence_parallelism = sharding_strategy.get("sequence_parallelism", 1)
        device_indexes = sharding_strategy.get("device_indexes", None)

        enable_dp_attention = sharding_strategy.get("enable_dp_attention",
                                                    False)
        if enable_dp_attention:
            # Replicate attention layer when num_kv_heads < TP
            num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
            kv_dtype = utils.get_jax_dtype_from_str_dtype(
                vllm_config.cache_config.cache_dtype) or jnp.bfloat16
            packing = 4 // jnp.dtype(kv_dtype).itemsize
            # When num_kv_heads * 2 / packing < TP, tensor parallelism would
            # duplicate KV heads across devices, wasting kv cache memory.
            # Use attention DP instead to reduce per-device num_kv_heads and
            # eliminate this waste.
            num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
            attn_dp = max(
                int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache),
                1)
            tensor_parallelism = tensor_parallelism // attn_dp
        else:
            attn_dp = 1

        sharding_strategy = ShardingStrategy(
            tensor_parallelism=tensor_parallelism,
            data_parallelism=data_parallelism,
            expert_parallelism=expert_parallelism,
            sequence_parallelism=sequence_parallelism,
            attention_data_parallelism=attn_dp)

        # Must override here to avoid vLLM spinning up multiple DP engines.
        if vllm_config.parallel_config.data_parallel_size > 1:
            vllm_config.parallel_config.data_parallel_size = 1
            vllm_config.parallel_config.data_parallel_rank = 0
            vllm_config.parallel_config.data_parallel_size_local = 1

        cls.validate(vllm_config, sharding_strategy)
        return cls(sharding_strategy, device_indexes)

    @classmethod
    def validate(cls, vllm_config, sharding_strategy):
        total_dp_size = sharding_strategy.data_parallelism * sharding_strategy.attention_data_parallelism
        if total_dp_size > 1:
            if vllm_config.speculative_config is not None:
                raise ValueError(
                    f"Speculative decoding is not supported with data parallelism "
                    f"(DP size: {total_dp_size}). Please disable speculative decoding or "
                    f"set data parallelism to 1.")
            if vllm_config.lora_config is not None:
                raise ValueError(
                    f"LoRA is not supported with data parallelism "
                    f"(DP size: {total_dp_size}). Please disable LoRA or "
                    f"set data parallelism to 1.")
            if not os.environ.get("NEW_MODEL_DESIGN", False):
                raise ValueError(
                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                    "NEW_MODEL_DESIGN=True.")

    @property
    def total_dp_size(self) -> int:
        return self.sharding_strategy.data_parallelism * self.sharding_strategy.attention_data_parallelism

    @property
    def model_dp_size(self) -> int:
        return self.sharding_strategy.data_parallelism

    @property
    def attn_dp_size(self) -> int:
        return self.sharding_strategy.attention_data_parallelism

    @property
    def tp_size(self) -> int:
        return self.sharding_strategy.tensor_parallelism

    @property
    def expert_size(self) -> int:
        return self.sharding_strategy.expert_parallelism

    @property
    def sequence_size(self) -> int:
        return self.sharding_strategy.sequence_parallelism

    @property
    def total_devices(self) -> int:
        return self._total_devices

    def __str__(self):
        return (f"ShardingConfigManager(total_devices={self.total_devices}, "
                f"sharding_strategy={self.sharding_strategy}, "
                f"device_indexes={self.device_indexes})")
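
Two things in from_vllm_config() are easy to miss: which keys it reads from vllm_config.additional_config, and how spare KV-cache capacity is converted into attention data parallelism when enable_dp_attention is set. The sketch below spells both out with hypothetical numbers; it is not configuration shipped in the wheel.

# Keys read by ShardingConfigManager.from_vllm_config() (values here are hypothetical):
additional_config = {
    "sharding": {
        "sharding_strategy": {
            "expert_parallelism": 1,
            "sequence_parallelism": 1,
            "device_indexes": None,
            "enable_dp_attention": True,
        },
    },
}

# Worked example of the attention-DP arithmetic for a bfloat16 KV cache (itemsize 2):
num_kv_heads = 8
packing = 4 // 2                                              # 4 // jnp.dtype(jnp.bfloat16).itemsize
kv_heads_per_device_in_cache = (num_kv_heads * 2) / packing   # 8.0
tensor_parallelism = 16
attn_dp = max(int(tensor_parallelism // kv_heads_per_device_in_cache), 1)   # 2
tensor_parallelism //= attn_dp                                # 16 -> 8, i.e. TP=8 x attn-DP=2
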
#TODO split this into block unique sharding config, i.e. attentionShardingConfig, MoEShardingConfig
@dataclass
class ShardingRulesConfig:
    """Holds detailed sharding configurations for individual tensors, namely logical rules.

    Each attribute in this class corresponds to a specific weight or activation
    tensor within a transformer model. The value of each attribute is a
    tuple of logical mesh axis names (e.g., 'dp', 'sp', 'tp'), which defines
    how the corresponding tensor's dimensions are partitioned across the device mesh.
    The dimension order in the attribute name (e.g., `btd` for batch, sequence,
    d_model) maps directly to the sharding tuple.

    TODO: update the mesh axis names to be clear and reduce confusion between prefill & generate
    """

    # Activation for attn input: (Batch * Sequence, Dim)
    activation_attention_td: tuple = (None, None)
    # Activation for attn out: (Batch * Sequence, Dim)
    activation_attention_out_td: tuple = (None, None)
    # Activation for q projection input: (Batch * Sequence, Dim)
    activation_q_td: tuple = (None, None)
    # Attention Out activation after projection: (Batch * Sequence, NumHeads, HeadDim)
    attn_o_tnh: tuple = (None, None, None)
    # Q vector: (Batch * Sequence, NumHeads, HeadDim)
    query_tnh: tuple = (None, None, None)
    # K/V vector: (Batch * Sequence, NumKVHeads, HeadDim)
    keyvalue_skh: tuple = (None, None, None)

    # Attention Q weight: (Dim, NumHeads, HeadDim)
    attn_q_weight_dnh: tuple = (None, None, None)
    # Attention K weight: (Dim, NumKVHeads, HeadDim)
    attn_k_weight_dkh: tuple = (None, None, None)
    # Attention V weight: (Dim, NumKVHeads, HeadDim)
    attn_v_weight_dkh: tuple = (None, None, None)
    # Attention Out weight: (NumHeads, HeadDim, Dim)
    attn_o_weight_nhd: tuple = (None, None, None)

    # Activation for ffw input: (Batch * Sequence, Dim)
    activation_ffw_td: tuple = (None, None)

    # Activation for ffw input: (Batch * Sequence, Expert, Dim)
    activation_ffw_ted: tuple = (None, None, None)

    # FFW hidden activation: (Batch * Sequence, FfwDim)
    ffw_hidden_tf: tuple = (None, None)

    # FFW up/gate weight: (Dim, FfwDim)
    ffw_weight_df: tuple = (None, None)
    # FFW down weight: (FfwDim, Dim)
    ffw_weight_fd: tuple = (None, None)
    # MoE gate/up weights: (NumExperts, Dim, FfwDim)
    moe_weights_edf: tuple = (None, None, None)
    # MoE down weights: (NumExperts, FfwDim, Dim)
    moe_weights_efd: tuple = (None, None, None)
    # MoE router weights: (Dim, NumExperts)
    moe_router_de: tuple = (None, None)
    # MoE router bias weights: (NumExperts,)
    moe_router_bias_e: tuple = (None, )

    # Embedding weight: (VocabSize, Dim)
    emb_weight_vd: tuple = (None, None)
    # Activation between layers: (Batch * Sequence, Dim)
    activation_td: tuple = (None, None)
    # Final activation before logits: (Batch * Sequence, Dim)
    prelogit_td: tuple = (None, None)
    # Logit activation: (Batch * Sequence, VocabSize)
    logits_tv: tuple = (None, None)
    # RMS norm scale weight: (Dim,)
    norm_scale: tuple = (None)
    # Vocab projection weight (tied embeddings): (Dim, VocabSize)
    vocab_vd: tuple = (None, None)
    vocab_dv: tuple = (None, None)


class ShardingConfig:
    """Container for operation-specific sharding configurations.

    This class holds two separate `ShardingRulesConfig` objects, one for the
    'prefill' phase and one for the 'generate' (or decode) phase of model
    execution. This allows tailoring sharding strategies to the different
    computational patterns of each phase.

    Example Sharding Strategy and Configuration:

    Sharding Strategy defines the high-level parallelism dimensions.
    For a device mesh like `Mesh((2, 4, 4, 4), ('data', 'seq', 'expert', 'tensor'))` on 128 devices:
    - data: Data Parallelism (2-way)
    - seq: Sequence Parallelism (4-way)
    - expert: Expert Parallelism (4-way)
    - tensor: Tensor Parallelism (4-way)

    ShardingConfig then maps tensor dimensions to these logical mesh axes.
    For example, a tensor with shape (Batch, Sequence, Dimension) could be sharded
    differently for prefill and decode/generate operations:

    - Prefill (long sequences, small batch):
      Sharding sequence dim on the 'sp' axis is often efficient.
      `prefill_rules.activation_attention_btd = (None, 'seq', 'tensor')`

    - Generate (short sequences, large batch):
      Sharding batch dim on the 'dp' axis is often efficient.
      `generate_rules.activation_attention_btd = ('data', None, 'tensor')`
    """

    def __init__(self,
                 prefill_rules=None,
                 generate_rules=None,
                 default_rules_cls=ShardingRulesConfig):
        """Initializes the ShardingConfig.

        Args:
            prefill_rules: An `ShardingRulesConfig` for the prefill phase.
                If None, a default config is created.
            generate_rules: An `ShardingRulesConfig` for the generate phase.
                If None, a default config is created.
            default_rules_cls: The default sharding rules (class) to use.
        """
        # Use a factory pattern to avoid mutable default arguments
        self.default_rules_cls = default_rules_cls
        self.prefill_rules = prefill_rules if prefill_rules is not None else default_rules_cls(
        )
        self.generate_rules = generate_rules if generate_rules is not None else default_rules_cls(
        )
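
A short usage sketch of the container above (not from the wheel; the import path is inferred from the file list): the per-phase rules are plain attributes, so a single rule can be read or replaced directly.

from tpu_inference.layers.common.sharding import ShardingConfig

cfg = ShardingConfig()                              # default rules: everything replicated
print(cfg.prefill_rules.activation_attention_td)    # (None, None)

# Shard the generate-phase FFW up-projection weight over the 'model' axis.
cfg.generate_rules.ffw_weight_df = (None, 'model')
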
def build_mesh(devices, strategy: dict[str, int]) -> Mesh:
    """Constructs a JAX device mesh from a sharding strategy.

    This method creates a logical grid of devices based on the parallelism
    degrees defined in the strategy. The logical axis names ('dp', 'ep',
    'sp', 'tp') are used to map tensor dimensions to the physical device grid.

    Args:
        strategy: A dictionary from upper level config.

    Returns:
        A JAX `Mesh` object.
    """

    axis_order = {
        "data": strategy.get("data_parallelism", 1),
        "expert": strategy.get("expert_parallelism", 1),
        "seq": strategy.get("sequence_parallelism", 1),
        "model": strategy.get("tensor_parallelism", 1),
    }
    # TODO: add logic to infer axis when the degree is -1
    mesh_axis_names = []
    mesh_shape = []
    for axis, dim in axis_order.items():
        mesh_axis_names.append(axis)
        mesh_shape.append(dim)

    if not mesh_shape:
        mesh_shape = [1]
        mesh_axis_names = [
            'data'
        ]  # default to data parallelism if no other strategy is specified

    devices = np.asarray(devices).reshape(mesh_shape)
    return Mesh(devices, axis_names=tuple(mesh_axis_names))
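
A usage sketch for build_mesh (not from the wheel): the strategy dict uses the same keys the function reads above, and the four axis sizes must multiply out to the number of devices passed in, or the reshape will fail.

import jax

from tpu_inference.layers.common.sharding import build_mesh

devices = jax.devices()
mesh = build_mesh(devices, {
    "data_parallelism": 1,
    "expert_parallelism": 1,
    "sequence_parallelism": 1,
    "tensor_parallelism": len(devices),
})
print(mesh.shape)    # e.g. {'data': 1, 'expert': 1, 'seq': 1, 'model': len(devices)}
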
class Sharding:
    """Generates and manages sharding configurations based on a high-level strategy.

    This class populates a `ShardingConfig` with detailed tensor sharding
    rules for both prefill and generation phases. It also allows for runtime
    overrides of these rules.

    Attributes:
        sharding_cfg: The generated `ShardingConfig` with detailed rules.
    """

    def __init__(self,
                 prefill_rules: dict | None = None,
                 generate_rules: dict | None = None,
                 default_rules_cls=ShardingRulesConfig,
                 vllm_config: 'VllmConfig' = None):
        """Initializes the Sharding manager.

        Args:
            prefill_rules: A dictionary of overrides for the prefill
                sharding config. Keys are attribute names in `ShardingRulesConfig`,
                and values are the new sharding tuples.
            generate_rules: A dictionary of overrides for the generate
                sharding config.
        """
        self.vllm_config = vllm_config
        self.default_rules_cls = default_rules_cls
        self.sharding_cfg = self.make_sharding_config(
            default_rules_cls=default_rules_cls,
            prefill_overrides=prefill_rules,
            generate_overrides=generate_rules)

    def _get_overrides(self, sharding_phase: str):
        """Return the overrides from the vLLM config for the given sharding phase."""
        overrides = {}
        try:
            overrides = self.vllm_config.additional_config["sharding"][
                "logical_rules"]["all"]
        except KeyError:
            pass

        try:
            additional_overrides = self.vllm_config.additional_config[
                "sharding"]["logical_rules"][f"{sharding_phase}"]
            overrides.update(additional_overrides)
        except KeyError:
            pass
        return overrides

    def __str__(self):
        """Succinct representation of relevant Sharding settings and overrides."""
        output_str = f" Using {self.default_rules_cls.__name__} logical rules.\n"
        output_str += f" {self.__class__.__name__:} overrides:\n"
        output_str += f" prefill logical_rule overrides:\n {json.dumps(self._get_overrides('prefill'), indent=4, default=str)}\n\n"
        output_str += f" generate logical_rule overrides:\n {json.dumps(self._get_overrides('generate'), indent=4, default=str)}\n\n"
        return output_str

    def validate_sharding_strategy(self, ):
        """Validates if the sharding strategy is compatible with the environment.

        This method is a placeholder now, and will check if the product of parallelism degrees
        matches the number of available devices.
        """
        #TODO: check num_devices % parallelism == 0
        #TODO: check num_devices == multiply(parallelism(with inferred))
        return

    def get_sharding_cfg(self) -> ShardingConfig:
        """Returns the generated sharding configuration."""
        return self.sharding_cfg

    def _apply_overrides(self, config_obj: ShardingRulesConfig,
                         overrides: dict | None):
        """Applies runtime overrides to a sharding configuration object.

        Args:
            config_obj: The sharding configuration object (e.g., prefill_rules)
                to be updated.
            overrides: A dictionary where keys are attribute names of the config
                object and values are the new sharding tuples.

        Raises:
            AttributeError: If a key in the overrides dictionary is not a valid
                attribute of the configuration object.
        """
        for key, value in overrides.items():
            if hasattr(config_obj, key):
                setattr(config_obj, key, value)
            else:
                # Raise an error for invalid keys to prevent silent failures
                raise AttributeError(
                    f"'{key}' is not a valid attribute of {type(config_obj).__name__}"
                )

    def _make_default_sharding_config(self, prefill_rules, generate_rules):

        # Populate Prefill Config
        # During prefill, sequence length is long, so we shard along the sequence axis.
        prefill_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
                                                 ShardingAxisName.ATTN_TENSOR)
        prefill_rules.activation_attention_out_td = (
            ShardingAxisName.ATTN_DATA, ShardingAxisName.ATTN_TENSOR)
        prefill_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
                                         ShardingAxisName.ATTN_TENSOR)
        #TODO: the default qkv and kvcache is sharded on head dim
        # We may change it after we finalize the KVCache design
        prefill_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
                                    ShardingAxisName.ATTN_HEAD, None)
        prefill_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
                                   ShardingAxisName.ATTN_HEAD, None)
        prefill_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
                                      ShardingAxisName.ATTN_HEAD, None)

        # Populate Generate (Decode) Config
        # During decode, batch size is the large dimension, so we shard along the batch axis.
        generate_rules.activation_attention_td = (ShardingAxisName.ATTN_DATA,
                                                  ShardingAxisName.ATTN_TENSOR)
        generate_rules.activation_attention_out_td = (
            ShardingAxisName.MLP_DATA, ShardingAxisName.ATTN_TENSOR)
        generate_rules.activation_q_td = (ShardingAxisName.ATTN_DATA,
                                          ShardingAxisName.ATTN_TENSOR)
        #TODO: the default qkv and kvcache is sharded on head dim
        # We may change it after we finalize the KVCache design
        generate_rules.attn_o_tnh = (ShardingAxisName.ATTN_DATA,
                                     ShardingAxisName.ATTN_HEAD, None)
        generate_rules.query_tnh = (ShardingAxisName.ATTN_DATA,
                                    ShardingAxisName.ATTN_HEAD, None)
        generate_rules.keyvalue_skh = (ShardingAxisName.ATTN_DATA,
                                       ShardingAxisName.ATTN_HEAD, None)
        generate_rules.attn_q_weight_dnh = (None, ShardingAxisName.ATTN_HEAD,
                                            ShardingAxisName.ATTN_TENSOR)
        generate_rules.attn_k_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
                                            ShardingAxisName.ATTN_TENSOR)
        generate_rules.attn_v_weight_dkh = (None, ShardingAxisName.ATTN_HEAD,
                                            ShardingAxisName.ATTN_TENSOR)
        generate_rules.attn_o_weight_nhd = (ShardingAxisName.ATTN_HEAD, None,
                                            ShardingAxisName.ATTN_TENSOR)
        generate_rules.activation_ffw_td = (ShardingAxisName.MLP_DATA, None)
        generate_rules.activation_ffw_ted = (ShardingAxisName.MLP_DATA,
                                             ShardingAxisName.EXPERT, None)
        generate_rules.ffw_hidden_tf = (ShardingAxisName.MLP_DATA,
                                        ShardingAxisName.MLP_TENSOR)
        # FFW weights are typically sharded along the hidden dimension (F).
        generate_rules.ffw_weight_df = (None, ShardingAxisName.MLP_TENSOR)
        generate_rules.ffw_weight_fd = (ShardingAxisName.MLP_TENSOR, None)
        # MoE weights are sharded along the expert axis and the hidden dimension.
        generate_rules.moe_weights_edf = (ShardingAxisName.EXPERT, None,
                                          ShardingAxisName.MOE_TENSOR)
        generate_rules.moe_weights_efd = (ShardingAxisName.EXPERT,
                                          ShardingAxisName.MOE_TENSOR, None)
        generate_rules.moe_router_de = (None, ShardingAxisName.EXPERT)

        # Embedding weight: (VocabSize, Dim)
        generate_rules.emb_weight_vd = (ShardingAxisName.MLP_TENSOR, None)
        generate_rules.activation_td = (ShardingAxisName.MLP_DATA,
                                        ShardingAxisName.ATTN_TENSOR)
        generate_rules.prelogit_td = (ShardingAxisName.MLP_DATA,
                                      ShardingAxisName.MLP_TENSOR)
        generate_rules.logits_tv = (ShardingAxisName.MLP_DATA,
                                    ShardingAxisName.MLP_TENSOR)
        generate_rules.vocab_vd = (ShardingAxisName.VOCAB, None)
        generate_rules.vocab_dv = (None, ShardingAxisName.VOCAB)

    def make_sharding_config(
            self,
            default_rules_cls: ShardingRulesConfig,
            prefill_overrides: dict | None = None,
            generate_overrides: dict | None = None) -> ShardingConfig:
        """Creates the detailed `ShardingConfig` with specific partitioning rules
        and applies any runtime overrides.

        This method populates the `prefill_rules` and
        `generate_rules` with hardcoded sharding rules that are generally
        effective for transformer models, and then updates them with any provided
        overrides.

        Args:
            prefill_overrides: A dictionary with attribute names and their new values
                for the prefill sharding configuration.
            generate_overrides: A dictionary with attribute names and their new values
                for the generate sharding configuration.

        Returns:
            The populated and overridden `ShardingConfig` object.
        """
        #TODO: organize into update_prefill() and update_decode for each axis
        #TODO: verify the sharding axes
        sharding_cfg = ShardingConfig(default_rules_cls=default_rules_cls)
        prefill_rules = sharding_cfg.prefill_rules
        generate_rules = sharding_cfg.generate_rules

        # Extract the overrides from the vllm_config if they are not provided programatically.
        if prefill_overrides is None:
            prefill_overrides = self._get_overrides("prefill")
        if generate_overrides is None:
            generate_overrides = self._get_overrides("generate")

        # Apply default sharding configs
        self._make_default_sharding_config(prefill_rules, generate_rules)

        # Apply overriding the runtime sharding rules
        self._apply_overrides(prefill_rules, prefill_overrides)
        self._apply_overrides(generate_rules, generate_overrides)

        return sharding_cfg

    #TODO: Add __repr__


class ShardingInfo:
    #TODO a sharding info class for visualizing & debugging the sharding performance
    # Will implement it for the next version
    pass
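
Putting the pieces together, a hedged end-to-end sketch (not from the wheel): logical-rule overrides can be supplied through vllm_config.additional_config["sharding"]["logical_rules"] under the keys "all", "prefill", and "generate" (the keys _get_overrides reads), or passed programmatically as plain dicts keyed by ShardingRulesConfig attribute names, as below.

from tpu_inference.layers.common.sharding import Sharding, ShardingAxisName

sharding = Sharding(
    prefill_rules={"ffw_hidden_tf": (None, ShardingAxisName.MLP_TENSOR)},
    generate_rules={"logits_tv": (None, ShardingAxisName.MLP_TENSOR)},
)
cfg = sharding.get_sharding_cfg()
print(cfg.generate_rules.logits_tv)    # (None, 'model') with the default 2D axis names
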