sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/sampling_params.py CHANGED
@@ -23,17 +23,19 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
+        min_new_tokens: int = 0,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        dtype: Optional[str] = None,
         regex: Optional[str] = None,
         n: int = 1,
     ) -> None:
@@ -42,12 +44,14 @@ class SamplingParams:
         self.top_k = top_k
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
+        self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
+        self.stop_token_ids = {*stop_token_ids}
         self.max_new_tokens = max_new_tokens
+        self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.dtype = dtype
         self.regex = regex
         self.n = n

@@ -57,8 +61,6 @@ class SamplingParams:
             self.top_k = 1
         if self.top_k == -1:
             self.top_k = 1 << 30  # whole vocabulary
-        if self.dtype == "int":
-            self.stop_strs = [" ", "\n"]

     def verify(self):
         if self.temperature < 0.0:
@@ -80,23 +82,44 @@ class SamplingParams:
             raise ValueError(
                 "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}."
             )
+        if not 0.0 <= self.repetition_penalty <= 2.0:
+            raise ValueError(
+                "repetition_penalty must be in (0, 2], got "
+                f"{self.repetition_penalty}."
+            )
+        if not 0 <= self.min_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must be in (0, max_new_tokens], got "
+                f"{self.min_new_tokens}."
+            )
         if self.max_new_tokens is not None:
             if self.max_new_tokens < 0:
                 raise ValueError(
                     f"max_new_tokens must be at least 0, got {self.max_new_tokens}."
                 )
+            if not self.min_new_tokens <= self.max_new_tokens:
+                raise ValueError(
+                    f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
+                    f"{self.min_new_tokens}."
+                )

     def normalize(self, tokenizer):
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]

             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
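
The new sampling fields are plain entries in a request's sampling_params dict, the same shape the warmup request in server.py below uses. A minimal sketch of exercising them over the /generate HTTP API (the local URL, prompt, and stop token id are assumptions, not part of this diff):

    import requests

    # Sketch only: assumes a local sglang server is already running on port 30000.
    resp = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {
                "max_new_tokens": 64,        # default raised from 16 to 128 in this release
                "min_new_tokens": 4,         # new field, validated in verify()
                "repetition_penalty": 1.05,  # new field, validated in verify()
                "stop_token_ids": [2],       # new field; the id here is illustrative
            },
        },
    )
    print(resp.json())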
sglang/srt/server.py CHANGED
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -74,7 +75,8 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    set_torch_compile_config,
+    prepare_model,
+    prepare_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -98,6 +100,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result

@@ -149,6 +152,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)


+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
@@ -159,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -235,6 +259,10 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

+    # Use model from www.modelscope.cn, first download the model.
+    server_args.model_path = prepare_model(server_args.model_path)
+    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
+
     # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
@@ -260,6 +288,8 @@

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

@@ -330,6 +360,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

     # Set ulimit
     set_ulimit()
@@ -347,20 +378,11 @@ def _set_envs_and_config(server_args: ServerArgs):
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
         maybe_set_triton_cache_manager()

-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
     # Check flashinfer version
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.3",
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -385,6 +407,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         except (AssertionError, requests.exceptions.RequestException) as e:
             last_traceback = get_exception_traceback()
             pass
+    model_info = res.json()

     if not success:
         if pipe_finish_writer is not None:
@@ -393,17 +416,24 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)

     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 8,
-                    },
-                },
+                url + request_name,
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
@@ -415,6 +445,15 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)

+    # Print warnings here
+    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
+        logger.warning(
+            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
+            "This combination is an experimental feature and we noticed it can lead to "
+            "wrong generation results. If you want to use chunked prefill, it is recommended "
+            "not using `--disable-radix-cache`."
+        )
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
@@ -492,11 +531,18 @@ class Runtime:
         prompt: str,
         sampling_params: Optional[Dict] = None,
     ):
-        json_data = {
-            "text": prompt,
-            "sampling_params": sampling_params,
-            "stream": True,
-        }
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
         pos = 0

         timeout = aiohttp.ClientTimeout(total=3 * 3600)
@@ -508,10 +554,13 @@
                         if chunk == "data: [DONE]\n\n":
                             break
                         data = json.loads(chunk[5:].strip("\n"))
-                        cur = data["text"][pos:]
-                        if cur:
-                            yield cur
-                        pos += len(cur)
+                        if hasattr(data, "text"):
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield cur
+                            pos += len(cur)
+                        else:
+                            yield data

     add_request = async_generate

@@ -534,5 +583,18 @@
         )
         return json.dumps(response.json())

+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
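
With the EmbeddingReqInput plumbing and the new routes above, embeddings are reachable over HTTP in two ways. A minimal sketch, assuming an embedding-capable model is already being served locally on port 30000 and that the /v1/embeddings body follows the usual OpenAI embeddings schema (the model name is a placeholder):

    import requests

    base_url = "http://127.0.0.1:30000"  # assumption: local server with an embedding model

    # Native endpoint registered via app.post("/encode")(encode_request)
    native = requests.post(base_url + "/encode", json={"text": "The capital city of France is"})
    print(native.json())

    # OpenAI-compatible endpoint backed by v1_embeddings
    openai_style = requests.post(
        base_url + "/v1/embeddings",
        json={"model": "default", "input": "The capital city of France is"},
    )
    print(openai_style.json())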
sglang/srt/server_args.py CHANGED
@@ -17,9 +17,12 @@ limitations under the License.

 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class ServerArgs:
@@ -27,6 +30,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -42,10 +46,11 @@

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
-    max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
+    chunked_prefill_size: int = 8192
+    max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0

@@ -62,15 +67,12 @@

     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGlang_storage"
+    file_storage_pth: str = "SGLang_storage"

     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"

-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -96,6 +98,10 @@
         if self.served_model_name is None:
             self.served_model_name = self.model_path

+        if self.chunked_prefill_size <= 0:
+            # Disable chunked prefill
+            self.chunked_prefill_size = None
+
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -107,6 +113,7 @@
                 self.mem_fraction_static = 0.87
             else:
                 self.mem_fraction_static = 0.88
+
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -151,6 +158,11 @@
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -226,12 +238,6 @@
             default=ServerArgs.mem_fraction_static,
             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
-        parser.add_argument(
-            "--max-prefill-tokens",
-            type=int,
-            default=ServerArgs.max_prefill_tokens,
-            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
-        )
         parser.add_argument(
             "--max-running-requests",
             type=int,
@@ -250,6 +256,18 @@
             default=ServerArgs.max_total_tokens,
             help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
         )
+        parser.add_argument(
+            "--chunked-prefill-size",
+            type=int,
+            default=ServerArgs.chunked_prefill_size,
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        )
+        parser.add_argument(
+            "--max-prefill-tokens",
+            type=int,
+            default=ServerArgs.max_prefill_tokens,
+            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
+        )
         parser.add_argument(
             "--schedule-policy",
             type=str,
@@ -347,14 +365,6 @@
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")

-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -439,6 +449,9 @@
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
+        if "gemma-2" in self.model_path.lower():
+            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
+            self.disable_flashinfer = False


 @dataclasses.dataclass
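
A short sketch of how the relocated chunked-prefill options behave after __post_init__ (the model path is a placeholder; constructing the dataclass loads nothing):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    print(args.chunked_prefill_size, args.max_prefill_tokens)  # 8192 16384 (new defaults)

    # A non-positive --chunked-prefill-size disables chunked prefill in __post_init__.
    args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
        chunked_prefill_size=-1,
        skip_tokenizer_init=True,  # new flag: requests must then carry input_ids, not text
    )
    print(args.chunked_prefill_size)  # None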
sglang/srt/utils.py CHANGED
@@ -35,7 +35,6 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from starlette.middleware.base import BaseHTTPMiddleware
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -197,6 +196,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
@@ -223,6 +224,15 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")


+def is_generation_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+    ):
+        return False
+    return True
+
+
 def decode_video_base64(video_base64):
     from PIL import Image

@@ -622,19 +632,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()


-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -646,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
             logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")


-def is_llama3_405b_fp8(model_config):
+def is_llama3_405b_fp8_head_16(model_config):
     """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
     if (
         model_config.hf_config.architectures[0] == "LlamaForCausalLM"
@@ -705,3 +702,23 @@ def add_api_key_middleware(app, api_key):
         if request.headers.get("Authorization") != "Bearer " + api_key:
             return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
+
+
+def prepare_model(model_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(model_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(model_path)
+    return model_path
+
+
+def prepare_tokenizer(tokenizer_path):
+    if "SGLANG_USE_MODELSCOPE" in os.environ:
+        if not os.path.exists(tokenizer_path):
+            from modelscope import snapshot_download
+
+            return snapshot_download(
+                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
+            )
+    return tokenizer_path
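
A sketch of the new ModelScope hook: the helpers only download when SGLANG_USE_MODELSCOPE is set in the environment and the given path does not exist locally; the repo id below is illustrative.

    import os

    os.environ["SGLANG_USE_MODELSCOPE"] = "1"

    from sglang.srt.utils import prepare_model, prepare_tokenizer

    model_path = prepare_model("qwen/Qwen2-7B-Instruct")
    # The tokenizer download skips weight files (*.bin, *.safetensors).
    tokenizer_path = prepare_tokenizer("qwen/Qwen2-7B-Instruct")
    print(model_path, tokenizer_path)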
sglang/test/run_eval.py CHANGED
@@ -16,6 +16,8 @@ from sglang.test.simple_eval_common import (


 def run_eval(args):
+    set_ulimit()
+
     if "OPENAI_API_KEY" not in os.environ:
         os.environ["OPENAI_API_KEY"] = "EMPTY"

@@ -39,6 +41,14 @@ def run_eval(args):
         eval_obj = MathEval(
             filename, equality_checker, args.num_examples, args.num_threads
         )
+    elif args.eval_name == "mgsm":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mgsm_en":
+        from sglang.test.simple_eval_mgsm import MGSMEval
+
+        eval_obj = MGSMEval(args.num_examples, args.num_threads, languages=["en"])
     elif args.eval_name == "gpqa":
         from sglang.test.simple_eval_gpqa import GPQAEval

@@ -109,7 +119,6 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=512)
-    set_ulimit()
     args = parser.parse_args()

     run_eval(args)
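
A hedged sketch of driving the new MGSM task from the command line via this module; it passes only the flags shown in this diff and assumes a local sglang server is already serving the OpenAI-compatible API at run_eval's default target.

    import subprocess

    # Connection settings not shown in this diff fall back to run_eval's own defaults.
    subprocess.run(
        [
            "python", "-m", "sglang.test.run_eval",
            "--eval-name", "mgsm_en",
            "--num-examples", "20",
            "--num-threads", "64",
        ],
        check=True,
    )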