sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode="max-autotune-no-cudagraphs",
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
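The compile mode is now read from the environment with a fallback to the old default, so existing deployments behave the same when the variable is unset. A minimal sketch of that lookup; the override value in the comment is only an illustration of another valid torch.compile mode, not a recommendation:

    import os

    import torch

    # Unset -> behaves exactly like the previous hard-coded value.
    mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")

    # e.g. export SGLANG_TORCH_COMPILE_MODE=reduce-overhead before launching
    compiled_forward = torch.compile(torch.nn.Linear(8, 8), mode=mode, dynamic=False)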
@@ -117,7 +120,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +130,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
 
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
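For reference, the new speculative-decoding capture list above expands to the following batch sizes (plain evaluation of the expression, nothing sglang-specific):

    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
    print(capture_bs)
    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
    #  64, 96, 128, 160]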
@@ -220,7 +220,19 @@ class CudaGraphRunner:
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
         # Speculative_inference
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -508,7 +520,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
sglang/srt/model_executor/forward_batch_info.py
@@ -33,6 +33,7 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union
 
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -42,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
-    image_inputs: Optional[List[ImageInputs]] = None
+    mm_inputs: Optional[List[MultimodalInputs]] = None
 
     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
-            image_inputs=batch.image_inputs,
+            mm_inputs=batch.multimodal_inputs,
             encoder_cached=batch.encoder_cached,
             encoder_lens=batch.encoder_lens,
             encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -331,6 +332,53 @@ class ForwardBatch:
 
         return ret
 
+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
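The merge path above assumes each per-request entry is either None or an object exposing merge(), pixel_values, and audio_features, with numpy arrays converted to tensors after the merge. A rough standalone sketch of the same pattern, using a stand-in class rather than sglang's real MultimodalInputs:

    from dataclasses import dataclass
    from typing import List, Optional

    import numpy as np
    import torch

    @dataclass
    class FakeMMInput:  # stand-in for MultimodalInputs, for illustration only
        pixel_values: np.ndarray

        def merge(self, other: "FakeMMInput") -> None:
            # Concatenate along the batch dimension, mirroring merged.merge(mm_input).
            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

    def merge_mm_inputs(mm_inputs: List[Optional[FakeMMInput]]) -> Optional[FakeMMInput]:
        valid = [x for x in mm_inputs if x is not None]
        if not valid:
            return None
        merged = valid[0]
        for other in valid[1:]:
            merged.merge(other)
        if isinstance(merged.pixel_values, np.ndarray):
            merged.pixel_values = torch.from_numpy(merged.pixel_values)
        return merged

    batch = [FakeMMInput(np.zeros((1, 3, 4, 4))), None, FakeMMInput(np.ones((2, 3, 4, 4)))]
    print(merge_mm_inputs(batch).pixel_values.shape)  # torch.Size([3, 3, 4, 4])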
@@ -341,8 +389,8 @@ class ForwardBatch:
             for i, _ in enumerate(mrope_positions_list):
                 mrope_position_delta = (
                     0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
+                    if batch.multimodal_inputs[i] is None
+                    else batch.multimodal_inputs[i].mrope_position_delta
                 )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                     mrope_position_delta,
@@ -351,13 +399,13 @@ class ForwardBatch:
                 )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if image_inputs is None:
+                if multimodal_inputs is None:
                     # text only
                     mrope_positions = [
                         [
@@ -374,16 +422,25 @@ class ForwardBatch:
                         input_tokens=self.input_ids[
                             extend_start_loc : extend_start_loc + extend_seq_len
                         ],
-                        image_grid_thw=image_inputs.image_grid_thws,
+                        image_grid_thw=multimodal_inputs.image_grid_thws,
+                        video_grid_thw=multimodal_inputs.video_grid_thws,
+                        image_token_id=multimodal_inputs.im_token_id,
+                        video_token_id=multimodal_inputs.video_token_id,
                         vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
                         spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                         context_len=0,
+                        seq_len=len(self.input_ids),
+                        second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                        tokens_per_second=hf_config.vision_config.tokens_per_second,
                     )
                 )
-                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                batch.multimodal_inputs[i].mrope_position_delta = (
+                    mrope_position_delta
+                )
                 mrope_positions_list[i] = mrope_positions
 
-        self.mrope_positions = torch.concat(
+        self.mrope_positions = torch.cat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
         )
@@ -449,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.concat(
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
sglang/srt/model_executor/model_runner.py
@@ -145,10 +145,12 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +189,6 @@ class ModelRunner:
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False
 
         # Init lora
         if server_args.lora_paths is not None:
@@ -209,6 +208,10 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()
 
+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
     def model_specific_adjustment(self):
         server_args = self.server_args
 
@@ -223,6 +226,9 @@ class ModelRunner:
                     "MLA optimization is turned on. Use flashinfer mla backend."
                 )
                 server_args.attention_backend = "flashinfer_mla"
+            elif server_args.enable_flashmla:
+                logger.info("MLA optimization is turned on. Use flashmla decode.")
+                server_args.attention_backend = "flashmla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 server_args.attention_backend = "triton"
@@ -254,18 +260,41 @@ class ModelRunner:
 
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
         ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
            )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True
 
+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+            assert (
+                server_args.enable_dp_attention == True
+            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
-        torch.get_device_module(self.device).set_device(self.gpu_id)
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -606,6 +635,8 @@ class ModelRunner:
             load_config=self.load_config,
             dtype=self.dtype,
             lora_backend=self.server_args.lora_backend,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         logger.info("LoRA manager ready.")
 
@@ -840,6 +871,23 @@ class ModelRunner:
             )
 
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -1009,6 +1057,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"
 
+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
sglang/srt/model_loader/loader.py
@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +202,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _maybe_download_from_modelscope(
         self, model: str, revision: Optional[str]
     ) -> Optional[str]:
-        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
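RemoteModelLoader.save_model stores each rank's tensors and the non-weight files under simple string keys derived from the model name in the URL. A quick illustration of that key layout; the names are made up and only the f-string patterns come from the diff above:

    model_name = "my-model"  # whatever parse_model_name(url) returns for your URL
    rank = 0
    key = "model.embed_tokens.weight"
    file_name = "config.json"

    print(f"{model_name}/keys/rank_{rank}/{key}")  # my-model/keys/rank_0/model.embed_tokens.weight
    print(f"{model_name}/files/{file_name}")       # my-model/files/config.json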
sglang/srt/model_loader/weight_utils.py
@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
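set_runai_streamer_env only copies integer concurrency / memory_limit values from model_loader_extra_config into the corresponding RUNAI_STREAMER_* variables and mirrors AWS_ENDPOINT_URL into RUNAI_STREAMER_S3_ENDPOINT when the latter is unset. A small sketch of the effect with an illustrative config dict (no sglang imports needed):

    import os

    extra_config = {"concurrency": 16, "memory_limit": 2 * 1024**3}  # example values

    if isinstance(extra_config.get("concurrency"), int):
        os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(extra_config["concurrency"])
    if isinstance(extra_config.get("memory_limit"), int):
        os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(extra_config["memory_limit"])
    if os.getenv("RUNAI_STREAMER_S3_ENDPOINT") is None and os.getenv("AWS_ENDPOINT_URL"):
        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = os.environ["AWS_ENDPOINT_URL"]

    print(os.environ["RUNAI_STREAMER_CONCURRENCY"])  # 16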