sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/bench_serving.py +11 -3
  3. sglang/lang/backend/openai.py +10 -0
  4. sglang/srt/configs/model_config.py +11 -2
  5. sglang/srt/constrained/xgrammar_backend.py +6 -0
  6. sglang/srt/layers/attention/__init__.py +0 -1
  7. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  8. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  9. sglang/srt/layers/logits_processor.py +30 -2
  10. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
  11. sglang/srt/layers/moe/topk.py +14 -0
  12. sglang/srt/layers/quantization/fp8.py +42 -2
  13. sglang/srt/layers/quantization/fp8_kernel.py +91 -18
  14. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  15. sglang/srt/managers/io_struct.py +29 -8
  16. sglang/srt/managers/schedule_batch.py +22 -15
  17. sglang/srt/managers/schedule_policy.py +1 -1
  18. sglang/srt/managers/scheduler.py +71 -34
  19. sglang/srt/managers/session_controller.py +102 -27
  20. sglang/srt/managers/tokenizer_manager.py +95 -55
  21. sglang/srt/managers/tp_worker.py +7 -0
  22. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  23. sglang/srt/model_executor/forward_batch_info.py +42 -3
  24. sglang/srt/model_executor/model_runner.py +4 -6
  25. sglang/srt/model_loader/loader.py +22 -11
  26. sglang/srt/models/gemma2.py +19 -0
  27. sglang/srt/models/llama.py +13 -2
  28. sglang/srt/models/llama_eagle.py +132 -0
  29. sglang/srt/openai_api/adapter.py +79 -2
  30. sglang/srt/openai_api/protocol.py +50 -0
  31. sglang/srt/sampling/sampling_params.py +9 -2
  32. sglang/srt/server.py +45 -39
  33. sglang/srt/server_args.py +17 -30
  34. sglang/srt/speculative/spec_info.py +19 -0
  35. sglang/srt/utils.py +62 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
  38. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
  39. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_eagle.py ADDED
@@ -0,0 +1,132 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class LlamaForCausalLMEagle(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [LlamaForCausalLMEagle]
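For orientation on the new file above: LlamaForCausalLMEagle.load_weights prefixes every checkpoint key that does not belong to lm_head with "model." before delegating to the base LlamaForCausalLM loader, so an EAGLE draft checkpoint stored with top-level tensor names maps onto self.model. A minimal sketch of that remapping, with made-up tensor names and shapes used purely for illustration:

import torch

# Hypothetical draft-checkpoint entries; real EAGLE checkpoints use the same
# top-level naming (no "model." prefix) but different shapes.
draft_weights = [
    ("fc.weight", torch.zeros(8, 16)),
    ("layers.0.self_attn.q_proj.weight", torch.zeros(8, 8)),
    ("lm_head.weight", torch.zeros(32, 8)),
]
# Same renaming rule as load_weights above: everything except lm_head gains "model.".
remapped = [
    (name if "lm_head" in name else "model." + name, w) for name, w in draft_weights
]
print([name for name, _ in remapped])
# ['model.fc.weight', 'model.layers.0.self_attn.q_proj.weight', 'lm_head.weight']

EntryClass is how sglang model files expose their classes to the model loader, so the draft model can be selected by its architecture name.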
sglang/srt/openai_api/adapter.py CHANGED
@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
     FileDeleteResponse,
     FileRequest,
     FileResponse,
+    FunctionResponse,
     LogProbs,
+    ToolCall,
     TopLogprob,
     UsageInfo,
 )
+from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -517,6 +520,7 @@ def v1_generate_request(
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
            "json_schema": request.json_schema,
+           "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
@@ -692,6 +696,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
 
 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)
 
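The extra_body handling added above lets the grammar fields ride along with the standard OpenAI client, which forwards unrecognized parameters via extra_body; the chat handler further down gets the identical treatment. A minimal client-side sketch, assuming a server already running on localhost:30000 and the official openai package; the model name and grammar are placeholders:

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.completions.create(
    model="default",
    prompt="Is the sky blue? Answer:",
    temperature=0,
    max_tokens=8,
    # Forwarded by the adapter: extra_body["ebnf"] becomes request_json["ebnf"].
    extra_body={"ebnf": 'root ::= "yes" | "no"'},
)
print(response.choices[0].text)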
@@ -870,6 +882,21 @@ def v1_chat_generate_request(
         # None skips any image processing in GenerateReqInput.
         if not isinstance(request.messages, str):
             # Apply chat template and its stop strings.
+            tools = None
+            if request.tools and request.tool_choice != "none":
+                request.skip_special_tokens = False
+                if request.stream:
+                    logger.warning("Streaming is not supported with tools.")
+                    request.stream = False
+                if not isinstance(request.tool_choice, str):
+                    tools = [
+                        item.function.model_dump()
+                        for item in request.tools
+                        if item.function.name == request.tool_choice.function.name
+                    ]
+                else:
+                    tools = [item.function.model_dump() for item in request.tools]
+
             if chat_template_name is None:
                 openai_compatible_messages = []
                 for message in request.messages:
@@ -893,6 +920,7 @@ def v1_chat_generate_request(
                     openai_compatible_messages,
                     tokenize=True,
                    add_generation_prompt=True,
+                    tools=tools,
                )
                if assistant_prefix:
                    prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -936,6 +964,7 @@ def v1_chat_generate_request(
            "frequency_penalty": request.frequency_penalty,
            "repetition_penalty": request.repetition_penalty,
            "regex": request.regex,
+           "ebnf": request.ebnf,
            "n": request.n,
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
@@ -1031,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
         finish_reason = ret_item["meta_info"]["finish_reason"]
 
+        tool_calls = None
+        text = ret_item["text"]
+
+        if isinstance(request, list):
+            tool_choice = request[idx].tool_choice
+            tools = request[idx].tools
+        else:
+            tool_choice = request.tool_choice
+            tools = request.tools
+
+        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
+            if finish_reason == "stop":
+                finish_reason = "tool_calls"
+            try:
+                text, call_info_list = parse_tool_response(text, tools)  # noqa
+                tool_calls = [
+                    ToolCall(
+                        id=str(call_info[0]),
+                        function=FunctionResponse(
+                            name=call_info[1], arguments=call_info[2]
+                        ),
+                    )
+                    for call_info in call_info_list
+                ]
+            except Exception as e:
+                logger.error(f"Exception: {e}")
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "Failed to parse fc related info to json format!",
+                )
+
         if to_file:
             # to make the choice data json serializable
             choice_data = {
                 "index": 0,
-                "message": {"role": "assistant", "content": ret_item["text"]},
+                "message": {
+                    "role": "assistant",
+                    "content": ret_item["text"] if tool_calls is None else None,
+                    "tool_calls": tool_calls,
+                },
                 "logprobs": choice_logprobs,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
                 "matched_stop": (
@@ -1047,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
         else:
             choice_data = ChatCompletionResponseChoice(
                 index=idx,
-                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                message=ChatMessage(
+                    role="assistant",
+                    content=ret_item["text"] if tool_calls is None else None,
+                    tool_calls=tool_calls,
+                ),
                 logprobs=choice_logprobs,
                 finish_reason=(finish_reason["type"] if finish_reason else ""),
                 matched_stop=(
@@ -1108,6 +1176,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        # For example, if 'ebnf' is given:
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
sglang/srt/openai_api/protocol.py CHANGED
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -256,6 +257,34 @@ class ResponseFormat(BaseModel):
     json_schema: Optional[JsonSchemaResponseFormat] = None
 
 
+class Function(BaseModel):
+    """Function descriptions."""
+
+    description: Optional[str] = Field(default=None, examples=[None])
+    name: str
+    parameters: Optional[object] = None
+
+
+class Tool(BaseModel):
+    """Function wrapper."""
+
+    type: str = Field(default="function", examples=["function"])
+    function: Function
+
+
+class ToolChoiceFuncName(BaseModel):
+    """The name of tool choice function."""
+
+    name: str
+
+
+class ToolChoice(BaseModel):
+    """The tool choice definition."""
+
+    function: ToolChoiceFuncName
+    type: Literal["function"] = Field(default="function", examples=["function"])
+
+
 class ChatCompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
@@ -276,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
     temperature: float = 0.7
     top_p: float = 1.0
     user: Optional[str] = None
+    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
+    tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
+        default="auto", examples=["none"]
+    )  # noqa
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     top_k: int = -1
@@ -288,11 +321,28 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None
+
+
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: str
+    type: Literal["function"] = "function"
+    function: FunctionResponse
 
 
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionResponseChoice(BaseModel):
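With the Tool/ToolChoice request fields and the FunctionResponse/ToolCall response models above, the OpenAI-compatible chat endpoint can return parsed function calls (the adapter disables streaming when tools are supplied). A minimal client-side sketch, assuming a server running on localhost:30000; the tool definition is a made-up example:

import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
message = response.choices[0].message
if message.tool_calls:
    # Populated from the ToolCall/FunctionResponse models defined above.
    call = message.tool_calls[0]
    print(call.function.name, call.function.arguments)
else:
    print(message.content)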
sglang/srt/sampling/sampling_params.py CHANGED
@@ -36,6 +36,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
 
         # Process some special cases
@@ -111,8 +113,13 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
-        if self.regex is not None and self.json_schema is not None:
-            raise ValueError("regex and json_schema cannot be both set.")
+        grammars = [
+            self.json_schema,
+            self.regex,
+            self.ebnf,
+        ]  # since mutually exclusive, only one can be set
+        if sum(x is not None for x in grammars) > 1:
+            raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
 
     def normalize(self, tokenizer):
         # Process stop strings
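The change above makes regex, json_schema, and ebnf mutually exclusive instead of only guarding regex against json_schema. A standalone sketch of the same check, kept independent of sglang so it can be run directly:

# Reproduces the new mutual-exclusion rule in isolation.
def check_constraints(regex=None, json_schema=None, ebnf=None):
    grammars = [json_schema, regex, ebnf]
    if sum(x is not None for x in grammars) > 1:
        raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

check_constraints(ebnf='root ::= "a" | "b"')  # fine: a single constraint
try:
    check_constraints(regex=r"\d+", ebnf='root ::= "a"')  # two constraints
except ValueError as e:
    print(e)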
sglang/srt/server.py CHANGED
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -109,6 +110,7 @@ app.add_middleware(
 tokenizer_manager: TokenizerManager = None
 scheduler_info: Dict = None
 
+
 ##### Native API endpoints #####
 
 
@@ -245,16 +247,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return ORJSONResponse(
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/open_session", methods=["GET", "POST"])
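The repeated ORJSONResponse blocks in this and the following endpoints are collapsed into the _create_error_response helper added later in this file, so every failure now returns the same {"error": {"message": ...}} payload with HTTP 400. A client-side sketch of the native endpoint handled above, assuming the route is registered as /get_weights_by_name on a locally running server and using an illustrative parameter name:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/get_weights_by_name",
    json={"name": "lm_head.weight", "truncate_size": 4},
)
if resp.ok:
    print(resp.json())
else:
    # Failures now share the standardized error payload.
    print(resp.status_code, resp.json())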
@@ -262,11 +259,13 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
     try:
         session_id = await tokenizer_manager.open_session(obj, request)
+        if session_id is None:
+            raise Exception(
+                "Failed to open the session. Check if a session with the same id is still open."
+            )
         return session_id
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +275,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 # fastapi implicitly converts json in the request to obj (dataclass)
@@ -312,9 +309,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
             return ret
         except ValueError as e:
             logger.error(f"Error: {e}")
-            return ORJSONResponse(
-                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-            )
+            return _create_error_response(e)
 
 
 @app.api_route("/encode", methods=["POST", "PUT"])
@@ -325,9 +320,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/classify", methods=["POST", "PUT"])
@@ -338,9 +331,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -416,6 +407,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
@@ -493,7 +490,16 @@ def launch_engine(
     # Wait for model to finish loading
     scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
-        data = scheduler_pipe_readers[i].recv()
+        try:
+            data = scheduler_pipe_readers[i].recv()
+        except EOFError as e:
+            logger.exception(e)
+            logger.error(
+                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
+            )
+            scheduler_procs[i].join()
+            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
+            raise
 
         if data["status"] != "ready":
             raise RuntimeError(
@@ -501,7 +507,7 @@ def launch_engine(
             )
         scheduler_infos.append(data)
 
-    # Assume all schedulers have same max_total_num_tokens
+    # Assume all schedulers have same scheduler_info
    scheduler_info = scheduler_infos[0]
 
 
@@ -849,12 +855,10 @@ class Engine:
             group_name=group_name,
             backend=backend,
         )
-
-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_init_group())
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )
 
     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
@@ -863,22 +867,24 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )
 
-        async def _update_weights():
-            return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
+    def update_weights_from_tensor(self, name, tensor):
+        """Update weights from distributed source."""
+        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_update_weights())
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_tensor(obj, None)
+        )
 
     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-
-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_get_weights())
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
 
 class Runtime:
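Engine.update_weights_from_tensor joins the existing offline weight utilities; each of them now drives the tokenizer manager's coroutine through the event loop directly instead of a local async wrapper. A minimal offline sketch, assuming the sglang.Engine entry point and a placeholder model; the parameter name is illustrative, and the tensor must match the real parameter's shape, dtype, and device:

import torch
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")  # placeholder model

name = "model.norm.weight"  # illustrative parameter name
print(engine.get_weights_by_name(name, truncate_size=4))

# Push an in-memory tensor into the running model under that name.
new_weight = torch.ones(2048, dtype=torch.bfloat16)  # shape/dtype must match the target
engine.update_weights_from_tensor(name, new_weight)
print(engine.get_weights_by_name(name, truncate_size=4))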
@@ -888,7 +894,7 @@ class Runtime:
     using the commond line interface.
 
     It is mainly used for the frontend language.
-    You should use the Engine class if you want to do normal offline processing.
+    You should use the Engine class above if you want to do normal offline processing.
     """
 
     def __init__(
sglang/srt/server_args.py CHANGED
@@ -55,7 +55,7 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None
 
-    # Port
+    # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
 
@@ -68,6 +68,7 @@ class ServerArgs:
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    prefill_only_one_req: bool = False
 
     # Other runtime options
     tp_size: int = 1
@@ -94,6 +95,7 @@ class ServerArgs:
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+
     # Expert parallelism
     ep_size: int = 1
 
@@ -217,6 +219,13 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
@@ -229,12 +238,6 @@ class ServerArgs:
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
             )
-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )
 
         # GGUF
         if (
@@ -430,13 +433,18 @@ class ServerArgs:
            default=ServerArgs.schedule_conservativeness,
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
-
        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading",
        )
+        parser.add_argument(
+            "--prefill-only-one-req",
+            type=bool,
+            help="If true, we only prefill one request at one prefill batch",
+            default=ServerArgs.prefill_only_one_req,
+        )
 
         # Other runtime options
         parser.add_argument(
@@ -555,6 +563,7 @@ class ServerArgs:
                "shortest_queue",
            ],
        )
+
        # Expert parallelism
        parser.add_argument(
            "--expert-parallel-size",
@@ -777,28 +786,6 @@ class ServerArgs:
            help="Delete the model checkpoint after loading the model.",
        )
 
-        # Deprecated arguments
-        parser.add_argument(
-            "--enable-overlap-schedule",
-            action=DeprecatedAction,
-            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action=DeprecatedAction,
-            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
-        )
-        parser.add_argument(
-            "--disable-disk-cache",
-            action=DeprecatedAction,
-            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size