sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
@@ -81,7 +82,9 @@ def patch_model(
         # tp_group.ca_comm = None
         yield torch.compile(
             torch.no_grad()(model.forward),
-            mode="max-autotune-no-cudagraphs",
+            mode=os.environ.get(
+                "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+            ),
             dynamic=False,
         )
     else:
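The torch.compile mode used when patching the model is now read from the SGLANG_TORCH_COMPILE_MODE environment variable, falling back to the previous hard-coded "max-autotune-no-cudagraphs". A minimal sketch of overriding it, assuming the variable is set before the engine process compiles the model; the model path and the enable_torch_compile flag are illustrative, not part of this diff:

    import os

    # The value is passed straight to torch.compile(mode=...); "default"
    # trades peak throughput for a much faster warmup than max-autotune.
    os.environ["SGLANG_TORCH_COMPILE_MODE"] = "default"

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        enable_torch_compile=True,
    )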
@@ -117,24 +120,21 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             else:
                 capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         else:
-            capture_bs = list(range(1, 33))
+            # Since speculative decoding requires more cuda graph memory, we
+            # capture less.
+            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
-    if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        if _is_hip:
+            capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
 
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
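For reference, the new default capture list under speculative decoding thins out batch sizes above 8 and appends a few large ones; the dedup and sort now happen once, after the req_to_token_pool clamp. The list can be inspected with plain Python:

    capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
    print(capture_bs)
    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31,
    #  64, 96, 128, 160]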
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -220,7 +221,19 @@ class CudaGraphRunner:
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
         # Speculative_inference
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -233,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -276,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
 
            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -357,7 +370,7 @@ class CudaGraphRunner:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [
@@ -459,7 +472,7 @@ class CudaGraphRunner:
        raw_num_token = raw_bs * self.num_tokens_per_bs
 
        # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            index = bisect.bisect_left(
                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
            )
@@ -485,7 +498,7 @@ class CudaGraphRunner:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
        if hasattr(forward_batch.spec_info, "hidden_states"):
@@ -508,7 +521,9 @@ class CudaGraphRunner:
        self.raw_num_token = raw_num_token
        self.bs = bs
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch)
        else:
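The EAGLE3 branch above reserves a hidden-state buffer three times the model's hidden size on the target worker and asks the model to capture auxiliary layers. A hedged launch sketch for trying EAGLE3 speculative decoding through the offline engine; the keyword arguments mirror the --speculative-* server arguments, and the draft model path plus the tuning values are placeholders, not taken from this diff:

    import sglang as sgl

    # Sketch only: "your-org/llama-3.1-8b-eagle3-draft" is a hypothetical
    # draft model; the step/top-k/token counts are illustrative.
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        speculative_algorithm="EAGLE3",
        speculative_draft_model_path="your-org/llama-3.1-8b-eagle3-draft",
        speculative_num_steps=5,
        speculative_eagle_topk=8,
        speculative_num_draft_tokens=32,
    )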
sglang/srt/model_executor/forward_batch_info.py

@@ -33,6 +33,7 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union
 
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -42,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
-    image_inputs: Optional[List[ImageInputs]] = None
+    mm_inputs: Optional[List[MultimodalInputs]] = None
 
     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            out_cache_loc=batch.out_cache_loc,
-            image_inputs=batch.image_inputs,
+            mm_inputs=batch.multimodal_inputs,
            encoder_cached=batch.encoder_cached,
            encoder_lens=batch.encoder_lens,
            encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -331,6 +332,53 @@ class ForwardBatch:
 
         return ret
 
+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
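merge_mm_inputs collapses the per-request multimodal inputs of a batch into one object and converts any numpy arrays to tensors before the forward pass. A self-contained sketch of the same merge flow, using a hypothetical stand-in class rather than the real MultimodalInputs from sglang/srt/managers/schedule_batch.py:

    import numpy as np
    import torch

    class FakeMMInput:
        """Hypothetical stand-in for MultimodalInputs, only to show the flow."""
        def __init__(self, pixel_values):
            self.pixel_values = pixel_values
        def merge(self, other):
            # The real class concatenates many more fields; pixel values only here.
            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])

    mm_inputs = [FakeMMInput(np.ones((1, 3))), None, FakeMMInput(np.zeros((2, 3)))]
    valid = [x for x in mm_inputs if x is not None]
    merged = valid[0]
    for other in valid[1:]:
        merged.merge(other)
    if isinstance(merged.pixel_values, np.ndarray):
        merged.pixel_values = torch.from_numpy(merged.pixel_values)
    print(merged.pixel_values.shape)  # torch.Size([3, 3])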
@@ -341,8 +389,8 @@ class ForwardBatch:
            for i, _ in enumerate(mrope_positions_list):
                mrope_position_delta = (
                    0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
+                    if batch.multimodal_inputs[i] is None
+                    else batch.multimodal_inputs[i].mrope_position_delta
                )
                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                    mrope_position_delta,
@@ -351,13 +399,13 @@ class ForwardBatch:
                )
        elif self.forward_mode.is_extend():
            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                extend_start_loc, extend_seq_len, extend_prefix_len = (
                    extend_start_loc_cpu[i],
                    batch.extend_seq_lens[i],
                    batch.extend_prefix_lens[i],
                )
-                if image_inputs is None:
+                if multimodal_inputs is None:
                    # text only
                    mrope_positions = [
                        [
@@ -374,16 +422,25 @@ class ForwardBatch:
                        input_tokens=self.input_ids[
                            extend_start_loc : extend_start_loc + extend_seq_len
                        ],
-                        image_grid_thw=image_inputs.image_grid_thws,
+                        image_grid_thw=multimodal_inputs.image_grid_thws,
+                        video_grid_thw=multimodal_inputs.video_grid_thws,
+                        image_token_id=multimodal_inputs.im_token_id,
+                        video_token_id=multimodal_inputs.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                        context_len=0,
+                        seq_len=len(self.input_ids),
+                        second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                        tokens_per_second=hf_config.vision_config.tokens_per_second,
                    )
                )
-                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                batch.multimodal_inputs[i].mrope_position_delta = (
+                    mrope_position_delta
+                )
                mrope_positions_list[i] = mrope_positions
 
-        self.mrope_positions = torch.concat(
+        self.mrope_positions = torch.cat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,
        )
@@ -449,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.concat(
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
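The switch from torch.concat to torch.cat is cosmetic: concat is documented as an alias of cat, so the concatenated mrope and position tensors are unchanged. Quick check in plain PyTorch:

    import torch

    a, b = torch.zeros(2, 3), torch.ones(2, 3)
    assert torch.equal(torch.cat([a, b], dim=0), torch.concat([a, b], dim=0))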
sglang/srt/model_executor/model_runner.py

@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -145,10 +146,12 @@ class ModelRunner:
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                "device": server_args.device,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                "disable_radix_cache": server_args.disable_radix_cache,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +190,6 @@ class ModelRunner:
        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
        if self.tp_size > 1 and supports_torch_tp:
            self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False
 
        # Init lora
        if server_args.lora_paths is not None:
@@ -209,6 +209,10 @@ class ModelRunner:
        self.cuda_graph_runner = None
        self.init_attention_backend()
 
+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
    def model_specific_adjustment(self):
        server_args = self.server_args
 
@@ -223,6 +227,13 @@ class ModelRunner:
                    "MLA optimization is turned on. Use flashinfer mla backend."
                )
                server_args.attention_backend = "flashinfer_mla"
+            elif server_args.enable_flashmla:
+                logger.info("MLA optimization is turned on. Use flashmla decode.")
+                server_args.attention_backend = "flashmla"
+            elif server_args.attention_backend == "fa3":
+                logger.info(
+                    f"MLA optimization is turned on. Use flash attention 3 backend."
+                )
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                server_args.attention_backend = "triton"
@@ -254,18 +265,38 @@ class ModelRunner:
 
        if self.model_config.hf_config.architectures == [
            "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
            )
            server_args.chunked_prefill_size = -1
            server_args.disable_radix_cache = True
 
+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+
    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
 
-        torch.get_device_module(self.device).set_device(self.gpu_id)
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
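The branches above route MLA models to the new flashmla backend (reachable through --enable-flashmla) or to the beta FlashAttention v3 backend when --attention-backend fa3 is requested, and DeepEP MoE gets its own --enable-deepep-moe switch. A hedged sketch of selecting the fa3 backend through the offline engine, assuming these server arguments are accepted as keyword args; the model path is illustrative, and the backend asserts SM90+ later in this file:

    import sglang as sgl

    # fa3 is beta: multimodal, FP8, and speculative decoding are not supported,
    # and it requires compute capability >= 9.0 (Hopper).
    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        attention_backend="fa3",
    )
    print(llm.generate("Hello, my name is", {"temperature": 0, "max_new_tokens": 8}))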
@@ -606,6 +637,8 @@ class ModelRunner:
            load_config=self.load_config,
            dtype=self.dtype,
            lora_backend=self.server_args.lora_backend,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
        )
        logger.info("LoRA manager ready.")
 
@@ -840,6 +873,23 @@ class ModelRunner:
            )
 
            self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -1009,6 +1059,22 @@ class ModelRunner:
            return False
        return rope_scaling.get("type", None) == "mrope"
 
+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
@@ -1018,8 +1084,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
 
 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-        return tensor.get(tp_rank)
-    return tensor
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())
 
 
 @dataclass
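Both new ModelRunner helpers are thin wrappers over loaders defined in sglang/srt/model_loader/loader.py: save_sharded_model delegates to ShardedStateLoader.save_model (default filename pattern model-rank-{rank}-part-{part}.safetensors) and save_remote_model to RemoteModelLoader.save_model. A hedged sketch of direct calls, assuming you already hold a constructed ModelRunner named runner; the path, URL scheme, and size cap are illustrative, not taken from this diff:

    # Writes one or more safetensors shards per tensor-parallel rank.
    runner.save_sharded_model(
        path="/tmp/llama3-sharded",
        pattern="model-rank-{rank}-part-{part}.safetensors",
        max_size=5 * 1024**3,  # assumed per-file byte budget
    )

    # Pushes tensors and config files to a KV connector (see sglang/srt/connector/).
    runner.save_remote_model(url="redis://localhost:6379/llama-3.1-8b-instruct")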
sglang/srt/model_loader/loader.py

@@ -9,11 +9,11 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
-import gguf
 import huggingface_hub
 import numpy as np
 import torch
@@ -25,6 +25,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +52,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +201,7 @@ class DefaultModelLoader(BaseModelLoader):
    def _maybe_download_from_modelscope(
        self, model: str, revision: Optional[str]
    ) -> Optional[str]:
-        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.
 
        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
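The ModelScope fallback is now keyed on SGLANG_USE_MODELSCOPE rather than the vLLM-named variable. A minimal sketch, assuming the variable is read at load time; the model name is illustrative:

    import os

    # Download weights from the ModelScope hub instead of the Hugging Face Hub.
    os.environ["SGLANG_USE_MODELSCOPE"] = "true"

    import sglang as sgl

    llm = sgl.Engine(model_path="Qwen/Qwen2.5-7B-Instruct")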
@@ -490,7 +497,7 @@ class ShardedStateLoader(BaseModelLoader):
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
    """
 
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1147,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader):
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
+
+        # only load the gguf module when needed
+        try:
+            import gguf
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install gguf via `pip install gguf` to use gguf quantizer."
+            ) from err
+
        config = model_config.hf_config
        model_type = config.model_type
        # hack: ggufs have a different name than transformers
@@ -1204,6 +1222,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
@@ -1225,4 +1390,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
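With the new branch, the remote load format dispatches to RemoteModelLoader, which streams weights through the connectors added under sglang/srt/connector/ (Redis and S3). A hedged sketch of the dispatch and of the KV key layout that RemoteModelLoader.save_model produces, assuming LoadConfig accepts load_format directly; the model name below is illustrative:

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat
    from sglang.srt.model_loader.loader import RemoteModelLoader, get_model_loader

    loader = get_model_loader(LoadConfig(load_format=LoadFormat.REMOTE))
    assert isinstance(loader, RemoteModelLoader)

    # Key layout written by RemoteModelLoader.save_model for a hypothetical
    # model name "llama3-8b" saved from tensor-parallel rank 0:
    #   llama3-8b/keys/rank_0/<parameter name>   -> serialized tensor
    #   llama3-8b/files/config.json              -> raw file contents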