sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/models/qwen2_5_vl.py
+++ b/sglang/srt/models/qwen2_5_vl.py
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -34,8 +33,15 @@ from einops import rearrange
 from transformers import AutoModel, Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
+from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLVisionConfig,
+)
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)
 
-from sglang.srt.configs import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -47,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
@@ -125,12 +132,15 @@ class Qwen2_5_VisionBlock(nn.Module):
         if attn_implementation == "sdpa":
             use_context_forward = False
             softmax_in_single_precision = False
+            flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
             use_context_forward = True
+            flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
             use_context_forward = False
+            flatten_batch = True
 
         self.attn = VisionAttention(
             embed_dim=dim,
@@ -139,7 +149,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_qkv_parallel=False,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
-            flatten_batch=True,
+            flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
@@ -192,9 +202,10 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.proj.weight.dtype
         L, C = x.shape
         x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
         return x
 
 
@@ -246,35 +257,15 @@ class Qwen2_5_VisionRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int, theta: float = 10000.0) -> None:
         super().__init__()
-        self.dim = dim
-        self.theta = theta
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self._seq_len_cached = 0
-        self._freqs_cached = None
-
-    def update_freqs_cache(self, seqlen: int) -> None:
-        if seqlen > self._seq_len_cached:
-            seqlen *= 2
-            self._seq_len_cached = seqlen
-            self.inv_freq = 1.0 / (
-                self.theta
-                ** (
-                    torch.arange(
-                        0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device
-                    )
-                    / self.dim
-                )
-            )
-            seq = torch.arange(
-                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-            )
-            freqs = torch.outer(seq, self.inv_freq)
-            self._freqs_cached = freqs
 
     def forward(self, seqlen: int) -> torch.Tensor:
-        self.update_freqs_cache(seqlen)
-        return self._freqs_cached[:seqlen]
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
 
 
 class Qwen2_5_VisionTransformer(nn.Module):
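
Note: the rotary embedding above drops the old update_freqs_cache path and simply recomputes the outer product of positions and inverse frequencies on every call. A minimal standalone sketch of the same computation in plain PyTorch, with sizes chosen only for illustration:

    import torch

    # One row of rotary frequencies per position, recomputed on each call.
    dim, theta, seqlen = 8, 10000.0, 5
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    seq = torch.arange(seqlen, dtype=inv_freq.dtype)
    freqs = torch.outer(seq, inv_freq)
    print(freqs.shape)  # torch.Size([5, 4])
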
@@ -293,7 +284,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_chans
+        in_chans: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -335,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
 
     def get_window_index(self, grid_thw):
-        window_index: list = []
         cu_window_seqlens: list = [0]
         window_index_id = 0
         vit_merger_window_size = (
             self.window_size // self.spatial_merge_size // self.patch_size
         )
-
+        window_index: list = []
         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
@@ -378,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
         window_index = torch.cat(window_index, dim=0)
-
         return window_index, cu_window_seqlens
 
     @property
@@ -391,29 +380,29 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = (
-                hpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
+
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
             )
-            wpos_ids = (
-                wpos_ids.reshape(
-                    h // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                    w // self.spatial_merge_size,
-                    self.spatial_merge_size,
-                )
-                .permute(0, 2, 1, 3)
-                .flatten()
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
             )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
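
Note: the rot_pos_emb rewrite above only unrolls the chained reshape/permute/flatten into separate statements and reads t, h, w via .tolist(); the resulting position ids are unchanged. For example, with h = w = 4 and spatial_merge_size = 2, the loop body yields the following row-index layout (plain-PyTorch reproduction of the same steps):

    import torch

    h, w, spatial_merge_size = 4, 4, 2
    hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
    hpos_ids = hpos_ids.reshape(
        h // spatial_merge_size, spatial_merge_size,
        w // spatial_merge_size, spatial_merge_size,
    )
    hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
    print(hpos_ids.tolist())
    # [0, 0, 1, 1, 0, 0, 1, 1, 2, 2, 3, 3, 2, 2, 3, 3]
    # row indices grouped into 2x2 spatial-merge windows
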
@@ -437,7 +426,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=x.device,
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+            dtype=torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
@@ -455,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         position_embeddings = (emb.cos(), emb.sin())
 
         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], device=grid_thw.device),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+            ]
+        )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
 
         # transformers
@@ -521,19 +513,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
         im_end_id: int = image_inputs.im_end_id
@@ -543,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
         return pattern.pad_input_tokens(input_ids, image_inputs)
 
-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -555,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         )
         return video_embeds
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -577,85 +560,25 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions
 
-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            # [B, s, hidden_size]
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-
-                pixel_values = image.pixel_values.clone().detach().requires_grad_(False)
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = left_idx + num_image_tokens
-
-                    tp_size = get_tensor_model_parallel_world_size()
-
-                    hidden_size = image_embeds.shape[-1]
-
-                    if hidden_size % tp_size != 0:
-                        padding_size = tp_size - (hidden_size % tp_size)
-                        image_embeds = F.pad(image_embeds, (0, padding_size))
-                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
-
-                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
-                    rank = get_tensor_model_parallel_rank()
-                    start_dim = rank * hidden_chunk_size
-                    end_dim = (rank + 1) * hidden_chunk_size
-                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
-                        image_embeds[
-                            image_embeds_offset : image_embeds_offset
-                            + num_image_tokens,
-                            ...,
-                            start_dim:end_dim,
-                        ]
-                    )
-                    image_embeds_offset += num_image_tokens
+            inputs_embeds = general_mm_embed_routine(
+                input_ids=input_ids,
+                forward_batch=forward_batch,
+                embed_tokens=self.get_input_embeddings(),
+                mm_data_embedding_func=self.get_image_feature,
+            )
 
         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
@@ -732,4 +655,3 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
 
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
-AutoModel.register(Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration)
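
Note: both Qwen2_5_VLForConditionalGeneration above and Qwen2VLForConditionalGeneration (later in this diff) replace the hand-rolled loop that spliced vision embeddings into the text embeddings with a single call to general_mm_embed_routine from sglang/srt/managers/mm_utils.py (a new module in this release, not shown in this diff). The sketch below only illustrates the core scatter step such a routine has to perform; it is not the actual mm_utils implementation, and the mask-based interface here is invented for the example:

    import torch

    def embed_with_vision(input_ids, embed_tokens, vision_embeds, vision_token_mask):
        # Embed the text tokens (placeholder ids are clamped into the vocab range),
        # then overwrite the placeholder positions with precomputed vision embeddings.
        inputs_embeds = embed_tokens(input_ids.clamp(min=0))
        inputs_embeds[vision_token_mask] = vision_embeds.to(inputs_embeds.dtype)
        return inputs_embeds

    embed = torch.nn.Embedding(32, 16)
    ids = torch.tensor([1, 2, 7, 7, 7, 3])      # 7 marks image placeholder tokens
    mask = ids.eq(7)
    vision = torch.randn(int(mask.sum()), 16)   # one embedding per placeholder token
    out = embed_with_vision(ids, embed, vision, mask)
    print(out.shape)  # torch.Size([6, 16])
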
--- /dev/null
+++ b/sglang/srt/models/qwen2_classification.py
@@ -0,0 +1,75 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2ForSequenceClassification(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "Qwen2ForSequenceClassification is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen2ForSequenceClassification,
+]
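
Note: qwen2_classification.py is new in this release. It reuses the Qwen2 decoder, adds a score head, and pools the last token of each request via Pooler(pooling_type=PoolingType.LAST, normalize=False). A rough plain-PyTorch sketch of that last-token pooling over a packed batch; names such as seq_lens are illustrative only, not the ForwardBatch API:

    import torch

    hidden_size, num_labels = 8, 2
    score = torch.nn.Linear(hidden_size, num_labels)

    # Two requests packed into one batch of 10 tokens (6 + 4).
    hidden_states = torch.randn(10, hidden_size)
    seq_lens = torch.tensor([6, 4])

    last_token_idx = torch.cumsum(seq_lens, dim=0) - 1    # tensor([5, 9])
    pooled_logits = score(hidden_states)[last_token_idx]  # [num_requests, num_labels]
    print(pooled_logits.shape)  # torch.Size([2, 2])
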
--- a/sglang/srt/models/qwen2_moe.py
+++ b/sglang/srt/models/qwen2_moe.py
@@ -44,10 +44,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
 
+expert_distribution_recorder = ExpertDistributionRecorder()
+
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -170,6 +173,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        qkv_bias: int = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -201,7 +205,7 @@ class Qwen2MoeAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=True,
+            bias=qkv_bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
         )
@@ -257,6 +261,8 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # note: replace config.num_hidden_layers < 80 with True once its available in transformers 4.50.0
+        qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -266,6 +272,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            qkv_bias=qkv_bias,
             prefix=add_prefix("self_attn", prefix),
         )
 
@@ -362,6 +369,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states = input_embeds
         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
            hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
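
Note: qwen2_moe.py now creates a module-level ExpertDistributionRecorder (from the new sglang/srt/managers/expert_distribution.py, not shown here) and tells it which decoder layer is running before each layer executes. The real recorder's API is not visible in this diff; the sketch below only illustrates the general idea of tracking per-layer expert usage and is not the actual implementation:

    from collections import defaultdict

    class ToyExpertRecorder:
        """Counts how often each expert id is selected, per decoder layer."""

        def __init__(self):
            self._layer = None
            self._counts = defaultdict(lambda: defaultdict(int))

        def set_current_layer(self, layer_idx):
            self._layer = layer_idx

        def record(self, expert_ids):
            for expert_id in expert_ids:
                self._counts[self._layer][int(expert_id)] += 1

    recorder = ToyExpertRecorder()
    recorder.set_current_layer(0)
    recorder.record([3, 3, 7])
    print(dict(recorder._counts[0]))  # {3: 2, 7: 1}
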
--- a/sglang/srt/models/qwen2_vl.py
+++ b/sglang/srt/models/qwen2_vl.py
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -42,10 +41,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.fc2.weight.dtype
+        return next(self.parameters()).dtype
 
     @property
     def device(self) -> torch.device:
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
             hpos_ids = (
@@ -471,18 +472,18 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
     # Use grid_t * grid_w * grid_h to pad tokens for each image
    # add replaced padding by unique image hash
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = image_inputs.im_start_id
-        im_end_id: int = image_inputs.im_end_id
+        im_start_id: int = multi_modal_inputs.im_start_id
+        im_end_id: int = multi_modal_inputs.im_end_id
 
         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
 
-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return video_embeds
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -514,67 +518,25 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions
 
-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
 
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.clone()
-
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-
-                    left_idx = start_idx + (image_offset - prefix_len + 1)
-                    right_idx = left_idx + num_image_tokens
-                    inputs_embeds[left_idx:right_idx] = image_embeds[
-                        image_embeds_offset : image_embeds_offset + num_image_tokens
-                    ]
-                    image_embeds_offset += num_image_tokens
-            input_ids = None
+            inputs_embeds = general_mm_embed_routine(
+                input_ids=input_ids,
+                forward_batch=forward_batch,
+                embed_tokens=self.get_input_embeddings(),
+                mm_data_embedding_func=self.get_image_feature,
+            )
 
         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,