sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
@@ -33,9 +35,10 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
@@ -84,6 +87,9 @@ class ForwardBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The sum of all sequence lengths
+    seq_lens_sum: int
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -92,6 +98,7 @@ class ForwardBatch:
     positions: torch.Tensor = None

     # For extend
+    extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_prefix_lens: Optional[torch.Tensor] = None
     extend_start_loc: Optional[torch.Tensor] = None
@@ -101,6 +108,12 @@ class ForwardBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None

+    # Encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # For LoRA
     lora_paths: Optional[List[str]] = None

@@ -112,14 +125,71 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None

+    # For Qwen2-VL
+    mrope_positions: torch.Tensor = None
+
+    def compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    batch.mrope_positions_delta[i][0],
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
+            for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
+                    mrope_position_delta = 0
+                else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ],
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                        )
+                    )
+                mrope_positions_list[i] = mrope_positions
+                batch.mrope_positions_delta[i].append(mrope_position_delta)
+
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+
     @classmethod
     def init_new(
         cls,
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device = model_runner.device

+        device = model_runner.device
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
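For readers unfamiliar with M-RoPE: the decode branch above only needs the per-request position delta computed during prefill. The sketch below illustrates what `MRotaryEmbedding.get_next_input_positions` returns, assuming the vLLM-style signature used in the hunk (the concrete numbers are a made-up example, not taken from this diff):

# Minimal sketch, assuming get_next_input_positions(delta, context_len, seq_len)
# returns, for each of the three M-RoPE axes (temporal, height, width), the
# positions range(context_len + delta, seq_len + delta).
def next_mrope_positions(delta: int, context_len: int, seq_len: int):
    return [list(range(context_len + delta, seq_len + delta)) for _ in range(3)]

# A prompt whose vision tokens were compressed into a smaller 3D position range
# ends up with a negative delta, so decode positions continue from where the
# prompt's multimodal positions stopped rather than from the raw token count.
print(next_mrope_positions(delta=-4, context_len=9, seq_len=10))  # [[5], [5], [5]]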
@@ -127,6 +197,12 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
+            image_inputs=batch.image_inputs,
+            encoder_cached=batch.encoder_cached,
+            encoder_lens=batch.encoder_lens,
+            encoder_lens_cpu=batch.encoder_lens_cpu,
+            encoder_out_cache_loc=batch.encoder_out_cache_loc,
+            seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             lora_paths=batch.lora_paths,
@@ -144,10 +220,11 @@ class ForwardBatch:
                 ],
                 axis=0,
             )
-            ret.image_inputs = batch.image_inputs
+            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
+
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -156,6 +233,9 @@ class ForwardBatch:
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

+        if model_runner.model_is_mrope:
+            ret.compute_mrope_positions(model_runner, batch)
+
         # Init attention information
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
sglang/srt/model_executor/model_runner.py
@@ -59,8 +59,11 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
+    is_attention_free_model,
+    is_embedding_model,
     is_generation_model,
     is_multimodal_model,
+    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -117,11 +120,16 @@ class ModelRunner:
         )

         if self.is_multimodal_model:
-            logger.info(
+            logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            if self.model_config.hf_config.architectures == [
+                "Qwen2VLForConditionalGeneration"
+            ]:
+                server_args.disable_radix_cache = True

         # Global vars
         if server_args.show_time_cost:
@@ -262,7 +270,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -316,11 +323,13 @@ class ModelRunner:

         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
-                config.model,
-                config.revision,
-                fall_back_to_pt=getattr(
-                    self.model, "fall_back_to_pt_during_load", True
-                ),
+                DefaultModelLoader.Source(
+                    config.model,
+                    revision=config.revision,
+                    fall_back_to_pt=getattr(
+                        self.model, "fall_back_to_pt_during_load", True
+                    ),
+                )
             )
             return iter

@@ -444,6 +453,7 @@ class ModelRunner:
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
+            use_records=False,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -499,7 +509,7 @@ class ModelRunner:
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
-            assert not self.has_cross_attention, (
+            assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
@@ -547,9 +557,7 @@ class ModelRunner:
        self.cuda_graph_runner = CudaGraphRunner(self)

    def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-            forward_batch.batch_size
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
            return self.cuda_graph_runner.replay(forward_batch)

        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
@@ -617,6 +625,15 @@ class ModelRunner:

        return logits

+    @property
+    def model_is_mrope(self) -> bool:
+        """Detect if the model has "mrope" rope_scaling type.
+        mrope requires keep "rope_deltas" between prompt and decoding phases."""
+        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
+        if rope_scaling is None:
+            return False
+        return rope_scaling.get("type", None) == "mrope"
+

 @lru_cache()
 def import_model_classes():
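For context, the rope_scaling block this new property inspects comes from the model's Hugging Face config. The fragment below is a hypothetical illustration of what an M-RoPE (Qwen2-VL-style) entry looks like, with the check pulled out as a standalone helper; the values are illustrative, not copied from any specific checkpoint:

# Hypothetical config fragment, for illustration only.
rope_scaling = {
    "type": "mrope",               # what model_is_mrope checks for
    "mrope_section": [16, 24, 24], # split of rotary dims across the three axes
}

def model_is_mrope(rope_scaling) -> bool:
    # Same check as the property above, as a free function.
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"

assert model_is_mrope(rope_scaling)   # True for an mrope config
assert not model_is_mrope(None)       # models without rope_scaling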
@@ -662,3 +679,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:

 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
+setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
+setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
+setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
+setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
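These setattr calls patch vLLM's ModelRegistry with sglang's own architecture predicates, which are imported from sglang/srt/utils.py in the first model_runner.py hunk above. A rough sketch of what such a predicate looks like; the architecture list here is an illustrative, intentionally incomplete example, not the real list from utils.py:

# Rough sketch of an architecture predicate like the ones patched in above.
def is_multimodal_model(model_architectures) -> bool:
    multimodal_archs = {
        "LlavaLlamaForCausalLM",
        "Qwen2VLForConditionalGeneration",   # added by this release
        "MllamaForConditionalGeneration",    # added by this release
    }
    return any(arch in multimodal_archs for arch in model_architectures)

print(is_multimodal_model(["Qwen2VLForConditionalGeneration"]))  # True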
sglang/srt/models/chatglm.py
@@ -303,7 +303,7 @@ class GLMTransformer(nn.Module):
         return hidden_states


-class ChatGLMModel(nn.Module):
+class ChatGLMM(nn.Module):
     def __init__(
         self,
         config,
@@ -366,7 +366,7 @@ class ChatGLMForCausalLM(nn.Module):
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMModel(config, cache_config, quant_config)
+        self.transformer = ChatGLMM(config, cache_config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)

@@ -401,4 +401,4 @@ class ChatGLMModel(ChatGLMForCausalLM):
     pass


-EntryClass = [ChatGLMForCausalLM, ChatGLMModel]
+EntryClass = [ChatGLMModel]
sglang/srt/models/deepseek_v2.py
@@ -250,7 +250,7 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        rope_scaling["type"] = "deepseek_yarn"
+        rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -398,7 +398,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        rope_scaling["type"] = "deepseek_yarn"
+        rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
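Both hunks rename the key under which the custom YaRN variant is registered in the rope_scaling dict, apparently to track newer vLLM rope utilities that read rope_scaling["rope_type"] rather than the older rope_scaling["type"]. The snippet below is a hypothetical before/after illustration only; the other field values are made up:

# Hypothetical rope_scaling dict, for illustration; values are made up.
rope_scaling = {
    "factor": 40,
    "original_max_position_embeddings": 4096,
}

# Before 0.3.4.post2:
#   rope_scaling["type"] = "deepseek_yarn"
# After 0.3.4.post2:
rope_scaling["rope_type"] = "deepseek_yarn"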