sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
# ruff: noqa: SIM117
|
4
4
|
import collections
|
5
|
+
import concurrent
|
5
6
|
import dataclasses
|
6
7
|
import fnmatch
|
7
8
|
import glob
|
@@ -11,14 +12,17 @@ import math
|
|
11
12
|
import os
|
12
13
|
import time
|
13
14
|
from abc import ABC, abstractmethod
|
15
|
+
from concurrent.futures import ThreadPoolExecutor
|
14
16
|
from contextlib import contextmanager
|
15
17
|
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
16
18
|
|
17
19
|
import huggingface_hub
|
18
20
|
import numpy as np
|
21
|
+
import safetensors.torch
|
19
22
|
import torch
|
20
23
|
from huggingface_hub import HfApi, hf_hub_download
|
21
24
|
from torch import nn
|
25
|
+
from tqdm.auto import tqdm
|
22
26
|
from transformers import AutoModelForCausalLM
|
23
27
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
24
28
|
|
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
|
|
41
45
|
set_default_torch_dtype,
|
42
46
|
)
|
43
47
|
from sglang.srt.model_loader.weight_utils import (
|
48
|
+
_BAR_FORMAT,
|
44
49
|
download_safetensors_index_file_from_hf,
|
45
50
|
download_weights_from_hf,
|
46
51
|
filter_duplicate_safetensors_files,
|
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
|
|
49
54
|
get_quant_config,
|
50
55
|
gguf_quant_weights_iterator,
|
51
56
|
initialize_dummy_weights,
|
57
|
+
multi_thread_pt_weights_iterator,
|
58
|
+
multi_thread_safetensors_weights_iterator,
|
52
59
|
np_cache_weights_iterator,
|
53
60
|
pt_weights_iterator,
|
54
61
|
safetensors_weights_iterator,
|
@@ -117,6 +124,9 @@ def _get_quantization_config(
|
|
117
124
|
quant_config = get_quant_config(
|
118
125
|
model_config, load_config, packed_modules_mapping
|
119
126
|
)
|
127
|
+
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
128
|
+
if quant_config is None:
|
129
|
+
return None
|
120
130
|
major, minor = get_device_capability()
|
121
131
|
|
122
132
|
if major is not None and minor is not None:
|
@@ -181,6 +191,9 @@ class BaseModelLoader(ABC):
|
|
181
191
|
class DefaultModelLoader(BaseModelLoader):
|
182
192
|
"""Model loader that can load different file types from disk."""
|
183
193
|
|
194
|
+
# default number of thread when enable multithread weight loading
|
195
|
+
DEFAULT_NUM_THREADS = 8
|
196
|
+
|
184
197
|
@dataclasses.dataclass
|
185
198
|
class Source:
|
186
199
|
"""A source for weights."""
|
@@ -208,10 +221,15 @@ class DefaultModelLoader(BaseModelLoader):
|
|
208
221
|
|
209
222
|
def __init__(self, load_config: LoadConfig):
|
210
223
|
super().__init__(load_config)
|
211
|
-
|
224
|
+
extra_config = load_config.model_loader_extra_config
|
225
|
+
allowed_keys = {"enable_multithread_load", "num_threads"}
|
226
|
+
unexpected_keys = set(extra_config.keys()) - allowed_keys
|
227
|
+
|
228
|
+
if unexpected_keys:
|
212
229
|
raise ValueError(
|
213
|
-
f"
|
214
|
-
f"
|
230
|
+
f"Unexpected extra config keys for load format "
|
231
|
+
f"{load_config.load_format}: "
|
232
|
+
f"{unexpected_keys}"
|
215
233
|
)
|
216
234
|
|
217
235
|
def _maybe_download_from_modelscope(
|
@@ -324,6 +342,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|
324
342
|
self, source: "Source"
|
325
343
|
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
326
344
|
"""Get an iterator for the model weights based on the load format."""
|
345
|
+
extra_config = self.load_config.model_loader_extra_config
|
327
346
|
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
328
347
|
source.model_or_path, source.revision, source.fall_back_to_pt
|
329
348
|
)
|
@@ -342,11 +361,30 @@ class DefaultModelLoader(BaseModelLoader):
|
|
342
361
|
weight_loader_disable_mmap = global_server_args_dict.get(
|
343
362
|
"weight_loader_disable_mmap"
|
344
363
|
)
|
345
|
-
|
346
|
-
|
347
|
-
|
364
|
+
|
365
|
+
if extra_config.get("enable_multithread_load"):
|
366
|
+
weights_iterator = multi_thread_safetensors_weights_iterator(
|
367
|
+
hf_weights_files,
|
368
|
+
max_workers=extra_config.get(
|
369
|
+
"num_threads", self.DEFAULT_NUM_THREADS
|
370
|
+
),
|
371
|
+
disable_mmap=weight_loader_disable_mmap,
|
372
|
+
)
|
373
|
+
else:
|
374
|
+
weights_iterator = safetensors_weights_iterator(
|
375
|
+
hf_weights_files, disable_mmap=weight_loader_disable_mmap
|
376
|
+
)
|
377
|
+
|
348
378
|
else:
|
349
|
-
|
379
|
+
if extra_config.get("enable_multithread_load"):
|
380
|
+
weights_iterator = multi_thread_pt_weights_iterator(
|
381
|
+
hf_weights_files,
|
382
|
+
max_workers=extra_config.get(
|
383
|
+
"num_threads", self.DEFAULT_NUM_THREADS
|
384
|
+
),
|
385
|
+
)
|
386
|
+
else:
|
387
|
+
weights_iterator = pt_weights_iterator(hf_weights_files)
|
350
388
|
|
351
389
|
# Apply the prefix.
|
352
390
|
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
@@ -385,9 +423,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|
385
423
|
self.load_config,
|
386
424
|
)
|
387
425
|
|
388
|
-
|
389
|
-
|
390
|
-
|
426
|
+
self.load_weights_and_postprocess(
|
427
|
+
model, self._get_all_weights(model_config, model), target_device
|
428
|
+
)
|
391
429
|
|
392
430
|
return model.eval()
|
393
431
|
|
@@ -499,6 +537,12 @@ class DummyModelLoader(BaseModelLoader):
|
|
499
537
|
model_config: ModelConfig,
|
500
538
|
device_config: DeviceConfig,
|
501
539
|
) -> nn.Module:
|
540
|
+
|
541
|
+
if get_bool_env_var("SGL_CPU_QUANTIZATION"):
|
542
|
+
return load_model_with_cpu_quantization(
|
543
|
+
self, model_config=model_config, device_config=device_config
|
544
|
+
)
|
545
|
+
|
502
546
|
with set_default_torch_dtype(model_config.dtype):
|
503
547
|
with torch.device(device_config.device):
|
504
548
|
model = _initialize_model(
|
@@ -1429,6 +1473,38 @@ class RemoteModelLoader(BaseModelLoader):
|
|
1429
1473
|
return model.eval()
|
1430
1474
|
|
1431
1475
|
|
1476
|
+
def load_model_with_cpu_quantization(
|
1477
|
+
self,
|
1478
|
+
*,
|
1479
|
+
model_config: ModelConfig,
|
1480
|
+
device_config: DeviceConfig,
|
1481
|
+
) -> nn.Module:
|
1482
|
+
target_device = torch.device(device_config.device)
|
1483
|
+
with set_default_torch_dtype(model_config.dtype):
|
1484
|
+
model = _initialize_model(
|
1485
|
+
model_config,
|
1486
|
+
self.load_config,
|
1487
|
+
)
|
1488
|
+
|
1489
|
+
if not isinstance(self, DummyModelLoader):
|
1490
|
+
model.load_weights(self._get_all_weights(model_config, model))
|
1491
|
+
|
1492
|
+
for _, module in model.named_modules():
|
1493
|
+
quant_method = getattr(module, "quant_method", None)
|
1494
|
+
if quant_method is not None:
|
1495
|
+
# When quant methods need to process weights after loading
|
1496
|
+
# (for repacking, quantizing, etc), they expect parameters
|
1497
|
+
# to be on the global target device. This scope is for the
|
1498
|
+
# case where cpu offloading is used, where we will move the
|
1499
|
+
# parameters onto device for processing and back off after.
|
1500
|
+
with device_loading_context(module, target_device):
|
1501
|
+
quant_method.process_weights_after_loading(module)
|
1502
|
+
|
1503
|
+
model.to(target_device)
|
1504
|
+
|
1505
|
+
return model.eval()
|
1506
|
+
|
1507
|
+
|
1432
1508
|
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
1433
1509
|
"""Get a model loader based on the load format."""
|
1434
1510
|
|
@@ -1,12 +1,14 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
|
2
2
|
|
3
3
|
"""Utilities for downloading and initializing model weights."""
|
4
|
+
import concurrent.futures
|
4
5
|
import fnmatch
|
5
6
|
import glob
|
6
7
|
import hashlib
|
7
8
|
import json
|
8
9
|
import logging
|
9
10
|
import os
|
11
|
+
import queue
|
10
12
|
import tempfile
|
11
13
|
from collections import defaultdict
|
12
14
|
from typing import (
|
@@ -207,6 +209,17 @@ def get_quant_config(
|
|
207
209
|
config["adapter_name_or_path"] = model_name_or_path
|
208
210
|
elif model_config.quantization == "modelopt":
|
209
211
|
if config["producer"]["name"] == "modelopt":
|
212
|
+
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
|
213
|
+
if config["quantization"]["quant_algo"] is None:
|
214
|
+
if (
|
215
|
+
model_config.hf_config.architectures[0]
|
216
|
+
!= "LlamaForCausalLMEagle3"
|
217
|
+
):
|
218
|
+
raise ValueError(
|
219
|
+
f"Invalid quant_config, quantization method: {model_config.quantization},"
|
220
|
+
f"hf architectures: {model_config.hf_config.architectures[0]}. "
|
221
|
+
)
|
222
|
+
return None
|
210
223
|
if "FP4" in config["quantization"]["quant_algo"]:
|
211
224
|
return ModelOptFp4Config.from_config(config)
|
212
225
|
else:
|
@@ -447,10 +460,67 @@ def safetensors_weights_iterator(
|
|
447
460
|
if disable_mmap:
|
448
461
|
with open(st_file, "rb") as f:
|
449
462
|
result = safetensors.torch.load(f.read())
|
463
|
+
for name, param in result.items():
|
464
|
+
yield name, param
|
450
465
|
else:
|
451
|
-
|
452
|
-
|
453
|
-
|
466
|
+
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
|
467
|
+
for name in f.keys():
|
468
|
+
yield name, f.get_tensor(name)
|
469
|
+
|
470
|
+
|
471
|
+
def multi_thread_safetensors_weights_iterator(
|
472
|
+
hf_weights_files: List[str],
|
473
|
+
is_all_weights_sharded: bool = False,
|
474
|
+
decryption_key: Optional[str] = None,
|
475
|
+
max_workers: int = 4,
|
476
|
+
disable_mmap: bool = False,
|
477
|
+
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
478
|
+
"""Multi-Thread iterate over the weights in the model safetensor files.
|
479
|
+
|
480
|
+
If is_all_weights_sharded is True, it uses more optimize read by reading an
|
481
|
+
entire file instead of reading each tensor one by one.
|
482
|
+
"""
|
483
|
+
if decryption_key:
|
484
|
+
logger.warning(
|
485
|
+
"Multi-Thread loading is not working for encrypted safetensor weights."
|
486
|
+
)
|
487
|
+
yield from safetensors_encrypted_weights_iterator(
|
488
|
+
hf_weights_files, is_all_weights_sharded, decryption_key
|
489
|
+
)
|
490
|
+
return
|
491
|
+
|
492
|
+
enable_tqdm = (
|
493
|
+
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
494
|
+
)
|
495
|
+
|
496
|
+
def _load_file(st_file: str):
|
497
|
+
if disable_mmap:
|
498
|
+
with open(st_file, "rb") as f:
|
499
|
+
result = safetensors.torch.load(f.read())
|
500
|
+
else:
|
501
|
+
with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
|
502
|
+
result = {k: f.get_tensor(k) for k in f.keys()}
|
503
|
+
|
504
|
+
return result
|
505
|
+
|
506
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
507
|
+
futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
|
508
|
+
|
509
|
+
if enable_tqdm:
|
510
|
+
futures_iter = tqdm(
|
511
|
+
concurrent.futures.as_completed(futures),
|
512
|
+
total=len(hf_weights_files),
|
513
|
+
desc="Multi-thread loading shards",
|
514
|
+
disable=not enable_tqdm,
|
515
|
+
bar_format=_BAR_FORMAT,
|
516
|
+
)
|
517
|
+
else:
|
518
|
+
futures_iter = concurrent.futures.as_completed(futures)
|
519
|
+
|
520
|
+
for future in futures_iter:
|
521
|
+
state_dict = future.result()
|
522
|
+
for name, param in state_dict.items():
|
523
|
+
yield name, param
|
454
524
|
|
455
525
|
|
456
526
|
def pt_weights_iterator(
|
@@ -471,6 +541,39 @@ def pt_weights_iterator(
|
|
471
541
|
del state
|
472
542
|
|
473
543
|
|
544
|
+
def multi_thread_pt_weights_iterator(
|
545
|
+
hf_weights_files: List[str],
|
546
|
+
max_workers: int = 4,
|
547
|
+
) -> Generator[Tuple[str, torch.Tensor], None, None]:
|
548
|
+
"""Multi-Thread iterate over the weights in the model bin/pt files."""
|
549
|
+
enable_tqdm = (
|
550
|
+
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
|
551
|
+
)
|
552
|
+
|
553
|
+
def _load_file(bin_file: str):
|
554
|
+
return torch.load(bin_file, map_location="cpu", weights_only=True)
|
555
|
+
|
556
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
557
|
+
futures = [
|
558
|
+
executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
|
559
|
+
]
|
560
|
+
|
561
|
+
if enable_tqdm:
|
562
|
+
futures_iter = tqdm(
|
563
|
+
concurrent.futures.as_completed(futures),
|
564
|
+
total=len(hf_weights_files),
|
565
|
+
desc="Multi-thread loading pt checkpoint shards",
|
566
|
+
disable=not enable_tqdm,
|
567
|
+
bar_format=_BAR_FORMAT,
|
568
|
+
)
|
569
|
+
else:
|
570
|
+
futures_iter = concurrent.futures.as_completed(futures)
|
571
|
+
|
572
|
+
for future in futures_iter:
|
573
|
+
state = future.result()
|
574
|
+
yield from state.items()
|
575
|
+
|
576
|
+
|
474
577
|
def get_gguf_extra_tensor_names(
|
475
578
|
gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
|
476
579
|
) -> List[str]:
|
@@ -858,3 +961,57 @@ def kv_cache_scales_loader(
|
|
858
961
|
tp_rank,
|
859
962
|
)
|
860
963
|
return []
|
964
|
+
|
965
|
+
|
966
|
+
def get_actual_shard_size(shard_size, weight_start, weight_end):
|
967
|
+
if weight_end < weight_start:
|
968
|
+
return 0
|
969
|
+
|
970
|
+
return min(shard_size, weight_end - weight_start)
|
971
|
+
|
972
|
+
|
973
|
+
def reset_param_data_if_needed(param_data, dim, start, length):
|
974
|
+
if length == 0:
|
975
|
+
return
|
976
|
+
|
977
|
+
assert length > 0, f"Length should be positive, but got {length}"
|
978
|
+
|
979
|
+
param_data.narrow(dim, start, length).zero_()
|
980
|
+
return
|
981
|
+
|
982
|
+
|
983
|
+
def narrow_padded_param_and_loaded_weight(
|
984
|
+
param_data,
|
985
|
+
loaded_weight,
|
986
|
+
param_data_start,
|
987
|
+
weight_start,
|
988
|
+
dim,
|
989
|
+
shard_size,
|
990
|
+
narrow_weight=True,
|
991
|
+
):
|
992
|
+
actual_shard_size = get_actual_shard_size(
|
993
|
+
shard_size, weight_start, loaded_weight.size(dim)
|
994
|
+
)
|
995
|
+
|
996
|
+
if narrow_weight:
|
997
|
+
if actual_shard_size > 0:
|
998
|
+
loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
|
999
|
+
else:
|
1000
|
+
# No real data to load; create a dummy tensor filled with zeros
|
1001
|
+
loaded_weight = torch.zeros_like(
|
1002
|
+
param_data.narrow(dim, param_data_start, actual_shard_size)
|
1003
|
+
)
|
1004
|
+
|
1005
|
+
# [Note] Reset padded weights to zero.
|
1006
|
+
# If the actual shard size is less than the shard size, we need to reset
|
1007
|
+
# the padded param_data to zero and then copy the loaded_weight into it.
|
1008
|
+
reset_param_data_if_needed(
|
1009
|
+
param_data,
|
1010
|
+
dim,
|
1011
|
+
param_data_start + actual_shard_size,
|
1012
|
+
shard_size - actual_shard_size,
|
1013
|
+
)
|
1014
|
+
|
1015
|
+
param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
|
1016
|
+
|
1017
|
+
return param_data, loaded_weight
|
@@ -21,6 +21,7 @@ from torch import nn
|
|
21
21
|
from transformers import PretrainedConfig
|
22
22
|
|
23
23
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
24
|
+
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
24
25
|
from sglang.srt.layers.layernorm import RMSNorm
|
25
26
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
26
27
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
@@ -82,7 +83,6 @@ class DeepseekModelNextN(nn.Module):
|
|
82
83
|
forward_batch: ForwardBatch,
|
83
84
|
input_embeds: torch.Tensor = None,
|
84
85
|
) -> torch.Tensor:
|
85
|
-
|
86
86
|
zero_allocator = BumpAllocator(
|
87
87
|
buffer_size=2,
|
88
88
|
dtype=torch.float32,
|
@@ -108,9 +108,10 @@ class DeepseekModelNextN(nn.Module):
|
|
108
108
|
)
|
109
109
|
|
110
110
|
residual = None
|
111
|
-
|
112
|
-
|
113
|
-
|
111
|
+
with get_global_expert_distribution_recorder().disable_this_region():
|
112
|
+
hidden_states, residual = self.decoder(
|
113
|
+
positions, hidden_states, forward_batch, residual, zero_allocator
|
114
|
+
)
|
114
115
|
|
115
116
|
if not forward_batch.forward_mode.is_idle():
|
116
117
|
if residual is not None:
|