sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """Logits processing."""

  import dataclasses
@@ -62,21 +60,21 @@ class LogitsMetadata:

  @classmethod
  def from_forward_batch(cls, forward_batch: ForwardBatch):
+ extend_logprob_pruned_lens_cpu = None
+
  if forward_batch.return_logprob:
  return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+ if forward_batch.forward_mode.is_extend():
+ extend_logprob_pruned_lens_cpu = [
+ extend_len - start_len
+ for extend_len, start_len in zip(
+ forward_batch.extend_seq_lens_cpu,
+ forward_batch.extend_logprob_start_lens_cpu,
+ )
+ ]
  else:
  return_top_logprob = False

- if forward_batch.forward_mode.is_extend():
- extend_logprob_pruned_lens_cpu = [
- extend_len - start_len
- for extend_len, start_len in zip(
- forward_batch.extend_seq_lens,
- forward_batch.extend_logprob_start_lens_cpu,
- )
- ]
- else:
- extend_logprob_pruned_lens_cpu = None
  return cls(
  forward_mode=forward_batch.forward_mode,
  top_logprobs_nums=forward_batch.top_logprobs_nums,
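
Note on the LogitsMetadata change above: the pruned length per request is simply the number of newly extended tokens minus the prefix skipped for logprob computation, and it is now built only when logprobs are requested for an extend batch. A minimal standalone sketch of that arithmetic (plain Python lists standing in for the ForwardBatch fields):

def pruned_logprob_lens(extend_seq_lens_cpu, extend_logprob_start_lens_cpu):
    # Tokens that still need logprobs = newly extended tokens minus the skipped prefix.
    return [
        extend_len - start_len
        for extend_len, start_len in zip(
            extend_seq_lens_cpu, extend_logprob_start_lens_cpu
        )
    ]

# Example: two requests extend by 5 and 8 tokens, logprobs start at offsets 2 and 0.
assert pruned_logprob_lens([5, 8], [2, 0]) == [3, 8]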
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -1,18 +1,19 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py

- from typing import Dict, Type
+ from typing import Callable, Dict, Optional, Type

+ import torch
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
  from vllm.model_executor.layers.quantization.awq import AWQConfig
  from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
  from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
  CompressedTensorsConfig,
  )
  from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
  from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
  from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from vllm.model_executor.layers.quantization.gguf import GGUFConfig
  from vllm.model_executor.layers.quantization.gptq import GPTQConfig
  from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -30,8 +31,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
  "tpu_int8": Int8TpuConfig,
  "fp8": Fp8Config,
  "fbgemm_fp8": FBGEMMFp8Config,
- # The order of gptq methods is important for config.py iteration over
- # override_quantization_method(..)
  "marlin": MarlinConfig,
  "gguf": GGUFConfig,
  "gptq_marlin_24": GPTQMarlin24Config,
@@ -47,20 +46,68 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {

  def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
  if quantization not in QUANTIZATION_METHODS:
- raise ValueError(f"Invalid quantization method: {quantization}")
+ raise ValueError(
+ f"Invalid quantization method: {quantization}. "
+ f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
+ )
  return QUANTIZATION_METHODS[quantization]


- __all__ = [
- "QuantizationConfig",
- "get_quantization_config",
- "QUANTIZATION_METHODS",
- ]
+ def fp8_moe_apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ renormalize: bool,
+ use_grouped_topk: bool,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ custom_routing_function: Optional[Callable] = None,
+ ) -> torch.Tensor:
+ """Enhanced apply method for FP8 MoE."""
+ from sglang.srt.layers.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+ # Expert selection
+ topk_weights, topk_ids = FusedMoE.select_experts(
+ hidden_states=x,
+ router_logits=router_logits,
+ use_grouped_topk=use_grouped_topk,
+ top_k=top_k,
+ renormalize=renormalize,
+ topk_group=topk_group,
+ num_expert_group=num_expert_group,
+ custom_routing_function=custom_routing_function,
+ )
+
+ # Expert fusion with FP8 quantization
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=True,
+ use_fp8_w8a8=True,
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ )
+
+
+ def fp8_get_quant_method(self, layer, prefix):
+ """Enhanced get_quant_method for FP8 config."""
+ from vllm.model_executor.layers.linear import LinearBase
+ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+ from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ is_layer_skipped,
+ )
+
+ from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.linear import UnquantizedLinearMethod

- """
- def fp8_get_quant_method(
- self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
  if isinstance(layer, LinearBase):
  if is_layer_skipped(prefix, self.ignored_layers):
  return UnquantizedLinearMethod()
@@ -70,5 +117,18 @@ def fp8_get_quant_method(
  return None


- setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
- """
+ def apply_monkey_patches():
+ """Apply all monkey patches in one place."""
+ setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
+ setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+
+
+ # Apply patches when module is imported
+ apply_monkey_patches()
+
+
+ __all__ = [
+ "QuantizationConfig",
+ "get_quantization_config",
+ "QUANTIZATION_METHODS",
+ ]
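
The hunk above moves the runtime overrides of vLLM's Fp8MoEMethod.apply and Fp8Config.get_quant_method into a single apply_monkey_patches() call executed at import time. A self-contained sketch of that pattern, using a stand-in class rather than the real vLLM types:

class UpstreamConfig:
    """Stand-in for a third-party class whose behavior we want to override."""

    def get_quant_method(self, layer, prefix):
        return "upstream behavior"


def patched_get_quant_method(self, layer, prefix):
    # Project-specific behavior, swapped in without forking the upstream code.
    return "patched behavior"


def apply_monkey_patches():
    """Apply all patches in one place so the overrides are easy to audit."""
    setattr(UpstreamConfig, "get_quant_method", patched_get_quant_method)


apply_monkey_patches()  # run at import time, mirroring the module above
assert UpstreamConfig().get_quant_method(layer=None, prefix="") == "patched behavior"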
sglang/srt/layers/radix_attention.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """Radix attention."""

  from torch import nn
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -1,16 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """MRotaryEmbedding"""
  from typing import Any, Dict, List, Optional, Tuple, Union

sglang/srt/layers/sampler.py CHANGED
@@ -1,5 +1,4 @@
  import logging
- import os
  from typing import Union

  import torch
@@ -8,7 +7,7 @@ from torch import nn
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

  if is_flashinfer_available():
  from flashinfer.sampling import (
@@ -19,17 +18,13 @@ if is_flashinfer_available():
  )


- # Crash on warning if we are running CI tests
- crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
  logger = logging.getLogger(__name__)


  class Sampler(nn.Module):
  def __init__(self):
  super().__init__()
- self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+ self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]

  def forward(
  self,
@@ -46,7 +41,8 @@ class Sampler(nn.Module):
  logits = torch.where(
  torch.isnan(logits), torch.full_like(logits, -1e5), logits
  )
- exit(1) if crash_on_warning else None
+ if crash_on_warnings():
+ raise ValueError("Detected errors during sampling! NaN in the logits.")

  if sampling_info.is_all_greedy:
  # Use torch.argmax if all requests use greedy sampling
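
The sampler change swaps the CI-only exit(1) for an exception guarded by crash_on_warnings() and makes NaN detection opt-in via enable_nan_detection. A standalone sketch of the NaN-handling pattern (not the Sampler class itself; strict stands in for crash_on_warnings()):

import torch


def sanitize_logits(logits: torch.Tensor, nan_detection: bool, strict: bool) -> torch.Tensor:
    if nan_detection and torch.any(torch.isnan(logits)):
        # Replace NaNs with a very negative value so they can never be sampled.
        logits = torch.where(torch.isnan(logits), torch.full_like(logits, -1e5), logits)
        if strict:
            # For example in CI: fail loudly instead of papering over the NaN.
            raise ValueError("Detected errors during sampling! NaN in the logits.")
    return logits


clean = sanitize_logits(torch.tensor([0.5, float("nan")]), nan_detection=True, strict=False)
# The NaN entry is now -1e5 and effectively unsampleable.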
sglang/srt/layers/torchao_utils.py CHANGED
@@ -62,6 +62,8 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
  granularity=GRANULARITY_MAP[granularity]
  ),
  )
+ else:
+ raise ValueError(f"Unexpected config: {torchao_config}")

  return dummy_linear.weight

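The torchao_utils change adds an explicit else branch so an unrecognized torchao_config fails loudly instead of silently returning an unquantized weight. An illustrative sketch of that guard on a hypothetical dispatcher (the config names here are made up):

def pick_quantizer(torchao_config: str) -> str:
    if torchao_config.startswith("int8"):
        return "int8 path"
    elif torchao_config.startswith("fp8"):
        return "fp8 path"
    else:
        # Without this branch a typo would fall through and return None.
        raise ValueError(f"Unexpected config: {torchao_config}")


pick_quantizer("int8-demo")    # returns "int8 path"
# pick_quantizer("int9-typo")  # raises ValueError: Unexpected config: int9-typo
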
sglang/srt/lora/lora.py CHANGED
@@ -1,17 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

  # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
  # and "Punica: Multi-Tenant LoRA Serving"
sglang/srt/lora/lora_config.py CHANGED
@@ -1,17 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

  import json
  import os
sglang/srt/lora/lora_manager.py CHANGED
@@ -1,22 +1,20 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================

  # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
  # and "Punica: Multi-Tenant LoRA Serving"

-
  import logging
  import re

@@ -146,9 +144,9 @@ class LoRAManager:
  }
  else:
  logger.warning(
- f"WARNING: get_module_name() is not defined, "
- f"which is used to map config module name to model implementation module name."
- f"Use the default one, but please check if it is correct for your model."
+ "WARNING: get_module_name() is not defined, "
+ "which is used to map config module name to model implementation module name."
+ "Use the default one, but please check if it is correct for your model."
  )
  self.target_modules = {
  get_module_name(module) for module in self.origin_target_modules
@@ -194,9 +192,9 @@ class LoRAManager:
  hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
  else:
  logger.warning(
- f"WARNING: get_hidden_dim() is not defined, "
- f"which is used to get the hidden dim for different lora modules"
- f"Use the default one, but please check if it is correct for your model."
+ "WARNING: get_hidden_dim() is not defined, "
+ "which is used to get the hidden dim for different lora modules"
+ "Use the default one, but please check if it is correct for your model."
  )
  hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
  c = self.loras[-1].get_stacked_multiply(module_A)
@@ -218,9 +216,9 @@ class LoRAManager:
  _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
  else:
  logger.warning(
- f"WARNING: get_hidden_dim() is not defined, "
- f"which is used to get the hidden dim for different lora modules"
- f"Use the default one, but please check if it is correct for your model."
+ "WARNING: get_hidden_dim() is not defined, "
+ "which is used to get the hidden dim for different lora modules"
+ "Use the default one, but please check if it is correct for your model."
  )
  _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
  c = self.loras[-1].get_stacked_multiply(module_B)
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -1,22 +1,21 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """A controller that dispatches requests to multiple data parallel workers."""

  import logging
  import multiprocessing as mp
+ import threading
  from enum import Enum, auto

  import zmq
@@ -28,6 +27,7 @@ from sglang.srt.managers.io_struct import (
  from sglang.srt.managers.scheduler import run_scheduler_process
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
+ bind_port,
  configure_logger,
  get_zmq_socket,
  kill_parent_process,
@@ -80,20 +80,62 @@ class DataParallelController:

  # Start data parallel workers
  base_gpu_id = 0
- self.workers = []
+ self.workers = [None] * server_args.dp_size
+
+ threads = []
+ sockets = []
  for dp_rank in range(server_args.dp_size):
  tmp_port_args = PortArgs.init_new(server_args)
+ tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
  tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

- send_to = self.launch_tensor_parallel_group(
- server_args,
- tmp_port_args,
- base_gpu_id,
- dp_rank,
+ if server_args.enable_dp_attention:
+ # Data parallelism resues the tensor parallelism group,
+ # so all dp ranks should use the same nccl port.
+ tmp_port_args.nccl_port = port_args.nccl_port
+ else:
+ # This port is checked free in PortArgs.init_new.
+ # We hold it first so that the next dp worker gets a different port
+ sockets.append(bind_port(tmp_port_args.nccl_port))
+
+ # Create a thread for each worker
+ thread = threading.Thread(
+ target=self.launch_worker_func,
+ args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
  )
+ threads.append(thread)
+ base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+ # Free all sockets before starting the threads to launch TP workers
+ for sock in sockets:
+ sock.close()
+
+ # Start all threads
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ def launch_worker_func(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ base_gpu_id: int,
+ dp_rank: int,
+ ):
+ logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

- self.workers.append(send_to)
- base_gpu_id += server_args.tp_size
+ launch_func_ = (
+ self.launch_tensor_parallel_process
+ if server_args.enable_dp_attention
+ else self.launch_tensor_parallel_group
+ )
+ self.workers[dp_rank] = launch_func_(
+ server_args,
+ port_args,
+ base_gpu_id,
+ dp_rank,
+ )

  def launch_tensor_parallel_group(
  self,
@@ -112,7 +154,7 @@ class DataParallelController:
  )
  for tp_rank in tp_rank_range:
  reader, writer = mp.Pipe(duplex=False)
- gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+ gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
  proc = mp.Process(
  target=run_scheduler_process,
  args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
@@ -125,9 +167,36 @@ class DataParallelController:
  self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
  )

- # Wait for model to finish loading
+ # Wait for model to finish loading and get max token nums
+ scheduler_info = []
  for i in range(len(scheduler_pipe_readers)):
- scheduler_pipe_readers[i].recv()
+ scheduler_info.append(scheduler_pipe_readers[i].recv())
+
+ self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+
+ return send_to
+
+ def launch_tensor_parallel_process(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ base_gpu_id: int,
+ dp_rank: int,
+ ):
+ reader, writer = mp.Pipe(duplex=False)
+ gpu_id = base_gpu_id
+ tp_rank = dp_rank
+ proc = mp.Process(
+ target=run_scheduler_process,
+ args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+ )
+ proc.start()
+ send_to = get_zmq_socket(
+ self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+ )
+
+ scheduler_info = reader.recv()
+ self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]

  return send_to

@@ -170,7 +239,9 @@ def run_data_parallel_controller_process(

  try:
  controller = DataParallelController(server_args, port_args)
- pipe_writer.send("ready")
+ pipe_writer.send(
+ {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+ )
  controller.event_loop()
  except Exception:
  msg = get_exception_traceback()
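
The controller change above pre-binds each worker's NCCL port with bind_port() and closes the sockets only right before the worker threads start, so concurrent PortArgs.init_new calls cannot hand out the same port twice. A standalone sketch of that hold-then-release trick (this bind_port is an assumption about the helper's behavior, not the actual sglang code):

import socket


def bind_port(port: int) -> socket.socket:
    """Bind and hold a port so nothing else can grab it while more ports are allocated."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    sock.listen(1)
    return sock


# Hold two distinct free ports (port 0 lets the OS pick), then release them
# just before handing them to the workers that will actually use them.
held = [bind_port(0) for _ in range(2)]
ports = [sock.getsockname()[1] for sock in held]
for sock in held:
    sock.close()
print(ports)  # two different port numbers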