sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +1 -0
  4. sglang/srt/aio_rwlock.py +100 -0
  5. sglang/srt/configs/model_config.py +8 -1
  6. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  7. sglang/srt/layers/linear.py +20 -2
  8. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  9. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  10. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  11. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
  12. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  13. sglang/srt/layers/moe/topk.py +191 -0
  14. sglang/srt/layers/quantization/__init__.py +3 -3
  15. sglang/srt/layers/quantization/fp8.py +169 -32
  16. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  17. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  18. sglang/srt/layers/torchao_utils.py +11 -15
  19. sglang/srt/managers/schedule_batch.py +16 -10
  20. sglang/srt/managers/scheduler.py +2 -2
  21. sglang/srt/managers/tokenizer_manager.py +86 -76
  22. sglang/srt/mem_cache/memory_pool.py +15 -8
  23. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  24. sglang/srt/model_executor/model_runner.py +6 -0
  25. sglang/srt/models/dbrx.py +1 -1
  26. sglang/srt/models/deepseek.py +1 -1
  27. sglang/srt/models/deepseek_v2.py +67 -18
  28. sglang/srt/models/grok.py +1 -1
  29. sglang/srt/models/mixtral.py +2 -2
  30. sglang/srt/models/olmoe.py +1 -1
  31. sglang/srt/models/qwen2_moe.py +1 -1
  32. sglang/srt/models/xverse_moe.py +1 -1
  33. sglang/srt/openai_api/adapter.py +4 -0
  34. sglang/srt/server.py +1 -0
  35. sglang/srt/utils.py +33 -44
  36. sglang/test/test_block_fp8.py +341 -0
  37. sglang/version.py +1 -1
  38. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
  39. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
  40. sglang/srt/layers/fused_moe_patch.py +0 -133
  41. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  42. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  43. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  44. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  45. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_process_tree
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None

+    # For streaming output
+    last_output_offset: int = 0
+

 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:

         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id

         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -152,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        # For update model weights
-        self.model_update_lock = asyncio.Lock()
-        self.model_update_result = None
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()

         # For session info
         self.session_futures = {}  # session_id -> asyncio event
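Note: the new sglang/srt/aio_rwlock.py (file 4 in the list above) is not shown in this diff, but the call sites imply an asyncio read-write lock exposing `reader_lock` and `writer_lock` async context managers: request handling shares the reader side, while weight updates take the writer side exclusively. A minimal sketch under those assumed semantics (not the actual sglang implementation, which may add writer preference to avoid starvation):

    import asyncio

    class RWLock:
        # Readers share the lock; a writer waits until no readers remain,
        # then runs alone. All state is guarded by one asyncio.Condition.
        class _Reader:
            def __init__(self, rw):
                self._rw = rw

            async def __aenter__(self):
                async with self._rw._cond:
                    await self._rw._cond.wait_for(lambda: not self._rw._writer)
                    self._rw._readers += 1

            async def __aexit__(self, *exc):
                async with self._rw._cond:
                    self._rw._readers -= 1
                    self._rw._cond.notify_all()

        class _Writer:
            def __init__(self, rw):
                self._rw = rw

            async def __aenter__(self):
                async with self._rw._cond:
                    await self._rw._cond.wait_for(
                        lambda: self._rw._readers == 0 and not self._rw._writer
                    )
                    self._rw._writer = True

            async def __aexit__(self, *exc):
                async with self._rw._cond:
                    self._rw._writer = False
                    self._rw._cond.notify_all()

        def __init__(self):
            self._cond = asyncio.Condition()
            self._readers = 0
            self._writer = False
            self.reader_lock = self._Reader(self)
            self.writer_lock = self._Writer(self)

A lock of this shape is what lets the hunks below drop the old polling on `model_update_lock.locked()` and the FIXME sleeps: generation holds `reader_lock`, weight sync holds `writer_lock`.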
@@ -181,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
@@ -191,17 +200,24 @@ class TokenizerManager:
         )

         obj.normalize_batch_and_arguments()
-        is_single = obj.is_single
-        if is_single:
-            tokenized_obj = await self._tokenize_one_request(obj)
-            self.send_to_scheduler.send_pyobj(tokenized_obj)
-            async for response in self._wait_one_response(obj, request, created_time):
-                yield response
-        else:
-            async for response in self._handle_batch_request(
-                obj, request, created_time
-            ):
-                yield response
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response

     async def _tokenize_one_request(
         self,
@@ -215,7 +231,7 @@ class TokenizerManager:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-cach` when you launch the server "
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds
@@ -301,8 +317,8 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
                 if self.server_args.log_requests:
-                    # Log requests
-                    logger.info(f"in={obj}, out={out}")
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                    logger.info(msg)
                 del self.rid_to_state[obj.rid]
                 yield out
                 break
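`dataclass_to_string_truncated` comes from the reworked sglang/srt/utils.py (file 35) and is not shown in this excerpt. Judging by its use on request and output dataclasses, it renders fields while capping very long values (prompt text, token-id lists) so log lines stay bounded. A plausible minimal version, with the `max_length` cutoff as an assumed parameter:

    import dataclasses

    def dataclass_to_string_truncated(obj, max_length: int = 2048) -> str:
        # Render a dataclass as "Name(field=value, ...)", truncating any
        # field whose repr exceeds max_length to "head ... tail".
        parts = []
        for field in dataclasses.fields(obj):
            text = repr(getattr(obj, field.name))
            if len(text) > max_length:
                half = max_length // 2
                text = f"{text[:half]} ... {text[-half:]}"
            parts.append(f"{field.name}={text}")
        return f"{type(obj).__name__}({', '.join(parts)})"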
@@ -423,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if not self.model_update_lock.locked():
-
-            async with self.model_update_lock:
-                # wait for the previous generation requests to finish
-                for i in range(3):
-                    while len(self.rid_to_state) > 0:
-                        await asyncio.sleep(0.001)
-                    # FIXME: We add some sleep here to avoid some race conditions.
-                    # We can use a read-write lock as a better fix.
-                    await asyncio.sleep(0.01)
-                self.send_to_scheduler.send_pyobj(obj)
-                self.model_update_result = asyncio.Future()
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)

-            if self.server_args.dp_size == 1:
-                result = await self.model_update_result
-                if result.success:
-                    self.server_args.model_path = obj.model_path
-                    self.server_args.load_format = obj.load_format
-                    self.model_path = obj.model_path
-                return result.success, result.message
-            else:  # self.server_args.dp_size > 1
-                self.model_update_tmp = []
-                result = await self.model_update_result
-
-                all_success = all([r.success for r in result])
-                if all_success is True:
-                    self.server_args.model_path = obj.model_path
-                    self.server_args.load_format = obj.load_format
-                    self.model_path = obj.model_path
-                all_message = [r.message for r in result]
-                all_message = " | ".join(all_message)
-                return all_success, all_message
-
-        else:
-            return False, "Another update is in progress. Please try again later."
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str, int]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message

     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)
@@ -487,25 +500,22 @@ class TokenizerManager:
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

-        if not self.model_update_lock.locked():
-            async with self.model_update_lock:
-                self.send_to_scheduler.send_pyobj(obj)
-                self.parameter_update_result = asyncio.Future()
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be for update weights from distributed"
-                result = await self.parameter_update_result
-                return result.success, result.message
-        else:
-            logger.error("Another parameter update is in progress in tokenizer manager")
-            return (
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -564,11 +574,11 @@ class TokenizerManager:

         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
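Wrapping `loop.create_task(...)` in `self.asyncio_tasks.add(...)` follows the pattern the asyncio documentation recommends: the event loop keeps only weak references to tasks, so a task whose handle is discarded can be garbage-collected before it finishes. In isolation:

    import asyncio

    background_tasks = set()

    async def watchdog():
        await asyncio.sleep(0.1)

    async def main():
        task = asyncio.get_running_loop().create_task(watchdog())
        background_tasks.add(task)  # strong reference keeps the task alive
        task.add_done_callback(background_tasks.discard)
        await asyncio.sleep(0.2)

    asyncio.run(main())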
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self._create_buffers()

+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
        ]

+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
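Storing `head_num`, `head_dim`, and `layer_num` on `self` lets the pool tear down and rebuild its KV buffers without re-passing constructor arguments. One plausible use (an assumption, since the callers are outside this excerpt) is releasing GPU memory around a weight reload:

    import torch

    def rebuild_kv_cache(pool) -> None:
        # Hypothetical helper; `pool` is an idle MHATokenToKVPool.
        pool._clear_buffers()     # drop references to k_buffer / v_buffer
        torch.cuda.empty_cache()  # return freed blocks to the CUDA allocator
        # ... memory-hungry work here, e.g. reloading model weights ...
        pool._create_buffers()    # re-allocate from the saved size/head/layer config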
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):


 class MLATokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):


 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

sglang/srt/model_executor/model_runner.py CHANGED
@@ -95,6 +95,12 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+            # FIXME(HandH1998)
+            if (
+                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+                and not self.server_args.disable_cuda_graph
+            ):
+                self.server_args.disable_cuda_graph = True

         if self.server_args.enable_double_sparsity:
             logger.info(
sglang/srt/models/dbrx.py CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
         return x


+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):

     def __init__(
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )

-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
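Replacing `ReplicatedLinear` with the bare-parameter `MoEGate` keeps the router weights unquantized and, for DeepSeek-V3 checkpoints that set `topk_method == "noaux_tc"`, exposes `e_score_correction_bias`, which is threaded into the experts as `correction_bias`. The selection logic itself lives in the new sglang/srt/layers/moe/topk.py (file 13, not shown here); below is only a rough sketch of the published noaux_tc idea, where the bias steers which experts are picked but not their routing weights (assumes `n_experts` divisible by `num_expert_group` with at least two experts per group):

    import torch

    def biased_grouped_topk_sketch(
        scores: torch.Tensor,           # (num_tokens, n_experts), e.g. sigmoid(gate logits)
        correction_bias: torch.Tensor,  # (n_experts,)
        num_expert_group: int,
        topk_group: int,
        top_k: int,
    ):
        n_tokens, n_experts = scores.shape
        biased = (scores + correction_bias).view(n_tokens, num_expert_group, -1)
        # Rank expert groups by the sum of their two best biased scores,
        # then mask out every group outside the top `topk_group`.
        group_scores = biased.topk(2, dim=-1).values.sum(dim=-1)
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        keep = torch.zeros_like(group_scores, dtype=torch.bool)
        keep.scatter_(1, group_idx, True)
        biased = biased.masked_fill(~keep.unsqueeze(-1), float("-inf"))
        # Pick experts by biased score, but weight them by the unbiased score.
        topk_idx = biased.reshape(n_tokens, n_experts).topk(top_k, dim=-1).indices
        topk_weights = scores.gather(1, topk_idx)
        return topk_weights, topk_idx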
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):

     def __init__(
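`input_to_float8` is unchanged; it only moves into sglang/srt/layers/quantization/fp8_utils.py (file 17) so other modules can share it. Note from the body above that the returned scale is the dequantization multiplier (the reciprocal of the scale used to quantize), so a round trip looks like:

    import torch
    from sglang.srt.layers.quantization.fp8_utils import input_to_float8

    x = torch.randn(16, 64)
    x_fp8, scale = input_to_float8(x)             # per-tensor fp8 quantization
    x_restored = x_fp8.to(torch.float32) * scale  # scale multiplies back to ~x
    assert torch.allclose(x, x_restored, atol=0.05, rtol=0.05)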
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        rope_scaling["rope_type"] = "deepseek_yarn"
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +471,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native

         self.attn_mqa = RadixAttention(
             self.num_local_heads,
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -909,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and w.dtype == torch.float8_e4m3fn
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        w, scale = block_quant_to_tensor_quant(
+                            w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale


-EntryClass = DeepseekV2ForCausalLM
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
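`block_quant_to_tensor_quant` also comes from fp8_utils.py and is not shown in this diff. Per the NOTE above, it has to dequantize a block-quantized fp8 weight using its per-block `weight_scale_inv` and requantize per-tensor so `bmm_fp8` can consume it. A plausible shape, reusing `input_to_float8` (the expansion/trim of the scale grid is an assumption):

    import torch
    from sglang.srt.layers.quantization.fp8_utils import input_to_float8

    def block_quant_to_tensor_quant_sketch(w_q, w_s_inv, weight_block_size):
        # w_q: fp8 weight of shape (n, k); w_s_inv: per-block dequant scales
        # of shape (ceil(n/bn), ceil(k/bk)). Expand each scale over its
        # block, dequantize, then requantize with one per-tensor scale.
        bn, bk = weight_block_size
        n, k = w_q.shape
        scales = (
            w_s_inv.repeat_interleave(bn, dim=0)[:n]
            .repeat_interleave(bk, dim=1)[:, :k]
        )
        w_dequant = w_q.to(torch.float32) * scales
        return input_to_float8(w_dequant)  # -> (fp8 tensor, per-tensor scale)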
sglang/srt/models/grok.py CHANGED
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import GeluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/mixtral.py CHANGED
@@ -27,8 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/olmoe.py CHANGED
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/qwen2_moe.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/xverse_moe.py CHANGED
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/openai_api/adapter.py CHANGED
@@ -858,6 +858,7 @@ def v1_chat_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     modalities_list = []
+    lora_paths = []

     # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -920,6 +921,7 @@ def v1_chat_generate_request(
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
+        lora_paths.append(request.lora_path)

         sampling_params = {
             "temperature": request.temperature,
@@ -958,6 +960,7 @@ def v1_chat_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         modalities_list = modalities_list[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -975,6 +978,7 @@ def v1_chat_generate_request(
         return_text_in_logprobs=True,
         rid=request_ids,
         modalities=modalities_list,
+        lora_path=lora_paths,
     )

     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
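This adapter change threads a per-request `lora_path` from the chat-completions payload into `GenerateReqInput`. Since `lora_path` is not a standard OpenAI field, a client would pass it as an extra body parameter; the endpoint URL and adapter name below are placeholders:

    import openai

    client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Hello!"}],
        extra_body={"lora_path": "my-adapter"},  # hypothetical LoRA adapter name
    )
    print(response.choices[0].message.content)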