sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +14 -0
  3. sglang/backend/anthropic.py +18 -12
  4. sglang/backend/base_backend.py +6 -0
  5. sglang/backend/openai.py +41 -12
  6. sglang/backend/runtime_endpoint.py +57 -6
  7. sglang/lang/chat_template.py +47 -26
  8. sglang/lang/interpreter.py +15 -2
  9. sglang/lang/ir.py +1 -1
  10. sglang/srt/constrained/__init__.py +23 -1
  11. sglang/srt/constrained/fsm_cache.py +14 -3
  12. sglang/srt/layers/context_flashattention_nopad.py +1 -1
  13. sglang/srt/layers/extend_attention.py +7 -6
  14. sglang/srt/layers/radix_attention.py +2 -10
  15. sglang/srt/layers/token_attention.py +12 -4
  16. sglang/srt/managers/io_struct.py +3 -1
  17. sglang/srt/managers/router/infer_batch.py +6 -2
  18. sglang/srt/managers/router/model_rpc.py +45 -32
  19. sglang/srt/managers/router/model_runner.py +40 -25
  20. sglang/srt/managers/tokenizer_manager.py +2 -0
  21. sglang/srt/model_config.py +12 -5
  22. sglang/srt/models/gemma.py +340 -0
  23. sglang/srt/models/llama2.py +5 -5
  24. sglang/srt/models/llava.py +2 -4
  25. sglang/srt/models/mixtral.py +5 -5
  26. sglang/srt/models/qwen.py +4 -4
  27. sglang/srt/models/qwen2.py +5 -5
  28. sglang/srt/models/stablelm.py +293 -0
  29. sglang/srt/server.py +111 -47
  30. sglang/srt/server_args.py +44 -9
  31. sglang/srt/utils.py +1 -0
  32. sglang/test/test_utils.py +1 -1
  33. sglang/utils.py +15 -12
  34. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
  35. sglang-0.1.14.dist-info/RECORD +64 -0
  36. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
  37. sglang/srt/models/gpt_neox.py +0 -274
  38. sglang-0.1.12.dist-info/RECORD +0 -63
  39. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py ADDED
@@ -0,0 +1,293 @@
+ # This code is based on:
+ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+ """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
+ model compatible with HuggingFace weights."""
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.router.model_runner import InputMetadata
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.linear import (
+     LinearMethodBase,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     VocabParallelEmbedding,
+     ParallelLMHead,
+ )
+ from vllm.model_executor.parallel_utils.parallel_state import (
+     get_tensor_model_parallel_world_size,
+ )
+ from vllm.model_executor.weight_utils import (
+     default_weight_loader,
+     hf_model_weights_iterator,
+ )
+
+
+ class StablelmMLP(nn.Module):
+     def __init__(
+         self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_up_proj = MergedColumnParallelLinear(
+             config.hidden_size,
+             [config.intermediate_size] * 2,
+             bias=False,
+             linear_method=linear_method,
+         )
+         self.down_proj = RowParallelLinear(
+             config.intermediate_size, config.hidden_size, bias=False
+         )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class StablelmAttention(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         linear_method: Optional[LinearMethodBase] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+         self.num_heads = self.total_num_heads // tp_size
+
+         self.total_num_key_value_heads = config.num_key_value_heads
+         if self.total_num_key_value_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_key_value_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_key_value_heads == 0
+         self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size)
+         self.head_dim = self.hidden_size // self.total_num_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         rope_pct = getattr(
+             config, "rope_pct", getattr(config, "partial_rotary_factor", 1)
+         )
+         self.rotary_ndims = int(self.head_dim * rope_pct)
+         self.scaling = self.head_dim**-0.5
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_key_value_heads * self.head_dim
+         self.qkv_bias = getattr(config, "use_qkv_bias", False)
+         if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
+             raise ValueError(
+                 f"hidden_size must be divisible by num_heads "
+                 f"(got `hidden_size`: {self.hidden_size}"
+                 f" and `num_heads`: {self.num_heads})."
+             )
+
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_key_value_heads,
+             self.qkv_bias,
+             linear_method=linear_method,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             self.hidden_size,
+             bias=False,
+             linear_method=linear_method,
+         )
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.rotary_ndims,
+             max_position=self.config.max_position_embeddings,
+             base=self.config.rope_theta,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_key_value_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class StablelmDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         linear_method: Optional[LinearMethodBase] = None,
+     ) -> None:
+         super().__init__()
+         self.self_attn = StablelmAttention(config, layer_id=layer_id)
+         self.mlp = StablelmMLP(config, linear_method)
+         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
+         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states, residual
+
+
+ class StableLMEpochModel(nn.Module):
+     def __init__(
+         self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+     ) -> None:
+         super().__init__()
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 StablelmDecoderLayer(config, i, linear_method)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
+         self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+             )
+         hidden_states = self.norm(hidden_states)
+         return hidden_states
+
+
+ class StableLmForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         linear_method: Optional[LinearMethodBase] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.linear_method = linear_method
+         self.model = StableLMEpochModel(config, linear_method)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(
+         self,
+         model_name_or_path: str,
+         cache_dir: Optional[str] = None,
+         load_format: str = "auto",
+         revision: Optional[str] = None,
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in hf_model_weights_iterator(
+             model_name_or_path, cache_dir, load_format, revision
+         ):
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = StableLmForCausalLM
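
The file above registers StableLmForCausalLM as the EntryClass, so StableLM-2 checkpoints can now be served by the SRT runtime. A minimal usage sketch, assuming the Runtime wrapper from sglang/srt/server.py (shown in the next diff) accepts model_path as its first argument; the checkpoint name is an example, not taken from this diff:

    from sglang.srt.server import Runtime

    # Sketch only: keyword names follow the Runtime signature in the server.py diff below.
    runtime = Runtime(
        model_path="stabilityai/stablelm-2-1_6b",  # example checkpoint
        tp_size=1,
        log_level="error",
    )
    print(runtime.url)  # Runtime stores the HTTP endpoint of the launched server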
sglang/srt/server.py CHANGED
@@ -1,6 +1,7 @@
  """SRT: SGLang Runtime"""

  import asyncio
+ import dataclasses
  import json
  import multiprocessing as mp
  import os
@@ -52,10 +53,31 @@ from sglang.srt.managers.openai_protocol import (
  from sglang.srt.managers.router.manager import start_router_process
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import alloc_usable_network_port, handle_port_init
+ from sglang.srt.utils import handle_port_init
+ from starlette.middleware.base import BaseHTTPMiddleware
+ from starlette.responses import JSONResponse

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+ API_KEY_HEADER_NAME = "X-API-Key"
+
+
+ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+     def __init__(self, app, api_key: str):
+         super().__init__(app)
+         self.api_key = api_key
+
+     async def dispatch(self, request: Request, call_next):
+         # extract API key from the request headers
+         api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+         if not api_key_header or api_key_header != self.api_key:
+             return JSONResponse(
+                 status_code=403,
+                 content={"detail": "Invalid API Key"},
+             )
+         response = await call_next(request)
+         return response
+

  app = FastAPI()
  tokenizer_manager = None
@@ -86,6 +108,11 @@ async def get_model_info():
      return result


+ @app.get("/get_server_args")
+ async def get_server_args():
+     return dataclasses.asdict(tokenizer_manager.server_args)
+
+
  @app.get("/flush_cache")
  async def flush_cache():
      await tokenizer_manager.flush_cache()
@@ -96,19 +123,25 @@ async def flush_cache():
      )


- async def stream_generator(obj):
+ async def detokenize_logprob_tokens(token_logprobs):
+     token_ids = [tid for tid, _ in token_logprobs]
+     token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
+     return [(text, logprob) for text, (_, logprob) in zip(token_texts, token_logprobs)]
+
+
+ async def stream_generator(obj: GenerateReqInput):
      async for out in tokenizer_manager.generate_request(obj):
+         if obj.return_logprob and obj.return_text_in_logprobs:
+             out["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+                 out["meta_info"]["token_logprob"]
+             )
          yield out


  async def make_openai_style_logprobs(token_logprobs):
      ret_logprobs = LogProbs()

-     # Detokenize
-     token_ids = [tid for tid, _ in token_logprobs]
-     token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
-
-     for token_text, (_, token_logprob) in zip(token_texts, token_logprobs):
+     for token_text, token_logprob in token_logprobs:
          ret_logprobs.tokens.append(token_text)
          ret_logprobs.token_logprobs.append(token_logprob)

@@ -132,6 +165,11 @@ async def generate_request(obj: GenerateReqInput):
          return StreamingResponse(stream_results(), media_type="text/event-stream")

      ret = await tokenizer_manager.generate_request(obj).__anext__()
+     if obj.return_logprob and obj.return_text_in_logprobs:
+         ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
+             ret["meta_info"]["token_logprob"]
+         )
+
      return ret


@@ -155,6 +193,7 @@ async def v1_completions(raw_request: Request):
              "regex": request.regex,
          },
          return_logprob=request.logprobs is not None,
+         return_text_in_logprobs=True,
          stream=request.stream,
      )
      adapted_request.post_init()
@@ -211,6 +250,7 @@ async def v1_completions(raw_request: Request):

      # Non-streaming response.
      ret = await generate_request(adapted_request)
+     ret = ret[0] if isinstance(ret, list) else ret

      prompt_tokens = ret["meta_info"]["prompt_tokens"]
      completion_tokens = ret["meta_info"]["completion_tokens"]
@@ -463,8 +503,10 @@ def launch_server(server_args, pipe_finish_writer):

      assert proc_router.is_alive() and proc_detoken.is_alive()

+     if server_args.api_key and server_args.api_key != "":
+         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+
      def _launch_server():
-         # Launch api server
          uvicorn.run(
              app,
              host=server_args.host,
@@ -474,49 +516,59 @@ def launch_server(server_args, pipe_finish_writer):
              loop="uvloop",
          )

-     t = threading.Thread(target=_launch_server)
-     t.start()
+     def _wait_and_warmup():
+         headers = {}
+         url = server_args.url()
+         if server_args.api_key and server_args.api_key != "":
+             headers[API_KEY_HEADER_NAME] = server_args.api_key

-     url = server_args.url()
-     for _ in range(60):
-         time.sleep(1)
-         try:
-             requests.get(url + "/get_model_info", timeout=5)
-             break
-         except requests.exceptions.RequestException as e:
-             pass
-     else:
-         if pipe_finish_writer is not None:
-             pipe_finish_writer.send(str(e))
+         for _ in range(120):
+             time.sleep(0.5)
+             try:
+                 requests.get(url + "/get_model_info", timeout=5, headers=headers)
+                 break
+             except requests.exceptions.RequestException as e:
+                 pass
          else:
-             print(e, flush=True)
-         return
+             if pipe_finish_writer is not None:
+                 pipe_finish_writer.send(str(e))
+             else:
+                 print(e, flush=True)
+             return

-     # Warmup
-     try:
-         # print("Warmup...", flush=True)
-         res = requests.post(
-             url + "/generate",
-             json={
-                 "text": "Say this is a warmup request.",
-                 "sampling_params": {
-                     "temperature": 0,
-                     "max_new_tokens": 16,
+         # Warmup
+         try:
+             # print("Warmup...", flush=True)
+             res = requests.post(
+                 url + "/generate",
+                 json={
+                     "text": "Say this is a warmup request.",
+                     "sampling_params": {
+                         "temperature": 0,
+                         "max_new_tokens": 16,
+                     },
                  },
-             },
-             timeout=60,
-         )
-         # print(f"Warmup done. model response: {res.json()['text']}")
-         # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-     except requests.exceptions.RequestException as e:
+                 headers=headers,
+                 timeout=60,
+             )
+             # print(f"Warmup done. model response: {res.json()['text']}")
+             # print("=" * 20, "Server is ready", "=" * 20, flush=True)
+         except requests.exceptions.RequestException as e:
+             if pipe_finish_writer is not None:
+                 pipe_finish_writer.send(str(e))
+             else:
+                 print(e, flush=True)
+             return
+
          if pipe_finish_writer is not None:
-             pipe_finish_writer.send(str(e))
-         else:
-             print(e, flush=True)
-         return
+             pipe_finish_writer.send("init ok")

-     if pipe_finish_writer is not None:
-         pipe_finish_writer.send("init ok")
+     t = threading.Thread(target=_wait_and_warmup)
+     t.start()
+     try:
+         _launch_server()
+     finally:
+         t.join()


  class Runtime:
@@ -529,11 +581,17 @@ class Runtime:
          trust_remote_code: bool = True,
          mem_fraction_static: float = ServerArgs.mem_fraction_static,
          max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
+         context_length: int = ServerArgs.context_length,
          tp_size: int = 1,
-         model_mode: List[str] = (),
          schedule_heuristic: str = "lpm",
+         attention_reduce_in_fp32: bool = False,
          random_seed: int = 42,
          log_level: str = "error",
+         disable_radix_cache: bool = False,
+         enable_flashinfer: bool = False,
+         disable_regex_jump_forward: bool = False,
+         disable_disk_cache: bool = False,
+         api_key: str = "",
          port: Optional[int] = None,
          additional_ports: Optional[Union[List[int], int]] = None,
      ):
@@ -550,11 +608,17 @@ class Runtime:
              trust_remote_code=trust_remote_code,
              mem_fraction_static=mem_fraction_static,
              max_prefill_num_token=max_prefill_num_token,
+             context_length=context_length,
              tp_size=tp_size,
-             model_mode=model_mode,
              schedule_heuristic=schedule_heuristic,
+             attention_reduce_in_fp32=attention_reduce_in_fp32,
              random_seed=random_seed,
              log_level=log_level,
+             disable_radix_cache=disable_radix_cache,
+             enable_flashinfer=enable_flashinfer,
+             disable_regex_jump_forward=disable_regex_jump_forward,
+             disable_disk_cache=disable_disk_cache,
+             api_key=api_key,
          )

          self.url = self.server_args.url()
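
Taken together, the server changes add optional API-key enforcement (the X-API-Key header checked by APIKeyValidatorMiddleware), a /get_server_args endpoint, and detokenized logprobs when both return_logprob and return_text_in_logprobs are set on a /generate request. A client-side sketch; the address and key are placeholders:

    import requests

    base_url = "http://127.0.0.1:30000"       # placeholder address
    headers = {"X-API-Key": "my-secret-key"}  # must match the server's --api-key

    resp = requests.post(
        base_url + "/generate",
        headers=headers,
        json={
            "text": "Say this is a warmup request.",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
            "return_logprob": True,
            "return_text_in_logprobs": True,
        },
        timeout=60,
    )
    # token_logprob entries arrive as (token_text, logprob) pairs after detokenization
    print(resp.json()["meta_info"]["token_logprob"][:3])

    # Requests without a valid key are rejected with status 403
    print(requests.get(base_url + "/get_server_args").status_code)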
sglang/srt/server_args.py CHANGED
@@ -16,17 +16,23 @@ class ServerArgs:
      trust_remote_code: bool = True
      mem_fraction_static: Optional[float] = None
      max_prefill_num_token: Optional[int] = None
+     context_length: Optional[int] = None
      tp_size: int = 1
-     model_mode: List[str] = ()
      schedule_heuristic: str = "lpm"
      schedule_conservativeness: float = 1.0
+     attention_reduce_in_fp32: bool = False
      random_seed: int = 42
      stream_interval: int = 8
      disable_log_stats: bool = False
      log_stats_interval: int = 10
      log_level: str = "info"
+
+     # optional modes
+     disable_radix_cache: bool = False
+     enable_flashinfer: bool = False
      disable_regex_jump_forward: bool = False
      disable_disk_cache: bool = False
+     api_key: str = ""

      def __post_init__(self):
          if self.tokenizer_path is None:
@@ -117,20 +123,18 @@ class ServerArgs:
              default=ServerArgs.max_prefill_num_token,
              help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
          )
+         parser.add_argument(
+             "--context-length",
+             type=int,
+             default=ServerArgs.context_length,
+             help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
+         )
          parser.add_argument(
              "--tp-size",
              type=int,
              default=ServerArgs.tp_size,
              help="Tensor parallelism degree.",
          )
-         parser.add_argument(
-             "--model-mode",
-             type=str,
-             default=[],
-             nargs="+",
-             choices=["flashinfer", "no-cache"],
-             help="Model mode: [flashinfer, no-cache]",
-         )
          parser.add_argument(
              "--schedule-heuristic",
              type=str,
@@ -149,6 +153,11 @@ class ServerArgs:
              default=ServerArgs.random_seed,
              help="Random seed.",
          )
+         parser.add_argument(
+             "--attention-reduce-in-fp32",
+             action="store_true",
+             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+         )
          parser.add_argument(
              "--stream-interval",
              type=int,
@@ -172,6 +181,17 @@ class ServerArgs:
              default=ServerArgs.log_stats_interval,
              help="Log stats interval in second.",
          )
+         # optional modes
+         parser.add_argument(
+             "--disable-radix-cache",
+             action="store_true",
+             help="Disable RadixAttention",
+         )
+         parser.add_argument(
+             "--enable-flashinfer",
+             action="store_true",
+             help="Enable flashinfer inference kernels",
+         )
          parser.add_argument(
              "--disable-regex-jump-forward",
              action="store_true",
@@ -182,6 +202,12 @@ class ServerArgs:
              action="store_true",
              help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
          )
+         parser.add_argument(
+             "--api-key",
+             type=str,
+             default=ServerArgs.api_key,
+             help="Set API Key",
+         )

      @classmethod
      def from_cli_args(cls, args: argparse.Namespace):
@@ -191,6 +217,15 @@ class ServerArgs:
      def url(self):
          return f"http://{self.host}:{self.port}"

+     def get_optional_modes_logging(self):
+         return (
+             f"disable_radix_cache={self.disable_radix_cache}, "
+             f"enable_flashinfer={self.enable_flashinfer}, "
+             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+             f"disable_disk_cache={self.disable_disk_cache}, "
+             f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+         )
+

  @dataclasses.dataclass
  class PortArgs:
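
The removed --model-mode flag is split into dedicated switches: --model-mode flashinfer becomes --enable-flashinfer and --model-mode no-cache becomes --disable-radix-cache, alongside the new --context-length, --attention-reduce-in-fp32, and --api-key options. A hedged sketch of the equivalent programmatic configuration, using only fields visible in this diff (the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
        context_length=4096,            # override the value from the model's config.json
        enable_flashinfer=True,         # replaces: --model-mode flashinfer
        disable_radix_cache=True,       # replaces: --model-mode no-cache
        attention_reduce_in_fp32=True,  # reduce attention in fp32 to dodge fp16 issues
        api_key="my-secret-key",        # clients must send X-API-Key: my-secret-key
    )
    print(server_args.url())
    print(server_args.get_optional_modes_logging())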
sglang/srt/utils.py CHANGED
@@ -103,6 +103,7 @@ def alloc_usable_network_port(num, used_list=()):
      def check_port(port):
          with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
              try:
+                 s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                  s.bind(("", port))
                  return True
              except socket.error:
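
Setting SO_REUSEADDR before bind lets the port probe succeed for ports that a just-terminated server left in the TIME_WAIT state. The same pattern as a standalone sketch:

    import socket

    def port_is_available(port: int) -> bool:
        # Same probe as check_port above: SO_REUSEADDR avoids false negatives
        # from sockets lingering in TIME_WAIT.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("", port))
                return True
            except socket.error:
                return False

    print(port_is_available(30000))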
sglang/test/test_utils.py CHANGED
@@ -155,7 +155,7 @@ def select_sglang_backend(args):
              global_config.enable_parallel_decoding = False
              global_config.enable_parallel_encoding = False
          backend = RuntimeEndpoint(f"{args.host}:{args.port}")
-     elif args.backend.startswith("gpt"):
+     elif args.backend.startswith("gpt-"):
          backend = OpenAI(args.backend)
      else:
          raise ValueError(f"Invalid backend: {args.backend}")
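
Narrowing the prefix check from "gpt" to "gpt-" routes only OpenAI-style model names to the OpenAI backend. A quick illustration with example backend names (not taken from the source):

    for name in ["gpt-3.5-turbo", "gpt-4", "gpt2", "srt"]:
        print(name, name.startswith("gpt-"))
    # gpt-3.5-turbo True, gpt-4 True, gpt2 False, srt False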