sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +2 -4
  4. sglang/bench_serving.py +75 -26
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +2 -2
  7. sglang/srt/configs/model_config.py +13 -14
  8. sglang/srt/constrained/__init__.py +13 -14
  9. sglang/srt/constrained/base_grammar_backend.py +13 -15
  10. sglang/srt/constrained/outlines_backend.py +13 -15
  11. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  12. sglang/srt/constrained/xgrammar_backend.py +38 -57
  13. sglang/srt/conversation.py +13 -15
  14. sglang/srt/hf_transformers_utils.py +13 -15
  15. sglang/srt/layers/activation.py +13 -13
  16. sglang/srt/layers/attention/flashinfer_backend.py +13 -6
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  18. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  19. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  20. sglang/srt/layers/custom_op_util.py +13 -14
  21. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  22. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  23. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  24. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  25. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  26. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  27. sglang/srt/layers/layernorm.py +13 -15
  28. sglang/srt/layers/logits_processor.py +13 -15
  29. sglang/srt/layers/quantization/__init__.py +77 -17
  30. sglang/srt/layers/radix_attention.py +13 -15
  31. sglang/srt/layers/rotary_embedding.py +13 -13
  32. sglang/srt/lora/lora.py +13 -14
  33. sglang/srt/lora/lora_config.py +13 -14
  34. sglang/srt/lora/lora_manager.py +22 -24
  35. sglang/srt/managers/data_parallel_controller.py +25 -19
  36. sglang/srt/managers/detokenizer_manager.py +13 -16
  37. sglang/srt/managers/io_struct.py +43 -28
  38. sglang/srt/managers/schedule_batch.py +55 -26
  39. sglang/srt/managers/schedule_policy.py +13 -15
  40. sglang/srt/managers/scheduler.py +89 -70
  41. sglang/srt/managers/session_controller.py +14 -15
  42. sglang/srt/managers/tokenizer_manager.py +29 -22
  43. sglang/srt/managers/tp_worker.py +13 -15
  44. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  45. sglang/srt/metrics/collector.py +13 -15
  46. sglang/srt/metrics/func_timer.py +13 -15
  47. sglang/srt/mm_utils.py +13 -14
  48. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  49. sglang/srt/model_executor/forward_batch_info.py +19 -17
  50. sglang/srt/model_executor/model_runner.py +42 -30
  51. sglang/srt/models/chatglm.py +15 -16
  52. sglang/srt/models/commandr.py +15 -16
  53. sglang/srt/models/dbrx.py +15 -16
  54. sglang/srt/models/deepseek.py +15 -15
  55. sglang/srt/models/deepseek_v2.py +15 -15
  56. sglang/srt/models/exaone.py +14 -15
  57. sglang/srt/models/gemma.py +14 -14
  58. sglang/srt/models/gemma2.py +24 -19
  59. sglang/srt/models/gemma2_reward.py +13 -14
  60. sglang/srt/models/gpt_bigcode.py +14 -14
  61. sglang/srt/models/grok.py +15 -15
  62. sglang/srt/models/internlm2.py +13 -15
  63. sglang/srt/models/internlm2_reward.py +13 -14
  64. sglang/srt/models/llama.py +21 -21
  65. sglang/srt/models/llama_classification.py +13 -14
  66. sglang/srt/models/llama_reward.py +13 -14
  67. sglang/srt/models/llava.py +13 -15
  68. sglang/srt/models/llavavid.py +13 -15
  69. sglang/srt/models/minicpm.py +13 -15
  70. sglang/srt/models/minicpm3.py +13 -15
  71. sglang/srt/models/mistral.py +13 -15
  72. sglang/srt/models/mixtral.py +15 -15
  73. sglang/srt/models/mixtral_quant.py +14 -14
  74. sglang/srt/models/olmo.py +21 -19
  75. sglang/srt/models/olmoe.py +23 -20
  76. sglang/srt/models/qwen.py +14 -14
  77. sglang/srt/models/qwen2.py +22 -19
  78. sglang/srt/models/qwen2_moe.py +17 -18
  79. sglang/srt/models/stablelm.py +18 -16
  80. sglang/srt/models/torch_native_llama.py +15 -17
  81. sglang/srt/models/xverse.py +13 -14
  82. sglang/srt/models/xverse_moe.py +15 -16
  83. sglang/srt/models/yivl.py +13 -15
  84. sglang/srt/openai_api/adapter.py +13 -15
  85. sglang/srt/openai_api/protocol.py +13 -15
  86. sglang/srt/sampling/sampling_batch_info.py +4 -1
  87. sglang/srt/sampling/sampling_params.py +13 -15
  88. sglang/srt/server.py +59 -34
  89. sglang/srt/server_args.py +22 -22
  90. sglang/srt/utils.py +196 -17
  91. sglang/test/few_shot_gsm8k.py +8 -4
  92. sglang/test/runners.py +13 -14
  93. sglang/test/test_utils.py +1 -1
  94. sglang/version.py +1 -1
  95. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  96. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
  97. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  98. sglang/srt/layers/fused_moe/__init__.py +0 -1
  99. sglang-0.3.6.dist-info/RECORD +0 -161
  100. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -1,18 +1,19 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
- from typing import Dict, Type
+ from typing import Callable, Dict, Optional, Type
 
+ import torch
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
  from vllm.model_executor.layers.quantization.awq import AWQConfig
  from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
  from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
      CompressedTensorsConfig,
  )
  from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
  from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
  from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from vllm.model_executor.layers.quantization.gguf import GGUFConfig
  from vllm.model_executor.layers.quantization.gptq import GPTQConfig
  from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -30,8 +31,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
      "tpu_int8": Int8TpuConfig,
      "fp8": Fp8Config,
      "fbgemm_fp8": FBGEMMFp8Config,
-     # The order of gptq methods is important for config.py iteration over
-     # override_quantization_method(..)
      "marlin": MarlinConfig,
      "gguf": GGUFConfig,
      "gptq_marlin_24": GPTQMarlin24Config,
@@ -47,20 +46,68 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
 
  def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
      if quantization not in QUANTIZATION_METHODS:
-         raise ValueError(f"Invalid quantization method: {quantization}")
+         raise ValueError(
+             f"Invalid quantization method: {quantization}. "
+             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
+         )
      return QUANTIZATION_METHODS[quantization]
 
 
- __all__ = [
-     "QuantizationConfig",
-     "get_quantization_config",
-     "QUANTIZATION_METHODS",
- ]
+ def fp8_moe_apply(
+     self,
+     layer: torch.nn.Module,
+     x: torch.Tensor,
+     router_logits: torch.Tensor,
+     top_k: int,
+     renormalize: bool,
+     use_grouped_topk: bool,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
+ ) -> torch.Tensor:
+     """Enhanced apply method for FP8 MoE."""
+     from sglang.srt.layers.fused_moe_triton import FusedMoE
+     from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+
+     # Expert selection
+     topk_weights, topk_ids = FusedMoE.select_experts(
+         hidden_states=x,
+         router_logits=router_logits,
+         use_grouped_topk=use_grouped_topk,
+         top_k=top_k,
+         renormalize=renormalize,
+         topk_group=topk_group,
+         num_expert_group=num_expert_group,
+         custom_routing_function=custom_routing_function,
+     )
+
+     # Expert fusion with FP8 quantization
+     return fused_experts(
+         x,
+         layer.w13_weight,
+         layer.w2_weight,
+         topk_weights=topk_weights,
+         topk_ids=topk_ids,
+         inplace=True,
+         use_fp8_w8a8=True,
+         w1_scale=layer.w13_weight_scale,
+         w2_scale=layer.w2_weight_scale,
+         a1_scale=layer.w13_input_scale,
+         a2_scale=layer.w2_input_scale,
+     )
+
+
+ def fp8_get_quant_method(self, layer, prefix):
+     """Enhanced get_quant_method for FP8 config."""
+     from vllm.model_executor.layers.linear import LinearBase
+     from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+     from vllm.model_executor.layers.quantization.utils.quant_utils import (
+         is_layer_skipped,
+     )
+
+     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+     from sglang.srt.layers.linear import UnquantizedLinearMethod
 
- """
- def fp8_get_quant_method(
-     self, layer: torch.nn.Module, prefix: str
- ) -> Optional["QuantizeMethodBase"]:
      if isinstance(layer, LinearBase):
          if is_layer_skipped(prefix, self.ignored_layers):
              return UnquantizedLinearMethod()
@@ -70,5 +117,18 @@ def fp8_get_quant_method
          return None
 
 
- setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
- """
+ def apply_monkey_patches():
+     """Apply all monkey patches in one place."""
+     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
+     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+
+
+ # Apply patches when module is imported
+ apply_monkey_patches()
+
+
+ __all__ = [
+     "QuantizationConfig",
+     "get_quantization_config",
+     "QUANTIZATION_METHODS",
+ ]
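
Note on the pattern above: fp8_moe_apply and fp8_get_quant_method are written as plain module-level functions that take self, then attached to the vllm classes with setattr, so existing Fp8MoEMethod / Fp8Config instances pick them up as ordinary bound methods as soon as this module is imported. A minimal sketch of that mechanism, using a hypothetical Greeter class rather than anything from vllm or sglang:

    # Sketch of the monkey-patch pattern used above; `Greeter` and
    # `loud_greet` are hypothetical stand-ins for Fp8MoEMethod/fp8_moe_apply.
    class Greeter:
        def greet(self, name: str) -> str:
            return f"hello {name}"

    def loud_greet(self, name: str) -> str:
        """Replacement method; `self` binds normally at call time."""
        return f"HELLO {name.upper()}!"

    def apply_monkey_patches() -> None:
        # A plain function becomes a bound method when looked up on an instance.
        setattr(Greeter, "greet", loud_greet)

    apply_monkey_patches()  # run at import time, as the module above does
    assert Greeter().greet("world") == "HELLO WORLD!"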
sglang/srt/layers/radix_attention.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """Radix attention."""
 
  from torch import nn
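
This hunk, and the matching ones in the files below, moves the Apache license header out of the module docstring and into a # comment block closed by a # ====== rule. One observable difference: a module's first string literal becomes its __doc__, so with the license text as comments the short description (here "Radix attention.") is what surfaces as the module docstring. A small sketch of that effect, using throwaway source strings rather than the actual files:

    # Only the first string literal in a module body becomes __doc__.
    src_before = '"""LICENSE TEXT"""\n"""Radix attention."""'
    src_after = '# LICENSE TEXT\n"""Radix attention."""'

    ns_before, ns_after = {}, {}
    exec(compile(src_before, "<before>", "exec"), ns_before)
    exec(compile(src_after, "<after>", "exec"), ns_after)

    assert ns_before["__doc__"] == "LICENSE TEXT"     # license shadowed the doc
    assert ns_after["__doc__"] == "Radix attention."  # comments free up __doc__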
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -1,16 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-     http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """MRotaryEmbedding"""
  from typing import Any, Dict, List, Optional, Tuple, Union
 
sglang/srt/lora/lora.py CHANGED
@@ -1,17 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
 
  # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
  # and "Punica: Multi-Tenant LoRA Serving"
sglang/srt/lora/lora_config.py CHANGED
@@ -1,17 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
 
  import json
  import os
sglang/srt/lora/lora_manager.py CHANGED
@@ -1,22 +1,20 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
 
  # Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
  # and "Punica: Multi-Tenant LoRA Serving"
 
-
  import logging
  import re
 
@@ -146,9 +144,9 @@ class LoRAManager:
              }
          else:
              logger.warning(
-                 f"WARNING: get_module_name() is not defined, "
-                 f"which is used to map config module name to model implementation module name."
-                 f"Use the default one, but please check if it is correct for your model."
+                 "WARNING: get_module_name() is not defined, "
+                 "which is used to map config module name to model implementation module name."
+                 "Use the default one, but please check if it is correct for your model."
              )
              self.target_modules = {
                  get_module_name(module) for module in self.origin_target_modules
@@ -194,9 +192,9 @@
                  hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
              else:
                  logger.warning(
-                     f"WARNING: get_hidden_dim() is not defined, "
-                     f"which is used to get the hidden dim for different lora modules"
-                     f"Use the default one, but please check if it is correct for your model."
+                     "WARNING: get_hidden_dim() is not defined, "
+                     "which is used to get the hidden dim for different lora modules"
+                     "Use the default one, but please check if it is correct for your model."
                  )
                  hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
              c = self.loras[-1].get_stacked_multiply(module_A)
@@ -218,9 +216,9 @@
                  _, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
              else:
                  logger.warning(
-                     f"WARNING: get_hidden_dim() is not defined, "
-                     f"which is used to get the hidden dim for different lora modules"
-                     f"Use the default one, but please check if it is correct for your model."
+                     "WARNING: get_hidden_dim() is not defined, "
+                     "which is used to get the hidden dim for different lora modules"
+                     "Use the default one, but please check if it is correct for your model."
                  )
                  _, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
              c = self.loras[-1].get_stacked_multiply(module_B)
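
The warning tweaks in this file drop the f prefix from string pieces that contain no placeholders. The pieces are still joined by implicit adjacent-literal concatenation, which inserts no separator, so the message text itself is unchanged, including the missing space between the second and third pieces. A quick sketch of that behavior, reusing the message above:

    # Adjacent string literals concatenate at compile time with no separator.
    msg = (
        "WARNING: get_hidden_dim() is not defined, "
        "which is used to get the hidden dim for different lora modules"
        "Use the default one, but please check if it is correct for your model."
    )
    assert "modulesUse" in msg  # the joined text keeps the source's missing space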
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """A controller that dispatches requests to multiple data parallel workers."""
 
  import logging
@@ -169,9 +167,12 @@ class DataParallelController:
                  self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
              )
 
-             # Wait for model to finish loading
+             # Wait for model to finish loading and get max token nums
+             scheduler_info = []
              for i in range(len(scheduler_pipe_readers)):
-                 scheduler_pipe_readers[i].recv()
+                 scheduler_info.append(scheduler_pipe_readers[i].recv())
+
+             self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
 
          return send_to
 
@@ -193,7 +194,10 @@
          send_to = get_zmq_socket(
              self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
          )
-         reader.recv()
+
+         scheduler_info = reader.recv()
+         self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
+
          return send_to
 
      def round_robin_scheduler(self, req):
@@ -235,7 +239,9 @@
 
      try:
          controller = DataParallelController(server_args, port_args)
-         pipe_writer.send("ready")
+         pipe_writer.send(
+             {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+         )
          controller.event_loop()
      except Exception:
          msg = get_exception_traceback()
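
Across these hunks the startup handshake between the data parallel controller and its schedulers changes from a bare "ready" string to a dict payload, which lets the controller learn max_total_num_tokens from the scheduler's reply. A self-contained sketch of that handshake pattern, with a hypothetical worker standing in for the real scheduler process:

    # Ready-handshake-with-payload sketch; `worker` and the token count are
    # hypothetical, not the actual sglang scheduler.
    import multiprocessing as mp

    def worker(pipe_writer) -> None:
        max_total_num_tokens = 4096  # would be derived from the loaded model
        pipe_writer.send({"status": "ready", "max_total_num_tokens": max_total_num_tokens})

    if __name__ == "__main__":
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=worker, args=(writer,))
        proc.start()
        info = reader.recv()  # blocks until the worker reports in
        assert info["status"] == "ready"
        print("max_total_num_tokens:", info["max_total_num_tokens"])
        proc.join()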
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """DetokenizerManager is a process that detokenizes the token ids."""
 
  import dataclasses
@@ -175,7 +173,6 @@ class DetokenizerManager:
                      output_strs=output_strs,
                      meta_info=recv_obj.meta_info,
                      finished_reason=recv_obj.finished_reason,
-                     session_ids=recv_obj.session_ids,
                  )
              )
 
sglang/srt/managers/io_struct.py CHANGED
@@ -1,18 +1,16 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
  """
  The definition of objects transfered between different
  processes (TokenizerManager, DetokenizerManager, Controller).
@@ -21,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
  import uuid
  from dataclasses import dataclass
  from enum import Enum
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Tuple, Union
 
  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -31,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
  class GenerateReqInput:
      # The input prompt. It can be a single prompt or a batch of prompts.
      text: Optional[Union[List[str], str]] = None
-     # The token ids for text; one can either specify text or input_ids.
+     # The token ids for text; one can specify either text or input_ids
      input_ids: Optional[Union[List[List[int]], List[int]]] = None
+     # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
      # The image input. It can be a file name, a url, or base64 encoded string.
      # See also python/sglang/srt/utils.py:load_image.
      image_data: Optional[Union[List[str], str]] = None
@@ -57,14 +57,21 @@
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
      # Session id info for continual prompting
-     session_id: Optional[Union[List[str], str]] = None
-     session_rid: Optional[Union[List[str], str]] = None
+     session: Optional[
+         Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
+     ] = None
 
      def normalize_batch_and_arguments(self):
-         if (self.text is None and self.input_ids is None) or (
-             self.text is not None and self.input_ids is not None
+         if (
+             self.text is None and self.input_ids is None and self.input_embeds is None
+         ) or (
+             self.text is not None
+             and self.input_ids is not None
+             and self.input_embeds is not None
          ):
-             raise ValueError("Either text or input_ids should be provided.")
+             raise ValueError(
+                 "Either text, input_ids or input_embeds should be provided."
+             )
 
          # Derive the batch size
          if self.text is not None:
@@ -74,13 +81,21 @@
          else:
              self.is_single = False
              self.batch_size = len(self.text)
-         else:
+             self.input_embeds = None
+         elif self.input_ids is not None:
              if isinstance(self.input_ids[0], int):
                  self.is_single = True
                  self.batch_size = 1
              else:
                  self.is_single = False
                  self.batch_size = len(self.input_ids)
+             self.input_embeds = None
+         else:
+             if isinstance(self.input_embeds[0][0], float):
+                 self.is_single = True
+                 self.batch_size = 1
+             else:
+                 self.batch_size = len(self.input_embeds)
 
          # Handle parallel sampling
          # When parallel sampling is used, we always treat the input as a batch.
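
These two hunks extend GenerateReqInput.normalize_batch_and_arguments to accept input_embeds: the request is treated as single or batched based on the nesting depth of whichever field is set (str vs List[str] for text, List[int] vs List[List[int]] for input_ids, 2-D vs 3-D float lists for input_embeds). A condensed sketch of that dispatch; the helper name infer_batch_size is hypothetical, and unlike the released code it also returns a batched flag in the embeds branch, which the diff above leaves unset:

    from typing import List, Optional, Tuple, Union

    def infer_batch_size(
        text: Optional[Union[str, List[str]]] = None,
        input_ids: Optional[Union[List[int], List[List[int]]]] = None,
        input_embeds: Optional[Union[List[List[float]], List[List[List[float]]]]] = None,
    ) -> Tuple[bool, int]:
        """Return (is_single, batch_size) by inspecting nesting depth."""
        if text is not None:
            return (True, 1) if isinstance(text, str) else (False, len(text))
        if input_ids is not None:
            return (True, 1) if isinstance(input_ids[0], int) else (False, len(input_ids))
        if input_embeds is not None:
            # A single request is a 2-D list: [seq_len][hidden_dim] of floats.
            if isinstance(input_embeds[0][0], float):
                return True, 1
            return False, len(input_embeds)
        raise ValueError("Either text, input_ids or input_embeds should be provided.")

    assert infer_batch_size(text="hi") == (True, 1)
    assert infer_batch_size(input_ids=[[1, 2], [3, 4]]) == (False, 2)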
@@ -203,9 +218,11 @@
 
      # LoRA related
      lora_path: Optional[str] = None  # None means just use the base model
+     # The input embeds
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
      # Session id info for continual prompting
-     session_id: Optional[int] = None
+     session_id: Optional[str] = None
      session_rid: Optional[str] = None
 
 
@@ -219,6 +236,8 @@ class EmbeddingReqInput:
      rid: Optional[Union[List[str], str]] = None
      # Dummy sampling params for compatibility
      sampling_params: Union[List[Dict], Dict] = None
+     # Dummy input embeds for compatibility
+     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
      def normalize_batch_and_arguments(self):
          if (self.text is None and self.input_ids is None) or (
@@ -301,8 +320,6 @@ class BatchTokenIDOut:
      meta_info: List[Dict]
      finished_reason: List[BaseFinishReason]
      no_stop_trim: List[bool]
-     # The updated session unique id
-     session_ids: List[str]
 
 
  @dataclass
@@ -315,8 +332,6 @@ class BatchStrOut:
      meta_info: List[Dict]
      # The finish reason
      finished_reason: List[BaseFinishReason]
-     # The update session unique id
-     session_ids: List[str]
 
 
  @dataclass
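
The session plumbing changes shape throughout io_struct.py: GenerateReqInput's separate session_id/session_rid fields collapse into a single session field holding one tuple per request, and the batched outputs drop their per-request session_ids lists. Assuming the tuple pairs the two fields it replaces, (session_id, rid), the client-side shape looks like this (values hypothetical):

    from typing import List, Optional, Tuple

    # One request: a (session_id, rid) pair; rid may be None to start fresh.
    single_session: Tuple[str, Optional[str]] = ("sess-1", None)

    # A batch: one pair per request in the batch.
    batch_sessions: List[Tuple[str, Optional[str]]] = [
        ("sess-1", None),
        ("sess-2", "rid-7"),
    ]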