sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
@@ -0,0 +1,117 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+class W8A8Int8Config(QuantizationConfig):
+    """Config class for W8A8 Int8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_int8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from sglang.srt.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Int8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Int8Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
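
The new w8a8_int8.py path pairs statically quantized, per-channel int8 weights with dynamic, per-token int8 activation quantization, then runs an int8 GEMM that dequantizes with both scale vectors. Below is a rough pure-PyTorch sketch of what per_token_quant_int8 and int8_scaled_mm compute together; the helper names quant_per_token and int8_mm_reference are invented for illustration, and the real sgl_kernel implementations may differ in rounding, clamping, and fusion details.

import torch


def quant_per_token(x: torch.Tensor):
    # Dynamic, per-token, symmetric int8 quantization: one scale per row so
    # that the row's largest magnitude maps to 127.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8).to(torch.float32) / 127.0
    x_q = torch.clamp(torch.round(x.to(torch.float32) / scale), -128, 127).to(torch.int8)
    return x_q, scale


def int8_mm_reference(x_q, w_q_t, x_scale, w_scale, out_dtype, bias=None):
    # Reference for an int8 x int8 GEMM followed by dequantization with a
    # per-token (row) scale and a per-channel (column) scale. A real kernel
    # accumulates in int32; float32 is exact here because K is small.
    acc = x_q.to(torch.float32) @ w_q_t.to(torch.float32)
    out = acc * x_scale * w_scale.reshape(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


# Usage: the weight is stored as an [N, K] int8 tensor with an [N, 1] float32
# scale (ChannelQuantScaleParameter), and is transposed once after loading.
x = torch.randn(4, 64, dtype=torch.float16)
w_q = torch.randint(-127, 128, (128, 64), dtype=torch.int8)
w_scale = torch.rand(128, 1, dtype=torch.float32) * 0.01
x_q, x_scale = quant_per_token(x)
y = int8_mm_reference(x_q, w_q.t(), x_scale, w_scale, out_dtype=x.dtype)
print(y.shape)  # torch.Size([4, 128])

Keeping one float32 scale per output channel means dequantization stays a cheap rank-1 rescale of the accumulator, which is why only the weight tensor and a single scale column are registered in create_weights above.
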
sglang/srt/layers/radix_attention.py
@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = None
+        self.v_scale = None
 
     def forward(
         self,
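
The radix_attention.py hunk only reserves k_scale and v_scale slots; such attributes are typically per-tensor scaling factors that a checkpoint can later populate for a quantized (e.g. FP8) KV cache. A hedged sketch of the idea follows; the helper name store_kv_maybe_quantized and the FP8 dtype are chosen purely for illustration and are not taken from sglang's attention backends.

from typing import Optional, Tuple

import torch


def store_kv_maybe_quantized(
    k: torch.Tensor,
    v: torch.Tensor,
    k_scale: Optional[float],
    v_scale: Optional[float],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # With no scales (the default above: k_scale = v_scale = None),
    # K/V stay in the model dtype.
    if k_scale is None or v_scale is None:
        return k, v
    # With scales, divide before the cast so the FP8 range covers the values.
    k_q = (k / k_scale).to(torch.float8_e4m3fn)
    v_q = (v / v_scale).to(torch.float8_e4m3fn)
    return k_q, v_q


k = torch.randn(2, 8, 128, dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, dtype=torch.bfloat16)
k_q, v_q = store_kv_maybe_quantized(k, v, k_scale=0.05, v_scale=0.05)
print(k_q.dtype)  # torch.float8_e4m3fn
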