sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  # ruff: noqa: SIM117
4
4
  import collections
5
+ import concurrent
5
6
  import dataclasses
6
7
  import fnmatch
7
8
  import glob
@@ -11,14 +12,17 @@ import math
11
12
  import os
12
13
  import time
13
14
  from abc import ABC, abstractmethod
15
+ from concurrent.futures import ThreadPoolExecutor
14
16
  from contextlib import contextmanager
15
17
  from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
16
18
 
17
19
  import huggingface_hub
18
20
  import numpy as np
21
+ import safetensors.torch
19
22
  import torch
20
23
  from huggingface_hub import HfApi, hf_hub_download
21
24
  from torch import nn
25
+ from tqdm.auto import tqdm
22
26
  from transformers import AutoModelForCausalLM
23
27
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
24
28
 
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
41
45
  set_default_torch_dtype,
42
46
  )
43
47
  from sglang.srt.model_loader.weight_utils import (
48
+ _BAR_FORMAT,
44
49
  download_safetensors_index_file_from_hf,
45
50
  download_weights_from_hf,
46
51
  filter_duplicate_safetensors_files,
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
49
54
  get_quant_config,
50
55
  gguf_quant_weights_iterator,
51
56
  initialize_dummy_weights,
57
+ multi_thread_pt_weights_iterator,
58
+ multi_thread_safetensors_weights_iterator,
52
59
  np_cache_weights_iterator,
53
60
  pt_weights_iterator,
54
61
  safetensors_weights_iterator,
@@ -117,6 +124,9 @@ def _get_quantization_config(
117
124
  quant_config = get_quant_config(
118
125
  model_config, load_config, packed_modules_mapping
119
126
  )
127
+ # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
128
+ if quant_config is None:
129
+ return None
120
130
  major, minor = get_device_capability()
121
131
 
122
132
  if major is not None and minor is not None:
@@ -181,6 +191,9 @@ class BaseModelLoader(ABC):
181
191
  class DefaultModelLoader(BaseModelLoader):
182
192
  """Model loader that can load different file types from disk."""
183
193
 
194
+ # Default number of threads used when multithreaded weight loading is enabled
195
+ DEFAULT_NUM_THREADS = 8
196
+
184
197
  @dataclasses.dataclass
185
198
  class Source:
186
199
  """A source for weights."""
@@ -208,10 +221,15 @@ class DefaultModelLoader(BaseModelLoader):
208
221
 
209
222
  def __init__(self, load_config: LoadConfig):
210
223
  super().__init__(load_config)
211
- if load_config.model_loader_extra_config:
224
+ extra_config = load_config.model_loader_extra_config
225
+ allowed_keys = {"enable_multithread_load", "num_threads"}
226
+ unexpected_keys = set(extra_config.keys()) - allowed_keys
227
+
228
+ if unexpected_keys:
212
229
  raise ValueError(
213
- f"Model loader extra config is not supported for "
214
- f"load format {load_config.load_format}"
230
+ f"Unexpected extra config keys for load format "
231
+ f"{load_config.load_format}: "
232
+ f"{unexpected_keys}"
215
233
  )
216
234
 
217
235
  def _maybe_download_from_modelscope(
@@ -324,6 +342,7 @@ class DefaultModelLoader(BaseModelLoader):
324
342
  self, source: "Source"
325
343
  ) -> Generator[Tuple[str, torch.Tensor], None, None]:
326
344
  """Get an iterator for the model weights based on the load format."""
345
+ extra_config = self.load_config.model_loader_extra_config
327
346
  hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
328
347
  source.model_or_path, source.revision, source.fall_back_to_pt
329
348
  )
@@ -342,11 +361,30 @@ class DefaultModelLoader(BaseModelLoader):
342
361
  weight_loader_disable_mmap = global_server_args_dict.get(
343
362
  "weight_loader_disable_mmap"
344
363
  )
345
- weights_iterator = safetensors_weights_iterator(
346
- hf_weights_files, disable_mmap=weight_loader_disable_mmap
347
- )
364
+
365
+ if extra_config.get("enable_multithread_load"):
366
+ weights_iterator = multi_thread_safetensors_weights_iterator(
367
+ hf_weights_files,
368
+ max_workers=extra_config.get(
369
+ "num_threads", self.DEFAULT_NUM_THREADS
370
+ ),
371
+ disable_mmap=weight_loader_disable_mmap,
372
+ )
373
+ else:
374
+ weights_iterator = safetensors_weights_iterator(
375
+ hf_weights_files, disable_mmap=weight_loader_disable_mmap
376
+ )
377
+
348
378
  else:
349
- weights_iterator = pt_weights_iterator(hf_weights_files)
379
+ if extra_config.get("enable_multithread_load"):
380
+ weights_iterator = multi_thread_pt_weights_iterator(
381
+ hf_weights_files,
382
+ max_workers=extra_config.get(
383
+ "num_threads", self.DEFAULT_NUM_THREADS
384
+ ),
385
+ )
386
+ else:
387
+ weights_iterator = pt_weights_iterator(hf_weights_files)
350
388
 
351
389
  # Apply the prefix.
352
390
  return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
@@ -385,9 +423,9 @@ class DefaultModelLoader(BaseModelLoader):
385
423
  self.load_config,
386
424
  )
