sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -26,9 +26,11 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -45,10 +47,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import
-
-
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
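This hunk tracks the mem_cache reorganization listed above (`paged_allocator.py → allocator.py`): both allocator classes now come from one module under a shared base class, which is also the type ModelRunner's constructor accepts below. A minimal sketch of the apparent hierarchy, inferred only from this import block; the `alloc`/`free` signatures are illustrative assumptions, not sglang's actual interface:

```python
# Hedged sketch: hierarchy inferred from the import block above; the
# alloc/free signatures are illustrative assumptions, not the real API.
import torch


class BaseTokenToKVPoolAllocator:
    """Shared interface for handing out KV-cache slot indices."""

    def alloc(self, need_size: int) -> torch.Tensor:
        raise NotImplementedError

    def free(self, free_index: torch.Tensor) -> None:
        raise NotImplementedError


class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Token-granularity allocation (effectively page size 1)."""


class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Page-granularity allocation for paged attention backends."""
```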
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()

 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -162,6 +168,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )

@@ -205,8 +213,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()

         # Update deep gemm configure
-        if
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)

         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
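The two-line change above reflects the larger `deep_gemm.py → deep_gemm_wrapper/` split in the file list: callers now go through a facade module whose `ENABLE_JIT_DEEPGEMM` flag gates all DeepGEMM configuration. A hedged sketch of that pattern; the env-var name and the function body are placeholders, not sglang's real configurer logic:

```python
# Hedged sketch of a capability-flag facade; the env-var name and the
# function body are placeholders, not sglang's real configurer logic.
import os

ENABLE_JIT_DEEPGEMM: bool = os.environ.get("EXAMPLE_ENABLE_JIT_DEEPGEMM", "0") == "1"


def update_deep_gemm_config(gpu_id: int, server_args) -> None:
    """Apply per-GPU DeepGEMM tuning (body elided in this sketch)."""
```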
@@ -218,6 +226,7 @@ class ModelRunner:

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -272,6 +281,10 @@ class ModelRunner:
             self.apply_torch_tp()

         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()

@@ -299,7 +312,7 @@ class ModelRunner:
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -543,7 +556,7 @@ class ModelRunner:
             monkey_patch_vllm_parallel_state()
             monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -761,6 +774,9 @@ class ModelRunner:
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
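With this branch, an entry in `--custom-weight-loader` can name a loader function by dotted import path; `dynamic_import` resolves it at call time, and the resolved callable receives the model plus the named tensors. A hedged sketch of how such a helper and loader could look; sglang's actual `dynamic_import` in `sglang/srt/utils.py` may differ, and `my_weight_loader` is a hypothetical user-supplied function:

```python
# Hedged sketch; sglang's real dynamic_import may differ, and
# my_weight_loader is a hypothetical user-supplied function.
import importlib


def dynamic_import(dotted_path: str):
    """Resolve "pkg.module.attr" to the object it names."""
    module_path, _, attr_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def my_weight_loader(model, named_tensors):
    """Receive (name, tensor) pairs, matching the call site above."""
    for name, tensor in named_tensors:
        print(f"would load {name} with shape {tuple(tensor.shape)}")
```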
@@ -787,7 +803,6 @@ class ModelRunner:
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +811,7 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")

     def profile_max_num_token(self, total_gpu_memory: int):
sglang/srt/model_loader/loader.py
CHANGED
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files, disable_mmap=weight_loader_disable_mmap
+            )
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)

@@ -1259,12 +1266,19 @@ class GGUFModelLoader(BaseModelLoader):
         ):
             model_config.hf_config.update({"tie_word_embeddings": True})

+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with
+            with target_device:
                 model = _initialize_model(model_config, self.load_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map)
             )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model


sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once

 logger = logging.getLogger(__name__)
@@ -206,7 +207,10 @@ def get_quant_config(
             config["adapter_name_or_path"] = model_name_or_path
         elif model_config.quantization == "modelopt":
             if config["producer"]["name"] == "modelopt":
-
+                if "FP4" in config["quantization"]["quant_algo"]:
+                    return ModelOptFp4Config.from_config(config)
+                else:
+                    return quant_cls.from_config(config)
             else:
                 raise ValueError(
                     f"Unsupported quantization config"
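For orientation, the new dispatch reads exactly two fields of the exported quantization config. A minimal Python illustration of the shape it expects; the `quant_algo` value is hypothetical and real exports carry many more fields:

```python
# Hypothetical modelopt-style config, shaped only to satisfy the lookups
# in get_quant_config above; real exported configs carry many more fields.
config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": "NVFP4"},  # any algo containing "FP4"
}

# "FP4" in the algo string routes to ModelOptFp4Config.from_config(config);
# anything else falls through to quant_cls.from_config(config).
assert "FP4" in config["quantization"]["quant_algo"]
```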
@@ -418,6 +422,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.

@@ -439,7 +444,11 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param

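The two safetensors entry points used above differ in how bytes reach memory: `load_file` memory-maps the file so pages fault in lazily, while `load` parses an in-memory buffer, which is what the `disable_mmap` path forces by reading the whole file up front. Both calls below are real safetensors APIs; only the shard filename is illustrative:

```python
# Both calls are real safetensors APIs; the shard filename is illustrative.
import safetensors.torch

st_file = "model-00001-of-00002.safetensors"

# Default path: memory-mapped read, pages faulted in on first access.
tensors_mmap = safetensors.torch.load_file(st_file, device="cpu")

# disable_mmap path: slurp the whole file, then parse from the buffer.
with open(st_file, "rb") as f:
    tensors_buf = safetensors.torch.load(f.read())
```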
sglang/srt/models/bert.py
CHANGED
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 BertConfig = None

@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()

@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)

         # Position embeddings.
-        position_embeddings = self.position_embeddings(
+        position_embeddings = self.position_embeddings(positions)

-        token_type_ids =
-
-
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )

         token_type_embeddings = self.token_type_embeddings(token_type_ids)

@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings


+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # simply taking the hidden state corresponding
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+
+        return pooled_output
+
+
 class BertEncoder(nn.Module):

     def __init__(
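Note the pooling choice in `BertPooler.forward`: it keeps only `hidden_states[0, :]`, i.e. the hidden state at the first (CLS) position of the flattened batch, then applies the dense + tanh head. A toy illustration of that indexing:

```python
# Toy illustration of first-token (CLS) pooling on a flattened
# [num_tokens, hidden_size] tensor, as BertPooler.forward indexes it.
import torch

hidden_states = torch.randn(7, 4)  # 7 tokens, hidden size 4
first_token_tensor = hidden_states[0, :]
assert first_token_tensor.shape == (4,)
```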
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()

+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output


@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)

     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@ class BertModel(nn.Module):

         hidden_states = self.embeddings(
             input_ids=input_ids,
-
+            positions=positions,
+            forward_batch=forward_batch,
         )

         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@ class BertModel(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass


-
+class BertForSequenceClassification(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
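The `weight_filter` generator above is a small routing trick worth calling out: checkpoint entries prefixed `bert.` are streamed into the wrapped `BertModel` with the prefix stripped, while everything else (e.g. `classifier.weight`) is set aside for the wrapper to load itself. A standalone illustration with made-up checkpoint entries:

```python
# Standalone illustration of the weight_filter routing; the checkpoint
# entries are made up, and tensors are stand-in strings.
leftovers = []


def weight_filter(weights):
    for name, weight in weights:
        if name.startswith("bert."):
            yield name[len("bert.") :], weight  # consumed by the BertModel
        else:
            leftovers.append((name, weight))  # loaded by the wrapper itself


ckpt = [
    ("bert.embeddings.word_embeddings.weight", "W0"),
    ("classifier.weight", "W1"),
]
assert list(weight_filter(ckpt)) == [("embeddings.word_embeddings.weight", "W0")]
assert leftovers == [("classifier.weight", "W1")]
```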
sglang/srt/models/deepseek_nextn.py
CHANGED
@@ -22,7 +22,6 @@ from transformers import PretrainedConfig

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -45,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size

         self.embed_tokens = VocabParallelEmbedding(
@@ -77,6 +82,7 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,15 +96,16 @@ class DeepseekModelNextN(nn.Module):
         else:
             hidden_states = input_embeds

-        hidden_states
-
-            (
-
-
-
-
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )

         residual = None
         hidden_states, residual = self.decoder(
@@ -106,7 +113,11 @@ class DeepseekModelNextN(nn.Module):
         )

         if not forward_batch.forward_mode.is_idle():
-
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states


@@ -127,23 +138,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-
-
-
-
-
-
-
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
     def forward(