sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +1 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/scheduler.py +2 -2
- sglang/srt/managers/tokenizer_manager.py +86 -76
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -0
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/server.py +1 -0
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks

+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -82,6 +87,9 @@ class ReqState:
     created_time: float
     first_token_time: Optional[float] = None

+    # For streaming output
+    last_output_offset: int = 0
+

 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
@@ -120,6 +128,7 @@ class TokenizerManager:

         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id

         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -152,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        #
-        self.model_update_lock =
-        self.model_update_result =
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()

         # For session info
         self.session_futures = {}  # session_id -> asyncio event
@@ -181,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()

-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
@@ -191,17 +200,24 @@ class TokenizerManager:
             )

         obj.normalize_batch_and_arguments()
-
-        if
-
-
-
-
-
-
-
-
-
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response

     async def _tokenize_one_request(
         self,
@@ -215,7 +231,7 @@ class TokenizerManager:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
                     "input_embeds is provided while disable_radix_cache is False. "
-                    "Please add `--disable-radix-
+                    "Please add `--disable-radix-cache` when you launch the server "
                     "if you want to use input_embeds as inputs."
                 )
             input_embeds = obj.input_embeds
@@ -301,8 +317,8 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
                 if self.server_args.log_requests:
-
-                    logger.info(
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                    logger.info(msg)
                 del self.rid_to_state[obj.rid]
                 yield out
                 break
@@ -423,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if
-
-
-
-
-            while len(self.rid_to_state) > 0:
-                await asyncio.sleep(0.001)
-            # FIXME: We add some sleep here to avoid some race conditions.
-            # We can use a read-write lock as a better fix.
-            await asyncio.sleep(0.01)
-            self.send_to_scheduler.send_pyobj(obj)
-            self.model_update_result = asyncio.Future()
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str, int]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message

     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)
@@ -487,25 +500,22 @@ class TokenizerManager:
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()

-
-
-
-
-
-
-
-
-
-
-
-            return
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -564,11 +574,11 @@ class TokenizerManager:

         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))

         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self._create_buffers()

+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]

+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):


 class MLATokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):


 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
+from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

sglang/srt/model_executor/model_runner.py
CHANGED
@@ -95,6 +95,12 @@ class ModelRunner:
         ):
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
+            # FIXME(HandH1998)
+            if (
+                "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+                and not self.server_args.disable_cuda_graph
+            ):
+                self.server_args.disable_cuda_graph = True

         if self.server_args.enable_double_sparsity:
             logger.info(
sglang/srt/models/dbrx.py
CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
         return x


+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):

     def __init__(
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )

-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0


-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):

     def __init__(
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +471,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native

         self.attn_mqa = RadixAttention(
             self.num_local_heads,
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -909,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and w.dtype == torch.float8_e4m3fn
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        w, scale = block_quant_to_tensor_quant(
+                            w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale


-
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
sglang/srt/models/grok.py
CHANGED
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import GeluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/mixtral.py
CHANGED
@@ -27,8 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/olmoe.py
CHANGED
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/xverse_moe.py
CHANGED
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/openai_api/adapter.py
CHANGED
@@ -858,6 +858,7 @@ def v1_chat_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     modalities_list = []
+    lora_paths = []

     # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -920,6 +921,7 @@ def v1_chat_generate_request(
         return_logprobs.append(request.logprobs)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
+        lora_paths.append(request.lora_path)

         sampling_params = {
             "temperature": request.temperature,
@@ -958,6 +960,7 @@ def v1_chat_generate_request(
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         modalities_list = modalities_list[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -975,6 +978,7 @@ def v1_chat_generate_request(
         return_text_in_logprobs=True,
         rid=request_ids,
         modalities=modalities_list,
+        lora_path=lora_paths,
     )

 return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]