sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +56 -12
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/model_config.py +5 -5
- sglang/srt/distributed/parallel_state.py +0 -7
- sglang/srt/entrypoints/engine.py +18 -15
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +75 -94
- sglang/srt/environ.py +16 -2
- sglang/srt/eplb/expert_distribution.py +30 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/flashattention_backend.py +12 -2
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
- sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +1 -0
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -272
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +4 -4
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/quantization/__init__.py +3 -5
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +13 -1
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
- sglang/srt/managers/scheduler.py +21 -15
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/tokenizer_manager.py +11 -19
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +82 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +44 -3
- sglang/srt/model_executor/model_runner.py +1 -149
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v_moe.py +29 -196
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +22 -1
- sglang/srt/models/qwen3.py +34 -4
- sglang/srt/models/qwen3_moe.py +2 -4
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/server_args.py +365 -186
- sglang/srt/single_batch_overlap.py +2 -7
- sglang/srt/utils/common.py +87 -42
- sglang/srt/utils/hf_transformers_utils.py +7 -3
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
- sglang/srt/models/vila.py +0 -306
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED

@@ -88,6 +88,7 @@ class RequestFuncOutput:
     latency: float = 0.0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+    text_chunks: List[str] = field(default_factory=list)
     prompt_len: int = 0
     error: str = ""
     output_len: int = 0

@@ -258,6 +259,9 @@ async def async_request_openai_completions(

                    # Decoding phase
                    else:
+                        output.text_chunks.append(
+                            data["choices"][0]["text"]
+                        )
                        output.itl.append(timestamp - most_recent_timestamp)

                    most_recent_timestamp = timestamp

@@ -574,9 +578,8 @@ async def async_request_sglang_generate(
                    num_new_tokens = output_len - last_output_len
                    if num_new_tokens == 0:
                        continue
-
-
-                    ) / num_new_tokens
+                    chunk_gap = timestamp - most_recent_timestamp
+                    adjust_itl = chunk_gap / num_new_tokens
                    output.itl.extend([adjust_itl] * num_new_tokens)

                    most_recent_timestamp = timestamp

@@ -764,6 +767,7 @@ def get_dataset(args, tokenizer, model_id=None):
            image_content=args.image_content,
            image_format=args.image_format,
            image_resolution=args.image_resolution,
+            backend=args.backend,
        )
    elif args.dataset_name == "generated-shared-prefix":
        assert not tokenize_prompt

@@ -781,6 +785,7 @@ def get_dataset(args, tokenizer, model_id=None):
        input_requests = sample_mmmu_requests(
            num_requests=args.num_prompts,
            processor=processor,
+            backend=args.backend,
            fixed_output_len=args.random_output_len,
            random_sample=True,
        )

@@ -1009,6 +1014,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
+    backend: str,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:

@@ -1081,7 +1087,7 @@ def sample_mmmu_requests(
            text_prompt = f"Question: {question}\n\nAnswer: "
            output_len = fixed_output_len if fixed_output_len is not None else 256
            data_row = create_mm_data_row(
-                text_prompt, [image], [image_data], output_len, processor
+                text_prompt, [image], [image_data], output_len, processor, backend
            )
            filtered_dataset.append(data_row)

@@ -1316,13 +1322,19 @@ def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
     )


-def create_mm_data_row(
+def create_mm_data_row(
+    text_prompt, images: list, images_base64, output_len, processor, backend
+):
     try:
-
-
-
-
-
+        if type(processor).__name__ == "Phi4MMProcessor":
+            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+            content_items = text_prompt.replace("image 1", "|endoftext10|")
+        else:
+            content_items = [
+                {"type": "image", "image": {"url": image_base64}}
+                for image_base64 in images_base64
+            ]
+            content_items.append({"type": "text", "text": text_prompt})
         prompt_str = processor.apply_chat_template(
             [{"role": "user", "content": content_items}],
             add_generation_prompt=True,

@@ -1362,8 +1374,16 @@ def create_mm_data_row(text_prompt, images: list, images_base64, output_len, pro
     # Vision tokens = total tokens - text tokens
     vision_prompt_len = prompt_len - text_prompt_len

+    use_raw_prompt = backend in [
+        "sglang-oai",
+        "sglang-oai-chat",
+        "vllm",
+        "vllm-chat",
+        "lmdeploy",
+        "lmdeploy-chat",
+    ]
     return DatasetRow(
-        prompt=text_prompt,
+        prompt=text_prompt if use_raw_prompt else prompt_str,
         prompt_len=prompt_len,
         output_len=output_len,
         text_prompt_len=text_prompt_len,

@@ -1382,6 +1402,7 @@ def sample_image_requests(
     image_content: str,
     image_format: str,
     image_resolution: str,
+    backend: str,
 ) -> List[DatasetRow]:
     """Generate requests with images.

@@ -1447,6 +1468,7 @@ def sample_image_requests(
            list(images_base64),
            int(output_lens[i]),
            processor,
+            backend,
        )

        dataset.append(data_row)

@@ -1607,6 +1629,7 @@ def calculate_metrics(
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
     backend: str,
+    accept_length: Optional[float] = None,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     output_lens: List[int] = []
     retokenized_output_lens: List[int] = []

@@ -1618,6 +1641,14 @@ def calculate_metrics(
     tpots: List[float] = []
     ttfts: List[float] = []
     e2e_latencies: List[float] = []
+    retokenized_itls: List[float] = []
+
+    use_retokenized_itl = (
+        accept_length is not None
+        and accept_length > 0
+        and backend in ("sglang-oai", "sglang-oai-chat")
+    )
+
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_len

@@ -1631,7 +1662,17 @@ def calculate_metrics(
                total_input_vision += input_requests[i].vision_prompt_len
            if output_len > 1:
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
-
+            if use_retokenized_itl:
+                for k, itl in enumerate(outputs[i].itl):
+                    num_tokens = len(
+                        tokenizer.encode(
+                            outputs[i].text_chunks[k], add_special_tokens=False
+                        )
+                    )
+                    adjusted_itl = itl / num_tokens
+                    retokenized_itls.extend([adjusted_itl] * num_tokens)
+            else:
+                itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)

            e2e_latencies.append(outputs[i].latency)

@@ -1647,6 +1688,8 @@ def calculate_metrics(
            "on the benchmark arguments.",
            stacklevel=2,
        )
+
+    itls = retokenized_itls if use_retokenized_itl else itls
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,

@@ -1910,6 +1953,7 @@ async def benchmark(
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
+        accept_length=accept_length,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
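The retokenized ITL path above spreads each streamed chunk's latency evenly over the number of tokens the tokenizer assigns to that chunk. A minimal standalone sketch of that adjustment is shown below; it is not part of bench_serving.py, the `tokenizer` argument is assumed to be any Hugging Face-style tokenizer, and the empty-chunk guard is an addition for safety rather than something visible in the diff.

from typing import List


def retokenize_itls(text_chunks: List[str], itls: List[float], tokenizer) -> List[float]:
    # Divide each chunk's inter-token latency across its retokenized token count,
    # mirroring the use_retokenized_itl branch added to calculate_metrics.
    adjusted: List[float] = []
    for chunk, chunk_latency in zip(text_chunks, itls):
        num_tokens = len(tokenizer.encode(chunk, add_special_tokens=False))
        if num_tokens == 0:
            continue  # nothing to attribute for an empty chunk (extra guard, not in the diff)
        adjusted.extend([chunk_latency / num_tokens] * num_tokens)
    return adjusted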
sglang/launch_server.py
CHANGED

@@ -12,10 +12,12 @@ if __name__ == "__main__":

    try:
        if server_args.grpc_mode:
+            # Handle gRPC server
            from sglang.srt.entrypoints.grpc_server import serve_grpc

            asyncio.run(serve_grpc(server_args))
        else:
+            # Handle HTTP server
            from sglang.srt.entrypoints.http_server import launch_server

            launch_server(server_args)
sglang/srt/batch_invariant_ops/batch_invariant_ops.py
CHANGED

@@ -9,6 +9,22 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
+from sglang.srt.utils.common import calc_diff, get_bool_env_var
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+_ENABLE_MM_DEEPGEMM = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
+)
+_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
+)
+
+if not _ENABLE_MM_DEEPGEMM:
+    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
+
 __all__ = [
     "set_batch_invariant_mode",
     "is_batch_invariant_mode_enabled",

@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
        tl.store(c_ptrs, c, mask=c_mask)


-def
+def _matmul_persistent_triton(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
     # Check constraints.

@@ -217,6 +233,54 @@ def matmul_persistent(
     return c


+def _matmul_persistent_deepgemm(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    M, K = a.shape
+    K, N = b.shape
+    dtype = a.dtype
+    out = torch.empty((M, N), device=a.device, dtype=dtype)
+
+    deep_gemm.bf16_gemm_nn(a, b, out)
+
+    # TODO can this be put in DeepGEMM's `c`?
+    if bias is not None:
+        out += bias
+
+    return out
+
+
+def matmul_persistent(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    if (
+        _ENABLE_MM_DEEPGEMM
+        and ENABLE_JIT_DEEPGEMM
+        and (a.dtype == torch.bfloat16)
+        and (b.dtype == torch.bfloat16)
+        and a.is_contiguous()
+        and b.transpose(0, 1).is_contiguous()
+    ):
+        if _ENABLE_MM_COMPARISON_TEST:
+            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
+            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+            diff = calc_diff(out_triton, out_deepgemm)
+            assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
+            # can be enabled for debugging
+            # print(
+            #     f"{diff=} "
+            #     f"{(out_triton - out_deepgemm).abs().mean()=} "
+            #     f"{(out_triton - out_deepgemm).abs().sum()=} "
+            #     f"{torch.sum(out_triton != out_deepgemm)=} "
+            # )
+            # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
+            return out_deepgemm
+
+        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+
+    return _matmul_persistent_triton(a=a, b=b, bias=bias)
+
+
 @triton.jit
 def _log_softmax_kernel(
     input_ptr,

@@ -495,16 +559,39 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems


+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None


 def is_batch_invariant_mode_enabled():
     return _batch_invariant_MODE


-def enable_batch_invariant_mode(
-
+def enable_batch_invariant_mode(
+    enable_bmm: bool = True,
+):
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return

@@ -517,11 +604,21 @@ def enable_batch_invariant_mode():
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

+    if enable_bmm:
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+        # Also monkeypatch torch.bmm directly as a fallback
+        _original_torch_bmm = torch.bmm
+        torch.bmm = bmm_batch_invariant
+

 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
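With this change, enabling batch-invariant mode can also route `aten::bmm` through the per-matrix persistent kernel and monkeypatch `torch.bmm`, and disabling the mode restores the original function. A small usage sketch of the intended call pattern follows; it assumes a CUDA device and that the module is importable at the path shown in the file list above, and it is not taken verbatim from the package.

import torch

from sglang.srt.batch_invariant_ops.batch_invariant_ops import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
)

a = torch.randn(4, 128, 64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4, 64, 32, device="cuda", dtype=torch.bfloat16)

enable_batch_invariant_mode(enable_bmm=True)  # registers aten::bmm and patches torch.bmm
try:
    out = torch.bmm(a, b)  # now runs the batch-invariant per-matrix matmul
finally:
    disable_batch_invariant_mode()  # restores the original torch.bmm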
sglang/srt/compilation/backend.py
CHANGED

@@ -392,7 +392,7 @@ class SGLangBackend:
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
        )

        from torch._dynamo.utils import lazy_format_graph_code
sglang/srt/configs/model_config.py
CHANGED

@@ -535,7 +535,7 @@ class ModelConfig:
        quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
        return quant_cfg

-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
        """Parse ModelOpt quantization config and return the appropriate quant_method."""
        json_quant_configs = quant_config_dict["quantization"]
        quant_algo = json_quant_configs.get("quant_algo", None)

@@ -547,8 +547,7 @@ class ModelConfig:
        elif quant_algo and "FP8" in quant_algo:
            return {"quant_method": "modelopt_fp8"}
        else:
-
-            return {"quant_method": "modelopt_fp8"}
+            return None

    def _is_already_quantized(self) -> bool:
        """Check if the model is already quantized based on config files."""

@@ -806,7 +805,7 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:

@@ -915,12 +914,13 @@ multimodal_model_archs = [
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]
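The notable behavioural change here is the fallback: an unrecognized ModelOpt `quant_algo` now yields `None` rather than silently defaulting to `modelopt_fp8`. A standalone illustration of that fallback is sketched below; `parse_modelopt_quant_method` is a hypothetical helper that mirrors only the branches visible in this hunk, not the actual `ModelConfig` method.

from typing import Optional


def parse_modelopt_quant_method(quant_algo: Optional[str]) -> Optional[dict]:
    # Only the FP8 branch is visible in this diff; other algorithms are elided here.
    if quant_algo and "FP8" in quant_algo:
        return {"quant_method": "modelopt_fp8"}
    return None  # previously this path fell back to {"quant_method": "modelopt_fp8"}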
sglang/srt/distributed/parallel_state.py
CHANGED

@@ -340,17 +340,10 @@ class GroupCoordinator:
        self.qr_comm: Optional[QuickAllReduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
            try:
                self.ca_comm = CustomAllreduce(
                    group=self.cpu_group,
                    device=self.device,
-                    max_size=ca_max_size,
                )
            except Exception as e:
                logger.warning(
sglang/srt/entrypoints/engine.py
CHANGED

@@ -101,7 +101,7 @@ class Engine(EngineBase):

    Note:
    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
-    2. Inter-process communication
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """

    def __init__(self, **kwargs):

@@ -109,6 +109,8 @@ class Engine(EngineBase):
        The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`.
        Please refer to `ServerArgs` for the documentation.
        """
+
+        # Parse server_args
        if "server_args" in kwargs:
            # Directly load server_args
            server_args = kwargs["server_args"]

@@ -118,29 +120,28 @@ class Engine(EngineBase):
            # Do not print logs by default
            kwargs["log_level"] = "error"
            server_args = ServerArgs(**kwargs)
+        self.server_args = server_args
+        logger.info(f"{server_args=}")

        # Shutdown the subprocesses automatically when the program exits
        atexit.register(self.shutdown)

-        # Allocate ports for inter-process communications
-        self.port_args = PortArgs.init_new(server_args)
-        logger.info(f"{server_args=}")
-
        # Launch subprocesses
-        tokenizer_manager, template_manager, scheduler_info =
-            server_args=server_args
-            port_args=self.port_args,
+        tokenizer_manager, template_manager, scheduler_info, port_args = (
+            _launch_subprocesses(server_args=server_args)
        )
-        self.server_args = server_args
        self.tokenizer_manager = tokenizer_manager
        self.template_manager = template_manager
        self.scheduler_info = scheduler_info
+        self.port_args = port_args

+        # Initialize ZMQ sockets
        context = zmq.Context(2)
        self.send_to_rpc = get_zmq_socket(
            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
        )

+        # Enable tracing
        if server_args.enable_trace:
            process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
            if server_args.disaggregation_mode == "null":

@@ -672,15 +673,17 @@ def _set_envs_and_config(server_args: ServerArgs):
    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
    if not server_args.enable_symm_mem:
        os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
    os.environ["CUDA_MODULE_LOADING"] = "AUTO"
-
+
    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        # flashinfer uses this environment variable for various kernels from MoE to quant kernels
        os.environ["TRTLLM_ENABLE_PDL"] = "1"

    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
        # Default to warning level, to avoid too many logs
        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+
    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
        # Need to set log to console, otherwise the log level won't take effect
        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"

@@ -709,7 +712,7 @@ def _set_envs_and_config(server_args: ServerArgs):
    if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
        assert_pkg_version(
            "sgl-kernel",
-            "0.3.16.
+            "0.3.16.post4",
            "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
        )

@@ -840,7 +843,7 @@ def _launch_subprocesses(

        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
            # When using `Engine` as a Python API, we don't want to block here.
-            return None, None, None
+            return None, None, None, port_args

        launch_dummy_health_check_server(
            server_args.host, server_args.port, server_args.enable_metrics

@@ -851,7 +854,7 @@ def _launch_subprocesses(
            logger.error(
                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
            )
-            return None, None, None
+            return None, None, None, port_args

    # Launch detokenizer process
    detoken_proc = mp.Process(

@@ -897,4 +900,4 @@ def _launch_subprocesses(

    tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]

-    return tokenizer_manager, template_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info, port_args
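After this change the engine no longer allocates `PortArgs` up front; `_launch_subprocesses` returns them as a fourth element and the `Engine` stores them on `self.port_args`. A minimal usage sketch of the public API is given below; the model path is an arbitrary example, and the snippet simply illustrates that the port information is now available on the engine object after construction.

import sglang as sgl

# Example model path; substitute any model supported by sglang.
llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
print(llm.port_args.rpc_ipc_name)  # IPC endpoint used for the ZMQ DEALER socket
llm.shutdown()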
sglang/srt/entrypoints/grpc_server.py
CHANGED

@@ -999,7 +999,6 @@ def _wait_and_warmup_grpc(
    # Mark health service as SERVING after warmup completes
    if health_servicer:
        health_servicer.set_serving()
-        logger.info("Health service marked as SERVING")

    logger.info("The server is fired up and ready to roll!")