tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
--- a/tpu_inference/models/jax/deepseek_v3.py
+++ b/tpu_inference/models/jax/deepseek_v3.py
@@ -1,3 +1,4 @@
+import os
 import re
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -13,6 +14,7 @@ from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.attention.attention import AttentionMetadata
 from tpu_inference.layers.jax.attention.deepseek_v3_attention import MLA
 from tpu_inference.layers.jax.constants import KVCacheType
@@ -69,6 +71,7 @@ class DeepSeekV3(nnx.Module):
 hidden_act: str = "silu"
 rms_norm_eps: float = 1e-06
 first_k_dense_replace: int = 3 # replace the first few MOE layers to dense layer.
+self.use_mla_kernel: bool = self.vllm_config.model_config.use_mla
 
 num_shared_experts = 1
 rope_theta = 10000
@@ -114,19 +117,30 @@ class DeepSeekV3(nnx.Module):
 qk_rope_head_dim=qk_rope_head_dim,
 v_head_dim=v_head_dim,
 num_local_experts=num_local_experts,
-model_dtype=dtype
+model_dtype=dtype,
+use_mla_kernel=self.use_mla_kernel)
 
 self.embedder = Embedder(vocab_size=vocab_size,
 hidden_size=hidden_size,
 dtype=dtype,
 rngs=self.rng,
-vd_sharding=(
+vd_sharding=(ShardingAxisName.MLP_TENSOR,
 None),
 random_init=self.random_init)
 
 self.layers = []
 
 def _create_mla() -> MLA:
+if self.use_mla_kernel:
+query_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+keyvalue_skh_spec = P(ShardingAxisName.MLP_TENSOR, None)
+attn_o_tnh_spec = P(ShardingAxisName.MLP_TENSOR, None, None)
+
+else:
+query_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+keyvalue_skh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+attn_o_tnh_spec = P(None, ShardingAxisName.MLP_TENSOR, None)
+
 return MLA(
 rope_theta=rope_theta,
 rope_scaling=rope_scaling,
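The _create_mla change above selects which axis of the query/key-value activations is sharded, depending on whether the fused MLA kernel is used (axis 0 with the kernel, axis 1 without). Below is a minimal, self-contained JAX sketch of what such a PartitionSpec choice means; the mesh axis names "data"/"model" and the (tokens, heads, head_dim) shape are illustrative stand-ins for tpu-inference's ShardingAxisName constants, not its actual configuration.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Tiny 2D device mesh; on CPU this degenerates to a single device, on a TPU
# slice it would span the real chips. Axis names are illustrative.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

q = jnp.zeros((16, 8, 128))  # (tokens, heads, head_dim)

# Shard the token axis over the "model" mesh axis (analogous to the
# use_mla_kernel=True branch above, which shards axis 0).
q_tokens_sharded = jax.device_put(q, NamedSharding(mesh, P("model", None, None)))

# Shard the heads axis instead (analogous to the use_mla_kernel=False branch).
q_heads_sharded = jax.device_put(q, NamedSharding(mesh, P(None, "model", None)))

print(q_tokens_sharded.sharding)  # shows which mesh axis each array axis maps to
print(q_heads_sharded.sharding)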
@@ -137,10 +151,12 @@ class DeepSeekV3(nnx.Module):
 rms_norm_eps=rms_norm_eps,
 v_head_dim=v_head_dim,
 mesh=self.mesh,
+use_mla_kernel=self.use_mla_kernel,
 random_init=self.random_init,
 hidden_size=hidden_size,
 num_attention_heads=num_attention_heads,
-num_key_value_heads=
+num_key_value_heads=1
+if self.use_mla_kernel else num_key_value_heads,
 head_dim=v_head_dim, # MLA uses v_head_dim as head_dim
 dtype=dtype,
 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
@@ -148,14 +164,14 @@ class DeepSeekV3(nnx.Module):
 rngs=self.rng,
 activation_attention_td=(None, None),
 activation_q_td=(None, None),
-query_tnh=
-keyvalue_skh=
+query_tnh=query_tnh_spec,
+keyvalue_skh=keyvalue_skh_spec,
 activation_attention_out_td=(None, None),
-attn_o_tnh=
-q_da_sharding=(None,
-anh_sharding=(None,
-kv_da_sharding=(None,
-nhd_sharding=(
+attn_o_tnh=attn_o_tnh_spec,
+q_da_sharding=(None, ShardingAxisName.VOCAB),
+anh_sharding=(None, ShardingAxisName.MLP_TENSOR, None),
+kv_da_sharding=(None, ShardingAxisName.VOCAB),
+nhd_sharding=(ShardingAxisName.MLP_TENSOR, None, None))
 
 for i in range(first_k_dense_replace):
 block = TransformerBlock(
@@ -176,14 +192,15 @@ class DeepSeekV3(nnx.Module):
 rngs=self.rng,
 ),
 attn=_create_mla(),
-custom_module=DenseFFW(
-
-
-
-
-
-
-
+custom_module=DenseFFW(
+dtype=dtype,
+hidden_act=hidden_act,
+hidden_size=hidden_size,
+intermediate_size=ffw_intermediate_size,
+rngs=self.rng,
+df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+fd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+random_init=self.random_init))
 
 self.layers.append(block)
 
@@ -200,9 +217,9 @@ class DeepSeekV3(nnx.Module):
 rngs=self.rng,
 routed_scaling_factor=2.5,
 dtype=dtype,
-activation_ffw_td=(
-ed_sharding=(
-e_sharding=(
+activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+ed_sharding=(ShardingAxisName.MLP_TENSOR, None),
+e_sharding=(ShardingAxisName.MLP_TENSOR, ))
 if self.sparse_matmul:
 # TODO: orginize the SparseMoE and DenseMoE better given they share most interfaces
 custom_module = SparseMoE(
@@ -216,10 +233,10 @@ class DeepSeekV3(nnx.Module):
 hidden_act=hidden_act,
 rngs=self.rng,
 random_init=self.random_init,
-activation_ffw_td=(
-activation_ffw_ted=(
-edf_sharding=(
-efd_sharding=(
+activation_ffw_td=(ShardingAxisName.MLP_TENSOR, None),
+activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
 quantized_dtype=self.weight_loader.quant_dtype
 if self.weight_loader.is_model_quantized else None,
 router=router) if is_moe_layer else DenseFFW(
@@ -229,8 +246,8 @@ class DeepSeekV3(nnx.Module):
 intermediate_size=ffw_intermediate_size,
 rngs=self.rng,
 random_init=self.random_init,
-df_sharding=(None,
-fd_sharding=(
+df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
 else:
 custom_module = MoE(
 dtype=dtype,
@@ -241,10 +258,10 @@ class DeepSeekV3(nnx.Module):
 hidden_act=hidden_act,
 rngs=self.rng,
 random_init=self.random_init,
-activation_ffw_td=(
-activation_ffw_ted=(
-edf_sharding=(
-efd_sharding=(
+activation_ffw_td=(ShardingAxisName.MLP_DATA, None),
+activation_ffw_ted=(ShardingAxisName.MLP_DATA, None, None),
+edf_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
+efd_sharding=(ShardingAxisName.MLP_TENSOR, None, None),
 router=router) if is_moe_layer else DenseFFW(
 dtype=dtype,
 hidden_act=hidden_act,
@@ -252,18 +269,18 @@ class DeepSeekV3(nnx.Module):
 intermediate_size=ffw_intermediate_size,
 rngs=self.rng,
 random_init=self.random_init,
-df_sharding=(None,
-fd_sharding=(
-
-shared_experts = DenseFFW(
-
-
-
-
-
-
-
-
+df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
+
+shared_experts = DenseFFW(
+dtype=dtype,
+hidden_act=hidden_act,
+hidden_size=hidden_size,
+intermediate_size=num_shared_experts * moe_intermediate_size,
+rngs=self.rng,
+random_init=self.random_init,
+df_sharding=(None, ShardingAxisName.MLP_TENSOR),
+fd_sharding=(ShardingAxisName.MLP_TENSOR, None))
 
 pre_attention_norm = RMSNorm(
 dims=hidden_size,
@@ -304,10 +321,28 @@ class DeepSeekV3(nnx.Module):
 hidden_size=hidden_size,
 dtype=dtype,
 rngs=self.rng,
-vd_sharding=(
-dv_sharding=(None,
+vd_sharding=(ShardingAxisName.MLP_TENSOR, None),
+dv_sharding=(None, ShardingAxisName.MLP_TENSOR),
 random_init=self.random_init)
 
+if os.environ.get("VLLM_LOGGING_LEVEL", "").upper() == "DEBUG":
+self._print_model_architecture()
+
+def _print_model_architecture(self):
+num_display_layers = 5
+
+logger.debug("### Embedding ###")
+nnx.display(self.embedder)
+
+logger.debug(f"\n### First {num_display_layers} Layers ###")
+# Loop through the slice and display each layer
+for i, layer in enumerate(self.layers[:num_display_layers]):
+logger.debug(f"\n--- Layer {i} ---")
+nnx.display(layer)
+
+logger.debug("\n### LM Head ###")
+nnx.display(self.lm_head)
+
 # For compatibility with flax.
 def apply(self, variables, *args, **kwargs):
 return self.__call__(*args, **kwargs)
@@ -352,10 +387,19 @@ class DeepSeekV3(nnx.Module):
 @dataclass
 class DeepSeekV3WeightLoader:
 
-def __init__(self,
-
-
-
+def __init__(self,
+vllm_config: VllmConfig,
+num_layers,
+hidden_size,
+q_lora_rank,
+kv_lora_rank,
+attn_heads,
+qk_nope_head_dim,
+qk_rope_head_dim,
+v_head_dim,
+num_local_experts,
+model_dtype,
+use_mla_kernel=False):
 self.num_layers = num_layers
 self.names_and_weights_generator = model_weights_generator(
 model_name_or_path=vllm_config.model_config.model,
@@ -364,7 +408,12 @@ class DeepSeekV3WeightLoader:
 self.is_verbose = vllm_config.additional_config.get(
 "is_verbose", None) is not None
 self.num_routed_experts = num_local_experts
+self.attn_heads = attn_heads
+self.qk_nope_head_dim = qk_nope_head_dim
+self.v_head_dim = v_head_dim
+self.kv_lora_rank = kv_lora_rank
 self.model_dtype = model_dtype
+self.use_mla_kernel = use_mla_kernel
 
 self._transpose_map = {
 # dense mlp
@@ -376,6 +425,8 @@ class DeepSeekV3WeightLoader:
 r"q_b_proj": (2, 0, 1),
 r"kv_a_proj_with_mqa": (1, 0),
 r"kv_b_proj": (2, 0, 1),
+r"k_b_proj": (2, 0, 1), # used for MLA kernel
+r"v_b_proj": (2, 0, 1), # used for MLA kernel
 r"o_proj": (1, 2, 0),
 # moe
 r"mlp\.gate\.weight": (1, 0),
@@ -393,6 +444,8 @@ class DeepSeekV3WeightLoader:
 (attn_heads, qk_nope_head_dim + qk_rope_head_dim, q_lora_rank),
 "kv_b_proj":
 (attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank),
+"k_b_proj": (attn_heads, qk_nope_head_dim, kv_lora_rank),
+"v_b_proj": (attn_heads, v_head_dim, kv_lora_rank),
 "o_proj": (hidden_size, attn_heads, v_head_dim)
 }
 
@@ -452,6 +505,13 @@ class DeepSeekV3WeightLoader:
 "model.layers.*.mlp.shared_experts.up_proj.weight":
 "layers.*.shared_experts.kernel_up_proj_DF",
 }
+if self.use_mla_kernel:
+self._loaded_to_standardized_keys.update({
+"model.layers.*.self_attn.k_b_proj.weight":
+"layers.*.attn.kernel_k_up_proj_ANH",
+"model.layers.*.self_attn.v_b_proj.weight":
+"layers.*.attn.kernel_v_up_proj_ANH",
+})
 
 # TODO (jacobplatin): we shouldn't hard-code this, but the logic to obtain the true quantized dtype
 # is non-trivial and the default checkpoints all use this dtype
@@ -487,6 +547,15 @@ class DeepSeekV3WeightLoader:
 "kv_b_proj": (attn_heads, (qk_nope_head_dim + v_head_dim) //
 self.quantization_block_size_n,
 kv_lora_rank // self.quantization_block_size_n),
+# used for MLA kernel
+"k_b_proj":
+(attn_heads,
+qk_nope_head_dim // self.quantization_block_size_n,
+kv_lora_rank // self.quantization_block_size_n),
+# used for MLA kernel
+"v_b_proj":
+(attn_heads, v_head_dim // self.quantization_block_size_n,
+kv_lora_rank // self.quantization_block_size_n),
 "o_proj":
 (hidden_size // self.quantization_block_size_n, attn_heads,
 v_head_dim // self.quantization_block_size_n),
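For block-quantized checkpoints, these reshape entries mean each scale tensor carries the weight's head dimension and kv_lora_rank divided by the quantization block size. A small illustrative calculation (the dimensions and the 128-wide blocks below are placeholder values, not read from any particular checkpoint):

# Placeholder dimensions for illustration only.
attn_heads = 16
qk_nope_head_dim = 128
v_head_dim = 128
kv_lora_rank = 512
block_n = block_k = 128  # quantization block sizes

# Weight shapes after the kv_b_proj split (see the reshape map above).
k_b_proj_weight = (attn_heads, qk_nope_head_dim, kv_lora_rank)  # (16, 128, 512)
v_b_proj_weight = (attn_heads, v_head_dim, kv_lora_rank)        # (16, 128, 512)

# Matching block-wise scale shapes.
k_b_proj_scale = (attn_heads, qk_nope_head_dim // block_n, kv_lora_rank // block_k)  # (16, 1, 4)
v_b_proj_scale = (attn_heads, v_head_dim // block_n, kv_lora_rank // block_k)        # (16, 1, 4)

print(k_b_proj_scale, v_b_proj_scale)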
@@ -802,21 +871,73 @@ class DeepSeekV3WeightLoader:
 f"Cumulative local memory: {cumulative_local_memory} GB"
 )
 else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+if self.use_mla_kernel and "kv_b_proj" in loaded_name:
+# loaded_weight shape: (num_heads * (d_k + d_v), kv_lora_rank)
+# scale shape: (num_heads * (d_k + d_v) / block_n, kv_lora_rank / block_k)
+# Reshape to (num_heads, (d_k + d_v), kv_lora_rank) and split
+weight_reshaped = loaded_weight.view(
+self.attn_heads,
+self.qk_nope_head_dim + self.v_head_dim,
+self.kv_lora_rank)
+k_weight = weight_reshaped[:, :self.
+qk_nope_head_dim, :].reshape(
+-1, self.kv_lora_rank)
+v_weight = weight_reshaped[:, self.
+qk_nope_head_dim:, :].reshape(
+-1, self.kv_lora_rank)
+
+loaded_weights_list = [k_weight, v_weight]
+loaded_names = [
+loaded_name.replace("kv_b_proj", "k_b_proj"),
+loaded_name.replace("kv_b_proj", "v_b_proj")
+]
+
+scales_list = [None, None]
+if scale is not None:
+bn = self.quantization_block_size_n
+bk = self.quantization_block_size_k
+scale_reshaped = scale.view(
+self.attn_heads,
+(self.qk_nope_head_dim + self.v_head_dim) //
+bn, self.kv_lora_rank // bk)
+
+k_scale = scale_reshaped[:, :self.
+qk_nope_head_dim //
+bn, :].reshape(
+-1,
+self.kv_lora_rank //
+bk)
+v_scale = scale_reshaped[:,
+self.qk_nope_head_dim //
+bn:, :].reshape(
+-1,
+self.kv_lora_rank //
+bk)
+scales_list = [k_scale, v_scale]
+
+else:
+loaded_weights_list = [loaded_weight]
+loaded_names = [loaded_name]
+scales_list = [scale]
+
+for loaded_name, loaded_weight, scale in zip(
+loaded_names, loaded_weights_list, scales_list):
+
+weight_bytes, weight_shards = self._load_individual_weight(
+loaded_name,
+loaded_weight,
+model_params,
+model_for_loading.mesh,
+scale=scale)
+if self.is_verbose:
+cumulative_global_memory += weight_bytes
+cumulative_local_memory += weight_shards
+logger.info(
+f"Cumulative global memory: {cumulative_global_memory} GB"
+)
+logger.info(
+f"Cumulative local memory: {cumulative_local_memory} GB"
+)
 
 del mlp_experts_gate_proj_weights
 del mlp_experts_up_proj_weights
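The loader change above splits the fused kv_b_proj up-projection into separate k_b_proj and v_b_proj tensors when the MLA kernel is enabled. Below is a minimal NumPy sketch of the same reshape-and-slice; the dimensions are illustrative, and the real code operates on torch tensors via .view and additionally splits the block-quantization scales the same way.

import numpy as np

# Illustrative dimensions only.
attn_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 8, 8, 16

# Fused kv_b_proj weight: (num_heads * (d_k + d_v), kv_lora_rank).
kv_b_proj = np.arange(
    attn_heads * (qk_nope_head_dim + v_head_dim) * kv_lora_rank,
    dtype=np.float32).reshape(-1, kv_lora_rank)

# Reshape to (num_heads, d_k + d_v, kv_lora_rank), slice per head, flatten back.
w = kv_b_proj.reshape(attn_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b_proj = w[:, :qk_nope_head_dim, :].reshape(-1, kv_lora_rank)
v_b_proj = w[:, qk_nope_head_dim:, :].reshape(-1, kv_lora_rank)

assert k_b_proj.shape == (attn_heads * qk_nope_head_dim, kv_lora_rank)
assert v_b_proj.shape == (attn_heads * v_head_dim, kv_lora_rank)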
--- a/tpu_inference/models/jax/gpt_oss.py
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -102,9 +102,9 @@ class GptOss(nnx.Module):
 rope_ntk_beta=rope_ntk_beta,
 rngs=self.rng,
 random_init=self.random_init,
-query_tnh=P(
-keyvalue_skh=P(
-attn_o_tnh=P(
+query_tnh=P("data", 'model', None),
+keyvalue_skh=P("data", 'model', None),
+attn_o_tnh=P("data", 'model', None),
 dnh_sharding=P(None, 'model', None),
 dkh_sharding=P(None, 'model', None),
 nhd_sharding=P('model', None, None),
--- a/tpu_inference/models/jax/llama3.py
+++ b/tpu_inference/models/jax/llama3.py
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
 "lm_head": "model.lm_head",
 })
 
-metadata_map = get_default_maps(self.vllm_config
+metadata_map = get_default_maps(self.vllm_config.model_config,
+self.mesh, mappings)
 load_hf_weights(vllm_config=self.vllm_config,
 model=self,
 metadata_map=metadata_map,
--- a/tpu_inference/models/jax/llama_eagle3.py
+++ b/tpu_inference/models/jax/llama_eagle3.py
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
 metadata_map: MetadataMap):
-model_config = vllm_config.
+model_config = vllm_config.speculative_config.draft_model_config
 hf_config = model_config.hf_config
 
 num_heads = hf_config.num_attention_heads
 num_kv_heads = hf_config.num_key_value_heads
-hidden_size =
-
+hidden_size = hf_config.hidden_size
 head_dim_original = model_config.get_head_size()
 
 metadata_map.reshape_map.update({
@@ -305,6 +304,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
 "fc": "model.fc.kernel",
 "lm_head": "lm_head.kernel",
 "d2t": "draft_id_to_target_id",
+"embed_tokens":
+"model.embed_tokens.embedding", # Some checkpoints need this
 }
 
 # Define keys to keep in original dtype (e.g., float32 for stability)
@@ -312,7 +313,9 @@ class EagleLlama3ForCausalLM(nnx.Module):
 r".*d2t.*",
 ]
 
-metadata_map = get_default_maps(
+metadata_map = get_default_maps(
+self.vllm_config.speculative_config.draft_model_config, self.mesh,
+mappings)
 
 update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -324,7 +327,7 @@ class EagleLlama3ForCausalLM(nnx.Module):
 is_draft_model=True,
 keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
 
-# If the embedding is not initialized, initialize it with a
+# If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
 if isinstance(self.model.embed_tokens.embedding.value,
 jax.ShapeDtypeStruct):
 self.model.embed_tokens.embedding.value = jnp.zeros(