sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
```diff
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
-from vllm.distributed import (
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_size,
+    initialize_dp_attention,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
```
```diff
@@ -50,13 +56,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
 
```
```diff
@@ -99,8 +107,10 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-            logger.info("MLA optimization is turned on. Use triton backend.")
-            self.server_args.attention_backend = "triton"
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
```
```diff
@@ -157,6 +167,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )
 
@@ -165,6 +176,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
```
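The `TorchMemorySaverAdapter` created here is later used to wrap model and KV-cache allocations (see the `load_model` and memory-pool hunks below). A minimal sketch of the pattern, assuming only the `create(enable=...)` factory and `region()` context manager that appear in this diff:

```python
# Sketch of the allocation-wrapping pattern behind --enable-memory-saver.
# Only the create()/region() API visible in this diff is assumed here.
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=True)  # becomes a no-op when enable=False

with adapter.region():
    # Allocations made inside the region (model weights, KV-cache tensors)
    # are tracked so they can be released and restored later.
    ...
```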
```diff
@@ -210,9 +225,12 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
-            monkey_patch_vllm_p2p_access_check()
+            monkey_patch_p2p_access_check()
+
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
@@ -220,7 +238,7 @@ class ModelRunner:
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
 
         if not self.is_draft_worker:
-            # Only
+            # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
                 world_size=self.tp_size,
```
```diff
@@ -229,11 +247,18 @@ class ModelRunner:
                 distributed_init_method=dist_init_method,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_dp_attention(
+                enable_dp_attention=self.server_args.enable_dp_attention,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+            )
 
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
```
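Data-parallel attention is now initialized right after `initialize_model_parallel`, and the runner keeps an attention-specific TP group next to the regular one. A hedged sketch of the sequence, with placeholder rank/size values rather than a working multi-process launch:

```python
# Illustrative only: mirrors the call order added in init_torch_distributed().
# The rank/size values are placeholders; a real run gets them from ServerArgs.
from sglang.srt.layers.dp_attention import (
    get_attention_tp_group,
    get_attention_tp_size,
    initialize_dp_attention,
)

tp_rank, tp_size, dp_size = 0, 1, 1

initialize_dp_attention(
    enable_dp_attention=False,  # server_args.enable_dp_attention
    tp_rank=tp_rank,
    tp_size=tp_size,
    dp_size=dp_size,
)

# KV-head partitioning in the hunks below now asks the attention TP group for
# its size (get_attention_tp_size()) instead of using tp_size directly.
attention_tp_group = get_attention_tp_group()
attention_tp_size = get_attention_tp_size()
```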
```diff
@@ -251,7 +276,8 @@ class ModelRunner:
         )
 
         # This can reduce thread conflicts and speed up weight loading.
-        torch.set_num_threads(1)
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
                 logger.info(
```
```diff
@@ -271,11 +297,38 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
+        monkey_patch_vllm_parallel_state(reverse=True)
+
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
 
         # Parse other args
         self.sliding_window_size = (
```
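The FP8 branch above relies on duck typing: a model opts into loading KV-cache scaling factors simply by exposing a `load_kv_cache_scales` method. A small sketch of that check (the helper name is hypothetical):

```python
# Hypothetical helper illustrating the callable/getattr check used above.
def maybe_load_kv_cache_scales(model, quantization_param_path: str) -> bool:
    loader = getattr(model, "load_kv_cache_scales", None)
    if callable(loader):
        loader(quantization_param_path)
        return True
    return False
```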
```diff
@@ -393,7 +446,7 @@ class ModelRunner:
 
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:
```
```diff
@@ -491,7 +544,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
```
```diff
@@ -516,6 +569,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
```
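Summarizing the `--kv-cache-dtype` handling after this hunk: `fp8_e5m2` keeps its existing mapping, while `fp8_e4m3` now maps to `torch.float8_e4m3fn` on CUDA. A hedged sketch (the helper name is made up, and the `is_hip()` branch for `fp8_e5m2` is taken from the surrounding context rather than this hunk):

```python
import torch

from sglang.srt.utils import is_cuda, is_hip


def resolve_kv_cache_dtype(name: str) -> torch.dtype:
    # Sketch only; the "auto" case (use the model dtype) is omitted.
    if name == "fp8_e5m2":
        return torch.float8_e5m2fnuz if is_hip() else torch.float8_e5m2
    if name == "fp8_e4m3" and is_cuda():
        return torch.float8_e4m3fn
    raise ValueError(f"Unsupported kv_cache_dtype: {name}.")
```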
```diff
@@ -562,7 +618,7 @@ class ModelRunner:
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
-
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
```
```diff
@@ -575,25 +631,28 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
```
sglang/srt/model_loader/loader.py
CHANGED
```diff
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
```
```diff
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
```
```diff
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
```
sglang/srt/model_loader/weight_utils.py
CHANGED
```diff
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
```
```diff
@@ -19,10 +29,10 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once
 
```
```diff
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
```
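A hedged usage sketch for the new `kv_cache_scales_loader`: a model's `load_kv_cache_scales` method can iterate the returned `(layer_index, scale)` pairs and attach them to its attention layers. The `k_scale` attribute and the layer access path below are assumptions for illustration, not taken from this diff:

```python
from sglang.srt.model_loader.weight_utils import kv_cache_scales_loader


def load_kv_cache_scales(model, quantization_param_path: str, tp_rank: int, tp_size: int):
    # model.config fields and the .self_attn.attn.k_scale attribute are assumed here.
    for layer_idx, scale in kv_cache_scales_loader(
        quantization_param_path,
        tp_rank,
        tp_size,
        model.config.num_hidden_layers,
        model.config.model_type,
    ):
        model.model.layers[layer_idx].self_attn.attn.k_scale = scale
```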
sglang/srt/models/baichuan.py
CHANGED
```diff
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
```
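The model files that follow (chatglm, commandr, dbrx, deepseek, and the rest) apply the same mechanical migration shown above for baichuan: parallel and rotary-embedding helpers move from `vllm` to their `sglang.srt` counterparts. In shorthand:

```python
# Before (0.4.1.post5): helpers were imported from vllm.
# from vllm.distributed import get_tensor_model_parallel_world_size
# from vllm.model_executor.layers.rotary_embedding import get_rope

# After (0.4.1.post7): the same names resolve inside sglang itself.
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.rotary_embedding import get_rope
```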
sglang/srt/models/chatglm.py
CHANGED
```diff
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
```
sglang/srt/models/commandr.py
CHANGED
```diff
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
```
sglang/srt/models/dbrx.py
CHANGED
```diff
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
```
sglang/srt/models/deepseek.py
CHANGED
```diff
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
```
sglang/srt/models/deepseek_v2.py
CHANGED
```diff
@@ -23,14 +23,13 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-
-
-
-        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
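DeepSeek-V2 now builds its rotary embedding through `get_rope_wrapper`, passing the target device from `global_server_args_dict` so the module can be placed correctly on the new CPU path. A hedged sketch with placeholder dimensions, assuming `get_rope_wrapper` mirrors `get_rope`'s signature plus the `device` argument seen above:

```python
from sglang.srt.layers.rotary_embedding import get_rope_wrapper

# Placeholder head size / positions; DeepSeek-V2 passes qk_rope_head_dim and its
# "deepseek_yarn" rope_scaling dict here instead.
rotary_emb = get_rope_wrapper(
    64,
    rotary_dim=64,
    max_position=4096,
    base=10000,
    rope_scaling=None,
    is_neox_style=False,
    device="cuda",  # threaded in from global_server_args_dict["device"]
)
```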
sglang/srt/models/exaone.py
CHANGED
```diff
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
```
sglang/srt/models/gemma.py
CHANGED
```diff
@@ -21,9 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
```
sglang/srt/models/gemma2.py
CHANGED
```diff
@@ -15,13 +15,13 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
 
-from typing import Iterable, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -44,23 +45,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +127,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-
-        self.rotary_emb = GemmaRotaryEmbedding(
-            self.head_dim,
+        self.rotary_emb = get_rope(
             self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
```
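With `get_rope` now provided by `sglang.srt.layers.rotary_embedding`, the temporary `GemmaRotaryEmbedding` subclass above becomes unnecessary; Gemma-2 attention builds a standard neox-style rotary embedding instead. A sketch with placeholder dimensions:

```python
from sglang.srt.layers.rotary_embedding import get_rope

head_dim = 256        # placeholder; the model uses config.head_dim
max_position = 8192   # placeholder; config.max_position_embeddings in the model

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position,
    base=10000,
    is_neox_style=True,
)
```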
sglang/srt/models/gpt2.py
CHANGED
```diff
@@ -17,16 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable,
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
-
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
```
sglang/srt/models/gpt_bigcode.py
CHANGED
```diff
@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
```