sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_serving.py +18 -1
  2. sglang/lang/interpreter.py +71 -1
  3. sglang/lang/ir.py +2 -0
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/chatglm.py +78 -0
  6. sglang/srt/configs/dbrx.py +279 -0
  7. sglang/srt/configs/model_config.py +16 -7
  8. sglang/srt/hf_transformers_utils.py +9 -14
  9. sglang/srt/layers/attention/__init__.py +8 -1
  10. sglang/srt/layers/attention/flashinfer_backend.py +21 -5
  11. sglang/srt/layers/linear.py +89 -47
  12. sglang/srt/layers/logits_processor.py +6 -6
  13. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
  14. sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
  15. sglang/srt/layers/moe/topk.py +4 -2
  16. sglang/srt/layers/parameter.py +439 -0
  17. sglang/srt/layers/quantization/__init__.py +5 -2
  18. sglang/srt/layers/quantization/fp8.py +107 -53
  19. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  20. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  21. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  22. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  23. sglang/srt/layers/radix_attention.py +2 -0
  24. sglang/srt/layers/vocab_parallel_embedding.py +16 -3
  25. sglang/srt/managers/cache_controller.py +307 -0
  26. sglang/srt/managers/configure_logging.py +43 -0
  27. sglang/srt/managers/data_parallel_controller.py +2 -0
  28. sglang/srt/managers/detokenizer_manager.py +0 -2
  29. sglang/srt/managers/io_struct.py +29 -13
  30. sglang/srt/managers/schedule_batch.py +7 -1
  31. sglang/srt/managers/scheduler.py +58 -15
  32. sglang/srt/managers/session_controller.py +1 -1
  33. sglang/srt/managers/tokenizer_manager.py +109 -45
  34. sglang/srt/mem_cache/memory_pool.py +313 -53
  35. sglang/srt/metrics/collector.py +32 -35
  36. sglang/srt/model_executor/cuda_graph_runner.py +14 -7
  37. sglang/srt/model_executor/forward_batch_info.py +20 -15
  38. sglang/srt/model_executor/model_runner.py +53 -10
  39. sglang/srt/models/chatglm.py +1 -1
  40. sglang/srt/models/dbrx.py +1 -1
  41. sglang/srt/models/grok.py +25 -16
  42. sglang/srt/models/llama.py +46 -4
  43. sglang/srt/models/qwen2.py +11 -0
  44. sglang/srt/models/qwen2_eagle.py +131 -0
  45. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  46. sglang/srt/sampling/sampling_batch_info.py +15 -5
  47. sglang/srt/sampling/sampling_params.py +1 -1
  48. sglang/srt/server.py +125 -69
  49. sglang/srt/server_args.py +39 -19
  50. sglang/srt/speculative/eagle_utils.py +93 -85
  51. sglang/srt/speculative/eagle_worker.py +48 -33
  52. sglang/srt/torch_memory_saver_adapter.py +59 -0
  53. sglang/srt/utils.py +61 -5
  54. sglang/test/test_programs.py +23 -1
  55. sglang/test/test_utils.py +36 -7
  56. sglang/version.py +1 -1
  57. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
  58. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
  59. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  61. {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -50,10 +50,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader import get_model
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  from sglang.srt.utils import (
  enable_show_time_cost,
  get_available_gpu_memory,
  init_custom_process_group,
+ is_cuda,
  is_hip,
  monkey_patch_vllm_gguf_config,
  monkey_patch_vllm_p2p_access_check,
@@ -89,6 +91,7 @@ class ModelRunner:
  self.is_draft_worker = is_draft_worker
  self.is_generation = model_config.is_generation
  self.is_multimodal = model_config.is_multimodal
+ self.should_log = tp_rank == 0
  self.spec_algorithm = SpeculativeAlgorithm.from_string(
  server_args.speculative_algorithm
  )
@@ -117,15 +120,21 @@ class ModelRunner:

  if self.is_multimodal:
  self.mem_fraction_static *= 0.95
+ logger.info(
+ f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+ f"because this is a multimodal model."
+ )
+
  if self.model_config.hf_config.architectures == [
  "MllamaForConditionalGeneration"
  ]:
  logger.info("Automatically turn off --chunked-prefill-size for mllama.")
  server_args.chunked_prefill_size = -1
- # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+
  if self.model_config.hf_config.architectures == [
  "Qwen2VLForConditionalGeneration"
  ]:
+ # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
  logger.info(
  "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
  )
@@ -158,6 +167,10 @@ class ModelRunner:
  # Get memory before model loading
  min_per_gpu_memory = self.init_torch_distributed()

+ self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=self.server_args.enable_memory_saver
+ )
+
  # Load the model
  self.sampler = Sampler()
  self.load_model()
@@ -198,7 +211,7 @@ class ModelRunner:
  if self.device == "cuda":
  backend = "nccl"
  elif self.device == "xpu":
- # TODO(liangan1):Just use gloo to bypass the initilization fail
+ # TODO(liangan1): Just use gloo to bypass the initilization fail
  # Need to use xccl for xpu backend in the future
  backend = "gloo"
  elif self.device == "hpu":
@@ -264,11 +277,35 @@ class ModelRunner:
  monkey_patch_vllm_gguf_config()

  # Load the model
- self.model = get_model(
- model_config=self.model_config,
- load_config=self.load_config,
- device_config=DeviceConfig(self.device),
- )
+ with self.memory_saver_adapter.region():
+ self.model = get_model(
+ model_config=self.model_config,
+ load_config=self.load_config,
+ device_config=DeviceConfig(self.device),
+ )
+
+ if self.server_args.kv_cache_dtype == "fp8_e4m3":
+ if self.server_args.quantization_param_path is not None:
+ if callable(getattr(self.model, "load_kv_cache_scales", None)):
+ self.model.load_kv_cache_scales(
+ self.server_args.quantization_param_path
+ )
+ logger.info(
+ "Loaded KV cache scaling factors from %s",
+ self.server_args.quantization_param_path,
+ )
+ else:
+ raise RuntimeError(
+ "Using FP8 KV cache and scaling factors provided but "
+ "model %s does not support loading scaling factors.",
+ self.model.__class__,
+ )
+ else:
+ logger.warning(
+ "Using FP8 KV cache but no scaling factors "
+ "provided. Defaulting to scaling factors of 1.0. "
+ "This may lead to less accurate results!"
+ )

  # Parse other args
  self.sliding_window_size = (
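The new sglang/srt/torch_memory_saver_adapter.py (+59 lines, not shown in this excerpt) backs the memory_saver_adapter used above. A minimal sketch of the pattern the ModelRunner relies on — a create(enable=...) factory plus a region() context manager that degrades to a no-op when the feature is off — might look like this; the class internals and the torch_memory_saver attribute names are assumptions:

from contextlib import nullcontext


class TorchMemorySaverAdapterSketch:
    """Hypothetical sketch; the real adapter lives in torch_memory_saver_adapter.py."""

    def __init__(self, enable: bool):
        self.enable = enable
        if enable:
            # Optional dependency; the module and singleton names are assumptions.
            from torch_memory_saver import torch_memory_saver

            self._saver = torch_memory_saver

    @classmethod
    def create(cls, enable: bool):
        return cls(enable)

    def region(self):
        # Tensors allocated inside this region can later be paused and resumed
        # (e.g. to release GPU memory between runs); with enable=False it is a
        # plain no-op context manager, so call sites need no branching.
        return self._saver.region() if self.enable else nullcontext()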
@@ -386,7 +423,7 @@ class ModelRunner:

  logger.info(
  f"init custom process group: master_address={master_address}, master_port={master_port}, "
- f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+ f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
  )

  try:
@@ -509,6 +546,9 @@ class ModelRunner:
  self.kv_cache_dtype = torch.float8_e5m2fnuz
  else:
  self.kv_cache_dtype = torch.float8_e5m2
+ elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+ if is_cuda():
+ self.kv_cache_dtype = torch.float8_e4m3fn
  else:
  raise ValueError(
  f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -556,6 +596,7 @@ class ModelRunner:
  max_context_len=self.model_config.context_len + 4,
  device=self.device,
  use_records=False,
+ enable_memory_saver=self.server_args.enable_memory_saver,
  )
  if (
  self.model_config.attention_arch == AttentionArch.MLA
@@ -568,6 +609,7 @@ class ModelRunner:
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
  )
  elif self.server_args.enable_double_sparsity:
  self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -578,6 +620,7 @@ class ModelRunner:
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
  heavy_channel_num=self.server_args.ds_heavy_channel_num,
+ enable_memory_saver=self.server_args.enable_memory_saver,
  )
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
@@ -587,6 +630,7 @@ class ModelRunner:
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
+ enable_memory_saver=self.server_args.enable_memory_saver,
  )
  logger.info(
  f"Memory pool end. "
@@ -627,7 +671,6 @@ class ModelRunner:
  )

  def init_double_sparsity_channel_config(self, selected_channel):
-
  selected_channel = "." + selected_channel + "_proj"
  self.sorted_channels = []
  # load channel config
@@ -718,7 +761,7 @@ class ModelRunner:
  elif forward_batch.forward_mode.is_idle():
  return self.forward_idle(forward_batch)
  else:
- raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
+ raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

  def sample(
  self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
sglang/srt/models/chatglm.py CHANGED
@@ -23,8 +23,8 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.transformers_utils.configs import ChatGLMConfig

+ from sglang.srt.configs import ChatGLMConfig
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
sglang/srt/models/dbrx.py CHANGED
@@ -25,8 +25,8 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.transformers_utils.configs.dbrx import DbrxConfig

+ from sglang.srt.configs import DbrxConfig
  from sglang.srt.layers.linear import (
  QKVParallelLinear,
  ReplicatedLinear,
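Both model files above switch from vllm.transformers_utils.configs to the newly vendored sglang.srt.configs package (items 4-6 in the file list), which drops a dependency on an internal vLLM module. The four lines added to sglang/srt/configs/__init__.py are not shown; presumably they just re-export the two new classes, roughly:

# Assumed shape of the additions to sglang/srt/configs/__init__.py.
from sglang.srt.configs.chatglm import ChatGLMConfig
from sglang.srt.configs.dbrx import DbrxConfig

# __all__ is presumably extended with the new names as well.
__all__ = ["ChatGLMConfig", "DbrxConfig"]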
sglang/srt/models/grok.py CHANGED
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  reduce_results=True,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
  bias=False,
  quant_config=quant_config,
  prefix=f"{prefix}.gate_up_proj",
+ use_presharded_weights=use_presharded_weights,
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
  quant_config=quant_config,
  prefix=f"{prefix}.down_proj",
  reduce_results=reduce_results,
+ use_presharded_weights=use_presharded_weights,
  )
  self.act_fn = GeluAndMul(approximate="tanh")

@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  tp_size: Optional[int] = None,
  reduce_results=True,
+ use_presharded_weights: bool = False,
  ):
  super().__init__()
  self.hidden_size = hidden_size
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
  renormalize=False,
  quant_config=quant_config,
  tp_size=tp_size,
+ use_presharded_weights=use_presharded_weights,
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
  max_position: int = 4096 * 32,
  rope_theta: float = 10000,
  quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
  ) -> None:
  super().__init__()
  self.config = config
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
  hidden_size,
  bias=False,
  quant_config=quant_config,
+ reduce_results=reduce_results,
  )
  self.rotary_emb = get_rope(
  self.head_dim,
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.num_experts = config.num_local_experts
  self.hidden_size = config.hidden_size
+ self.layer_id = layer_id

  rope_theta = getattr(config, "rope_theta", 10000)
  self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
  ),
  quant_config=quant_config,
  reduce_results=True,
+ use_presharded_weights=use_presharded_weights,
  )
  self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.config = config
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
  )
  self.layers = nn.ModuleList(
  [
- Grok1DecoderLayer(config, i, quant_config=quant_config)
+ Grok1DecoderLayer(
+ config,
+ i,
+ quant_config=quant_config,
+ use_presharded_weights=use_presharded_weights,
+ )
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = Grok1Model(config, quant_config=quant_config)
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
- self.logits_processor = LogitsProcessor(config)

- # Monkey patch _prepare_weights to load pre-sharded weights
  if (
  self.config.num_local_experts > 0
  and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
  else:
  self.use_presharded_weights = False

+ self.model = Grok1Model(
+ config,
+ quant_config=quant_config,
+ use_presharded_weights=self.use_presharded_weights,
+ )
+ self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+ self.logits_processor = LogitsProcessor(config)
+
  def forward(
  self,
  input_ids: torch.Tensor,
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
  def load_weights(
  self,
  weights: Iterable[Tuple[str, torch.Tensor]],
- use_presharded_weights: bool | None = None,
  ):
- if use_presharded_weights is None:
- use_presharded_weights = self.use_presharded_weights
  num_experts = self.config.num_local_experts

  stacked_params_mapping = [
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
  continue
  name = name.replace(weight_name, param_name)

- if use_presharded_weights:
- extra_kwargs = {
- "use_presharded_weights": use_presharded_weights
- }
- else:
- extra_kwargs = {}
-
  load_weight_wrapper(
  name,
  loaded_weight,
  name,
  shard_id=shard_id,
  expert_id=expert_id,
- **extra_kwargs,
  )
  break
  else:
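use_presharded_weights now flows from Grok1ForCausalLM through Grok1Model into the linear and MoE layers at construction time, instead of being threaded through load_weights per call. The layer-side handling lives in layers/linear.py and layers/moe/fused_moe_triton/layer.py (changed above, not shown); sketched under the assumption that the weight loaders simply skip tensor-parallel narrowing when the checkpoint already stores one shard per rank, as the Grok-1 release does:

# Hypothetical loader sketch; the real logic sits in sglang/srt/layers/linear.py.
import torch


def load_column_parallel_weight(param: torch.nn.Parameter,
                                loaded_weight: torch.Tensor,
                                tp_rank: int,
                                tp_size: int,
                                use_presharded_weights: bool) -> None:
    if use_presharded_weights:
        # The checkpoint already holds only this rank's shard, so copy it verbatim.
        shard = loaded_weight
    else:
        # Regular path: narrow the full tensor to this rank's slice of the
        # output dimension.
        shard_size = loaded_weight.shape[0] // tp_size
        shard = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
    param.data.copy_(shard)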
sglang/srt/models/llama.py CHANGED
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
  from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -100,6 +104,7 @@ class LlamaAttention(nn.Module):
  max_position_embeddings: int = 8192,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
+ bias: bool = False,
  ) -> None:
  super().__init__()
  self.hidden_size = hidden_size
@@ -132,14 +137,14 @@ class LlamaAttention(nn.Module):
  self.head_dim,
  self.total_num_heads,
  self.total_num_kv_heads,
- bias=False,
+ bias=bias,
  quant_config=quant_config,
  prefix=f"{prefix}.qkv_proj",
  )
  self.o_proj = RowParallelLinear(
  self.total_num_heads * self.head_dim,
  hidden_size,
- bias=False,
+ bias=bias,
  quant_config=quant_config,
  prefix=f"{prefix}.o_proj",
  )
@@ -194,6 +199,11 @@ class LlamaDecoderLayer(nn.Module):
  )
  rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+ # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+ # Support internlm/internlm-7b with bias
+ attention_bias = getattr(config, "attention_bias", False) or getattr(
+ config, "bias", False
+ )
  self.self_attn = LlamaAttention(
  config=config,
  hidden_size=self.hidden_size,
@@ -206,6 +216,7 @@ class LlamaDecoderLayer(nn.Module):
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
  prefix=f"{prefix}.self_attn",
+ bias=attention_bias,
  )
  self.mlp = LlamaMLP(
  hidden_size=self.hidden_size,
@@ -292,6 +303,30 @@ class LlamaModel(nn.Module):
  hidden_states, _ = self.norm(hidden_states, residual)
  return hidden_states

+ # If this function is called, it should always initialize KV cache scale
+ # factors (or else raise an exception). Thus, handled exceptions should
+ # make sure to leave KV cache scale factors in a known good (dummy) state
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+ tp_size = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ for layer_idx, scaling_factor in kv_cache_scales_loader(
+ quantization_param_path,
+ tp_rank,
+ tp_size,
+ self.config.num_hidden_layers,
+ self.config.__class__.model_type,
+ ):
+ if not isinstance(self.layers[layer_idx], nn.Identity):
+ layer_self_attn = self.layers[layer_idx].self_attn
+
+ if hasattr(layer_self_attn.attn, "k_scale"):
+ layer_self_attn.attn.k_scale = scaling_factor
+ layer_self_attn.attn.v_scale = scaling_factor
+ else:
+ raise RuntimeError(
+ "Self attention has no KV cache scaling " "factor attribute!"
+ )
+

  class LlamaForCausalLM(nn.Module):

@@ -527,9 +562,16 @@ class LlamaForCausalLM(nn.Module):
  torch.cuda.empty_cache()
  torch.cuda.synchronize()

+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+ self.model.load_kv_cache_scales(quantization_param_path)
+

  class Phi3ForCausalLM(LlamaForCausalLM):
  pass


- EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
+ class InternLM3ForCausalLM(LlamaForCausalLM):
+ pass
+
+
+ EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]
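Taken together with the model_runner.py hunk above, the FP8 KV cache path is duck-typed: ModelRunner only calls load_kv_cache_scales when the loaded model defines it, and the Llama family (now including InternLM3 and Phi-3 via subclassing) delegates to LlamaModel, which assigns per-layer k_scale/v_scale from vLLM's kv_cache_scales_loader. An illustrative end-to-end sketch driven through ServerArgs; the model path and JSON file are placeholders, and the scales-file schema is whatever kv_cache_scales_loader expects:

# Illustrative wiring only; field names are taken from the ServerArgs fields
# referenced in this diff (kv_cache_dtype, quantization_param_path).
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    kv_cache_dtype="fp8_e4m3",                        # resolves to torch.float8_e4m3fn on CUDA
    quantization_param_path="kv_cache_scales.json",   # per-layer scaling factors
)
# With no quantization_param_path, ModelRunner logs a warning and falls back
# to scaling factors of 1.0, which may reduce accuracy.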
sglang/srt/models/qwen2.py CHANGED
@@ -362,5 +362,16 @@ class Qwen2ForCausalLM(nn.Module):
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)

+ def get_embed_and_head(self):
+ return self.model.embed_tokens.weight, self.lm_head.weight
+
+ def set_embed_and_head(self, embed, head):
+ del self.model.embed_tokens.weight
+ del self.lm_head.weight
+ self.model.embed_tokens.weight = embed
+ self.lm_head.weight = head
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+

  EntryClass = Qwen2ForCausalLM
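get_embed_and_head / set_embed_and_head let the EAGLE draft model reuse the target model's embedding and LM head tensors instead of holding duplicates. The caller sits in speculative/eagle_worker.py (changed above, not shown), so the exact hand-off is an assumption; roughly:

# Hypothetical hand-off between target and draft model; the real call site is
# in sglang/srt/speculative/eagle_worker.py, which this excerpt omits.
def tie_draft_to_target(target_model, draft_model):
    # Borrow the target model's embedding and lm_head parameters...
    embed, head = target_model.get_embed_and_head()
    # ...and install them on the draft model; set_embed_and_head deletes the
    # draft's own copies first and then empties the CUDA cache.
    draft_model.set_embed_and_head(embed, head)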
sglang/srt/models/qwen2_eagle.py ADDED
@@ -0,0 +1,131 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from
+ # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+ """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
+
+ Qwen2Config = None
+
+
+ class Qwen2DecoderLayer(Qwen2DecoderLayer):
+ def __init__(
+ self,
+ config: Qwen2Config,
+ layer_id: int = 0,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__(config, layer_id, quant_config)
+
+ # Skip the input_layernorm
+ # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+ if layer_id == 0:
+ del self.input_layernorm
+ setattr(self, "input_layernorm", lambda x: x)
+
+
+ class Qwen2Model(nn.Module):
+ def __init__(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ )
+ self.layers = nn.ModuleList(
+ [
+ Qwen2DecoderLayer(
+ config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+ )
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ ) -> torch.Tensor:
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+
+ hidden_states = self.fc(
+ torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+ )
+
+ residual = None
+ for i in range(len(self.layers)):
+ layer = self.layers[i]
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ forward_batch,
+ residual,
+ )
+ return hidden_states + residual
+
+
+ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+ def __init__(
+ self,
+ config: Qwen2Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ cache_config=None,
+ ) -> None:
+ nn.Module.__init__(self)
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Qwen2Model(config, quant_config=quant_config)
+ if self.config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
+ self.logits_processor = LogitsProcessor(config)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ for name, loaded_weight in weights:
+ if "lm_head" not in name:
+ name = "model." + name
+ super().load_weights([(name, loaded_weight)])
+
+
+ EntryClass = [Qwen2ForCausalLMEagle]
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py CHANGED
@@ -3,6 +3,11 @@ from typing import List
  import torch

  from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+ from sglang.srt.utils import is_cuda_available
+
+ is_cuda = is_cuda_available()
+ if is_cuda:
+ from sgl_kernel import sampling_scaling_penalties


  class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
  self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

  def _apply(self, logits: torch.Tensor) -> torch.Tensor:
- return torch.where(
- logits > 0,
- logits / self.cumulated_repetition_penalties,
- logits * self.cumulated_repetition_penalties,
- )
+ if is_cuda:
+ return sampling_scaling_penalties(
+ logits, self.cumulated_repetition_penalties
+ )
+ else:
+ return torch.where(
+ logits > 0,
+ logits / self.cumulated_repetition_penalties,
+ logits * self.cumulated_repetition_penalties,
+ )

  def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
  self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
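This hunk and the sampling_batch_info.py hunk below take the same shape: prefer the fused sgl_kernel sampling_scaling_penalties kernel on CUDA builds and keep the torch.where expression as the portable fallback. The two paths are intended to be numerically equivalent; a quick reference check, assuming sgl_kernel is installed and the kernel accepts the same shapes and dtypes as the fallback:

# Sanity check that the fused kernel matches the torch.where fallback.
import torch
from sgl_kernel import sampling_scaling_penalties

logits = torch.randn(8, 32000, device="cuda")
penalties = torch.ones_like(logits)
penalties[:, ::7] = 1.3  # pretend every 7th vocab id was generated already

reference = torch.where(logits > 0, logits / penalties, logits * penalties)
fused = sampling_scaling_penalties(logits, penalties)

assert torch.allclose(fused, reference, atol=1e-5, rtol=1e-5)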
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional

  import torch

+ from sglang.srt.utils import is_cuda_available
+
+ is_cuda = is_cuda_available()
+ if is_cuda:
+ from sgl_kernel import sampling_scaling_penalties
+
  import sglang.srt.sampling.penaltylib as penaltylib

  logger = logging.getLogger(__name__)
@@ -232,6 +238,7 @@ class SamplingBatchInfo:
  self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
  self.logit_bias, other.logit_bias, len(self), len(other), self.device
  )
+ self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

  def apply_logits_bias(self, logits: torch.Tensor):
  # Apply logit_bias
@@ -244,11 +251,14 @@ class SamplingBatchInfo:

  # repetition
  if self.scaling_penalties is not None:
- logits[:] = torch.where(
- logits > 0,
- logits / self.scaling_penalties,
- logits * self.scaling_penalties,
- )
+ if is_cuda:
+ logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+ else:
+ logits[:] = torch.where(
+ logits > 0,
+ logits / self.scaling_penalties,
+ logits * self.scaling_penalties,
+ )

  # Apply regex vocab_mask
  if self.vocab_mask is not None:
sglang/srt/sampling/sampling_params.py CHANGED
@@ -23,7 +23,7 @@ class SamplingParams:
  The sampling parameters.

  See docs/references/sampling_params.md or
- https://sgl-project.github.io/references/sampling_params.html
+ https://docs.sglang.ai/references/sampling_params.html
  for the documentation.
  """