tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/sharding.py

@@ -0,0 +1,406 @@

import json
from dataclasses import dataclass

import numpy as np
from jax.sharding import Mesh
from vllm.config import VllmConfig

BATCH_AXIS_NAME = 'data'
SEQUENCE_AXIS_NAME = 'data'
DATA_AXIS_NAME = 'data'
ATTN_HEAD_AXIS_NAME = 'model'
ATTN_TENSOR_AXIS_NAME = None
MLP_TENSOR_AXIS_NAME = ('model', 'expert')
MOE_TENSOR_AXIS_NAME = 'model'
EXPERT_AXIS_NAME = 'expert'
VOCAB_AXIS_NAME = ('data', 'expert', 'model')


@dataclass
class ShardingStrategy:
    """Defines the high-level parallelism strategy.

    This class specifies how many ways each type of parallelism (tensor, expert,
    sequence, data) should be distributed across the available devices.

    Attributes:
        tensor_parallelism: The degree of tensor parallelism (e.g., splitting
            weights of a single layer).
        expert_parallelism: The degree of expert parallelism for MoE models.
        sequence_parallelism: The degree of sequence parallelism (splitting
            activations along the sequence length dimension).
        data_parallelism: The degree of data parallelism (splitting the batch
            across devices).
    """
    tensor_parallelism: int = 1
    expert_parallelism: int = 1
    sequence_parallelism: int = 1
    data_parallelism: int = 1


#TODO: split this into block-specific sharding configs, i.e. AttentionShardingConfig, MoEShardingConfig
@dataclass
class ShardingRulesConfig:
    """Holds detailed sharding configurations for individual tensors, namely logical rules.

    Each attribute in this class corresponds to a specific weight or activation
    tensor within a transformer model. The value of each attribute is a
    tuple of logical mesh axis names (e.g., 'data', 'seq', 'model'), which defines
    how the corresponding tensor's dimensions are partitioned across the device mesh.
    The dimension order in the attribute name (e.g., `btd` for batch, sequence,
    d_model) maps directly to the sharding tuple.

    TODO: update the mesh axis names to be clear and reduce confusion between prefill & generate
    """

    # Activation for attn input: (Batch * Sequence, Dim)
    activation_attention_td: tuple = (None, None)
    # Activation for attn out: (Batch * Sequence, Dim)
    activation_attention_out_td: tuple = (None, None)
    # Activation for q projection input: (Batch * Sequence, Dim)
    activation_q_td: tuple = (None, None)
    # Attention Out activation after projection: (Batch * Sequence, NumHeads, HeadDim)
    attn_o_tnh: tuple = (None, None, None)
    # Q vector: (Batch * Sequence, NumHeads, HeadDim)
    query_tnh: tuple = (None, None, None)
    # K/V vector: (Batch * Sequence, NumKVHeads, HeadDim)
    keyvalue_skh: tuple = (None, None, None)

    # Attention Q weight: (Dim, NumHeads, HeadDim)
    attn_q_weight_dnh: tuple = (None, None, None)
    # Attention K weight: (Dim, NumKVHeads, HeadDim)
    attn_k_weight_dkh: tuple = (None, None, None)
    # Attention V weight: (Dim, NumKVHeads, HeadDim)
    attn_v_weight_dkh: tuple = (None, None, None)
    # Attention Out weight: (NumHeads, HeadDim, Dim)
    attn_o_weight_nhd: tuple = (None, None, None)

    # Activation for ffw input: (Batch * Sequence, Dim)
    activation_ffw_td: tuple = (None, None)

    # Activation for ffw input: (Batch * Sequence, Expert, Dim)
    activation_ffw_ted: tuple = (None, None, None)

    # FFW hidden activation: (Batch * Sequence, FfwDim)
    ffw_hidden_tf: tuple = (None, None)

    # FFW up/gate weight: (Dim, FfwDim)
    ffw_weight_df: tuple = (None, None)
    # FFW down weight: (FfwDim, Dim)
    ffw_weight_fd: tuple = (None, None)
    # MoE gate/up weights: (NumExperts, Dim, FfwDim)
    moe_weights_edf: tuple = (None, None, None)
    # MoE down weights: (NumExperts, FfwDim, Dim)
    moe_weights_efd: tuple = (None, None, None)
    # MoE router weights: (Dim, NumExperts)
    moe_router_de: tuple = (None, None)
    # MoE router bias weights: (NumExperts,)
    moe_router_bias_e: tuple = (None, )

    # Embedding weight: (VocabSize, Dim)
    emb_weight_vd: tuple = (None, None)
    # Activation between layers: (Batch * Sequence, Dim)
    activation_td: tuple = (None, None)
    # Final activation before logits: (Batch * Sequence, Dim)
    prelogit_td: tuple = (None, None)
    # Logit activation: (Batch * Sequence, VocabSize)
    logits_tv: tuple = (None, None)
    # RMS norm scale weight: (Dim,)
    norm_scale: tuple = (None, )
    # Vocab projection weight (tied embeddings): (Dim, VocabSize)
    vocab_vd: tuple = (None, None)
    vocab_dv: tuple = (None, None)


class ShardingConfig:
    """Container for operation-specific sharding configurations.

    This class holds two separate `ShardingRulesConfig` objects, one for the
    'prefill' phase and one for the 'generate' (or decode) phase of model
    execution. This allows tailoring sharding strategies to the different
    computational patterns of each phase.

    Example Sharding Strategy and Configuration:

    Sharding Strategy defines the high-level parallelism dimensions.
    For a device mesh like `Mesh((2, 4, 4, 4), ('data', 'seq', 'expert', 'tensor'))` on 128 devices:
    - data: Data Parallelism (2-way)
    - seq: Sequence Parallelism (4-way)
    - expert: Expert Parallelism (4-way)
    - tensor: Tensor Parallelism (4-way)

    ShardingConfig then maps tensor dimensions to these logical mesh axes.
    For example, a tensor with shape (Batch, Sequence, Dimension) could be sharded
    differently for prefill and decode/generate operations:

    - Prefill (long sequences, small batch):
      Sharding the sequence dim on the 'seq' axis is often efficient.
      `prefill_rules.activation_attention_btd = (None, 'seq', 'tensor')`

    - Generate (short sequences, large batch):
      Sharding the batch dim on the 'data' axis is often efficient.
      `generate_rules.activation_attention_btd = ('data', None, 'tensor')`
    """

    def __init__(self,
                 prefill_rules=None,
                 generate_rules=None,
                 default_rules_cls=ShardingRulesConfig):
        """Initializes the ShardingConfig.

        Args:
            prefill_rules: A `ShardingRulesConfig` for the prefill phase.
                If None, a default config is created.
            generate_rules: A `ShardingRulesConfig` for the generate phase.
                If None, a default config is created.
            default_rules_cls: The default sharding rules (class) to use.
        """
        # Use a factory pattern to avoid mutable default arguments
        self.default_rules_cls = default_rules_cls
        self.prefill_rules = prefill_rules if prefill_rules is not None else default_rules_cls(
        )
        self.generate_rules = generate_rules if generate_rules is not None else default_rules_cls(
        )


def build_mesh(devices, strategy: dict[str, int]) -> Mesh:
    """Constructs a JAX device mesh from a sharding strategy.

    This function creates a logical grid of devices based on the parallelism
    degrees defined in the strategy. The logical axis names ('data', 'expert',
    'seq', 'model') are used to map tensor dimensions to the physical device grid.

    Args:
        devices: The JAX devices to arrange into the mesh.
        strategy: A dictionary of parallelism degrees from the upper-level config.

    Returns:
        A JAX `Mesh` object.
    """

    axis_order = {
        "data": strategy.get("data_parallelism", 1),
        "expert": strategy.get("expert_parallelism", 1),
        "seq": strategy.get("sequence_parallelism", 1),
        "model": strategy.get("tensor_parallelism", 1),
    }
    # TODO: add logic to infer axis when the degree is -1
    mesh_axis_names = []
    mesh_shape = []
    for axis, dim in axis_order.items():
        mesh_axis_names.append(axis)
        mesh_shape.append(dim)

    if not mesh_shape:
        mesh_shape = [1]
        mesh_axis_names = [
            'data'
        ]  # default to data parallelism if no other strategy is specified

    devices = np.asarray(devices).reshape(mesh_shape)
    return Mesh(devices, axis_names=tuple(mesh_axis_names))


class Sharding:
    """Generates and manages sharding configurations based on a high-level strategy.

    This class populates a `ShardingConfig` with detailed tensor sharding
    rules for both prefill and generation phases. It also allows for runtime
    overrides of these rules.

    Attributes:
        sharding_cfg: The generated `ShardingConfig` with detailed rules.
    """

    def __init__(self,
                 prefill_rules: dict | None = None,
                 generate_rules: dict | None = None,
                 default_rules_cls=ShardingRulesConfig,
                 vllm_config: VllmConfig = None):
        """Initializes the Sharding manager.

        Args:
            prefill_rules: A dictionary of overrides for the prefill
                sharding config. Keys are attribute names in `ShardingRulesConfig`,
                and values are the new sharding tuples.
            generate_rules: A dictionary of overrides for the generate
                sharding config.
        """
        self.vllm_config = vllm_config
        self.default_rules_cls = default_rules_cls
        self.sharding_cfg = self.make_sharding_config(
            default_rules_cls=default_rules_cls,
            prefill_overrides=prefill_rules,
            generate_overrides=generate_rules)

    def _get_overrides(self, sharding_phase: str):
        """Return the overrides from the vLLM config for the given sharding phase."""
        overrides = {}
        try:
            overrides = self.vllm_config.additional_config["sharding"][
                "logical_rules"]["all"]
        except KeyError:
            pass

        try:
            additional_overrides = self.vllm_config.additional_config[
                "sharding"]["logical_rules"][f"{sharding_phase}"]
            overrides.update(additional_overrides)
        except KeyError:
            pass
        return overrides

    def __str__(self):
        """Succinct representation of relevant Sharding settings and overrides."""
        output_str = f" Using {self.default_rules_cls.__name__} logical rules.\n"
        output_str += f" {self.__class__.__name__:} overrides:\n"
        output_str += f" prefill logical_rule overrides:\n {json.dumps(self._get_overrides('prefill'), indent=4, default=str)}\n\n"
        output_str += f" generate logical_rule overrides:\n {json.dumps(self._get_overrides('generate'), indent=4, default=str)}\n\n"
        return output_str

    def validate_sharding_strategy(self, ):
        """Validates that the sharding strategy is compatible with the environment.

        This method is currently a placeholder; it will check whether the product of
        parallelism degrees matches the number of available devices.
        """
        #TODO: check num_devices % parallelism == 0
        #TODO: check num_devices == multiply(parallelism(with inferred))
        return

    def get_sharding_cfg(self) -> ShardingConfig:
        """Returns the generated sharding configuration."""
        return self.sharding_cfg

    def _apply_overrides(self, config_obj: ShardingRulesConfig,
                         overrides: dict | None):
        """Applies runtime overrides to a sharding configuration object.

        Args:
            config_obj: The sharding configuration object (e.g., prefill_rules)
                to be updated.
            overrides: A dictionary where keys are attribute names of the config
                object and values are the new sharding tuples.

        Raises:
            AttributeError: If a key in the overrides dictionary is not a valid
                attribute of the configuration object.
        """
        for key, value in overrides.items():
            if hasattr(config_obj, key):
                setattr(config_obj, key, value)
            else:
                # Raise an error for invalid keys to prevent silent failures
                raise AttributeError(
                    f"'{key}' is not a valid attribute of {type(config_obj).__name__}"
                )

    def _make_default_sharding_config(self, prefill_rules, generate_rules):

        # Populate Prefill Config
        # During prefill, sequence length is long, so we shard along the sequence axis.
        prefill_rules.activation_attention_td = (DATA_AXIS_NAME,
                                                 ATTN_TENSOR_AXIS_NAME)
        prefill_rules.activation_attention_out_td = (DATA_AXIS_NAME,
                                                     ATTN_TENSOR_AXIS_NAME)
        prefill_rules.activation_q_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
        #TODO: the default qkv and kvcache are sharded on head dim
        # We may change it after we finalize the KVCache design
        prefill_rules.attn_o_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
        prefill_rules.query_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
        prefill_rules.keyvalue_skh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME,
                                      None)

        # Populate Generate (Decode) Config
        # During decode, batch size is the large dimension, so we shard along the batch axis.
        generate_rules.activation_attention_td = (DATA_AXIS_NAME,
                                                  ATTN_TENSOR_AXIS_NAME)
        generate_rules.activation_attention_out_td = (DATA_AXIS_NAME,
                                                      ATTN_TENSOR_AXIS_NAME)
        generate_rules.activation_q_td = (DATA_AXIS_NAME,
                                          ATTN_TENSOR_AXIS_NAME)
        #TODO: the default qkv and kvcache are sharded on head dim
        # We may change it after we finalize the KVCache design
        generate_rules.attn_o_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
        generate_rules.query_tnh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME, None)
        generate_rules.keyvalue_skh = (DATA_AXIS_NAME, ATTN_HEAD_AXIS_NAME,
                                       None)
        generate_rules.attn_q_weight_dnh = (None, ATTN_HEAD_AXIS_NAME,
                                            ATTN_TENSOR_AXIS_NAME)
        generate_rules.attn_k_weight_dkh = (None, ATTN_HEAD_AXIS_NAME,
                                            ATTN_TENSOR_AXIS_NAME)
        generate_rules.attn_v_weight_dkh = (None, ATTN_HEAD_AXIS_NAME,
                                            ATTN_TENSOR_AXIS_NAME)
        generate_rules.attn_o_weight_nhd = (ATTN_HEAD_AXIS_NAME, None,
                                            ATTN_TENSOR_AXIS_NAME)
        generate_rules.activation_ffw_td = (DATA_AXIS_NAME, None)
        generate_rules.activation_ffw_ted = (DATA_AXIS_NAME, EXPERT_AXIS_NAME,
                                             None)
        generate_rules.ffw_hidden_tf = (DATA_AXIS_NAME, MLP_TENSOR_AXIS_NAME)
        # FFW weights are typically sharded along the hidden dimension (F).
        generate_rules.ffw_weight_df = (None, MLP_TENSOR_AXIS_NAME)
        generate_rules.ffw_weight_fd = (MLP_TENSOR_AXIS_NAME, None)
        # MoE weights are sharded along the expert axis and the hidden dimension.
        generate_rules.moe_weights_edf = (EXPERT_AXIS_NAME, None,
                                          MOE_TENSOR_AXIS_NAME)
        generate_rules.moe_weights_efd = (EXPERT_AXIS_NAME,
                                          MOE_TENSOR_AXIS_NAME, None)
        generate_rules.moe_router_de = (None, EXPERT_AXIS_NAME)

        # Embedding weight: (VocabSize, Dim)
        generate_rules.emb_weight_vd = (MLP_TENSOR_AXIS_NAME, None)
        generate_rules.activation_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
        generate_rules.prelogit_td = (DATA_AXIS_NAME, ATTN_TENSOR_AXIS_NAME)
        generate_rules.logits_tv = (DATA_AXIS_NAME, MLP_TENSOR_AXIS_NAME)
        generate_rules.vocab_vd = (VOCAB_AXIS_NAME, None)
        generate_rules.vocab_dv = (None, VOCAB_AXIS_NAME)

    def make_sharding_config(
            self,
            default_rules_cls: ShardingRulesConfig,
            prefill_overrides: dict | None = None,
            generate_overrides: dict | None = None) -> ShardingConfig:
        """Creates the detailed `ShardingConfig` with specific partitioning rules
        and applies any runtime overrides.

        This method populates the `prefill_rules` and
        `generate_rules` with hardcoded sharding rules that are generally
        effective for transformer models, and then updates them with any provided
        overrides.

        Args:
            prefill_overrides: A dictionary with attribute names and their new values
                for the prefill sharding configuration.
            generate_overrides: A dictionary with attribute names and their new values
                for the generate sharding configuration.

        Returns:
            The populated and overridden `ShardingConfig` object.
        """
        #TODO: organize into update_prefill() and update_decode() for each axis
        #TODO: verify the sharding axes
        sharding_cfg = ShardingConfig(default_rules_cls=default_rules_cls)
        prefill_rules = sharding_cfg.prefill_rules
        generate_rules = sharding_cfg.generate_rules

        # Extract the overrides from the vllm_config if they are not provided programmatically.
        if prefill_overrides is None:
            prefill_overrides = self._get_overrides("prefill")
        if generate_overrides is None:
            generate_overrides = self._get_overrides("generate")

        # Apply default sharding configs
        self._make_default_sharding_config(prefill_rules, generate_rules)

        # Apply the runtime sharding rule overrides
        self._apply_overrides(prefill_rules, prefill_overrides)
        self._apply_overrides(generate_rules, generate_overrides)

        return sharding_cfg

    #TODO: Add __repr__


class ShardingInfo:
    #TODO: a sharding info class for visualizing & debugging the sharding performance
    # Will implement it for the next version
    pass
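For reference, a minimal usage sketch of the module above (not part of the diff). It assumes the file is importable as `tpu_inference.layers.jax.sharding` and that at least one JAX device is available; the strategy and override values are illustrative only.

import jax

from tpu_inference.layers.jax.sharding import Sharding, build_mesh

# Lay the devices out along the fixed (data, expert, seq, model) axis order
# used by build_mesh; here all devices go to the 'model' (tensor) axis.
mesh = build_mesh(jax.devices(), {"tensor_parallelism": len(jax.devices())})
print(mesh)  # mesh with axis sizes data=1, expert=1, seq=1, model=<num devices>

# Override a single prefill rule programmatically; keys must be attribute
# names of ShardingRulesConfig. Passing explicit dicts (even an empty one for
# generate) keeps _get_overrides from consulting a vllm_config.
sharding = Sharding(prefill_rules={"ffw_weight_df": (None, "model")},
                    generate_rules={})
cfg = sharding.get_sharding_cfg()
print(cfg.prefill_rules.ffw_weight_df)   # (None, 'model') - overridden
print(cfg.generate_rules.ffw_weight_df)  # (None, ('model', 'expert')) - default

An unknown key such as `{"not_a_rule": ...}` raises an AttributeError in `_apply_overrides`, so typos in override names fail loudly rather than being silently ignored.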
tpu_inference/layers/jax/transformer_block.py

@@ -0,0 +1,76 @@

from dataclasses import dataclass
from typing import Any, Tuple

# Flax and JAX sharding imports
import jax
from flax import nnx

from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
                                                           KVCache)
from tpu_inference.layers.jax.layers import DenseFFW
from tpu_inference.layers.jax.moe.moe import MoE


@dataclass(kw_only=True)
class TransformerBlock(nnx.Module):
    """
    A heavyweight module that serves as a stateful live block in serving.

    custom_module can be either a dense module (i.e., DenseFFW) or MoE.
    """
    pre_attention_norm: nnx.Module
    pre_mlp_norm: nnx.Module
    custom_module: nnx.Module
    attn: nnx.Module
    use_attention_rope: bool = True
    quant: Any | None = None

    def __call__(
            self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
            attention_metadata: AttentionMetadata
    ) -> Tuple[KVCache, jax.Array]:
        # Attn Block
        attn_residual_TD = x_TD
        x_TD = self.pre_attention_norm(x_TD)
        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
                                              attention_metadata,
                                              self.use_attention_rope)
        attn_output_TD += attn_residual_TD

        # FFW Block
        ffw_residual_TD = attn_output_TD
        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
        logits_TD = self.custom_module(normed_ffw_input_TD)
        logits_TD += ffw_residual_TD
        return new_cache, logits_TD


@dataclass(kw_only=True)
class SharedExpertsTransformerBlock(TransformerBlock):
    """A modified TransformerBlock that sums the MoE layer output with the shared expert output."""
    shared_experts: nnx.Module

    def __call__(self, x_TD, is_prefill, kv_cache, attention_metadata):
        # Attn Block
        attn_residual_TD = x_TD
        x_TD = self.pre_attention_norm(x_TD)
        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
                                              attention_metadata,
                                              self.use_attention_rope)
        attn_output_TD += attn_residual_TD

        # FFW Block
        ffw_residual_TD = attn_output_TD
        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
        if isinstance(self.custom_module, MoE):
            logits_TD = self.custom_module(normed_ffw_input_TD)
            # Add the shared expert outputs to the MoE outputs.
            shared_expert_output_TD = self.shared_experts(normed_ffw_input_TD)
            logits_TD += shared_expert_output_TD
        elif isinstance(self.custom_module, DenseFFW):
            logits_TD = self.custom_module(normed_ffw_input_TD)
        else:
            raise ValueError(
                f"Invalid custom module type: {type(self.custom_module)}")
        logits_TD += ffw_residual_TD
        return new_cache, logits_TD
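To make the residual wiring above explicit, here is a hedged, framework-free sketch of the same dataflow (not part of the diff). `rms_norm` and the lambda submodules are placeholders for the block's `pre_attention_norm`/`pre_mlp_norm`, `attn`, and `custom_module`, not the package's implementations.

import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Simple pre-norm stand-in (no learned scale).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def block(x_TD, attn_fn, ffw_fn):
    # Attention sub-block: pre-norm, attention, residual add.
    attn_out_TD = attn_fn(rms_norm(x_TD)) + x_TD
    # FFW sub-block: pre-norm, dense or MoE module, residual add.
    return ffw_fn(rms_norm(attn_out_TD)) + attn_out_TD


x = jnp.ones((8, 16))  # (tokens T, model dim D)
y = block(x, attn_fn=lambda h: h, ffw_fn=lambda h: 2.0 * h)
print(y.shape)  # (8, 16)

The shared-experts variant only changes the FFW sub-block: when the custom module is an MoE, the shared expert output is added to the routed expert output before the residual connection.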
tpu_inference/layers/vllm/attention.py

@@ -0,0 +1,184 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional, Tuple

import jax
import torch
from jax.sharding import Mesh
from torchax.interop import jax_view, torch_view
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)

from tpu_inference import utils
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.attention_interface import attention
from tpu_inference.logger import init_logger
from tpu_inference.models.vllm.vllm_model_wrapper_context import \
    get_vllm_model_wrapper_context

logger = init_logger(__name__)


class PallasAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        return PallasAttentionBackendImpl


class PallasAttentionBackendImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[int] = None,
        use_irope: bool = False,
    ) -> None:
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
                kv_cache_dtype)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for "
                "PallasAttentionBackendImpl")

        if kv_cache.numel():
            raise RuntimeError(
                "KV cache from vLLM Attention layer should be empty but has "
                f"the size of {kv_cache.numel()}.")

        del kv_cache  # Use kv_cache from vllm wrapper context values instead.

        vllm_model_wrapper_context = get_vllm_model_wrapper_context()
        kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
            layer.layer_name]
        kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]

        mesh = vllm_model_wrapper_context.mesh

        query, key, value = jax_view(query), jax_view(key), jax_view(value)
        q_scale = k_scale = v_scale = None
        if self.kv_cache_quantized_dtype:
            key, value = utils.quantize_kv(key, value,
                                           self.kv_cache_quantized_dtype,
                                           layer._k_scale_float,
                                           layer._v_scale_float)
            # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
            # q_scale = layer._q_scale_float
            k_scale = layer._k_scale_float
            v_scale = layer._v_scale_float

        new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
                                               attn_metadata, mesh, self.scale,
                                               self.head_size, self.num_heads,
                                               self.num_kv_heads, q_scale,
                                               k_scale, v_scale)
        vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache

        return torch_view(outputs)


@functools.partial(
    jax.jit,
    static_argnums=(
        5, 6, 7, 8, 9, 10, 11, 12
    ),  # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
    donate_argnums=(0, ),  # donate kv_cache
)
def _jax_attn_func(
    kv_cache: jax.Array,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    attention_metadata: AttentionMetadata,
    mesh: Mesh,
    scale: float,
    head_size: int,
    num_heads: int,
    num_kv_heads: int,
    q_scale: Optional[float] = None,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
) -> Tuple[jax.Array, jax.Array]:
    del scale  # Unused for now, as the attention function applies a default scale.

    # Get shapes from vLLM
    q_len, q_compute_dim = q.shape
    k_len, k_compute_dim = k.shape
    assert k.shape == v.shape
    assert q_compute_dim == head_size * num_heads
    assert k_compute_dim == head_size * num_kv_heads

    # Convert the shapes from vLLM's convention to what the attention function expects
    # (q_len, num_heads, head_size)
    q = q.reshape(q_len, num_heads, head_size)
    # (k_len, num_kv_heads, head_size)
    k = k.reshape(k_len, num_kv_heads, head_size)
    v = v.reshape(k_len, num_kv_heads, head_size)

    new_kv_cache, outputs = attention(
        kv_cache,
        q,
        k,
        v,
        attention_metadata,
        mesh,
        q_scale=q_scale,
        k_scale=k_scale,
        v_scale=v_scale,
    )

    # Convert the shape back to vLLM's convention
    assert outputs.shape[0] == q_len
    assert outputs.shape[1] == num_heads
    assert outputs.shape[2] == head_size
    outputs = outputs.reshape(q_len, q_compute_dim)

    return new_kv_cache, outputs
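The decorator on `_jax_attn_func` combines two standard JAX patterns: non-array arguments (mesh, scale, head counts, quantization scales) are marked static so they become compile-time constants, and the KV cache buffer is donated so XLA may reuse its memory for the updated cache. Below is a small self-contained sketch of that pattern (not part of the diff); `toy_cache_update` is a hypothetical stand-in for the real attention call.

import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.jit,
    static_argnums=(2, ),  # 'scale' is a hashable compile-time constant
    donate_argnums=(0, ),  # donate the cache so its buffer can be reused
)
def toy_cache_update(cache, new_vals, scale):
    # Write the scaled values into the front of the cache and return both the
    # updated cache and the values, mirroring the (new_kv_cache, outputs) pair.
    updated = cache.at[:new_vals.shape[0]].set(scale * new_vals)
    return updated, scale * new_vals


cache = jnp.zeros((16, 4))
cache, out = toy_cache_update(cache, jnp.ones((3, 4)), 2.0)
print(cache[:4, 0])  # [2. 2. 2. 0.]

Donation only pays off on devices that support buffer aliasing (such as TPUs); on CPU JAX simply warns and copies, which is why the real kernel reserves it for the large KV cache rather than the small activations.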