sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -1,23 +1,22 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import pkgutil
@@ -56,10 +55,13 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
-
+    is_hip,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )
 
 logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.
+            logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
@@ -139,15 +141,26 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
-                "
-                "
+                "enable_nan_detection": server_args.enable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
 
-
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -166,14 +179,15 @@ class ModelRunner:
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
         # Init torch distributed
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
-            torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
         # ToDO(liangan1):Just use gloo to bypass the initilization fail
         # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
-            torch.xpu.set_device(self.gpu_id)
             backend = "gloo"
+        elif self.device == "hpu":
+            backend = "hccl"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -215,6 +229,49 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+
+        if "task" in sig.parameters:
+            params["task"] = ""
+
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -232,42 +289,25 @@ class ModelRunner:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
         # Prepare the vllm model config
-        monkey_patch_vllm_dummy_weight_loader()
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
-
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
-        self.dtype = self.vllm_model_config.dtype
 
-
-
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
            if hasattr(self.model, "get_attention_sliding_window_size")
            else None
         )
+        self.dtype = self.vllm_model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -293,17 +333,9 @@ class ModelRunner:
         target_device = torch.device(self.device)
 
         try:
-
-
-
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
@@ -412,7 +444,10 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-
+            if is_hip():  # Using natively supported format
+                self.kv_cache_dtype = torch.float8_e5m2fnuz
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -551,6 +586,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
@@ -564,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-
-
-
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
@@ -576,21 +626,37 @@ class ModelRunner:
                 get_embedding=True,
             )
 
+    def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         sampling_info = forward_batch.sampling_info
-        sampling_info.
-
+        if sampling_info.sampling_info_done:
+            # Overlap mode: the function update_regex_vocab_mask was executed
+            # in process_batch_result of the last batch.
+            if sampling_info.grammars:
+                sampling_info.sampling_info_done.wait()
+        else:
+            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+            sampling_info.update_regex_vocab_mask()
+            sampling_info.update_penalties()
         logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
 
         # Sample the next tokens.
@@ -616,7 +682,7 @@ class ModelRunner:
 
         # Apply regex vocab_mask
         if sampling_info.vocab_mask is not None:
-            logits =
+            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
 
         return logits
 
@@ -640,7 +706,9 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}.
+            logger.warning(f"Ignore import error when loading {name}. {e}")
+            if crash_on_warnings():
+                raise ValueError(f"Ignore import error when loading {name}. {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
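The new get_model_config_params method above builds the kwargs for vllm's ModelConfig and then probes its constructor with inspect.signature, so the same code works whether or not the installed vllm accepts a `task` argument. Below is a minimal sketch of that signature-probing pattern; the `Config` class is a hypothetical stand-in for vllm's ModelConfig, not part of either package.

# Minimal sketch (hypothetical Config class) of the signature-probing
# pattern used by get_model_config_params().
import inspect


class Config:
    # Stand-in: newer "versions" of this class accept an extra `task` argument.
    def __init__(self, model: str, dtype: str = "auto", task: str = ""):
        self.model, self.dtype, self.task = model, dtype, task


def build_config_params(model_path: str) -> dict:
    params = {"model": model_path, "dtype": "auto"}
    # Only pass `task` if this Config version actually declares it.
    if "task" in inspect.signature(Config.__init__).parameters:
        params["task"] = ""
    return params


config = Config(**build_config_params("dummy/model-path"))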
sglang/srt/model_parallel.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Common utilities for torch model parallelism.
+"""
+
+from typing import Optional
+
+import torch
+from torch.distributed.device_mesh import DeviceMesh
+
+try:
+    from torch.distributed.tensor import DTensor, Shard
+except ImportError:
+    # torch 2.4 or older
+    from torch.distributed._tensor import DTensor, Shard
+
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    RowwiseParallel,
+    parallelize_module,
+)
+
+
+class ColwiseParallelSharded(ColwiseParallel):
+    """
+    A version of ColwiseParallel where the local weight has been already
+    sharded. This is used for the fused wqkv case, where during loading, we
+    already sharded wq, wk, wv before fusing them.
+    """
+
+    # Override the _partition_linear_fn in ColwiseParallel
+    def _partition_linear_fn(self, name, module, device_mesh):
+        # colwise shard weight/bias to Shard(0), weight be Shard(0)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        for name, param in module.named_parameters():
+            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
+            module.register_parameter(name, dist_param)
+
+
+class RowwiseParallelMaybeWait(RowwiseParallel):
+    """
+    A version of RowwiseParallel that waits for the output (establish dependency
+    between comm stream and compute stream in CUDA sense) before going into the
+    next op. This is needed to workaround the current interaction between
+    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
+    """
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        outputs = super(
+            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
+        )._prepare_output_fn(
+            output_layouts, use_local_output, mod, outputs, device_mesh
+        )
+        # wait for the output to be ready
+        if isinstance(outputs, AsyncCollectiveTensor):
+            return outputs.wait()
+        else:
+            return outputs
+
+
+def tensor_parallel(
+    module: torch.nn.Module,
+    device_mesh: Optional[DeviceMesh] = None,
+):
+    """
+    Tensor parallelize the model across the given device mesh.
+    Args:
+        module (`torch.nn.Module`):
+            The module to tensor parallelize.
+        device_mesh (`torch.distributed.DeviceMesh`):
+            The device mesh to use for tensor parallelism.
+    """
+
+    # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
+    # No op if `_tp_plan` attribute does not exist under the module.
+    # This is a helper function to be used with `model.apply` to recursively
+    # parallelize a model.
+    def tplize(mod: torch.nn.Module) -> None:
+        tp_plan = getattr(mod, "_tp_plan", None)
+        if tp_plan is None:
+            return
+        for child_name, tp_style in tp_plan.items():
+            submod = mod.get_submodule(child_name)
+            if tp_style == "Colwise":
+                parallelize_module(submod, device_mesh, ColwiseParallel())
+            elif tp_style == "Rowwise":
+                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
+            elif tp_style == "Colwise_Sharded":
+                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
+            else:
+                raise ValueError(f"Unknown TP style {tp_style}")
+
+    # `apply` is a native method of `nn.Module` that recursively applies a
+    # function to every submodule.
+    module.apply(tplize)
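The tensor_parallel() helper added above walks the module tree and parallelizes any submodule that declares a `_tp_plan` dict mapping child names to the "Colwise", "Rowwise", or "Colwise_Sharded" styles; ModelRunner.apply_torch_tp builds a one-dimensional device mesh of size tp_size and passes the loaded model through it. Below is a minimal usage sketch under those assumptions; TinyMLP and the launch command are illustrative only and must run inside a torch.distributed process group (e.g. via torchrun).

# Minimal sketch (hypothetical TinyMLP): opting a module into the torch-TP path.
# Launch with e.g.: torchrun --nproc-per-node=2 tp_sketch.py
import os

import torch
from sglang.srt.model_parallel import tensor_parallel


class TinyMLP(torch.nn.Module):
    # Consumed by tensor_parallel(): child name -> TP style.
    _tp_plan = {"up_proj": "Colwise", "down_proj": "Rowwise"}

    def __init__(self, dim: int = 128):
        super().__init__()
        self.up_proj = torch.nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = torch.nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))


if __name__ == "__main__":
    tp_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Mirrors ModelRunner.apply_torch_tp: one mesh dimension of size tp_size.
    device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
    model = TinyMLP().cuda()
    tensor_parallel(model, device_mesh)
    out = model(torch.randn(4, 128, device="cuda"))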
sglang/srt/models/chatglm.py
CHANGED
@@ -1,22 +1,21 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/commandr.py
CHANGED
@@ -1,19 +1,16 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -32,12 +29,14 @@ limitations under the License.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ==============================================================================
 
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
 
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/dbrx.py
CHANGED
@@ -1,21 +1,20 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
-
+
 from typing import Iterable, Optional, Tuple
 
 import torch
@@ -25,11 +24,11 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
sglang/srt/models/deepseek.py
CHANGED
@@ -1,21 +1,21 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
 """Inference-only Deepseek model."""
+
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -26,11 +26,11 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,