sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,232 @@
+# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
+
+import numpy as np
+import torch
+from decord import VideoReader, cpu
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.internvl import InternVLChatModel
+
+
+class InternVLImageProcessor(BaseMultimodalProcessor):
+    models = [InternVLChatModel]
+
+    def __init__(self, hf_config, server_args, _image_processor):
+        super().__init__(hf_config, server_args, _image_processor)
+        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        patch_size = hf_config.vision_config.patch_size
+
+        self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+        self.IMG_START_TOKEN = "<img>"
+        self.IMG_END_TOKEN = "</img>"
+        self.IMG_TOKEN = "<image>"
+        self.num_image_token = int(
+            (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
+        )
+
+        tokenizer = self._processor
+        self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
+        self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
+        self.img_context_token_id = tokenizer.convert_tokens_to_ids(
+            self.IMG_CONTEXT_TOKEN
+        )
+
+    @staticmethod
+    def build_transform(input_size):
+        IMAGENET_MEAN = (0.485, 0.456, 0.406)
+        IMAGENET_STD = (0.229, 0.224, 0.225)
+
+        def resize_image(img, size):
+            return img.resize((size, size), Image.Resampling.BICUBIC)
+
+        def to_tensor(img):
+            # Convert PIL Image to numpy array
+            img_array = np.array(img).astype(np.float32) / 255.0
+            # Convert HWC to CHW format
+            img_array = img_array.transpose(2, 0, 1)
+            return torch.from_numpy(img_array)
+
+        def normalize(tensor, mean, std):
+            mean = torch.tensor(mean).view(-1, 1, 1)
+            std = torch.tensor(std).view(-1, 1, 1)
+            return (tensor - mean) / std
+
+        def transform(img):
+            img = img.convert("RGB") if img.mode != "RGB" else img
+            img = resize_image(img, input_size)
+            tensor = to_tensor(img)
+            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
+            return tensor
+
+        return transform
+
+    @staticmethod
+    def dynamic_preprocess(
+        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+    ):
+
+        def find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, width, height, image_size
+        ):
+            best_ratio_diff = float("inf")
+            best_ratio = (1, 1)
+            area = width * height
+            for ratio in target_ratios:
+                target_aspect_ratio = ratio[0] / ratio[1]
+                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+                if ratio_diff < best_ratio_diff:
+                    best_ratio_diff = ratio_diff
+                    best_ratio = ratio
+                elif ratio_diff == best_ratio_diff:
+                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                        best_ratio = ratio
+            return best_ratio
+
+        orig_width, orig_height = image.size
+        aspect_ratio = orig_width / orig_height
+
+        # calculate the existing image aspect ratio
+        target_ratios = set(
+            (i, j)
+            for n in range(min_num, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num and i * j >= min_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # find the closest aspect ratio to the target
+        target_aspect_ratio = find_closest_aspect_ratio(
+            aspect_ratio, target_ratios, orig_width, orig_height, image_size
+        )
+
+        # calculate the target width and height
+        target_width = image_size * target_aspect_ratio[0]
+        target_height = image_size * target_aspect_ratio[1]
+        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+        # resize the image
+        resized_img = image.resize((target_width, target_height))
+        processed_images = []
+        for i in range(blocks):
+            box = (
+                (i % (target_width // image_size)) * image_size,
+                (i // (target_width // image_size)) * image_size,
+                ((i % (target_width // image_size)) + 1) * image_size,
+                ((i // (target_width // image_size)) + 1) * image_size,
+            )
+            # split the image
+            split_img = resized_img.crop(box)
+            processed_images.append(split_img)
+        assert len(processed_images) == blocks
+        if use_thumbnail and len(processed_images) != 1:
+            thumbnail_img = image.resize((image_size, image_size))
+            processed_images.append(thumbnail_img)
+        return processed_images
+
+    @staticmethod
+    def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
+        if bound:
+            start, end = bound[0], bound[1]
+        else:
+            start, end = -100000, 100000
+        start_idx = max(first_idx, round(start * fps))
+        end_idx = min(round(end * fps), max_frame)
+        seg_size = float(end_idx - start_idx) / num_segments
+        frame_indices = np.array(
+            [
+                int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
+                for idx in range(num_segments)
+            ]
+        )
+        return frame_indices
+
+    @staticmethod
+    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
+        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        max_frame = len(vr) - 1
+        fps = float(vr.get_avg_fps())
+
+        pixel_values_list, num_patches_list = [], []
+        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        frame_indices = InternVLImageProcessor.get_index(
+            bound, fps, max_frame, first_idx=0, num_segments=num_segments
+        )
+        for frame_index in frame_indices:
+            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
+            img = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(tile) for tile in img]
+            pixel_values = torch.stack(pixel_values)
+            num_patches_list.append(pixel_values.shape[0])
+            pixel_values_list.append(pixel_values)
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
+    async def process_mm_data_async(
+        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
+            max_req_input_len=max_req_input_len,
+            discard_alpha_channel=True,
+        )
+
+        def process_image_internvl(image, input_size=448, max_num=12):
+            transform = InternVLImageProcessor.build_transform(input_size=input_size)
+            images = InternVLImageProcessor.dynamic_preprocess(
+                image, image_size=input_size, use_thumbnail=True, max_num=max_num
+            )
+            pixel_values = [transform(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            return pixel_values
+
+        num_patches_list = []
+        pixel_values = []
+        # Process each input with allocated frames
+        for image_index, image in enumerate(base_output.images):
+            try:
+                # TODO: video input
+                raw_image = process_image_internvl(image)
+                pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                pixel_values += pixel_value
+                num_patches = raw_image.shape[0]
+                num_patches_list += [num_patches]
+
+            except FileNotFoundError as e:
+                print(e)
+                return None
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
+
+        for idx, num_patches in enumerate(num_patches_list):
+            image_tokens = (
+                self.IMG_START_TOKEN
+                + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                + self.IMG_END_TOKEN
+            )
+            input_text = input_text.replace("<image>", image_tokens, 1)
+
+        tokenizer = self._processor
+        return {
+            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
+            .flatten()
+            .tolist(),
+            "mm_items": items,
+            "im_start_id": self.img_start_token_id,
+            "im_end_id": self.img_end_token_id,
+            "im_token_id": self.img_context_token_id,
+        }
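
Note: dynamic_preprocess above tiles each image into up to max_num image_size×image_size patches whose grid best matches the source aspect ratio (plus an optional thumbnail), and process_mm_data_async then expands each <image> placeholder into num_image_token context tokens per patch. A standalone check of the grid-selection arithmetic, mirroring the loop above (the tie-break-by-area branch is omitted; it does not trigger for this input):

    # A 1024x512 input (aspect ratio 2.0) with max_num=12 should select a 2x1
    # grid, i.e. two 448x448 tiles after resizing to 896x448.
    min_num, max_num = 1, 12
    target_ratios = sorted(
        {
            (i, j)
            for n in range(min_num, max_num + 1)
            for i in range(1, n + 1)
            for j in range(1, n + 1)
            if min_num <= i * j <= max_num
        },
        key=lambda r: r[0] * r[1],
    )
    aspect_ratio = 1024 / 512
    best = min(target_ratios, key=lambda r: abs(aspect_ratio - r[0] / r[1]))
    assert best == (2, 1)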
@@ -745,6 +745,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
 
+    # For multimodal inputs
+    multimodal_inputs: Optional[List] = None
+
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
@@ -1050,6 +1053,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Copy prefix and do some basic check
         input_embeds = []
         extend_input_logprob_token_ids = []
+        multimodal_inputs = []
 
         for i, (req, seq_len, pre_len) in enumerate(zip(reqs, seq_lens, prefix_lens)):
             req.req_pool_idx = req_pool_indices[i]
@@ -1065,6 +1069,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
 
+            multimodal_inputs.append(req.multimodal_inputs)
+
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1147,6 +1153,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             if input_embeds
             else None
         )
+        for mm_input in multimodal_inputs:
+            if mm_input is None:
+                continue
+            for mm_item in mm_input.mm_items:
+                pixel_values = getattr(mm_item, "pixel_values", None)
+                if isinstance(pixel_values, torch.Tensor):
+                    mm_item.pixel_values = pixel_values.to(
+                        self.device, non_blocking=True
+                    )
+        self.multimodal_inputs = multimodal_inputs
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1452,6 +1468,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
+        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
         self.out_cache_loc = None
@@ -1500,6 +1517,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.token_ids_logprobs = [None] * len(self.reqs) + other.token_ids_logprobs
         self.reqs.extend(other.reqs)
+        self.multimodal_inputs.extend(other.multimodal_inputs)
 
         self.return_logprob |= other.return_logprob
         self.has_stream |= other.has_stream
@@ -1558,7 +1576,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
+            multimodal_inputs=self.multimodal_inputs,
             encoder_cached=self.encoder_cached,
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
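
Note: the ScheduleBatch hunks above hoist per-request multimodal inputs onto the batch and copy each pixel_values tensor to the scheduler's device with non_blocking=True. As a standalone reminder (not sglang code), non_blocking host-to-device copies only overlap with compute when the source tensor is in pinned memory; otherwise the copy degrades to a synchronous one:

    import torch

    if torch.cuda.is_available():
        # Pinned (page-locked) memory lets non_blocking=True actually overlap.
        cpu_pixels = torch.randn(3, 448, 448).pin_memory()
        gpu_pixels = cpu_pixels.to("cuda", non_blocking=True)
        # Kernels queued later on the same stream are ordered after the copy;
        # synchronize only if the CPU must observe the copied data.
        torch.cuda.current_stream().synchronize()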
@@ -455,7 +455,10 @@ class PrefillAdder:
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = req.extend_input_len
+        input_tokens = (
+            -(-req.extend_input_len // self.tree_cache.page_size)
+            * self.tree_cache.page_size
+        )
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
@@ -477,7 +480,10 @@ class PrefillAdder:
                 req.last_node_global, req.prefix_indices
             )
             req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-            input_tokens = req.extend_input_len
+            input_tokens = (
+                -(-req.extend_input_len // self.tree_cache.page_size)
+                * self.tree_cache.page_size
+            )
             prefix_len = len(req.prefix_indices)
 
             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
@@ -493,12 +499,12 @@ class PrefillAdder:
                    ),
                )
            else:
-                if self.rem_chunk_tokens == 0:
+                # Make sure at least one page is available
+                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                if trunc_len <= 0:
                    return AddReqResult.OTHER
 
                # Chunked prefill
-                trunc_len = self.rem_chunk_tokens
-
                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
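
Note: both PrefillAdder hunks round the extend length up to a whole number of KV-cache pages; -(-n // p) is ceiling division expressed with Python's floor division. A quick check:

    def round_up_to_page(n: int, page_size: int) -> int:
        # -(-n // p) == ceil(n / p) for a positive page_size
        return -(-n // page_size) * page_size

    assert round_up_to_page(1, 16) == 16
    assert round_up_to_page(16, 16) == 16
    assert round_up_to_page(17, 16) == 32

The truncation change applies the same rule from the other side: chunked prefill now proceeds only when the remaining chunk budget covers at least one full page (trunc_len = rem_chunk_tokens - page_size + 1 must be positive).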
@@ -52,7 +52,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -83,6 +87,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -413,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()
 
@@ -435,6 +443,7 @@ class Scheduler(
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+                (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
@@ -451,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args
 
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
@@ -475,7 +474,7 @@ class Scheduler(
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
-            self.tokenizer = self.processor.tokenizer
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -498,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -920,6 +920,10 @@ class Scheduler(
             )
             custom_logit_processor = None
 
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -1527,6 +1531,10 @@ class Scheduler(
         ):
             self.stop_profile()
 
+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
@@ -2002,6 +2010,13 @@ class Scheduler(
             del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()
 
+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
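
Note: slow_down stores a per-scheduler sleep that run_batch applies before every forward pass; a non-positive forward_sleep_time clears it. The new io_struct.py messages themselves are not shown in this diff; a plausible sketch, with the field name taken from the handler above and the dataclass layout assumed:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SlowDownReqInput:
        # Seconds to sleep before each forward pass; None (or <= 0) disables it.
        forward_sleep_time: Optional[float] = None

    @dataclass
    class SlowDownReqOutput:
        pass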
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
     get_kv_class,
 )
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -161,17 +167,7 @@ class TokenizerManager:
         # Read model args
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
 
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
@@ -199,7 +195,7 @@ class TokenizerManager:
                 self.tokenizer = self.processor = None
             else:
                 self.processor = _processor
-                self.tokenizer = self.processor.tokenizer
+                self.tokenizer = get_tokenizer_from_processor(self.processor)
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
             self.mm_processor = get_dummy_processor()
@@ -265,6 +261,9 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.slow_down_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -318,6 +317,10 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    SlowDownReqOutput,
+                    self.slow_down_communicator.handle_recv,
+                ),
                 (
                     FlushCacheReqOutput,
                     self.flush_cache_communicator.handle_recv,
@@ -876,6 +879,14 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         await self.resume_memory_occupation_communicator(obj)
 
+    async def slow_down(
+        self,
+        obj: SlowDownReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.slow_down_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
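
Note: the recurring change from self.processor.tokenizer to get_tokenizer_from_processor(self.processor) suggests some processors do not wrap a separate tokenizer. The helper itself lives in the hf_transformers_utils.py change listed above and is not shown here; a minimal sketch of the assumed behavior:

    def get_tokenizer_from_processor(processor):
        # Assumed behavior: return the wrapped tokenizer when present,
        # otherwise the processor is (or acts as) the tokenizer itself.
        return getattr(processor, "tokenizer", processor)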
@@ -21,7 +21,11 @@ import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
@@ -61,20 +65,13 @@ class TpModelWorker:
         self.pp_rank = pp_rank
 
         # Init model and tokenizer
-        self.model_config = ModelConfig(
-            (
+        self.model_config = ModelConfig.from_server_args(
+            server_args,
+            model_path=(
                 server_args.model_path
                 if not is_draft_worker
                 else server_args.speculative_draft_model_path
             ),
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )
 
@@ -102,7 +99,7 @@ class TpModelWorker:
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
-            self.tokenizer = self.processor.tokenizer
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
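
Note: three call sites (Scheduler, TokenizerManager, TpModelWorker) now build ModelConfig through a from_server_args classmethod instead of repeating the keyword list. Its body is not shown in this diff; a sketch inferred from the removed inline calls and the tp_worker overrides (model_path, is_draft_model), offered as an assumption:

    class ModelConfig:
        @classmethod
        def from_server_args(cls, server_args, model_path=None, **kwargs):
            # Forward the same fields the old inline constructor calls passed,
            # letting callers override model_path (e.g. for draft workers).
            return cls(
                model_path or server_args.model_path,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                context_length=server_args.context_length,
                model_override_args=server_args.json_model_override_args,
                is_embedding=server_args.is_embedding,
                enable_multimodal=server_args.enable_multimodal,
                dtype=server_args.dtype,
                quantization=server_args.quantization,
                **kwargs,
            )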
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
         self,
         req_to_token_pool: ReqToTokenPool,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = page_size
 
     def reset(self):
         pass
@@ -374,9 +374,9 @@ class MHATokenToKVPool(KVCache):
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
-                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
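
Note: the memory_pool.py hunk issues the K copy on the current stream and keeps only the V copy on alt_stream, so the two writes can actually run concurrently; previously both sat on alt_stream while the current stream idled. The general two-stream pattern, as a standalone sketch assuming CUDA tensors:

    import torch

    def overlapped_copy(dst_a, src_a, dst_b, src_b):
        current = torch.cuda.current_stream()
        side = torch.cuda.Stream()
        side.wait_stream(current)   # side stream sees all prior work
        dst_a.copy_(src_a)          # first copy runs on the current stream
        with torch.cuda.stream(side):
            dst_b.copy_(src_b)      # second copy overlaps on the side stream
        current.wait_stream(side)   # rejoin before later kernels use dst_b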