387
425
 
388
- self.load_weights_and_postprocess(
389
- model, self._get_all_weights(model_config, model), target_device
390
- )
426
+ self.load_weights_and_postprocess(
427
+ model, self._get_all_weights(model_config, model), target_device
428
+ )
391
429
 
392
430
  return model.eval()
393
431
 
@@ -499,6 +537,12 @@ class DummyModelLoader(BaseModelLoader):
499
537
  model_config: ModelConfig,
500
538
  device_config: DeviceConfig,
501
539
  ) -> nn.Module:
540
+
541
+ if get_bool_env_var("SGL_CPU_QUANTIZATION"):
542
+ return load_model_with_cpu_quantization(
543
+ self, model_config=model_config, device_config=device_config
544
+ )
545
+
502
546
  with set_default_torch_dtype(model_config.dtype):
503
547
  with torch.device(device_config.device):
504
548
  model = _initialize_model(
@@ -1429,6 +1473,38 @@ class RemoteModelLoader(BaseModelLoader):
1429
1473
  return model.eval()
1430
1474
 
1431
1475
 
1476
def load_model_with_cpu_quantization(
    self,
    *,
    model_config: ModelConfig,
    device_config: DeviceConfig,
) -> nn.Module:
    """Build a model, post-process its quantized weights, then move it to the device.

    The model is created with ``model_config.dtype`` as the default torch
    dtype. Unless ``self`` is a ``DummyModelLoader``, real weights are loaded
    through ``self._get_all_weights``. Every submodule exposing a
    ``quant_method`` then has ``process_weights_after_loading`` applied, and
    only afterwards is the whole model moved to ``device_config.device``.

    Args:
        self: The model loader instance this helper is invoked on behalf of.
        model_config: Model configuration; supplies the dtype.
        device_config: Device configuration; supplies the target device.

    Returns:
        The loaded model, switched to eval mode.
    """
    target_device = torch.device(device_config.device)

    with set_default_torch_dtype(model_config.dtype):
        model = _initialize_model(model_config, self.load_config)

        # Dummy loaders have no checkpoint to read from.
        if not isinstance(self, DummyModelLoader):
            weights = self._get_all_weights(model_config, model)
            model.load_weights(weights)

        for _, submodule in model.named_modules():
            quant_method = getattr(submodule, "quant_method", None)
            if quant_method is None:
                continue
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(submodule, target_device):
                quant_method.process_weights_after_loading(submodule)

    model.to(target_device)
    return model.eval()
1506
+
1507
+
1432
1508
  def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1433
1509
  """Get a model loader based on the load format."""
1434
1510
 
@@ -1,12 +1,14 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
2
2
 
3
3
  """Utilities for downloading and initializing model weights."""
4
+ import concurrent.futures
4
5
  import fnmatch
5
6
  import glob
6
7
  import hashlib
7
8
  import json
8
9
  import logging
9
10
  import os
11
+ import queue
10
12
  import tempfile
11
13
  from collections import defaultdict
12
14
  from typing import (
@@ -207,6 +209,17 @@ def get_quant_config(
207
209
  config["adapter_name_or_path"] = model_name_or_path
208
210
  elif model_config.quantization == "modelopt":
209
211
  if config["producer"]["name"] == "modelopt":
212
+ # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
213
+ if config["quantization"]["quant_algo"] is None:
214
+ if (
215
+ model_config.hf_config.architectures[0]
216
+ != "LlamaForCausalLMEagle3"
217
+ ):
218
+ raise ValueError(
219
+ f"Invalid quant_config, quantization method: {model_config.quantization},"
220
+ f"hf architectures: {model_config.hf_config.architectures[0]}. "
221
+ )
222
+ return None
210
223
  if "FP4" in config["quantization"]["quant_algo"]:
211
224
  return ModelOptFp4Config.from_config(config)
212
225
  else:
@@ -447,10 +460,67 @@ def safetensors_weights_iterator(
447
460
  if disable_mmap:
448
461
  with open(st_file, "rb") as f:
449
462
  result = safetensors.torch.load(f.read())
463
+ for name, param in result.items():
464
+ yield name, param
450
465
  else:
451
- result = safetensors.torch.load_file(st_file, device="cpu")
452
- for name, param in result.items():
453
- yield name, param
466
+ with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
467
+ for name in f.keys():
468
+ yield name, f.get_tensor(name)
469
+
470
+
471
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over safetensors weights, loading shard files on a thread pool.

    Every file is submitted to a ``ThreadPoolExecutor`` up front and its
    tensors are yielded as soon as that file has finished loading, so shards
    arrive in completion order rather than list order.

    Args:
        hf_weights_files: Paths of the safetensors shard files.
        is_all_weights_sharded: Only forwarded to
            ``safetensors_encrypted_weights_iterator`` when ``decryption_key``
            is set; the multi-threaded path does not use it.
        decryption_key: When truthy, a warning is logged and loading is
            delegated to ``safetensors_encrypted_weights_iterator`` instead
            of the thread pool.
        max_workers: Number of worker threads in the pool.
        disable_mmap: When true, each file is read fully into memory and
            parsed with ``safetensors.torch.load``; otherwise tensors are
            fetched one by one through ``safetensors.safe_open``.

    Yields:
        ``(name, tensor)`` pairs for every tensor in every file.
    """
    if decryption_key:
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Only show a progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.torch.load(f.read())
        with safetensors.safe_open(st_file, framework="pt", device="cpu") as f:
            return {k: f.get_tensor(k) for k in f.keys()}

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Deliberately do not bind the list of futures to a local name:
        # as_completed() drops its own reference to each future once it has
        # been yielded, so a consumed shard's state dict can be garbage
        # collected instead of staying alive until the last shard is done.
        completed = concurrent.futures.as_completed(
            [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
        )

        # tqdm with disable=True is a plain pass-through iterator, so one call
        # covers both the progress-bar and the silent case.
        futures_iter = tqdm(
            completed,
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state_dict = future.result()
            for name, param in state_dict.items():
                yield name, param
454
524
 
455
525
 
456
526
  def pt_weights_iterator(
@@ -471,6 +541,39 @@ def pt_weights_iterator(
471
541
  del state
472
542
 
473
543
 
544
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over weights in bin/pt checkpoint files, loading on a thread pool.

    Every file is submitted to a ``ThreadPoolExecutor`` up front and loaded
    with ``torch.load(..., map_location="cpu", weights_only=True)``. Entries
    are yielded as soon as their file has finished loading, so files arrive
    in completion order rather than list order.

    Args:
        hf_weights_files: Paths of the bin/pt checkpoint files.
        max_workers: Number of worker threads in the pool.

    Yields:
        The ``(key, value)`` items of each loaded checkpoint.
    """
    # Only show a progress bar on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Deliberately do not bind the list of futures to a local name:
        # as_completed() drops its own reference to each future once it has
        # been yielded, so a consumed checkpoint can be garbage collected
        # instead of staying alive until the last file is done.
        completed = concurrent.futures.as_completed(
            [executor.submit(_load_file, bin_file) for bin_file in hf_weights_files]
        )

        # tqdm with disable=True is a plain pass-through iterator, so one call
        # covers both the progress-bar and the silent case.
        futures_iter = tqdm(
            completed,
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state = future.result()
            yield from state.items()
575
+
576
+
474
577
  def get_gguf_extra_tensor_names(
475
578
  gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
476
579
  ) -> List[str]:
@@ -858,3 +961,57 @@ def kv_cache_scales_loader(
858
961
  tp_rank,
859
962
  )
860
963
  return []
964
+
965
+
966
def get_actual_shard_size(shard_size, weight_start, weight_end):
    """Return how many of ``shard_size`` elements lie in ``[weight_start, weight_end)``.

    Returns 0 when ``weight_end`` is before ``weight_start``; otherwise the
    smaller of ``shard_size`` and the distance from start to end.
    """
    available = weight_end - weight_start
    if available < 0:
        return 0
    return min(shard_size, available)
971
+
972
+
973
def reset_param_data_if_needed(param_data, dim, start, length):
    """Zero, in place, ``length`` entries of ``param_data`` along ``dim`` from ``start``.

    Does nothing when ``length`` is 0. A negative ``length`` trips the
    assertion. Always returns ``None``.
    """
    if length != 0:
        assert length > 0, f"Length should be positive, but got {length}"
        # narrow() returns a view, so zero_() writes through to param_data.
        param_data.narrow(dim, start, length).zero_()
981
+
982
+
983
def narrow_padded_param_and_loaded_weight(
    param_data,
    loaded_weight,
    param_data_start,
    weight_start,
    dim,
    shard_size,
    narrow_weight=True,
):
    """Narrow a parameter and a loaded weight to the part of a shard that has data.

    The usable length is ``get_actual_shard_size(shard_size, weight_start,
    loaded_weight.size(dim))``. Any remainder of the ``shard_size`` window in
    ``param_data`` beyond that length is zeroed in place.

    Args:
        param_data: Destination tensor; its padded tail is zeroed in place.
        loaded_weight: Source tensor read from the checkpoint.
        param_data_start: Offset of the shard inside ``param_data`` on ``dim``.
        weight_start: Offset of the shard inside ``loaded_weight`` on ``dim``.
        dim: Dimension along which both tensors are narrowed.
        shard_size: Nominal shard length along ``dim``.
        narrow_weight: When false, ``loaded_weight`` is returned untouched.

    Returns:
        ``(param_data_view, loaded_weight)`` where the first element is
        ``param_data`` narrowed to the usable length.
    """
    usable = get_actual_shard_size(shard_size, weight_start, loaded_weight.size(dim))

    if narrow_weight:
        if usable > 0:
            loaded_weight = loaded_weight.narrow(dim, weight_start, usable)
        else:
            # No real data to load; create a dummy tensor filled with zeros
            empty_view = param_data.narrow(dim, param_data_start, usable)
            loaded_weight = torch.zeros_like(empty_view)

    # [Note] Reset padded weights to zero.
    # If the actual shard size is less than the shard size, the padded part of
    # param_data is cleared so that only the loaded_weight region carries data.
    pad_start = param_data_start + usable
    pad_length = shard_size - usable
    reset_param_data_if_needed(param_data, dim, pad_start, pad_length)

    return param_data.narrow(dim, param_data_start, usable), loaded_weight
@@ -21,6 +21,7 @@ from torch import nn
21
21
  from transformers import PretrainedConfig
22
22
 
23
23
  from sglang.srt.distributed import get_tensor_model_parallel_world_size
24
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
24
25
  from sglang.srt.layers.layernorm import RMSNorm
25
26
  from sglang.srt.layers.logits_processor import LogitsProcessor
26
27
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -82,7 +83,6 @@ class DeepseekModelNextN(nn.Module):
82
83
  forward_batch: ForwardBatch,
83
84
  input_embeds: torch.Tensor = None,
84
85
  ) -> torch.Tensor:
85
-
86
86
  zero_allocator = BumpAllocator(
87
87
  buffer_size=2,
88
88
  dtype=torch.float32,
@@ -108,9 +108,10 @@ class DeepseekModelNextN(nn.Module):
108
108
  )
109
109
 
110
110
  residual = None
111
- hidden_states, residual = self.decoder(
112
- positions, hidden_states, forward_batch, residual, zero_allocator
113
- )
111
+ with get_global_expert_distribution_recorder().disable_this_region():
112
+ hidden_states, residual = self.decoder(
113
+ positions, hidden_states, forward_batch, residual, zero_allocator
114
+ )
114
115
 
115
116
  if not forward_batch.forward_mode.is_idle():
116
117
  if residual is not None: