sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -290,6 +290,9 @@ class DictOutput(object):
     def __getitem__(self, item):
         return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
     def __setitem__(self, key, value):
         self.__dict__[key] = value
 
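Note: the new __contains__ lets callers write `key in output` instead of reaching into __dict__. A minimal self-contained sketch of the behavior (the class body is reduced to the methods shown above; the key name is a made-up example):

    class DictOutput(object):
        def __getitem__(self, item):
            return self.__dict__[item]

        def __contains__(self, key):
            return key in self.__dict__

        def __setitem__(self, key, value):
            self.__dict__[key] = value

    out = DictOutput()
    out["images_emb_mask"] = [True, False]
    assert "images_emb_mask" in out   # supported via the new __contains__
    assert "missing_key" not in out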
--- a/sglang/srt/configs/model_config.py
+++ b/sglang/srt/configs/model_config.py
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -210,6 +211,21 @@ class ModelConfig:
         self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)
 
+    @staticmethod
+    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+        return ModelConfig(
+            model_path=model_path or server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+            **kwargs,
+        )
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -538,6 +554,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "CLIPModel",
     "KimiVLForConditionalGeneration",
+    "InternVLChatModel",
 ]
 
 
--- a/sglang/srt/constrained/xgrammar_backend.py
+++ b/sglang/srt/constrained/xgrammar_backend.py
@@ -18,6 +18,7 @@ import logging
 from typing import List, Optional, Tuple, Union
 
 import torch
+import xgrammar
 from xgrammar import (
     CompiledGrammar,
     GrammarCompiler,
@@ -58,17 +59,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.override_stop_tokens = override_stop_tokens
         self.finished = False
 
-        # Fix (from vLLM team): postpone the import of apply_token_bitmask_inplace_kernels to the
-        # class init site to avoid re-initializing CUDA in forked subprocess.
-        from xgrammar.kernels import apply_token_bitmask_inplace_kernels
-
-        self.use_token_bitmask_triton = get_bool_env_var(
-            "SGLANG_TOKEN_BITMASK_TRITON", "false"
-        )
-        self.apply_vocab_mask_cuda = apply_token_bitmask_inplace_kernels.get(
-            "cuda", None
+        from xgrammar.kernels.apply_token_bitmask_inplace_cpu import (
+            apply_token_bitmask_inplace_cpu,
         )
-        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_kernels.get("cpu", None)
+
+        self.apply_vocab_mask_cpu = apply_token_bitmask_inplace_cpu
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -113,15 +108,12 @@ class XGrammarGrammar(BaseGrammarObject):
         return vocab_mask.to(device, non_blocking=True)
 
     def apply_vocab_mask(self, logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if (
-            not self.use_token_bitmask_triton
-            and logits.device.type == "cuda"
-            and self.apply_vocab_mask_cuda
-        ):
-            return self.apply_vocab_mask_cuda(logits, vocab_mask)
-        if logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
-            return self.apply_vocab_mask_cpu(logits, vocab_mask)
-        apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        if logits.device.type == "cuda":
+            apply_token_bitmask_inplace_triton(logits, vocab_mask)
+        elif logits.device.type == "cpu" and self.apply_vocab_mask_cpu:
+            self.apply_vocab_mask_cpu(logits, vocab_mask)
+        else:
+            raise RuntimeError(f"Unsupported device: {logits.device.type}")
 
     def copy(self):
         matcher = GrammarMatcher(
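Note: the dispatch is now unconditional: CUDA logits always go through the Triton kernel, CPU logits through xgrammar's CPU kernel, and any other device raises. For readers unfamiliar with the bitmask format, here is a pure-PyTorch sketch of the masking semantics, assuming xgrammar's packing of 32 tokens per int32 word (an illustration, not the Triton kernel itself):

    import torch

    def apply_token_bitmask_reference(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
        # Bit i of word w governs token w * 32 + i; a zero bit bans the token.
        vocab_size = logits.shape[-1]
        token_ids = torch.arange(vocab_size, device=logits.device)
        bits = (bitmask[..., token_ids // 32] >> (token_ids % 32)) & 1
        logits.masked_fill_(bits == 0, float("-inf"))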
--- a/sglang/srt/conversation.py
+++ b/sglang/srt/conversation.py
@@ -48,6 +48,7 @@ class SeparatorStyle(IntEnum):
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
     GEMMA3 = auto()
+    MPT = auto()
 
 
 @dataclasses.dataclass
@@ -327,6 +328,16 @@ class Conversation:
                 ret += role
             return ret
 
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -570,8 +581,11 @@ def generate_chat_conv(
                     real_content += "\n"  # for video
                 real_content += content.text
             elif content.type == "image_url":
-                # NOTE: Only works for llava
-                real_content += image_token
+                # NOTE: works for llava and internvl-2.5
+                if conv.name == "internvl-2-5":
+                    real_content = image_token + real_content
+                else:
+                    real_content += image_token
                 conv.append_image(content.image_url.url)
             elif content.type == "audio_url":
                 real_content += audio_token
@@ -703,6 +717,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="internvl-2-5",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
+        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+        sep_style=SeparatorStyle.MPT,
+        sep="<|im_end|>\n",
+        stop_str=["<|im_end|>", "<|action_end|>"],
+        image_token="<image>",
+    )
+)
+
 # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_conv_template(
     Conversation(
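Note: combined with the SeparatorStyle.MPT branch above, the internvl-2-5 template renders a one-turn request roughly like this (the user text is a placeholder; the image token is prepended by the generate_chat_conv change above, and the assistant tag is left open for generation):

    <|im_start|>system
    {system_message}<|im_end|>
    <|im_start|>user
    <image>Describe this picture.<|im_end|>
    <|im_start|>assistant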
@@ -826,7 +853,7 @@ register_conv_template(
 
 
 @register_conv_template_matching_function
-def match_deepseek_janus_pro(model_path: str):
+def match_llama_3_vision(model_path: str):
     if (
         "llama" in model_path.lower()
         and "3.2" in model_path.lower()
--- a/sglang/srt/disaggregation/decode.py
+++ b/sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+import os
 from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
@@ -97,7 +98,9 @@ class DecodePreallocQueue:
         self.tp_size = tp_size
         self.bootstrap_port = bootstrap_port
 
-        self.num_reserved_decode_tokens = 512
+        self.num_reserved_decode_tokens = int(
+            os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
+        )
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
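Note: the reservation budget was previously a hard-coded 512 tokens and is now overridable per deployment. A hedged sketch (1024 is an arbitrary example value; the variable must be set in the decode server's environment before DecodePreallocQueue is constructed):

    import os

    os.environ["SGLANG_NUM_RESERVED_DECODE_TOKENS"] = "1024"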
--- a/sglang/srt/disaggregation/mini_lb.py
+++ b/sglang/srt/disaggregation/mini_lb.py
@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
 """
 
 import asyncio
+import dataclasses
+import logging
 import random
 import urllib
 from itertools import chain
-from typing import List
+from typing import List, Optional
 
 import aiohttp
 import orjson
@@ -14,11 +16,32 @@ import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
+from sglang.srt.disaggregation.utils import PDRegistryRequest
 
+
+def setup_logger():
+    logger = logging.getLogger("pdlb")
+    logger.setLevel(logging.INFO)
+
+    formatter = logging.Formatter(
+        "[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    handler = logging.StreamHandler()
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    return logger
+
+
+logger = setup_logger()
+
+
+@dataclasses.dataclass
 class PrefillConfig:
-    def __init__(self, url: str, bootstrap_port: int):
-        self.url = url
-        self.bootstrap_port = bootstrap_port
+    url: str
+    bootstrap_port: Optional[int] = None
 
 
 class MiniLoadBalancer:
@@ -28,6 +51,10 @@ class MiniLoadBalancer:
         self.decode_servers = decode_servers
 
     def select_pair(self):
+        # TODO: return some message instead of panic
+        assert len(self.prefill_configs) > 0, "No prefill servers available"
+        assert len(self.decode_servers) > 0, "No decode servers available"
+
         prefill_config = random.choice(self.prefill_configs)
         decode_server = random.choice(self.decode_servers)
         return prefill_config.url, prefill_config.bootstrap_port, decode_server
@@ -47,7 +74,7 @@
             session.post(f"{decode_server}/{endpoint}", json=modified_request),
         ]
         # Wait for both responses to complete. Prefill should end first.
-        prefill_response, decode_response = await asyncio.gather(*tasks)
+        _, decode_response = await asyncio.gather(*tasks)
 
         return ORJSONResponse(
             content=await decode_response.json(),
@@ -268,6 +295,32 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/register")
+async def register(obj: PDRegistryRequest):
+    if obj.mode == "prefill":
+        load_balancer.prefill_configs.append(
+            PrefillConfig(obj.registry_url, obj.bootstrap_port)
+        )
+        logger.info(
+            f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
+        )
+    elif obj.mode == "decode":
+        load_balancer.decode_servers.append(obj.registry_url)
+        logger.info(f"Registered decode server: {obj.registry_url}")
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid mode. Must be either PREFILL or DECODE.",
+        )
+
+    logger.info(
+        f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
+        f"#Decode servers: {len(load_balancer.decode_servers)}"
+    )
+
+    return Response(status_code=200)
+
+
 def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
     load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
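Note: /register lets prefill and decode servers attach to a running balancer instead of being fixed at startup. A hedged client sketch; the JSON fields mirror how the handler reads PDRegistryRequest (mode, registry_url, bootstrap_port), while the addresses and port are placeholder values:

    import requests

    LB = "http://127.0.0.1:8000"  # placeholder balancer address

    # Register a prefill server; bootstrap_port may be omitted for decode servers.
    requests.post(f"{LB}/register", json={
        "mode": "prefill",
        "registry_url": "http://10.0.0.1:30000",
        "bootstrap_port": 8998,
    })

    # Register a decode server.
    requests.post(f"{LB}/register", json={
        "mode": "decode",
        "registry_url": "http://10.0.0.2:30000",
    })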
@@ -279,15 +332,16 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
     parser.add_argument(
-        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
+        "--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
     )
     parser.add_argument(
-        "--prefill-bootstrap-ports",
-        help="Comma-separated bootstrap ports for prefill servers",
-        default="8998",
+        "--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
     )
     parser.add_argument(
-        "--decode", required=True, help="Comma-separated URLs for decode servers"
+        "--prefill-bootstrap-ports",
+        type=int,
+        nargs="+",
+        help="Bootstrap ports for prefill servers",
     )
     parser.add_argument(
         "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
@@ -297,22 +351,19 @@
     )
     args = parser.parse_args()
 
-    prefill_urls = args.prefill.split(",")
-    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
-
-    if len(bootstrap_ports) == 1:
-        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    bootstrap_ports = args.prefill_bootstrap_ports
+    if bootstrap_ports is None:
+        bootstrap_ports = [None] * len(args.prefill)
+    elif len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(args.prefill)
     else:
-        if len(bootstrap_ports) != len(prefill_urls):
+        if len(bootstrap_ports) != len(args.prefill):
             raise ValueError(
                 "Number of prefill URLs must match number of bootstrap ports"
             )
-            exit(1)
-
-    prefill_configs = []
-    for url, port in zip(prefill_urls, bootstrap_ports):
-        prefill_configs.append(PrefillConfig(url, port))
 
-    decode_addrs = args.decode.split(",")
+    prefill_configs = [
+        PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
+    ]
 
-    run(prefill_configs, decode_addrs, args.host, args.port)
+    run(prefill_configs, args.decode, args.host, args.port)
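Note: this is a breaking CLI change: --prefill and --decode now take space-separated URL lists (nargs="+") instead of comma-separated strings, both default to empty so the balancer can start bare and be populated via /register, and --prefill-bootstrap-ports no longer defaults to 8998. A hedged invocation sketch (the module path and URLs are placeholder assumptions):

    python -m sglang.srt.disaggregation.mini_lb \
        --prefill http://10.0.0.1:30000 http://10.0.0.3:30000 \
        --prefill-bootstrap-ports 8998 \
        --decode http://10.0.0.2:30000 \
        --host 0.0.0.0 --port 8000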
--- a/sglang/srt/disaggregation/mooncake/conn.py
+++ b/sglang/srt/disaggregation/mooncake/conn.py
@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
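Note: the rewrite replaces the Python-level scan with NumPy: np.diff flags every position where either index stream breaks contiguity, and np.split cuts both arrays at those positions in one pass. A worked example of the same logic:

    import numpy as np

    src = np.array([4, 5, 6, 10, 11], dtype=np.int64)
    dst = np.array([0, 1, 2, 3, 4], dtype=np.int64)

    # dst is contiguous throughout, but src jumps from 6 to 10,
    # so a single break point lands at position 3.
    brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    print(np.split(src, brk))  # [array([4, 5, 6]), array([10, 11])]
    print(np.split(dst, brk))  # [array([0, 1, 2]), array([3, 4])]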