sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py ADDED
@@ -0,0 +1,209 @@
+ from __future__ import annotations
+
+ import dataclasses
+ from typing import TYPE_CHECKING, List
+
+ import torch
+
+ import sglang.srt.sampling.penaltylib as penaltylib
+
+ if TYPE_CHECKING:
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+ @dataclasses.dataclass
+ class SamplingBatchInfo:
+ # Basic Info
+ vocab_size: int
+
+ # Batched sampling params
+ temperatures: torch.Tensor = None
+ top_ps: torch.Tensor = None
+ top_ks: torch.Tensor = None
+ min_ps: torch.Tensor = None
+
+ # Dispatch in CUDA graph
+ need_min_p_sampling: bool = False
+
+ # Bias Tensors
+ logit_bias: torch.Tensor = None
+ vocab_mask: torch.Tensor = None
+
+ # Penalizer
+ penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+ linear_penalties: torch.Tensor = None
+ scaling_penalties: torch.Tensor = None
+
+ def has_bias(self):
+ return (
+ self.logit_bias is not None
+ or self.vocab_mask is not None
+ or self.linear_penalties is not None
+ or self.scaling_penalties is not None
+ )
+
+ @classmethod
+ def dummy_one(cls, max_bs: int, vocab_size: int):
+ ret = cls(vocab_size=vocab_size)
+ ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+ ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+ ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+ ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+ return ret
+
+ def __getitem__(self, key):
+ if isinstance(key, slice):
+ # NOTE: We do not use cuda graph when there is bias tensors
+ assert not self.has_bias()
+ return SamplingBatchInfo(
+ vocab_size=self.vocab_size,
+ temperatures=self.temperatures[key],
+ top_ps=self.top_ps[key],
+ top_ks=self.top_ks[key],
+ min_ps=self.min_ps[key],
+ need_min_p_sampling=self.need_min_p_sampling,
+ )
+ else:
+ raise NotImplementedError
+
+ def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+ # NOTE: We do not use cuda graph when there is bias tensors
+ assert not self.has_bias()
+
+ self.vocab_size = other.vocab_size
+ self.need_min_p_sampling = other.need_min_p_sampling
+
+ self.temperatures[:bs] = other.temperatures
+ self.top_ps[:bs] = other.top_ps
+ self.top_ks[:bs] = other.top_ks
+ self.min_ps[:bs] = other.min_ps
+
+ @classmethod
+ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+ device = "cuda"
+ reqs = batch.reqs
+ ret = cls(vocab_size=vocab_size)
+
+ ret.temperatures = torch.tensor(
+ [r.sampling_params.temperature for r in reqs],
+ dtype=torch.float,
+ device=device,
+ ).view(-1, 1)
+ ret.top_ps = torch.tensor(
+ [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+ )
+ ret.top_ks = torch.tensor(
+ [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+ )
+ ret.min_ps = torch.tensor(
+ [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+ )
+ ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+
+ # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+ # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+ # should not add hefty computation overhead other than simple checks.
+ #
+ # While we choose not to even create the class instances if they are not required, this
+ # could add additional complexity to the {ScheduleBatch} class, especially we need to
+ # handle {filter_batch()} and {merge()} cases as well.
+ ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+ vocab_size=vocab_size,
+ batch=batch,
+ device=device,
+ Penalizers={
+ penaltylib.BatchedFrequencyPenalizer,
+ penaltylib.BatchedMinNewTokensPenalizer,
+ penaltylib.BatchedPresencePenalizer,
+ penaltylib.BatchedRepetitionPenalizer,
+ },
+ )
+
+ # Handle logit bias but only allocate when needed
+ ret.logit_bias = None
+
+ ret.update_regex_vocab_mask(batch)
+
+ return ret
+
+ def prepare_penalties(self):
+ self.scaling_penalties = None
+ self.linear_penalties = None
+
+ for penalizer in self.penalizer_orchestrator.penalizers.values():
+ if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+ if penalizer.is_prepared():
+ self.scaling_penalties = penalizer.cumulated_repetition_penalties
+ else:
+ if penalizer.is_prepared():
+ if self.linear_penalties is None:
+ bs = self.penalizer_orchestrator.batch.batch_size()
+ self.linear_penalties = torch.zeros(
+ (bs, self.vocab_size),
+ dtype=torch.float32,
+ device="cuda",
+ )
+ self.linear_penalties = penalizer.apply(self.linear_penalties)
+
+ def update_regex_vocab_mask(self, batch: ScheduleBatch):
+ bs, reqs = batch.batch_size(), batch.reqs
+ device = "cuda"
+ has_regex = any(req.regex_fsm is not None for req in reqs)
+
+ # Reset the vocab mask
+ self.vocab_mask = None
+
+ if has_regex:
+ for i, req in enumerate(reqs):
+ if req.regex_fsm is not None:
+ if self.vocab_mask is None:
+ self.vocab_mask = torch.zeros(
+ bs, self.vocab_size, dtype=torch.bool, device=device
+ )
+ self.vocab_mask[i][
+ req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+ ] = 1
+
+ def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+ self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+ for item in [
+ "temperatures",
+ "top_ps",
+ "top_ks",
+ "min_ps",
+ "logit_bias",
+ ]:
+ self_val = getattr(self, item, None)
+ if self_val is not None: # logit_bias can be None
+ setattr(self, item, self_val[new_indices])
+
+ def merge(self, other: "SamplingBatchInfo"):
+ self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+ for item in [
+ "temperatures",
+ "top_ps",
+ "top_ks",
+ "min_ps",
+ ]:
+ self_val = getattr(self, item, None)
+ other_val = getattr(other, item, None)
+ setattr(self, item, torch.concat([self_val, other_val]))
+
+ # logit_bias can be None
+ if self.logit_bias is not None or other.logit_bias is not None:
+ vocab_size = (
+ self.logit_bias.shape[1]
+ if self.logit_bias is not None
+ else other.logit_bias.shape[1]
+ )
+ if self.logit_bias is None:
+ self.logit_bias = torch.zeros(
+ (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+ )
+ if other.logit_bias is None:
+ other.logit_bias = torch.zeros(
+ (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+ )
+ self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
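The min_ps / need_min_p_sampling fields above back the new min_p sampling parameter: only tokens whose probability is at least min_p times the probability of the most likely token are kept. A minimal PyTorch sketch of that filtering rule, for illustration only (the released sampling kernel presumably lives in sglang/srt/layers/sampler.py per the file list, not in this snippet):

import torch

def min_p_filter(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    # probs: (batch, vocab) softmax output; min_ps: (batch,) values in [0, 1]
    top = probs.max(dim=-1, keepdim=True).values           # per-row max probability
    keep = probs >= min_ps.unsqueeze(-1) * top              # keep tokens close enough to the top
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    return filtered / filtered.sum(dim=-1, keepdim=True)    # renormalize the survivors

probs = torch.softmax(torch.randn(2, 8), dim=-1)
print(min_p_filter(probs, torch.tensor([0.1, 0.3])))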
sglang/srt/{sampling_params.py → sampling/sampling_params.py} RENAMED
@@ -30,6 +30,7 @@ class SamplingParams:
  temperature: float = 1.0,
  top_p: float = 1.0,
  top_k: int = -1,
+ min_p: float = 0.0,
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  repetition_penalty: float = 1.0,
@@ -42,6 +43,7 @@ class SamplingParams:
  self.temperature = temperature
  self.top_p = top_p
  self.top_k = top_k
+ self.min_p = min_p
  self.frequency_penalty = frequency_penalty
  self.presence_penalty = presence_penalty
  self.repetition_penalty = repetition_penalty
@@ -69,6 +71,8 @@ class SamplingParams:
  )
  if not 0.0 < self.top_p <= 1.0:
  raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+ if not 0.0 <= self.min_p <= 1.0:
+ raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
  if self.top_k < -1 or self.top_k == 0:
  raise ValueError(
  f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -123,3 +127,17 @@ class SamplingParams:
  else:
  stop_str_max_len = max(stop_str_max_len, len(stop_str))
  self.stop_str_max_len = stop_str_max_len
+
+ def to_srt_kwargs(self):
+ return {
+ "max_new_tokens": self.max_new_tokens,
+ "stop": self.stop_strs,
+ "stop_token_ids": list(self.stop_token_ids),
+ "temperature": self.temperature,
+ "top_p": self.top_p,
+ "top_k": self.top_k,
+ "frequency_penalty": self.frequency_penalty,
+ "presence_penalty": self.presence_penalty,
+ "ignore_eos": self.ignore_eos,
+ "regex": self.regex,
+ }
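With the change above, min_p rides along with the other sampling parameters. A client-side sketch of passing it over HTTP; the address is an assumption (sglang's usual default launch port), not something stated in this diff:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed host/port of a running launch_server
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.7,
            "top_p": 0.9,
            "min_p": 0.05,        # new in 0.2.14; must lie in [0, 1]
            "max_new_tokens": 16,
        },
    },
)
print(resp.json())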
sglang/srt/server.py CHANGED
@@ -24,7 +24,6 @@ import json
  import logging
  import multiprocessing as mp
  import os
- import sys
  import threading
  import time
  from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

  import aiohttp
- import psutil
  import requests
  import uvicorn
  import uvloop
@@ -52,7 +50,11 @@ from sglang.srt.managers.controller_single import (
  start_controller_process as start_controller_process_single,
  )
  from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
- from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
+ from sglang.srt.managers.io_struct import (
+ EmbeddingReqInput,
+ GenerateReqInput,
+ UpdateWeightReqInput,
+ )
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
  from sglang.srt.openai_api.adapter import (
  load_chat_template_for_openai_api,
@@ -72,6 +74,7 @@ from sglang.srt.utils import (
  add_api_key_middleware,
  allocate_init_ports,
  assert_pkg_version,
+ configure_logger,
  enable_show_time_cost,
  kill_child_process,
  maybe_set_triton_cache_manager,
@@ -92,10 +95,25 @@ tokenizer_manager = None

  @app.get("/health")
  async def health() -> Response:
- """Health check."""
+ """Check the health of the http server."""
  return Response(status_code=200)


+ @app.get("/health_generate")
+ async def health_generate(request: Request) -> Response:
+ """Check the health of the inference server by generating one token."""
+ gri = GenerateReqInput(
+ text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+ )
+ try:
+ async for _ in tokenizer_manager.generate_request(gri, request):
+ break
+ return Response(status_code=200)
+ except Exception as e:
+ logger.exception(e)
+ return Response(status_code=503)
+
+
  @app.get("/get_model_info")
  async def get_model_info():
  result = {
@@ -120,6 +138,23 @@ async def flush_cache():
  )


+ @app.post("/update_weights")
+ async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+ success, message = await tokenizer_manager.update_weights(obj, request)
+ content = {"message": message, "success": str(success)}
+ if success:
+ return JSONResponse(
+ content,
+ status_code=HTTPStatus.OK,
+ )
+ else:
+ return JSONResponse(
+ content,
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+
+
  async def generate_request(obj: GenerateReqInput, request: Request):
  """Handle a generate request."""
  if obj.stream:
@@ -236,15 +271,12 @@ def launch_server(
  """Launch an HTTP server."""
  global tokenizer_manager

- logging.basicConfig(
- level=getattr(logging, server_args.log_level.upper()),
- format="%(message)s",
- )
+ configure_logger(server_args)

  server_args.check_server_args()
  _set_envs_and_config(server_args)

- # Allocate ports
+ # Allocate ports for inter-process communications
  server_args.port, server_args.additional_ports = allocate_init_ports(
  server_args.port,
  server_args.additional_ports,
@@ -264,27 +296,29 @@ def launch_server(
  server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

  # Launch processes for multi-node tensor parallelism
- if server_args.nnodes > 1:
- if server_args.node_rank != 0:
- tp_size_local = server_args.tp_size // server_args.nnodes
- gpu_ids = [
- i for _ in range(server_args.nnodes) for i in range(tp_size_local)
- ]
- tp_rank_range = list(
- range(
- server_args.node_rank * tp_size_local,
- (server_args.node_rank + 1) * tp_size_local,
- )
- )
- procs = launch_tp_servers(
- gpu_ids,
- tp_rank_range,
- server_args,
- ports[3],
- model_overide_args,
+ if server_args.nnodes > 1 and server_args.node_rank != 0:
+ tp_size_local = server_args.tp_size // server_args.nnodes
+ gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+ tp_rank_range = list(
+ range(
+ server_args.node_rank * tp_size_local,
+ (server_args.node_rank + 1) * tp_size_local,
  )
- while True:
- pass
+ )
+ procs = launch_tp_servers(
+ gpu_ids,
+ tp_rank_range,
+ server_args,
+ ports[3],
+ model_overide_args,
+ )
+
+ try:
+ for p in procs:
+ p.join()
+ finally:
+ kill_child_process(os.getpid(), including_parent=False)
+ return

  # Launch processes
  tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -297,11 +331,13 @@ def launch_server(
  start_process = start_controller_process_single
  else:
  start_process = start_controller_process_multi
+
  proc_controller = mp.Process(
  target=start_process,
  args=(server_args, port_args, pipe_controller_writer, model_overide_args),
  )
  proc_controller.start()
+
  proc_detoken = mp.Process(
  target=start_detokenizer_process,
  args=(
@@ -319,15 +355,11 @@ def launch_server(
  if controller_init_state != "init ok" or detoken_init_state != "init ok":
  proc_controller.kill()
  proc_detoken.kill()
- print(
- f"Initialization failed. controller_init_state: {controller_init_state}",
- flush=True,
- )
- print(
- f"Initialization failed. detoken_init_state: {detoken_init_state}",
- flush=True,
+ raise RuntimeError(
+ "Initialization failed. "
+ f"controller_init_state: {controller_init_state}, "
+ f"detoken_init_state: {detoken_init_state}"
  )
- sys.exit(1)
  assert proc_controller.is_alive() and proc_detoken.is_alive()

  # Add api key authorization
@@ -336,12 +368,12 @@ def launch_server(

  # Send a warmup request
  t = threading.Thread(
- target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+ target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
  )
  t.start()

- # Listen for requests
  try:
+ # Listen for requests
  uvicorn.run(
  app,
  host=server_args.host,
@@ -389,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  )


- def _wait_and_warmup(server_args, pipe_finish_writer):
+ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
  headers = {}
  url = server_args.url()
  if server_args.api_key:
@@ -412,8 +444,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
  if not success:
  if pipe_finish_writer is not None:
  pipe_finish_writer.send(last_traceback)
- print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
- sys.exit(1)
+ logger.error(f"Initialization failed. warmup error: {last_traceback}")
+ kill_child_process(pid, including_parent=False)
+ return

  # Send a warmup request
  request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -438,21 +471,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
  timeout=600,
  )
  assert res.status_code == 200, f"{res}"
- except Exception as e:
+ except Exception:
  last_traceback = get_exception_traceback()
  if pipe_finish_writer is not None:
  pipe_finish_writer.send(last_traceback)
- print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
- sys.exit(1)
-
- # Print warnings here
- if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
- logger.warning(
- "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
- "This combination is an experimental feature and we noticed it can lead to "
- "wrong generation results. If you want to use chunked prefill, it is recommended "
- "not using `--disable-radix-cache`."
- )
+ logger.error(f"Initialization failed. warmup error: {last_traceback}")
+ kill_child_process(pid, including_parent=False)
+ return

  logger.info("The server is fired up and ready to roll!")
  if pipe_finish_writer is not None:
@@ -490,6 +515,7 @@ class Runtime:

  self.pid = None
  pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
  proc = mp.Process(
  target=launch_server,
  args=(self.server_args, model_overide_args, pipe_writer),
@@ -566,15 +592,17 @@ class Runtime:

  def generate(
  self,
- prompt: str,
+ prompt: Union[str, List[str]],
  sampling_params: Optional[Dict] = None,
  return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
  top_logprobs_num: Optional[Union[List[int], int]] = None,
  ):
  json_data = {
  "text": prompt,
  "sampling_params": sampling_params,
  "return_logprob": return_logprob,
+ "logprob_start_len": logprob_start_len,
  "top_logprobs_num": top_logprobs_num,
  }
  response = requests.post(
@@ -585,7 +613,7 @@ class Runtime:

  def encode(
  self,
- prompt: str,
+ prompt: Union[str, List[str]],
  ):
  json_data = {
  "text": prompt,
sglang/srt/server_args.py CHANGED
@@ -33,11 +33,13 @@ class ServerArgs:
  skip_tokenizer_init: bool = False
  load_format: str = "auto"
  dtype: str = "auto"
+ kv_cache_dtype: str = "auto"
  trust_remote_code: bool = True
  context_length: Optional[int] = None
  quantization: Optional[str] = None
  served_model_name: Optional[str] = None
  chat_template: Optional[str] = None
+ is_embedding: bool = False

  # Port
  host: str = "127.0.0.1"
@@ -79,12 +81,14 @@ class ServerArgs:
  disable_radix_cache: bool = False
  disable_regex_jump_forward: bool = False
  disable_cuda_graph: bool = False
+ disable_cuda_graph_padding: bool = False
  disable_disk_cache: bool = False
+ disable_custom_all_reduce: bool = False
+ enable_mixed_chunk: bool = False
  enable_torch_compile: bool = False
  enable_p2p_check: bool = False
  enable_mla: bool = False
- attention_reduce_in_fp32: bool = False
- efficient_weight_load: bool = False
+ triton_attention_reduce_in_fp32: bool = False

  # Distributed args
  nccl_init_addr: Optional[str] = None
@@ -193,11 +197,23 @@ class ServerArgs:
  '* "float" is shorthand for FP32 precision.\n'
  '* "float32" for FP32 precision.',
  )
+ parser.add_argument(
+ "--kv-cache-dtype",
+ type=str,
+ default=ServerArgs.kv_cache_dtype,
+ choices=["auto", "fp8_e5m2"],
+ help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+ )
  parser.add_argument(
  "--trust-remote-code",
  action="store_true",
  help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
  )
+ parser.add_argument(
+ "--is-embedding",
+ action="store_true",
+ help="Whether to use a CausalLM as an embedding model.",
+ )
  parser.add_argument(
  "--context-length",
  type=int,
@@ -391,11 +407,27 @@ class ServerArgs:
  action="store_true",
  help="Disable cuda graph.",
  )
+ parser.add_argument(
+ "--disable-cuda-graph-padding",
+ action="store_true",
+ help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+ )
  parser.add_argument(
  "--disable-disk-cache",
  action="store_true",
  help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
  )
+ parser.add_argument(
+ "--disable-custom-all-reduce",
+ action="store_true",
+ default=False,
+ help="Disable the custom all-reduce kernel and fall back to NCCL.",
+ )
+ parser.add_argument(
+ "--enable-mixed-chunk",
+ action="store_true",
+ help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+ )
  parser.add_argument(
  "--enable-torch-compile",
  action="store_true",
@@ -409,13 +441,13 @@ class ServerArgs:
  parser.add_argument(
  "--enable-mla",
  action="store_true",
- help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+ help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
  )
  parser.add_argument(
- "--attention-reduce-in-fp32",
+ "--triton-attention-reduce-in-fp32",
  action="store_true",
  help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
- "This only affects Triton attention kernels",
+ "This only affects Triton attention kernels.",
  )
  parser.add_argument(
  "--efficient-weight-load",
@@ -433,15 +465,6 @@ class ServerArgs:
  def url(self):
  return f"http://{self.host}:{self.port}"

- def print_mode_args(self):
- return (
- f"disable_flashinfer={self.disable_flashinfer}, "
- f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
- f"disable_radix_cache={self.disable_radix_cache}, "
- f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
- f"disable_disk_cache={self.disable_disk_cache}, "
- )
-
  def check_server_args(self):
  assert (
  self.tp_size % self.nnodes == 0
@@ -449,8 +472,13 @@ class ServerArgs:
  assert not (
  self.dp_size > 1 and self.node_rank is not None
  ), "multi-node data parallel is not supported"
+ if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+ logger.info(
+ "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+ )
+ self.trust_remote_code = False
  if "gemma-2" in self.model_path.lower():
- logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
+ logger.info("When using sliding window in gemma-2, turn on flashinfer.")
  self.disable_flashinfer = False

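The new and renamed flags above all map onto ServerArgs fields. A minimal sketch of setting them programmatically, assuming sglang.Runtime forwards its keyword arguments to ServerArgs (the model name is a placeholder, not taken from this diff):

import sglang as sgl

runtime = sgl.Runtime(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    kv_cache_dtype="fp8_e5m2",             # new: FP8 E5M2 KV cache (CUDA 11.8+)
    enable_mixed_chunk=True,               # new: mix prefill and decode in one batch
    disable_custom_all_reduce=True,        # new: fall back to NCCL for all-reduce
    triton_attention_reduce_in_fp32=True,  # renamed from attention_reduce_in_fp32
)
print(runtime.generate("Hello", sampling_params={"max_new_tokens": 8, "min_p": 0.1}))
runtime.shutdown()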