sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_nextn.py

@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
-
+from sglang.srt.utils import BumpAllocator, add_prefix
 
 logger = logging.getLogger(__name__)
 
@@ -177,263 +158,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
-                        torch.bfloat16
-                    )
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
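
The hunk above removes the NextN draft model's roughly 250-line bespoke weight loader; in 0.4.6.post3 DeepseekV3ForCausalLMNextN simply delegates to the shared DeepseekV2ForCausalLM.load_weights path with is_nextn=True (see the deepseek_v2.py hunks below). A minimal sketch of that delegation pattern follows; TargetModel and NextNDraftModel are hypothetical stand-ins, not sglang classes.

from typing import Iterable, Tuple

import torch


class TargetModel:
    # Stand-in for DeepseekV2ForCausalLM: one loader handles both the main
    # model and the NextN draft layer, switched by `is_nextn`.
    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn: bool = False
    ):
        layer_prefix = "model.decoder" if is_nextn else "model.layers"
        for name, tensor in weights:
            print(f"routing {name} -> {layer_prefix} ({tuple(tensor.shape)})")


class NextNDraftModel(TargetModel):
    # Stand-in for DeepseekV3ForCausalLMNextN after this refactor.
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        super().load_weights(weights, is_nextn=True)
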
sglang/srt/models/deepseek_v2.py

@@ -59,10 +59,11 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_dequant,
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -88,6 +89,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )
 
 _is_hip = is_hip()
@@ -356,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 topk_idx,
                 topk_weights,
                 reorder_topk_ids,
+                num_recv_tokens_per_expert,
                 seg_indptr,
                 masked_m,
                 expected_m,
@@ -367,10 +370,13 @@
             )
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
             reorder_topk_ids=reorder_topk_ids,
             seg_indptr=seg_indptr,
             masked_m=masked_m,
             expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
             forward_mode=forward_mode,
         )
         if self.ep_size > 1:
@@ -421,6 +427,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +550,8 @@
             prefix=add_prefix("attn_mha", prefix),
         )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,20 +715,36 @@
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
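
The added branch above overlaps the two independent RMSNorm calls (q_a_layernorm on q, kv_a_layernorm on k_nope) by issuing one of them on a secondary CUDA stream, and only while a CUDA graph is being captured (torch.cuda.is_current_stream_capturing()). A minimal sketch of the stream-overlap pattern, with illustrative names (overlap_two_ops, f, g) rather than sglang APIs:

import torch


def overlap_two_ops(f, g, x, y, alt_stream: torch.cuda.Stream):
    # Run f(x) on the current stream and g(y) on alt_stream concurrently,
    # then re-join so both results are safe to consume afterwards.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)      # alt_stream sees all prior work
    out_x = f(x)                         # issued on the current stream
    with torch.cuda.stream(alt_stream):
        out_y = g(y)                     # issued concurrently on alt_stream
    current.wait_stream(alt_stream)      # current waits before using out_y
    return out_x, out_y
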
@@ -750,14 +775,9 @@
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -769,8 +789,8 @@
 
         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
                 )
             )
             attn_bmm_output = attn_output.new_empty(
@@ -1104,6 +1124,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1154,7 @@
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1398,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
@@ -1383,6 +1406,7 @@
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
@@ -1391,6 +1415,9 @@
 
         self.dp_size = get_attention_dp_size()
 
+    def get_input_embeddings(self) -> torch.Tensor:
+        return self.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1464,8 +1491,9 @@ class DeepseekV2ForCausalLM(nn.Module):
            ):
                self.n_share_experts_fusion = 0
                global_server_args_dict["n_share_experts_fusion"] = 0
-                logger.info(
-                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+                log_info_on_rank0(
+                    logger,
+                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
            else:
                assert (
@@ -1480,8 +1508,9 @@
            ):
                self.n_share_experts_fusion = self.tp_size
                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-                logger.info(
-                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
+                log_info_on_rank0(
+                    logger,
+                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                )
 
     def get_input_embeddings(self) -> nn.Embedding:
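
Both messages above now go through log_info_on_rank0(logger, msg), newly imported from sglang.srt.utils in the import hunk earlier, so the notice is printed once instead of once per tensor-parallel rank. The exact sglang implementation is not shown in this diff; a typical helper of this shape is sketched below as an assumption.

import logging

import torch.distributed as dist


def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    # Emit the message only on rank 0 so multi-GPU deployments do not
    # repeat the same startup notice once per process.
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)
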
@@ -1502,11 +1531,20 @@
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):
 
         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1552,13 +1590,22 @@
 
                     if (
                         _is_cuda
-                        and _ENABLE_JIT_DEEPGEMM
                         and weight_block_size[0] == 128
                         and weight_block_size[1] == 128
                         and model_dtype == torch.bfloat16
                     ):
-                        block_scale = weight_scale
-                        use_deep_gemm_bmm = True
+                        if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
+                            "SGL_USE_DEEPGEMM_BMM", "false"
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w = block_quant_dequant(
+                                weight,
+                                weight_scale,
+                                weight_block_size,
+                                model_dtype,
+                            )
                     else:
                         w, scale = block_quant_to_tensor_quant(
                             weight, weight_scale, weight_block_size
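
With this hunk the DeepGEMM masked-BMM path becomes opt-in: even when JIT DeepGEMM is available, it is taken only if the SGL_USE_DEEPGEMM_BMM environment variable is truthy; otherwise the fp8 block weights are dequantized with block_quant_dequant to the model dtype and the regular bmm path is used. get_bool_env_var presumably comes from sglang.srt.utils alongside get_int_env_var; the stand-alone version below is an assumption about its behavior, shown only to illustrate the gate.

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as True, anything else as False.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")


# Example: opt in to the DeepGEMM masked-BMM path for MLA weight absorption.
# export SGL_USE_DEEPGEMM_BMM=true
use_bmm = get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
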
@@ -1612,7 +1659,20 @@
                     self_attn.w_vc = w_vc.contiguous()
                     self_attn.use_deep_gemm_bmm = True
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1700,19 @@
                     "up_proj.weight_scale_inv",
                 ]
             names_to_remove = []
-            for moe_layer in tqdm(
+
+            moe_layers = (
                 range(
                     self.config.first_k_dense_replace,
                     self.config.num_hidden_layers,
                     self.config.moe_layer_freq,
-                ),
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in tqdm(
+                moe_layers,
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
@@ -1686,18 +1753,46 @@
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
 
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # TODO(HandH1998): Modify it when nextn is supported.
-            if hasattr(self.config, "num_nextn_predict_layers"):
-                num_nextn_layers = self.config.num_nextn_predict_layers
-                if num_nextn_layers > 0 and name.startswith("model.layers"):
-                    name_list = name.split(".")
-                    if (
-                        len(name_list) >= 3
-                        and int(name_list[2]) >= self.config.num_hidden_layers
-                    ):
-                        continue
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1881,7 @@
                         )
                         weight_loader(param, loaded_weight)
 
-        self.post_load_weights()
+        self.post_load_weights(is_nextn=is_nextn)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
sglang/srt/models/gemma3_mm.py

@@ -281,7 +281,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         pixel_values = torch.stack(
             flatten_nested_list([item.pixel_values for item in items]), dim=0
         )
-        pixel_values = pixel_values.to("cuda")
+        pixel_values = pixel_values.to(device=self.vision_tower.device)
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
         vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
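
The one-line change above replaces the hard-coded "cuda" target with the vision tower's own device, so multimodal inputs stay correct when the model is placed on a non-default GPU or split across devices. A minimal sketch of the general pattern, assuming a module that owns at least one parameter (move_to_module_device is illustrative, not an sglang helper):

import torch
from torch import nn


def move_to_module_device(x: torch.Tensor, module: nn.Module) -> torch.Tensor:
    # Derive the target device from the module itself instead of assuming
    # "cuda"; works the same on CPU, cuda:1, etc.
    device = next(module.parameters()).device
    return x.to(device=device)
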
sglang/srt/models/internlm2.py

@@ -290,6 +290,9 @@ class InternLM2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.tok_embeddings
+
     @torch.no_grad()
     def forward(
         self,