sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py ADDED
@@ -0,0 +1,79 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedPresencePenalizer(_BatchedPenalizer):
+     """
+     Presence penalizer penalizes tokens based on their presence in the output.
+     """
+
+     presence_penalties: torch.Tensor = None
+     cumulated_presence_penalties: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.presence_penalty != 0.0
+             for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.cumulated_presence_penalties = (
+             torch.tensor(
+                 data=[0.0 for _ in self.orchestrator.reqs()],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .repeat(1, self.orchestrator.vocab_size)
+         )
+
+         self.presence_penalties = (
+             torch.tensor(
+                 data=[
+                     req.sampling_params.presence_penalty
+                     for req in self.orchestrator.reqs()
+                 ],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .expand_as(self.cumulated_presence_penalties)
+         )
+
+     def _teardown(self):
+         del self.presence_penalties
+         del self.cumulated_presence_penalties
+
+         self.presence_penalties = None
+         self.cumulated_presence_penalties = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         pass
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         mask = output_ids.occurrence_count() > 0
+         self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         logits -= self.cumulated_presence_penalties
+         return logits
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
+         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
+             indices_tensor_to_keep
+         ]
+
+     def _merge(self, their: "BatchedPresencePenalizer"):
+         self.presence_penalties = torch.cat(
+             [self.presence_penalties, their.presence_penalties], dim=0
+         )
+         self.cumulated_presence_penalties = torch.cat(
+             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
+             dim=0,
+         )
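
The class above subtracts a per-request presence penalty from the logits of every token that has already appeared in the output. A minimal standalone sketch of the same rule, outside the penaltylib orchestrator (tensor names and sizes here are illustrative, not the penaltylib API):

import torch

vocab_size = 10
presence_penalty = 1.5
logits = torch.zeros(vocab_size)

output_ids = torch.tensor([3, 7, 3])                 # tokens generated so far
seen = torch.zeros(vocab_size, dtype=torch.bool)
seen[output_ids] = True                              # presence only; the count does not matter

logits = logits - presence_penalty * seen.float()    # mirrors _apply() above
print(logits)                                        # positions 3 and 7 are now -1.5
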
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py ADDED
@@ -0,0 +1,83 @@
+ import typing
+
+ import torch
+
+ from ..orchestrator import _BatchedPenalizer, _TokenIDs
+
+
+ class BatchedRepetitionPenalizer(_BatchedPenalizer):
+     """
+     Repetition penalizer penalizes tokens based on their repetition in the input and output.
+     """
+
+     repetition_penalties: torch.Tensor = None
+     cumulated_repetition_penalties: torch.Tensor = None
+
+     def _is_required(self) -> bool:
+         return any(
+             req.sampling_params.repetition_penalty != 1.0
+             for req in self.orchestrator.reqs()
+         )
+
+     def _prepare(self):
+         self.cumulated_repetition_penalties = (
+             torch.tensor(
+                 data=[1.0 for _ in self.orchestrator.reqs()],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .repeat(1, self.orchestrator.vocab_size)
+         )
+
+         self.repetition_penalties = (
+             torch.tensor(
+                 data=[
+                     req.sampling_params.repetition_penalty
+                     for req in self.orchestrator.reqs()
+                 ],
+                 dtype=torch.float32,
+                 device=self.orchestrator.device,
+             )
+             .unsqueeze_(1)
+             .expand_as(self.cumulated_repetition_penalties)
+         )
+
+     def _teardown(self):
+         del self.repetition_penalties
+         del self.cumulated_repetition_penalties
+
+         self.repetition_penalties = None
+         self.cumulated_repetition_penalties = None
+
+     def _cumulate_input_tokens(self, input_ids: _TokenIDs):
+         mask = input_ids.occurrence_count() > 0
+         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+     def _cumulate_output_tokens(self, output_ids: _TokenIDs):
+         mask = output_ids.occurrence_count() > 0
+         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
+
+     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
+         return torch.where(
+             logits > 0,
+             logits / self.cumulated_repetition_penalties,
+             logits * self.cumulated_repetition_penalties,
+         )
+
+     def _filter(
+         self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
+     ):
+         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
+         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
+             indices_tensor_to_keep
+         ]
+
+     def _merge(self, their: "BatchedRepetitionPenalizer"):
+         self.repetition_penalties = torch.cat(
+             [self.repetition_penalties, their.repetition_penalties], dim=0
+         )
+         self.cumulated_repetition_penalties = torch.cat(
+             [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],
+             dim=0,
+         )
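
The repetition penalizer uses the common asymmetric rule: a penalized token's logit is divided by the penalty when positive and multiplied by it when negative, so the token always becomes less likely. A standalone sketch of that rule with illustrative values (not the penaltylib API):

import torch

penalty = 1.3
logits = torch.tensor([2.0, -1.0, 0.5])
penalized = torch.tensor([True, True, False])   # tokens seen in the input or output

scale = torch.where(penalized, torch.full_like(logits, penalty), torch.ones_like(logits))
adjusted = torch.where(logits > 0, logits / scale, logits * scale)
print(adjusted)   # tensor([1.5385, -1.3000, 0.5000])
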
sglang/srt/sampling_params.py CHANGED
@@ -23,13 +23,16 @@ _SAMPLING_EPS = 1e-6
  class SamplingParams:
      def __init__(
          self,
-         max_new_tokens: int = 16,
+         max_new_tokens: int = 128,
+         min_new_tokens: int = 0,
          stop: Optional[Union[str, List[str]]] = None,
+         stop_token_ids: Optional[List[int]] = [],
          temperature: float = 1.0,
          top_p: float = 1.0,
          top_k: int = -1,
          frequency_penalty: float = 0.0,
          presence_penalty: float = 0.0,
+         repetition_penalty: float = 1.0,
          ignore_eos: bool = False,
          skip_special_tokens: bool = True,
          spaces_between_special_tokens: bool = True,
@@ -42,8 +45,11 @@ class SamplingParams:
          self.top_k = top_k
          self.frequency_penalty = frequency_penalty
          self.presence_penalty = presence_penalty
+         self.repetition_penalty = repetition_penalty
          self.stop_strs = stop
+         self.stop_token_ids = {*stop_token_ids}
          self.max_new_tokens = max_new_tokens
+         self.min_new_tokens = min_new_tokens
          self.ignore_eos = ignore_eos
          self.skip_special_tokens = skip_special_tokens
          self.spaces_between_special_tokens = spaces_between_special_tokens
@@ -80,23 +86,44 @@ class SamplingParams:
              raise ValueError(
                  "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
              )
+         if not 0.0 <= self.repetition_penalty <= 2.0:
+             raise ValueError(
+                 "repetition_penalty must be in (0, 2], got "
+                 f"{self.repetition_penalty}."
+             )
+         if not 0 <= self.min_new_tokens:
+             raise ValueError(
+                 f"min_new_tokens must be in (0, max_new_tokens], got "
+                 f"{self.min_new_tokens}."
+             )
          if self.max_new_tokens is not None:
              if self.max_new_tokens < 0:
                  raise ValueError(
                      f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                  )
+             if not self.min_new_tokens <= self.max_new_tokens:
+                 raise ValueError(
+                     f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                     f"{self.min_new_tokens}."
+                 )

      def normalize(self, tokenizer):
          # Process stop strings
          if self.stop_strs is None:
              self.stop_strs = []
-             self.stop_str_max_len = 0
+             if self.stop_token_ids is None:
+                 self.stop_str_max_len = 0
+             else:
+                 self.stop_str_max_len = 1
          else:
              if isinstance(self.stop_strs, str):
                  self.stop_strs = [self.stop_strs]

              stop_str_max_len = 0
              for stop_str in self.stop_strs:
-                 stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                 stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                 if tokenizer is not None:
+                     stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                     stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                 else:
+                     stop_str_max_len = max(stop_str_max_len, len(stop_str))
          self.stop_str_max_len = stop_str_max_len
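
A hedged sketch of constructing SamplingParams with the fields added in this release. The field names and defaults come from the diff above; the assumptions are that the validation shown above is exposed as a verify() method and that 128009 is merely an example stop token id.

from sglang.srt.sampling_params import SamplingParams

params = SamplingParams(
    max_new_tokens=128,        # default raised from 16 to 128
    min_new_tokens=8,          # new: force at least 8 generated tokens
    repetition_penalty=1.1,    # new: 1.0 disables the penalty
    stop_token_ids=[128009],   # new: stop on explicit token ids (example id)
    temperature=0.7,
)
params.verify()                # raises ValueError if the new fields are out of range
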
sglang/srt/server.py CHANGED
@@ -52,13 +52,15 @@ from sglang.srt.managers.controller_single import (
      start_controller_process as start_controller_process_single,
  )
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
- from sglang.srt.managers.io_struct import GenerateReqInput
+ from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
  from sglang.srt.openai_api.adapter import (
      load_chat_template_for_openai_api,
      v1_batches,
      v1_chat_completions,
      v1_completions,
+     v1_delete_file,
+     v1_embeddings,
      v1_files_create,
      v1_retrieve_batch,
      v1_retrieve_file,
@@ -73,7 +75,8 @@ from sglang.srt.utils import (
      enable_show_time_cost,
      kill_child_process,
      maybe_set_triton_cache_manager,
-     set_torch_compile_config,
+     prepare_model,
+     prepare_tokenizer,
      set_ulimit,
  )
  from sglang.utils import get_exception_traceback
@@ -97,6 +100,7 @@ async def health() -> Response:
  async def get_model_info():
      result = {
          "model_path": tokenizer_manager.model_path,
+         "is_generation": tokenizer_manager.is_generation,
      }
      return result

@@ -148,6 +152,21 @@ app.post("/generate")(generate_request)
  app.put("/generate")(generate_request)


+ async def encode_request(obj: EmbeddingReqInput, request: Request):
+     """Handle an embedding request."""
+     try:
+         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+         return ret
+     except ValueError as e:
+         return JSONResponse(
+             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+         )
+
+
+ app.post("/encode")(encode_request)
+ app.put("/encode")(encode_request)
+
+
  @app.post("/v1/completions")
  async def openai_v1_completions(raw_request: Request):
      return await v1_completions(tokenizer_manager, raw_request)
@@ -158,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
      return await v1_chat_completions(tokenizer_manager, raw_request)


+ @app.post("/v1/embeddings")
+ async def openai_v1_embeddings(raw_request: Request):
+     response = await v1_embeddings(tokenizer_manager, raw_request)
+     return response
+
+
  @app.get("/v1/models")
  def available_models():
      """Show available models."""
@@ -175,6 +200,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
      )


+ @app.delete("/v1/files/{file_id}")
+ async def delete_file(file_id: str):
+     # https://platform.openai.com/docs/api-reference/files/delete
+     return await v1_delete_file(file_id)
+
+
  @app.post("/v1/batches")
  async def openai_v1_batches(raw_request: Request):
      return await v1_batches(tokenizer_manager, raw_request)
@@ -228,6 +259,10 @@ def launch_server(
      )
      logger.info(f"{server_args=}")

+     # Use model from www.modelscope.cn, first download the model.
+     server_args.model_path = prepare_model(server_args.model_path)
+     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
      # Launch processes for multi-node tensor parallelism
      if server_args.nnodes > 1:
          if server_args.node_rank != 0:
@@ -340,10 +375,6 @@ def _set_envs_and_config(server_args: ServerArgs):
          # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
          maybe_set_triton_cache_manager()

-     # Set torch compile config
-     if server_args.enable_torch_compile:
-         set_torch_compile_config()
-
      # Set global chat template
      if server_args.chat_template:
          # TODO: replace this with huggingface transformers template
@@ -353,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
      if not server_args.disable_flashinfer:
          assert_pkg_version(
              "flashinfer",
-             "0.1.3",
+             "0.1.4",
              "Please uninstall the old version and "
              "reinstall the latest version by following the instructions "
              "at https://docs.flashinfer.ai/installation.html.",
@@ -367,35 +398,63 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
          headers["Authorization"] = f"Bearer {server_args.api_key}"

      # Wait until the server is launched
+     success = False
      for _ in range(120):
          time.sleep(1)
          try:
-             requests.get(url + "/get_model_info", timeout=5, headers=headers)
+             res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+             assert res.status_code == 200, f"{res}"
+             success = True
              break
-         except requests.exceptions.RequestException:
+         except (AssertionError, requests.exceptions.RequestException) as e:
+             last_traceback = get_exception_traceback()
              pass
+     model_info = res.json()
+
+     if not success:
+         if pipe_finish_writer is not None:
+             pipe_finish_writer.send(last_traceback)
+         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+         sys.exit(1)

      # Send a warmup request
+     request_name = "/generate" if model_info["is_generation"] else "/encode"
+     max_new_tokens = 8 if model_info["is_generation"] else 1
+     json_data = {
+         "sampling_params": {
+             "temperature": 0,
+             "max_new_tokens": max_new_tokens,
+         },
+     }
+     if server_args.skip_tokenizer_init:
+         json_data["input_ids"] = [10, 11, 12]
+     else:
+         json_data["text"] = "The capital city of France is"
+
      try:
          for _ in range(server_args.dp_size):
              res = requests.post(
-                 url + "/generate",
-                 json={
-                     "text": "The capital city of France is",
-                     "sampling_params": {
-                         "temperature": 0,
-                         "max_new_tokens": 8,
-                     },
-                 },
+                 url + request_name,
+                 json=json_data,
                  headers=headers,
                  timeout=600,
              )
-             assert res.status_code == 200
+             assert res.status_code == 200, f"{res}"
      except Exception as e:
+         last_traceback = get_exception_traceback()
          if pipe_finish_writer is not None:
-             pipe_finish_writer.send(get_exception_traceback())
-         print(f"Initialization failed. warmup error: {e}", flush=True)
-         raise e
+             pipe_finish_writer.send(last_traceback)
+         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+         sys.exit(1)
+
+     # Print warnings here
+     if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+         logger.warning(
+             "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+             "This combination is an experimental feature and we noticed it can lead to "
+             "wrong generation results. If you want to use chunked prefill, it is recommended "
+             "not using `--disable-radix-cache`."
+         )

      logger.info("The server is fired up and ready to roll!")
      if pipe_finish_writer is not None:
@@ -516,5 +575,18 @@ class Runtime:
          )
          return json.dumps(response.json())

+     def encode(
+         self,
+         prompt: str,
+     ):
+         json_data = {
+             "text": prompt,
+         }
+         response = requests.post(
+             self.url + "/encode",
+             json=json_data,
+         )
+         return json.dumps(response.json())
+
      def __del__(self):
          self.shutdown()
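
A hedged sketch of exercising the new embedding endpoints over HTTP. The port is the usual sglang default and the /v1/embeddings payload follows the OpenAI embeddings API shape, neither of which is spelled out in this diff, so treat both as assumptions; the server must be running an embedding model (for example one backed by the new llama_embedding.py).

import requests

base = "http://127.0.0.1:30000"   # assumed default sglang port

# Native endpoint added in this release (same payload Runtime.encode() sends)
res = requests.post(base + "/encode", json={"text": "The capital city of France is"})
print(res.json())

# OpenAI-compatible endpoint added in this release
res = requests.post(
    base + "/v1/embeddings",
    json={"model": "default", "input": "The capital city of France is"},
)
print(res.json())
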
sglang/srt/server_args.py CHANGED
@@ -27,6 +27,7 @@ class ServerArgs:
      model_path: str
      tokenizer_path: Optional[str] = None
      tokenizer_mode: str = "auto"
+     skip_tokenizer_init: bool = False
      load_format: str = "auto"
      dtype: str = "auto"
      trust_remote_code: bool = True
@@ -42,10 +43,11 @@ class ServerArgs:

      # Memory and scheduling
      mem_fraction_static: Optional[float] = None
-     max_prefill_tokens: Optional[int] = None
      max_running_requests: Optional[int] = None
      max_num_reqs: Optional[int] = None
      max_total_tokens: Optional[int] = None
+     chunked_prefill_size: int = -1
+     max_prefill_tokens: int = 16384
      schedule_policy: str = "lpm"
      schedule_conservativeness: float = 1.0

@@ -62,15 +64,12 @@ class ServerArgs:

      # Other
      api_key: Optional[str] = None
-     file_storage_pth: str = "SGlang_storage"
+     file_storage_pth: str = "SGLang_storage"

      # Data parallelism
      dp_size: int = 1
      load_balance_method: str = "round_robin"

-     # Chunked Prefill
-     chunked_prefill_size: Optional[int] = None
-
      # Optimization/debug options
      disable_flashinfer: bool = False
      disable_flashinfer_sampling: bool = False
@@ -96,6 +95,10 @@ class ServerArgs:
          if self.served_model_name is None:
              self.served_model_name = self.model_path

+         if self.chunked_prefill_size <= 0:
+             # Disable chunked prefill
+             self.chunked_prefill_size = None
+
          if self.mem_fraction_static is None:
              if self.tp_size >= 16:
                  self.mem_fraction_static = 0.79
@@ -107,6 +110,7 @@ class ServerArgs:
                  self.mem_fraction_static = 0.87
              else:
                  self.mem_fraction_static = 0.88
+
          if isinstance(self.additional_ports, int):
              self.additional_ports = [self.additional_ports]
          elif self.additional_ports is None:
@@ -151,6 +155,11 @@ class ServerArgs:
              "tokenizer if available, and 'slow' will "
              "always use the slow tokenizer.",
          )
+         parser.add_argument(
+             "--skip-tokenizer-init",
+             action="store_true",
+             help="If set, skip init tokenizer and pass input_ids in generate request",
+         )
          parser.add_argument(
              "--load-format",
              type=str,
@@ -226,12 +235,6 @@ class ServerArgs:
              default=ServerArgs.mem_fraction_static,
              help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
          )
-         parser.add_argument(
-             "--max-prefill-tokens",
-             type=int,
-             default=ServerArgs.max_prefill_tokens,
-             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-         )
          parser.add_argument(
              "--max-running-requests",
              type=int,
@@ -250,6 +253,18 @@ class ServerArgs:
              default=ServerArgs.max_total_tokens,
              help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
          )
+         parser.add_argument(
+             "--chunked-prefill-size",
+             type=int,
+             default=ServerArgs.chunked_prefill_size,
+             help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+         )
+         parser.add_argument(
+             "--max-prefill-tokens",
+             type=int,
+             default=ServerArgs.max_prefill_tokens,
+             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+         )
          parser.add_argument(
              "--schedule-policy",
              type=str,
@@ -264,6 +279,7 @@ class ServerArgs:
              help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
          )
          parser.add_argument(
+             "--tensor-parallel-size",
              "--tp-size",
              type=int,
              default=ServerArgs.tp_size,
@@ -318,6 +334,7 @@ class ServerArgs:

          # Data parallelism
          parser.add_argument(
+             "--data-parallel-size",
              "--dp-size",
              type=int,
              default=ServerArgs.dp_size,
@@ -345,14 +362,6 @@ class ServerArgs:
          )
          parser.add_argument("--node-rank", type=int, help="The node rank.")

-         # Chunked prefill
-         parser.add_argument(
-             "--chunked-prefill-size",
-             type=int,
-             default=ServerArgs.chunked_prefill_size,
-             help="The size of the chunked prefill.",
-         )
-
          # Optimization/debug options
          parser.add_argument(
              "--disable-flashinfer",
@@ -413,6 +422,8 @@ class ServerArgs:

      @classmethod
      def from_cli_args(cls, args: argparse.Namespace):
+         args.tp_size = args.tensor_parallel_size
+         args.dp_size = args.data_parallel_size
          attrs = [attr.name for attr in dataclasses.fields(cls)]
          return cls(**{attr: getattr(args, attr) for attr in attrs})

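A hedged sketch of the new ServerArgs behavior: a non-positive chunked_prefill_size is normalized to None in __post_init__ (chunked prefill disabled), max_prefill_tokens is now a plain default of 16384, and the long-form flags --tensor-parallel-size / --data-parallel-size alias --tp-size / --dp-size. The model path below is a placeholder.

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
    chunked_prefill_size=-1,   # default; becomes None after __post_init__
    max_prefill_tokens=16384,  # new default, now independent of chunked prefill
    tp_size=1,
)
assert args.chunked_prefill_size is None   # chunked prefill disabled

On the command line the equivalent is --chunked-prefill-size -1 (or simply omitting the flag) together with --tensor-parallel-size 1.
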
sglang/srt/utils.py CHANGED
@@ -197,6 +197,8 @@ def allocate_init_ports(
  def get_int_token_logit_bias(tokenizer, vocab_size):
      """Get the logit bias for integer-only tokens."""
      # a bug when model's vocab size > tokenizer.vocab_size
+     if tokenizer == None:
+         return [-1e5] * vocab_size
      vocab_size = tokenizer.vocab_size
      logit_bias = np.zeros(vocab_size, dtype=np.float32)
      for t_id in range(vocab_size):
@@ -223,6 +225,15 @@ def is_multimodal_model(model):
      raise ValueError("unrecognized type")


+ def is_generation_model(model_architectures):
+     if (
+         "LlamaEmbeddingModel" in model_architectures
+         or "MistralModel" in model_architectures
+     ):
+         return False
+     return True
+
+
  def decode_video_base64(video_base64):
      from PIL import Image

@@ -622,19 +633,6 @@ def receive_addrs(model_port_args, server_args):
      dist.destroy_process_group()


- def set_torch_compile_config():
-     # The following configurations are for torch compile optimizations
-     import torch._dynamo.config
-     import torch._inductor.config
-
-     torch._inductor.config.coordinate_descent_tuning = True
-     torch._inductor.config.triton.unique_kernel_names = True
-     torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-
-     # FIXME: tmp workaround
-     torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
  def set_ulimit(target_soft_limit=65535):
      resource_type = resource.RLIMIT_NOFILE
      current_soft, current_hard = resource.getrlimit(resource_type)
@@ -705,3 +703,23 @@ def add_api_key_middleware(app, api_key):
          if request.headers.get("Authorization") != "Bearer " + api_key:
              return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
          return await call_next(request)
+
+
+ def prepare_model(model_path):
+     if "SGLANG_USE_MODELSCOPE" in os.environ:
+         if not os.path.exists(model_path):
+             from modelscope import snapshot_download
+
+             return snapshot_download(model_path)
+     return model_path
+
+
+ def prepare_tokenizer(tokenizer_path):
+     if "SGLANG_USE_MODELSCOPE" in os.environ:
+         if not os.path.exists(tokenizer_path):
+             from modelscope import snapshot_download
+
+             return snapshot_download(
+                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+             )
+     return tokenizer_path
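
A hedged sketch of the new ModelScope path: when the SGLANG_USE_MODELSCOPE environment variable is present (only its presence is checked, not its value) and the given path does not exist locally, prepare_model and prepare_tokenizer fall back to modelscope.snapshot_download, which requires the modelscope package to be installed. The model id below is a placeholder.

import os

from sglang.srt.utils import prepare_model, prepare_tokenizer

os.environ["SGLANG_USE_MODELSCOPE"] = "1"
model_path = prepare_model("qwen/Qwen2-7B-Instruct")        # placeholder ModelScope id
tokenizer_path = prepare_tokenizer("qwen/Qwen2-7B-Instruct")
print(model_path, tokenizer_path)
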
sglang/test/run_eval.py CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (


  def run_eval(args):
+     set_ulimit()
+
      if "OPENAI_API_KEY" not in os.environ:
          os.environ["OPENAI_API_KEY"] = "EMPTY"

@@ -39,6 +41,14 @@ def run_eval(args):
          eval_obj = MathEval(
              filename, equality_checker, args.num_examples, args.num_threads
          )
+     elif args.eval_name == "mgsm":
+         from sglang.test.simple_eval_mgsm import MGSMEval
+
+         eval_obj = MGSMEval(args.num_examples, args.num_threads)
+     elif args.eval_name == "mgsm_en":
+         from sglang.test.simple_eval_mgsm import MGSMEval
+
+         eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
      elif args.eval_name == "gpqa":
          from sglang.test.simple_eval_gpqa import GPQAEval

@@ -109,7 +119,6 @@ if __name__ == "__main__":
      parser.add_argument("--eval-name", type=str, default="mmlu")
      parser.add_argument("--num-examples", type=int)
      parser.add_argument("--num-threads", type=int, default=512)
-     set_ulimit()
      args = parser.parse_args()

      run_eval(args)