sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -31,6 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
+from functools import total_ordering
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -117,13 +118,14 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DECODE or self == ForwardMode.IDLE
 
 
+@total_ordering
 class CaptureHiddenMode(IntEnum):
     # Do not capture anything.
-    NULL = auto()
-    # Capture hidden states of all tokens.
-    FULL = auto()
+    NULL = 0
     # Capture a hidden state of the last token.
-    LAST = auto()
+    LAST = 1
+    # Capture hidden states of all tokens.
+    FULL = 2
 
     def need_capture(self):
         return self != CaptureHiddenMode.NULL
@@ -134,6 +136,9 @@ class CaptureHiddenMode(IntEnum):
     def is_last(self):
         return self == CaptureHiddenMode.LAST
 
+    def __lt__(self, other):
+        return self.value < other.value
+
 
 @dataclass
 class ForwardBatch:
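Note: with explicit values, @total_ordering, and __lt__, capture modes become comparable, presumably so callers can merge capture requirements (e.g. by taking the maximum, NULL < LAST < FULL). A minimal, hypothetical sketch of that usage; the names below are illustrative, not taken from the package:

    # Hypothetical illustration of ordering an IntEnum the same way.
    from enum import IntEnum
    from functools import total_ordering

    @total_ordering
    class Mode(IntEnum):
        NULL = 0  # capture nothing
        LAST = 1  # capture the last token's hidden state
        FULL = 2  # capture hidden states of all tokens

        def __lt__(self, other):
            return self.value < other.value

    # Merging two requests keeps the stricter capture mode.
    assert max(Mode.NULL, Mode.LAST) is Mode.LAST
    assert Mode.LAST < Mode.FULL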
@@ -219,6 +224,9 @@ class ForwardBatch:
     # For input embeddings
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
 
@@ -295,6 +303,7 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
         )
         device = model_runner.device
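Note: token_type_ids are the standard BERT segment ids; for a cross-encoder, query and document are packed into one sequence and distinguished by segment 0 vs 1. A hedged, tokenizer-level sketch of where those ids come from (the checkpoint name is only an example):

    # Hypothetical example of building token_type_ids for a cross-encoder pair.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")  # example model
    enc = tokenizer("what is sglang?", "SGLang is a fast serving framework.", return_tensors="pt")
    # Query tokens get segment id 0, document tokens get 1; ForwardBatch.token_type_ids
    # carries these ids through to BertEmbedding.token_type_embeddings.
    print(enc["token_type_ids"])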
@@ -351,8 +360,8 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+            ret.extend_num_tokens = batch.extend_num_tokens
             if support_triton(model_runner.server_args.attention_backend):
-                ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
sglang/srt/model_executor/model_runner.py CHANGED
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-    _ENABLE_JIT_DEEPGEMM,
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -205,8 +205,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
 
         # Update deep gemm configure
-        if _ENABLE_JIT_DEEPGEMM:
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
sglang/srt/model_loader/loader.py CHANGED
@@ -1259,12 +1259,19 @@ class GGUFModelLoader(BaseModelLoader):
         ):
             model_config.hf_config.update({"tie_word_embeddings": True})
 
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map)
            )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model
 
 
sglang/srt/models/bert.py CHANGED
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 BertConfig = None
 
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
 
@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)
 
         # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
+        position_embeddings = self.position_embeddings(positions)
 
-        token_type_ids = torch.zeros(
-            input_shape, dtype=torch.long, device=inputs_embeds.device
-        )
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings
 
 
+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # simply taking the hidden state corresponding
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+
+        return pooled_output
+
+
 class BertEncoder(nn.Module):
 
     def __init__(
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()
 
+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output
 
 
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config, quant_config=quant_config, prefix=f"encoder"
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)
 
     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):
 
         hidden_states = self.embeddings(
             input_ids=input_ids,
-            position_ids=positions,
+            positions=positions,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        return self.pooler(hidden_states, forward_batch)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 
@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass
 
 
-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
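Note: the weight_filter generator above streams "bert."-prefixed tensors into the inner BertModel and sets the remaining head weights aside. A standalone, hypothetical sketch of the same prefix-routing pattern (names are illustrative, not from the package):

    # Hypothetical illustration of splitting checkpoint tensors by name prefix.
    def split_by_prefix(weights, prefix="bert."):
        leftovers = []

        def prefixed():
            for name, tensor in weights:
                if name.startswith(prefix):
                    yield name[len(prefix):], tensor   # routed to the submodule
                else:
                    leftovers.append((name, tensor))   # e.g. "classifier.weight"

        return prefixed(), leftovers

    # The generator must be fully consumed (as load_weights does above) before
    # `leftovers` actually contains the classifier-head tensors.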
sglang/srt/models/deepseek_v2.py CHANGED
@@ -51,11 +51,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -66,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -109,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -980,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -1013,7 +1010,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -1032,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
@@ -1708,53 +1709,35 @@ class DeepseekV2ForCausalLM(nn.Module):
     def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.num_fused_shared_experts = (
-            0
-            if global_server_args_dict["disable_shared_experts_fusion"]
-            else self.config.n_shared_experts
-        )
-        if self.num_fused_shared_experts > 0:
-            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-            if (
-                not _is_cuda
-                or self.config.architectures[0] != architecture
-                or self.config.n_routed_experts != 256
-            ):
-                self.num_fused_shared_experts = 0
-                global_server_args_dict["disable_shared_experts_fusion"] = True
-                log_info_on_rank0(
-                    logger,
-                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
-                )
-            elif (
-                global_server_args_dict["enable_deepep_moe"]
-                or global_server_args_dict["enable_ep_moe"]
-            ):
-                self.num_fused_shared_experts = 0
-                global_server_args_dict["disable_shared_experts_fusion"] = True
-                log_info_on_rank0(
-                    logger,
-                    "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
-                )
-        elif self.num_fused_shared_experts == 0:
-            if (
-                _is_cuda
-                and torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == architecture
-                and self.config.n_routed_experts == 256
-                and (
-                    not (
-                        global_server_args_dict["enable_deepep_moe"]
-                        or global_server_args_dict["enable_ep_moe"]
-                    )
-                )
-            ):
-                self.num_fused_shared_experts = self.config.n_shared_experts
-                global_server_args_dict["disable_shared_experts_fusion"] = False
-                log_info_on_rank0(
-                    logger,
-                    "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
-                )
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (9, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_routed_experts != 256
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+        elif (
+            global_server_args_dict["enable_deepep_moe"]
+            or global_server_args_dict["enable_ep_moe"]
+        ):
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
1786
1769
  for name in weight_names:
1787
1770
  if "kv_b_proj" in name:
1788
1771
  layer_id = int(name.split(".")[2])
1789
- # filter the nextn layer.
1790
- if layer_id != self.config.num_hidden_layers:
1772
+ if layer_id < self.config.num_hidden_layers:
1791
1773
  layer_ids.add(layer_id)
1792
1774
 
1793
1775
  for layer_id in layer_ids:
@@ -1847,8 +1829,10 @@ class DeepseekV2ForCausalLM(nn.Module):
1847
1829
  and weight_block_size[1] == 128
1848
1830
  and model_dtype == torch.bfloat16
1849
1831
  ):
1850
- if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
1851
- "SGL_USE_DEEPGEMM_BMM", "false"
1832
+ if (
1833
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1834
+ and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
1835
+ and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
1852
1836
  ):
1853
1837
  block_scale = weight_scale
1854
1838
  use_deep_gemm_bmm = True
@@ -1932,6 +1916,65 @@ class DeepseekV2ForCausalLM(nn.Module):
1932
1916
  self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
1933
1917
  self_attn.use_deep_gemm_bmm = True
1934
1918
 
1919
+ if (
1920
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
1921
+ and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
1922
+ ):
1923
+ self._weight_requant_ue8m0()
1924
+
1925
+ def _weight_requant_ue8m0(self):
1926
+ weight_block_size = self.quant_config.weight_block_size
1927
+
1928
+ moe_layers = list(
1929
+ range(
1930
+ self.config.first_k_dense_replace,
1931
+ self.config.num_hidden_layers,
1932
+ self.config.moe_layer_freq,
1933
+ )
1934
+ )
1935
+
1936
+ for layer_id in range(self.config.num_hidden_layers):
1937
+ layer = self.model.layers[layer_id]
1938
+
1939
+ for module in [
1940
+ layer.self_attn.fused_qkv_a_proj_with_mqa,
1941
+ layer.self_attn.q_b_proj,
1942
+ layer.self_attn.kv_b_proj,
1943
+ layer.self_attn.o_proj,
1944
+ ]:
1945
+ requant_weight_ue8m0_inplace(
1946
+ module.weight, module.weight_scale_inv, weight_block_size
1947
+ )
1948
+
1949
+ if layer_id in moe_layers:
1950
+ shared_experts = getattr(layer.mlp, "shared_experts", None)
1951
+ if shared_experts is not None:
1952
+ for module in [
1953
+ shared_experts.gate_up_proj,
1954
+ shared_experts.down_proj,
1955
+ ]:
1956
+ requant_weight_ue8m0_inplace(
1957
+ module.weight, module.weight_scale_inv, weight_block_size
1958
+ )
1959
+
1960
+ experts = layer.mlp.experts
1961
+ if isinstance(experts, DeepEPMoE):
1962
+ for w in [
1963
+ experts.w13_weight_fp8,
1964
+ experts.w2_weight_fp8,
1965
+ ]:
1966
+ requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
1967
+ else:
1968
+ mlp = layer.mlp
1969
+ assert isinstance(mlp, DeepseekV2MLP)
1970
+ for module in [
1971
+ mlp.gate_up_proj,
1972
+ mlp.down_proj,
1973
+ ]:
1974
+ requant_weight_ue8m0_inplace(
1975
+ module.weight, module.weight_scale_inv, weight_block_size
1976
+ )
1977
+
1935
1978
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
1936
1979
 
1937
1980
  if is_nextn:
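Note: UE8M0 is an exponent-only scale format (unsigned 8-bit exponent, no mantissa), so every scale is a power of two; requantization therefore snaps each block scale to a power of two and rescales the weight so the dequantized product is preserved. A rough, hypothetical sketch of that idea, not the actual requant_weight_ue8m0_inplace implementation:

    # Hypothetical sketch: snap a block scale to a power of two and fold the
    # correction factor back into the weight block.
    import torch

    def snap_scale_to_pow2(weight_block: torch.Tensor, scale: torch.Tensor):
        pow2_scale = torch.exp2(torch.ceil(torch.log2(scale)))  # power-of-two scale
        adjusted_block = weight_block * (scale / pow2_scale)    # keep block * scale constant
        return adjusted_block, pow2_scale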
@@ -1952,101 +1995,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale_inv",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale_inv",
-                        "up_proj.weight",
-                        "up_proj.weight_scale_inv",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.num_fused_shared_experts} "
-                "shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
 
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -2072,9 +2020,19 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]
 
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            logger.info("Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
         weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             weight_names.append(name)
 
             if not is_nextn:
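Note: this replaces the old up-front cloning pass above; shared-expert tensors are now simply renamed on the fly so they load as one extra routed expert. A tiny example of the mapping, using the 256 routed experts required by the fusion check earlier in the diff:

    # The shared expert is stored as routed expert index n_routed_experts.
    n_routed_experts = 256  # DeepSeek V3/R1 value assumed by the fusion check
    name = "model.layers.5.mlp.shared_experts.gate_proj.weight"
    fused = name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
    assert fused == "model.layers.5.mlp.experts.256.gate_proj.weight"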