sglang-0.3.5.post2-py3-none-any.whl → sglang-0.3.6.post1-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those two released versions.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
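The per-file stats above can be reproduced locally from the two wheels. Below is a minimal sketch, assuming both artifacts have already been downloaded (e.g. with `pip download sglang==0.3.6.post1 --no-deps`); the filenames are the standard PyPI wheel names for these releases, and `read_wheel` is our own helper, not part of sglang:

# Minimal sketch: rebuild a wheel-to-wheel diff like the one shown on this page.
import difflib
import zipfile

OLD = "sglang-0.3.5.post2-py3-none-any.whl"
NEW = "sglang-0.3.6.post1-py3-none-any.whl"


def read_wheel(path: str) -> dict:
    """Map each archive member name to its decoded text content."""
    with zipfile.ZipFile(path) as zf:
        return {n: zf.read(n).decode("utf-8", errors="replace") for n in zf.namelist()}


old_files, new_files = read_wheel(OLD), read_wheel(NEW)
for name in sorted(set(old_files) | set(new_files)):
    old_text = old_files.get(name, "").splitlines(keepends=True)
    new_text = new_files.get(name, "").splitlines(keepends=True)
    diff = list(difflib.unified_diff(old_text, new_text, fromfile=name, tofile=name))
    if diff:
        print("".join(diff))

The selected hunks below (deepseek_v2.py, exaone.py, gemma.py, gemma2.py, gemma2_reward.py, gpt_bigcode.py, grok.py) are shown in that format.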
sglang/srt/models/deepseek_v2.py CHANGED
@@ -1,35 +1,37 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
             )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
             )
         else:
-            self.q_proj = ColumnParallelLinear(
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
                 self.hidden_size,
-                self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
            )
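A note on the branches above: with `use_dp=True` every rank keeps all attention heads and uses `ReplicatedLinear` projections, while the tensor-parallel branch shards heads across ranks with `ColumnParallelLinear`/`RowParallelLinear`. A rough back-of-the-envelope sketch of what that means per rank (the head counts and dims below are illustrative DeepSeek-V2-style values, not read from any config):

# Illustrative only: per-rank attention layout in the two branches above.
num_heads = 128          # total attention heads
qk_head_dim = 192        # qk_nope_head_dim + qk_rope_head_dim
v_head_dim = 128
tp_size = 8

# Tensor-parallel branch: heads are sharded, so each rank holds 1/tp_size
# of the q_b_proj / o_proj weights.
tp_local_heads = num_heads // tp_size                  # 16
tp_q_b_out_per_rank = tp_local_heads * qk_head_dim     # 3072 output dims per rank

# Data-parallel attention branch: every rank keeps all heads and replicated
# projection weights, while the MoE MLP below the attention stays sharded.
dp_local_heads = num_heads                             # 128
dp_q_b_out_per_rank = dp_local_heads * qk_head_dim     # 24576 output dims per rank

print(tp_local_heads, dp_local_heads)
print(tp_q_b_out_per_rank, dp_q_b_out_per_rank)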
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
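The new `all_gather` helper pads each rank's token slice to the longest slice in the batch, gathers everything into `forward_batch.gathered_buffer`, drops the padding again, and returns the concatenated tokens plus this rank's slice bounds so the MLP output can later be cut back down. The single-process sketch below imitates that bookkeeping without `torch.distributed`; the per-rank tensors are simulated, only the pad/concat/slice logic mirrors the helper above:

# Single-process sketch of the padding/concat/slice bookkeeping in all_gather.
import torch

hidden_size = 4
global_num_tokens = [3, 1, 2]          # tokens held by ranks 0..2 (all_lens)
world_size = len(global_num_tokens)
max_len = max(global_num_tokens)

# What each rank contributes (rank i holds global_num_tokens[i] token vectors).
per_rank = [torch.randn(n, hidden_size) for n in global_num_tokens]

# Each rank pads its slice to max_len before the collective ...
padded = [torch.nn.functional.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in per_rank]
gathered_buffer = torch.cat(padded)    # shape: (world_size * max_len, hidden_size)

# ... and afterwards strips the padding, like the list comprehension over
# forward_batch.gathered_buffer in the diff.
gathered = torch.cat(
    [gathered_buffer[i * max_len : i * max_len + global_num_tokens[i]] for i in range(world_size)]
)

rank = 1
start = sum(global_num_tokens[:rank])
end = start + global_num_tokens[rank]
assert torch.equal(gathered[start:end], per_rank[rank])  # rank 1 recovers its own tokens
print(gathered.shape)                  # torch.Size([6, 4]) == sum(global_num_tokens) tokens

The decoder layer below uses exactly these slice bounds to run the (still tensor-parallel) MLP on the union of all ranks' tokens and then keep only its own share.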
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         return hidden_states, residual
 
 
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
            )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
-        self.logits_processor = LogitsProcessor(config)
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/exaone.py CHANGED
@@ -1,18 +1,17 @@
-"""
-Copyright 2024 The LGcns AI Engineering Team
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2024 The LGcns AI Engineering Team
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from llama2.py
 """Inference-only Exaone model compatible with HuggingFace weights."""
sglang/srt/models/gemma.py CHANGED
@@ -1,21 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/gemma2.py CHANGED
@@ -1,20 +1,20 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
+
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
@@ -38,6 +38,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -97,7 +98,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +110,7 @@
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.layer_idx = layer_idx
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +157,13 @@
             dtype=torch.get_default_dtype(),
         )
 
-        use_sliding_window = layer_idx % 2 == 0 and hasattr(config, "sliding_window")
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=layer_idx,
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +189,7 @@
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-        layer_idx: int,
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +197,7 @@
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-            layer_idx=layer_idx,
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -267,11 +268,15 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList(
-            [
-                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-                for layer_idx in range(config.num_hidden_layers)
-            ]
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -332,6 +337,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True
 
     def __init__(
         self,
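Note that gemma2.py now builds its decoder stack through `make_layers` from `sglang.srt.utils` instead of a hand-rolled `nn.ModuleList`. The helper itself is not shown in this diff; a minimal stand-in that is consistent with the call site above (a two-argument layer factory plus a `prefix` keyword) could look like the sketch below, so treat the body as an assumption rather than the real implementation:

# Assumed stand-in for sglang.srt.utils.make_layers; the shipped helper may add
# more features (e.g. weight-name prefixes for quantization or partitioning).
from typing import Callable

import torch.nn as nn


def make_layers(
    num_hidden_layers: int,
    layer_fn: Callable[[int, str], nn.Module],
    prefix: str = "",
) -> nn.ModuleList:
    """Build the decoder stack by calling layer_fn(idx, per-layer prefix) once per layer."""
    return nn.ModuleList(
        [layer_fn(idx, f"{prefix}layers.{idx}") for idx in range(num_hidden_layers)]
    )

With a helper like this, per-layer construction stays in one place instead of being repeated in every model file, which is presumably why several models in this release move to it.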
sglang/srt/models/gemma2_reward.py CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from typing import Iterable, Optional, Tuple
 
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -1,21 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/grok.py CHANGED
@@ -1,21 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
+
 import warnings
 from typing import Iterable, List, Optional, Tuple
 
@@ -31,7 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.fused_moe_grok import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,