sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
-    is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -109,7 +108,7 @@ class ServerArgs:
     log_level: str = "info"
     log_level_http: Optional[str] = None
     log_requests: bool = False
-    log_requests_level: int = 0
+    log_requests_level: int = 2
     crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
@@ -131,6 +130,7 @@ class ServerArgs:
     enable_cache_report: bool = False
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None
+    tool_server: Optional[str] = None
 
     # Data parallelism
     dp_size: int = 1
@@ -278,15 +278,11 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
-    # For tool server
-    tool_server: Optional[str] = None
-
     # Deprecated arguments
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
 
     def __post_init__(self):
-
         # Check deprecated arguments
         def print_deprecated_warning(message: str):
             logger.warning(f"\033[33m{message}\033[0m")
@@ -392,6 +388,9 @@ class ServerArgs:
             self.attention_backend = "torch_native"
             self.sampling_backend = "pytorch"
 
+        # Model-specific adjustments
+        self.model_specific_adjustments()
+
         # Set kernel backends
         if self.device == "cpu":
             if self.attention_backend is None:
@@ -433,7 +432,10 @@ class ServerArgs:
             )
             self.page_size = 128
 
-        if self.attention_backend == "trtllm_mla":
+        if (
+            self.attention_backend == "trtllm_mla"
+            or self.decode_attention_backend == "trtllm_mla"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -444,11 +446,17 @@ class ServerArgs:
                     f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
                 )
                 self.page_size = 64
+
            if self.speculative_algorithm is not None:
                raise ValueError(
                    "trtllm_mla backend does not support speculative decoding yet."
                )
 
+            if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
+                raise ValueError(
+                    "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
+                )
+
         if (
             self.attention_backend == "trtllm_mha"
             or self.decode_attention_backend == "trtllm_mha"
@@ -470,55 +478,9 @@ class ServerArgs:
                 "trtllm_mha backend does not support speculative decoding yet."
             )
 
-        model_arch = self.get_hf_config().architectures[0]
-        if model_arch in ["GptOssForCausalLM"]:
-            if self.attention_backend is None:
-                # default is triton, but we could have trtllm_mha as an option
-                self.attention_backend = "triton"
-            assert (
-                self.attention_backend == "trtllm_mha"
-                or self.attention_backend == "triton"
-            )
-            quantization_config = getattr(
-                self.get_hf_config(), "quantization_config", None
-            )
-            is_mxfp4_quant_format = (
-                quantization_config is not None
-                and quantization_config.get("quant_method") == "mxfp4"
-            )
-
-            if is_sm100_supported() and is_mxfp4_quant_format:
-                self.enable_flashinfer_mxfp4_moe = True
-                self.enable_triton_kernel_moe = False
-                logger.info(
-                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
-                )
-            else:
-                if self.enable_triton_kernel_moe:
-                    assert (
-                        self.ep_size == 1
-                    ), "Triton kernel MoE is only supported when ep_size == 1"
-                if not self.enable_triton_kernel_moe and self.ep_size == 1:
-                    self.enable_triton_kernel_moe = True
-                    logger.info(
-                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
-                    )
-
-            self.disable_hybrid_swa_memory = True
-
-            if is_mxfp4_quant_format:
-                # use bf16 for mxfp4 triton kernels
-                self.dtype = "bfloat16"
-
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
-                "Mixed chunk is disabled because of using dual chunk flash attention backend"
-            )
-            logger.warning(
-                "Radix cache is disabled because of using dual chunk flash attention backend"
-            )
-            logger.warning(
-                "Cuda graph is disabled because of using dual chunk flash attention backend"
+                "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
             )
             self.enable_mixed_chunk = False
             self.disable_cuda_graph = True
@@ -583,7 +545,7 @@ class ServerArgs:
 
         if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
             self.expert_distribution_recorder_mode = "stat"
-            logger.info(
+            logger.warning(
                 "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
             )
 
@@ -591,9 +553,6 @@ class ServerArgs:
             self.ep_dispatch_algorithm is None
         ):
             self.ep_dispatch_algorithm = "static"
-            logger.info(
-                "EPLB is enabled or init_expert_location is provided. ep_dispatch_algorithm is configured."
-            )
 
         if self.enable_eplb:
             assert self.ep_size > 1 or self.moe_a2a_backend is not None
@@ -616,6 +575,11 @@ class ServerArgs:
                 "Pipeline parallelism is incompatible with overlap schedule."
             )
 
+        if self.hicache_storage_backend == "mooncake":
+            # to use mooncake storage backend, the following conditions must be met:
+            self.hicache_io_backend = "kernel"
+            self.hicache_mem_layout = "page_first"
+
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -1112,7 +1076,7 @@ class ServerArgs:
         parser.add_argument(
             "--log-requests-level",
             type=int,
-            default=0,
+            default=ServerArgs.log_requests_level,
             help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
             choices=[0, 1, 2, 3],
         )
@@ -1231,7 +1195,7 @@ class ServerArgs:
         parser.add_argument(
             "--tool-call-parser",
             type=str,
-            choices=[
+            choices=[  # TODO: use FunctionCallParser.DetectorMap.keys()
                 "qwen25",
                 "mistral",
                 "llama3",
@@ -1241,10 +1205,17 @@ class ServerArgs:
                 "qwen3_coder",
                 "glm45",
                 "step3",
+                "gpt-oss",
             ],
             default=ServerArgs.tool_call_parser,
             help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', 'qwen3_coder', 'glm45', and 'step3'.",
         )
+        parser.add_argument(
+            "--tool-server",
+            type=str,
+            default=None,
+            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
+        )
 
         # Data parallelism
         parser.add_argument(
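Note on the new --tool-server flag above: its value is either the literal string 'demo' or a comma-separated list of tool server URLs. A minimal, hypothetical sketch of how such a value could be split into URLs (illustration only; this is not sglang's actual parsing code, and parse_tool_server is an invented name):

    def parse_tool_server(value):
        # None -> no tool server; "demo" -> built-in demo tools;
        # anything else -> comma-separated list of URLs.
        if value is None or value == "demo":
            return value
        return [url.strip() for url in value.split(",") if url.strip()]

    print(parse_tool_server("http://localhost:8080,http://localhost:8081"))
    # -> ['http://localhost:8080', 'http://localhost:8081']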
@@ -1344,55 +1315,42 @@ class ServerArgs:
         )
 
         # Kernel backend
+        ATTN_BACKENDS = [
+            "aiter",
+            "cutlass_mla",
+            "fa3",
+            "flashinfer",
+            "flashmla",
+            "intel_amx",
+            "torch_native",
+            "ascend",
+            "triton",
+            "trtllm_mla",
+            "trtllm_mha",
+            "dual_chunk_flash_attn",
+            "wave",
+        ]
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=[
-                "aiter",
-                "cutlass_mla",
-                "fa3",
-                "flashinfer",
-                "flashmla",
-                "intel_amx",
-                "torch_native",
-                "ascend",
-                "triton",
-                "trtllm_mla",
-                "trtllm_mha",
-                "dual_chunk_flash_attn",
-            ],
+            choices=ATTN_BACKENDS,
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
-        parser.add_argument(
-            "--decode-attention-backend",
-            type=str,
-            choices=[
-                "flashinfer",
-                "triton",
-                "torch_native",
-                "fa3",
-                "flashmla",
-                "cutlass_mla",
-            ],
-            default=ServerArgs.decode_attention_backend,
-            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
-        )
-
         parser.add_argument(
             "--prefill-attention-backend",
             type=str,
-            choices=[
-                "flashinfer",
-                "triton",
-                "torch_native",
-                "fa3",
-                "flashmla",
-                "cutlass_mla",
-            ],
+            choices=ATTN_BACKENDS,
             default=ServerArgs.prefill_attention_backend,
             help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=ATTN_BACKENDS,
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
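The refactor above declares the backend list once (ATTN_BACKENDS) and reuses it for the attention, prefill, and decode flags, which also exposes the new 'wave' backend to all three. A small standalone sketch of the same argparse pattern (shortened, illustrative backend names, not the full sglang CLI):

    import argparse

    BACKENDS = ["triton", "flashinfer", "fa3"]  # shortened for illustration

    parser = argparse.ArgumentParser()
    # A single shared `choices` list keeps the three flags in sync.
    parser.add_argument("--attention-backend", choices=BACKENDS, default=None)
    parser.add_argument("--prefill-attention-backend", choices=BACKENDS, default=None)
    parser.add_argument("--decode-attention-backend", choices=BACKENDS, default=None)

    args = parser.parse_args(["--attention-backend", "triton"])
    print(args.attention_backend)  # -> triton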
@@ -1493,7 +1451,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
-            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+            help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
         )
         parser.add_argument(
             "--deepep-mode",
@@ -1612,7 +1570,6 @@ class ServerArgs:
             default=ServerArgs.hicache_mem_layout,
             help="The layout of host memory pool for hierarchical cache.",
         )
-
         parser.add_argument(
             "--hicache-storage-backend",
             type=str,
@@ -1985,14 +1942,6 @@ class ServerArgs:
             help="Disable mmap while loading weight using safetensors.",
         )
 
-        # For tool server
-        parser.add_argument(
-            "--tool-server",
-            type=str,
-            default=None,
-            help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
-        )
-
         # Deprecated arguments
         parser.add_argument(
             "--enable-ep-moe",
@@ -2056,25 +2005,6 @@ class ServerArgs:
             None,
         }, "moe_dense_tp_size only support 1 and None currently"
 
-        # Check model architecture
-        model_arch = self.get_hf_config().architectures[0]
-        if "Llama4" in model_arch:
-            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
-
-        if model_arch in [
-            "Gemma2ForCausalLM",
-            "Gemma3ForCausalLM",
-            "Gemma3ForConditionalGeneration",
-            "Gemma3nForCausalLM",
-            "Gemma3nForConditionalGeneration",
-        ]:
-            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
-            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning(
-                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
-            )
-            self.disable_hybrid_swa_memory = True
-
         # Check LoRA
         self.check_lora_server_args()
 
@@ -2085,22 +2015,20 @@ class ServerArgs:
         ), "enable_mixed_chunk is required for speculative decoding"
 
         # Check chunked prefill
-        assert (
-            self.chunked_prefill_size % self.page_size == 0
-        ), "chunked_prefill_size must be divisible by page_size"
+        # Skip validation if chunked prefill is disabled (i.e., size <= 0).
+        if self.chunked_prefill_size > 0:
+            assert (
+                self.chunked_prefill_size % self.page_size == 0
+            ), "chunked_prefill_size must be divisible by page_size"
 
     def check_lora_server_args(self):
-        assert (
-            self.max_loras_per_batch > 0
-            # FIXME
-            and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and radix attention is in progress"
+        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
 
         # Enable LoRA if any LoRA paths are provided for backward compatibility.
         if self.lora_paths:
             if self.enable_lora is None:
                 self.enable_lora = True
-                logger.info(
+                logger.warning(
                     "--enable-lora is set to True because --lora-paths is provided."
                 )
             elif self.enable_lora is False:
@@ -2172,6 +2100,58 @@ class ServerArgs:
             f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
         )
 
+    def model_specific_adjustments(self):
+        hf_config = self.get_hf_config()
+        model_arch = hf_config.architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            if self.attention_backend is None:
+                self.attention_backend = "triton"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
+            quantization_config = getattr(hf_config, "quantization_config", None)
+            is_mxfp4_quant_format = (
+                quantization_config is not None
+                and quantization_config.get("quant_method") == "mxfp4"
+            )
+
+            if is_sm100_supported() and is_mxfp4_quant_format:
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+                logger.warning(
+                    "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
+                )
+            else:
+                if self.enable_triton_kernel_moe:
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if not self.enable_triton_kernel_moe and self.ep_size == 1:
+                    self.enable_triton_kernel_moe = True
+                    logger.warning(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )
+            self.disable_hybrid_swa_memory = True
+            if is_mxfp4_quant_format:
+                # use bf16 for mxfp4 triton kernels
+                self.dtype = "bfloat16"
+        elif "Llama4" in model_arch:
+            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+        elif model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3ForCausalLM",
+            "Gemma3ForConditionalGeneration",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
+            # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
+            # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
+            self.disable_hybrid_swa_memory = True
+
     def adjust_mem_fraction_for_vlm(self, model_config):
         vision_config = getattr(model_config.hf_config, "vision_config", None)
         if vision_config is None:
@@ -2209,10 +2189,6 @@ class ServerArgs:
         self.mem_fraction_static = (
             original_server_arg_mem_fraction * final_overall_factor
         )
-        logger.warning(
-            f"Multimodal model: Dynamically adjusted --mem-fraction-static "
-            f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
-        )
 
 
 def prepare_server_args(argv: List[str]) -> ServerArgs:
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
+        elif self.server_args.attention_backend == "aiter":
+            from sglang.srt.layers.attention.aiter_backend import (
+                AiterAttnBackend,
+                AiterMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = AiterMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = AiterAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
sglang/srt/two_batch_overlap.py CHANGED
@@ -26,11 +26,13 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, get_bool_env_var
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
 
+_is_hip = is_hip()
+
 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 
 logger = logging.getLogger(__name__)
@@ -822,9 +824,15 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
-        operations_strategy.deep_gemm_num_sms
-    ):
+    context = (
+        empty_context()
+        if _is_hip
+        else deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            operations_strategy.deep_gemm_num_sms
+        )
+    )
+
+    with context:
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
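The HIP branch above swaps the DeepGEMM SM-configuration context for a no-op context so the `with` body stays identical on both code paths. A minimal sketch of that pattern using contextlib (assuming sglang's empty_context behaves like a no-op context manager):

    import contextlib

    @contextlib.contextmanager
    def configure_num_sms(num_sms):
        # Stand-in for deep_gemm_wrapper.configure_deep_gemm_num_sms(...).
        print(f"limit to {num_sms} SMs")
        yield
        print("restore SM limit")

    def run(is_hip: bool):
        # Pick the real context manager or a no-op one at runtime.
        context = contextlib.nullcontext() if is_hip else configure_num_sms(16)
        with context:
            return "forward pass"

    print(run(is_hip=True))   # no SM configuration on ROCm
    print(run(is_hip=False))  # SM configuration applied on CUDA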
sglang/srt/utils.py CHANGED
@@ -815,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
             vr = VideoReader(tmp_file.name, ctx=ctx)
         elif video_file.startswith("data:"):
             _, encoded = video_file.split(",", 1)
-            video_bytes = base64.b64decode(encoded)
+            video_bytes = pybase64.b64decode(encoded)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
@@ -823,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
         elif os.path.isfile(video_file):
             vr = VideoReader(video_file, ctx=ctx)
         else:
-            video_bytes = base64.b64decode(video_file)
+            video_bytes = pybase64.b64decode(video_file)
             tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
             tmp_file.write(video_bytes)
             tmp_file.close()
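pybase64 is used here as a faster, API-compatible replacement for the standard-library decoder, so the change is a drop-in swap. A tiny sketch (assuming the pybase64 package is installed):

    import base64

    import pybase64

    encoded = base64.b64encode(b"fake mp4 payload")
    # pybase64.b64decode accepts the same input and returns the same bytes
    # as base64.b64decode, just decoded with a SIMD-accelerated backend.
    assert pybase64.b64decode(encoded) == base64.b64decode(encoded)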
@@ -2960,7 +2960,7 @@ class ConcurrentCounter:
         This suspends the calling coroutine without blocking the thread, allowing
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
-        self.wait_for(lambda count: count == 0)
+        await self.wait_for(lambda count: count == 0)
 
 
 @lru_cache(maxsize=1)
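The fix above adds the missing await: calling an async method without awaiting it only creates a coroutine object, so the caller would return immediately instead of suspending until the counter reaches zero. A minimal illustration with a simplified stand-in class (not the sglang ConcurrentCounter):

    import asyncio

    class ToyCounter:
        def __init__(self):
            self._zero = asyncio.Event()

        def set_zero(self):
            self._zero.set()

        async def wait_until_zero(self):
            # Must be awaited by the caller; otherwise nothing runs.
            await self._zero.wait()

    async def main():
        counter = ToyCounter()
        asyncio.get_running_loop().call_later(0.01, counter.set_zero)
        await counter.wait_until_zero()  # correct: awaited
        print("counter reached zero")

    asyncio.run(main())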
sglang/srt/weight_sync/tensor_bucket.py ADDED
@@ -0,0 +1,106 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import torch
+
+
+@dataclass
+class FlattenedTensorMetadata:
+    """Metadata for a tensor in a flattened bucket"""
+
+    name: str
+    shape: torch.Size
+    dtype: torch.dtype
+    start_idx: int
+    end_idx: int
+    numel: int
+
+
+class FlattenedTensorBucket:
+    """
+    A bucket that flattens multiple tensors into a single tensor for efficient processing
+    while preserving all metadata needed for reconstruction.
+    """
+
+    def __init__(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]] = None,
+        flattened_tensor: torch.Tensor = None,
+        metadata: List[FlattenedTensorMetadata] = None,
+    ):
+        """
+        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.
+        Args:
+            named_tensors: List of (name, tensor) tuples (for creating new bucket)
+            flattened_tensor: Pre-flattened tensor (for reconstruction)
+            metadata: Pre-computed metadata (for reconstruction)
+        """
+        if named_tensors is not None:
+            # Create bucket from named tensors
+            self.metadata: List[FlattenedTensorMetadata] = [None] * len(named_tensors)
+            self.flattened_tensor: torch.Tensor = None
+
+            if not named_tensors:
+                raise ValueError("Cannot create empty tensor bucket")
+
+            # Collect metadata and flatten tensors
+            current_idx = 0
+            flattened_tensors: List[torch.Tensor] = [None] * len(named_tensors)
+
+            for i, (name, tensor) in enumerate(named_tensors):
+                flattened = tensor.flatten()
+                flattened_tensors[i] = flattened
+
+                # Store metadata
+
+                numel = flattened.numel()
+                metadata_obj = FlattenedTensorMetadata(
+                    name=name,
+                    shape=tensor.shape,
+                    dtype=tensor.dtype,
+                    start_idx=current_idx,
+                    end_idx=current_idx + numel,
+                    numel=numel,
+                )
+                self.metadata[i] = metadata_obj
+                current_idx += numel
+
+            # Concatenate all flattened tensors
+            self.flattened_tensor = torch.cat(flattened_tensors, dim=0)
+        else:
+            # Initialize from pre-flattened data
+            if flattened_tensor is None or metadata is None:
+                raise ValueError(
+                    "Must provide either named_tensors or both flattened_tensor and metadata"
+                )
+            self.flattened_tensor = flattened_tensor
+            self.metadata = metadata
+
+    def get_flattened_tensor(self) -> torch.Tensor:
+        """Get the flattened tensor containing all bucket tensors"""
+        return self.flattened_tensor
+
+    def get_metadata(self) -> List[FlattenedTensorMetadata]:
+        """Get metadata for all tensors in the bucket"""
+        return self.metadata
+
+    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
+        """
+        Reconstruct original tensors from flattened tensor with optimized performance.
+        Uses memory-efficient operations to minimize allocations and copies.
+        """
+        # preallocate the result list
+        reconstructed = [None] * len(self.metadata)
+
+        for i, meta in enumerate(self.metadata):
+            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
+                meta.shape
+            )
+
+            # batch dtype conversion (if needed)
+            if tensor.dtype != meta.dtype:
+                tensor = tensor.to(meta.dtype)
+
+            reconstructed[i] = (meta.name, tensor)
+
+        return reconstructed
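A short round-trip sketch of the new bucket API added above (assuming the module is importable as sglang.srt.weight_sync.tensor_bucket, matching its path in this wheel):

    import torch

    from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket

    named = [
        ("layer0.weight", torch.randn(4, 8)),
        ("layer0.bias", torch.zeros(8)),
    ]

    # Flatten all tensors into one 1-D tensor plus per-tensor metadata.
    bucket = FlattenedTensorBucket(named_tensors=named)
    flat = bucket.get_flattened_tensor()   # single tensor of 4*8 + 8 elements
    meta = bucket.get_metadata()

    # On the receiving side: rebuild the bucket and recover the (name, tensor)
    # pairs with their original shapes and dtypes.
    rebuilt = FlattenedTensorBucket(flattened_tensor=flat, metadata=meta)
    for name, tensor in rebuilt.reconstruct_tensors():
        print(name, tuple(tensor.shape), tensor.dtype)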