sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
-import collections
 import datetime
 import gc
 import inspect
@@ -52,6 +51,18 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.eplb_manager import EPLBManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -61,6 +72,7 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
+from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
@@ -94,6 +106,8 @@ from sglang.srt.utils import (
     set_cuda_arch,
 )
 
+_is_hip = is_hip()
+
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 
@@ -103,6 +117,19 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)
 
 
+class RankZeroFilter(logging.Filter):
+    """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
+
+    def __init__(self, is_rank_zero):
+        super().__init__()
+        self.is_rank_zero = is_rank_zero
+
+    def filter(self, record):
+        if record.levelno == logging.INFO:
+            return self.is_rank_zero
+        return True
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
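For reference, the filter added above changes only how INFO records propagate; every other level still passes on every rank. A minimal standalone sketch of that behavior (the logger name and the pretend rank are illustrative, not part of the diff):

```python
import logging

class RankZeroFilter(logging.Filter):
    """Drop INFO records on non-zero ranks; let every other level through."""

    def __init__(self, is_rank_zero: bool):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno == logging.INFO:
            return self.is_rank_zero
        return True

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")
logger.addFilter(RankZeroFilter(is_rank_zero=False))  # pretend this is tp_rank != 0
logger.info("suppressed on this rank")                # filtered out
logger.warning("still visible on every rank")         # passes through
```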
@@ -126,6 +153,10 @@ class ModelRunner:
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
+
+        # Apply the rank zero filter to logger
+        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
+            logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.pp_rank = pp_rank
@@ -135,7 +166,9 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
-        self.
+        self.is_multimodal_chunked_prefill_supported = (
+            model_config.is_multimodal_chunked_prefill_supported
+        )
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -145,6 +178,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
+        self.forward_pass_id = 0
+
         # Model-specific adjustment
         self.model_specific_adjustment()
 
@@ -163,10 +198,13 @@ class ModelRunner:
             "disable_radix_cache": server_args.disable_radix_cache,
             "enable_nan_detection": server_args.enable_nan_detection,
             "enable_dp_attention": server_args.enable_dp_attention,
+            "enable_dp_lm_head": server_args.enable_dp_lm_head,
             "enable_ep_moe": server_args.enable_ep_moe,
             "enable_deepep_moe": server_args.enable_deepep_moe,
+            "deepep_config": server_args.deepep_config,
             "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
             "moe_dense_tp_size": server_args.moe_dense_tp_size,
+            "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
             "n_share_experts_fusion": server_args.n_share_experts_fusion,
             "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
             "torchao_config": server_args.torchao_config,
@@ -175,6 +213,7 @@ class ModelRunner:
             "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
             "use_mla_backend": self.use_mla_backend,
             "mm_attention_backend": server_args.mm_attention_backend,
+            "ep_num_redundant_experts": server_args.ep_num_redundant_experts,
         }
     )
 
@@ -202,6 +241,31 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )
 
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+            set_global_expert_distribution_recorder(
+                ExpertDistributionRecorder.init_new(
+                    server_args,
+                    get_global_expert_location_metadata(),
+                    rank=self.tp_rank,
+                )
+            )
+
+        self.eplb_manager = (
+            EPLBManager(self)
+            if self.server_args.enable_eplb and (not self.is_draft_worker)
+            else None
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
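The block above wires expert-parallel bookkeeping (expert location metadata, the expert-distribution recorder, and an optional EPLB manager) into ModelRunner through module-level set/get helpers. A minimal sketch of that set-once/get-anywhere singleton pattern, assuming the real helpers behave similarly; all names below are illustrative stand-ins, not the sglang API:

```python
from typing import Optional

class ExpertLocationMetadata:
    """Stand-in for the real metadata object."""

    def __init__(self, physical_to_logical: list):
        self.physical_to_logical = physical_to_logical

    def debug_str(self) -> str:
        return f"physical_to_logical={self.physical_to_logical}"

_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None

def set_global_expert_location_metadata(value: ExpertLocationMetadata) -> None:
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None, "already initialized"
    _global_expert_location_metadata = value

def get_global_expert_location_metadata() -> ExpertLocationMetadata:
    assert _global_expert_location_metadata is not None, "not initialized yet"
    return _global_expert_location_metadata

# Only the non-draft worker initializes the global, mirroring the hunk above.
set_global_expert_location_metadata(ExpertLocationMetadata([0, 1, 2, 3]))
print(get_global_expert_location_metadata().debug_str())
```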
@@ -270,6 +334,8 @@ class ModelRunner:
                     and is_fa3_default_architecture(self.model_config.hf_config)
                 ):
                     server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    server_args.attention_backend = "aiter"
                 else:
                     server_args.attention_backend = (
                         "flashinfer" if is_flashinfer_available() else "triton"
@@ -280,10 +346,9 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
-
-
-
-            )
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
         elif self.use_mla_backend:
@@ -293,10 +358,9 @@ class ModelRunner:
                     "flashmla",
                     "cutlass_mla",
                 ]:
-
-
-
-                    )
+                    logger.info(
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+                    )
                 else:
                     raise ValueError(
                         f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -315,10 +379,9 @@ class ModelRunner:
                 server_args.attention_backend = "triton"
 
         if server_args.enable_double_sparsity:
-
-
-
-            )
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
             if server_args.ds_heavy_channel_type is None:
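Taken together, the four hunks above mostly reflow multi-line logger.info calls, but they also add an AMD path to the backend auto-selection. A condensed sketch of the fallback order they encode for the non-MLA path when --attention-backend is left unset (the predicate arguments are stand-ins for the real checks in model_specific_adjustment):

```python
def choose_attention_backend(
    fa3_default_arch: bool, is_hip: bool, flashinfer_available: bool
) -> str:
    """Condensed fallback order for the non-MLA auto-selection shown above."""
    if fa3_default_arch:
        return "fa3"
    if is_hip:
        return "aiter"  # new ROCm default added in this release
    return "flashinfer" if flashinfer_available else "triton"

assert choose_attention_backend(False, True, True) == "aiter"
assert choose_attention_backend(False, False, False) == "triton"
```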
@@ -329,26 +392,25 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.90
-
-
-
-
-
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+            if not self.is_multimodal_chunked_prefill_supported:
+                server_args.chunked_prefill_size = -1
                 logger.info(
-                "Automatically turn
+                    f"Automatically turn of --chunked-prefill-size as it is not supported for "
+                    f"{self.model_config.hf_config.model_type}"
                 )
-            server_args.chunked_prefill_size = -1
 
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
-
-            logger.info("Disable chunked prefix cache when page size > 1.")
+            logger.info("Disable chunked prefix cache when page size > 1.")
             server_args.disable_chunked_prefix_cache = True
 
         if not server_args.disable_chunked_prefix_cache:
-
-            logger.info("Chunked prefix cache is turned on.")
+            logger.info("Chunked prefix cache is turned on.")
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -445,10 +507,9 @@ class ModelRunner:
         torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
-
-
-
-                )
+                logger.info(
+                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+                )
                 self.server_args.dtype = "float16"
                 self.model_config.dtype = torch.float16
                 if torch.cuda.get_device_capability()[1] < 5:
@@ -484,11 +545,10 @@ class ModelRunner:
                     self.model.load_kv_cache_scales(
                         self.server_args.quantization_param_path
                     )
-
-
-
-
-                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
                 else:
                     raise RuntimeError(
                         "Using FP8 KV cache and scaling factors provided but "
@@ -531,6 +591,16 @@ class ModelRunner:
                 f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
             ) from None
 
+    def update_expert_location(
+        self, new_expert_location_metadata: ExpertLocationMetadata
+    ):
+        expert_location_updater.update_expert_location(
+            self.model.routed_experts_weights_of_layer,
+            new_expert_location_metadata,
+            nnodes=self.server_args.nnodes,
+            rank=self.tp_rank,
+        )
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -552,13 +622,7 @@ class ModelRunner:
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
-                DefaultModelLoader.Source(
-                    config.model_path,
-                    revision=config.revision,
-                    fall_back_to_pt=getattr(
-                        self.model, "fall_back_to_pt_during_load", True
-                    ),
-                )
+                DefaultModelLoader.Source.init_new(config, self.model)
             )
             return iter
 
@@ -631,7 +695,6 @@ class ModelRunner:
                 rank=rank,
                 group_name=group_name,
             )
-            dist.barrier(group=self._model_update_group, device_ids=[rank])
             return True, "Succeeded to initialize custom process group."
         except Exception as e:
             message = f"Failed to initialize custom process group: {e}."
@@ -726,12 +789,15 @@ class ModelRunner:
             distributed=get_world_group().world_size > 1,
             cpu_group=get_world_group().cpu_group,
         )
-        if self.
-            num_layers = (
-                self.model_config.
-
-
+        if self.is_draft_worker:
+            num_layers = getattr(
+                self.model_config.hf_config,
+                "num_nextn_predict_layers",
+                self.num_effective_layers,
             )
+        else:
+            num_layers = self.num_effective_layers
+        if self.use_mla_backend:
             # FIXME: pipeline parallelism is not compatible with mla backend
             assert self.pp_size == 1
             cell_size = (
@@ -743,7 +809,7 @@ class ModelRunner:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                *
+                * num_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
@@ -762,7 +828,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if
+            if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
@@ -940,6 +1006,10 @@ class ModelRunner:
             )
 
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+            self.attn_backend = AiterAttnBackend(self)
        elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -1020,7 +1090,7 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
-        tic = time.
+        tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
@@ -1028,13 +1098,12 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.
+            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
             f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
     def apply_torch_tp(self):
-
-        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
 
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
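The two hunks above switch the CUDA-graph capture timing to time.perf_counter(), which is monotonic and high-resolution, so it is well suited to measuring elapsed durations. A minimal usage sketch of the same start/stop pattern:

```python
import time

tic = time.perf_counter()          # monotonic, high-resolution start mark
sum(range(1_000_000))              # stand-in for the timed work
elapsed = time.perf_counter() - tic
print(f"Time elapsed: {elapsed:.2f} s")
```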
@@ -1093,6 +1162,27 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            output = self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+        if self.eplb_manager is not None:
+            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
+
+        return output
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
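The hunk above splits ModelRunner.forward into a thin wrapper plus _forward_raw: the wrapper tags each pass with an id, runs the real forward inside the expert-distribution recorder's context manager, then notifies the optional EPLB manager. A simplified, self-contained sketch of that wrapper pattern; the recorder and runner below are stand-ins, not the sglang classes:

```python
import contextlib

class DemoRecorder:
    @contextlib.contextmanager
    def with_forward_pass(self, forward_pass_id, batch):
        print(f"recording pass {forward_pass_id}")
        try:
            yield
        finally:
            print(f"finished pass {forward_pass_id}")

class DemoRunner:
    def __init__(self, recorder, eplb_manager=None):
        self.recorder = recorder
        self.eplb_manager = eplb_manager
        self.forward_pass_id = 0

    def forward(self, batch):
        self.forward_pass_id += 1
        with self.recorder.with_forward_pass(self.forward_pass_id, batch):
            output = self._forward_raw(batch)
        if self.eplb_manager is not None:
            self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
        return output

    def _forward_raw(self, batch):
        return [x * 2 for x in batch]  # stand-in for the actual model forward

print(DemoRunner(DemoRecorder()).forward([1, 2, 3]))
```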
@@ -1171,7 +1261,7 @@ class ModelRunner:
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.
+        rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
         is_mrope_enabled = "mrope_section" in rope_scaling
sglang/srt/model_loader/loader.py
CHANGED
@@ -197,6 +197,15 @@ class DefaultModelLoader(BaseModelLoader):
         fall_back_to_pt: bool = True
         """Whether .pt weights can be used."""
 
+        @classmethod
+        def init_new(cls, model_config: ModelConfig, model):
+            return cls(
+                model_config.model_path,
+                model_config.revision,
+                prefix="",
+                fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
+            )
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
@@ -341,12 +350,7 @@ class DefaultModelLoader(BaseModelLoader):
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
 
-        primary_weights = DefaultModelLoader.Source(
-            model_config.model_path,
-            model_config.revision,
-            prefix="",
-            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
-        )
+        primary_weights = DefaultModelLoader.Source.init_new(model_config, model)
         yield from self._get_weights_iterator(primary_weights)
 
         secondary_weights = cast(
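The two loader hunks are one refactor: the Source dataclass gains an init_new() factory, and the call sites that previously built a Source by hand now delegate to it. A minimal sketch of that factory-classmethod pattern with stand-in config and model objects (only the field names mirror the diff; the surrounding classes are illustrative):

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class Source:
    model_path: str
    revision: Optional[str] = None
    prefix: str = ""
    fall_back_to_pt: bool = True

    @classmethod
    def init_new(cls, model_config, model):
        # Centralizes the construction logic that was duplicated at call sites.
        return cls(
            model_config.model_path,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
        )

class FakeModelConfig:
    model_path = "org/some-model"
    revision = None

print(Source.init_new(FakeModelConfig(), model=object()))
```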
sglang/srt/models/clip.py
CHANGED
@@ -168,7 +168,7 @@ class CLIPEncoderLayer(nn.Module):
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
-            prefix=add_prefix("
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = CLIPMLP(
             config,
@@ -395,6 +395,10 @@ class CLIPVisionModel(nn.Module):
             config, quant_config, prefix=add_prefix("vision_model", prefix)
         )
 
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
     def forward(self, pixel_values: torch.Tensor):
         return self.vision_model(pixel_values)
 
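The second clip.py hunk adds a device property on CLIPVisionModel that simply delegates to the inner vision tower. A small stand-in module showing the same delegation (illustrative only, not the sglang model; the inner tower here derives its device from one of its parameters):

```python
import torch
from torch import nn

class VisionTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self) -> torch.device:
        return self.proj.weight.device

class VisionWrapper(nn.Module):
    """Stand-in for a wrapper that forwards .device like CLIPVisionModel above."""

    def __init__(self):
        super().__init__()
        self.vision_model = VisionTower()

    @property
    def device(self) -> torch.device:
        return self.vision_model.device

print(VisionWrapper().device)  # cpu unless the module has been moved to a GPU
```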