sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +14 -0
  3. sglang/backend/anthropic.py +18 -12
  4. sglang/backend/base_backend.py +6 -0
  5. sglang/backend/openai.py +41 -12
  6. sglang/backend/runtime_endpoint.py +57 -6
  7. sglang/lang/chat_template.py +47 -26
  8. sglang/lang/interpreter.py +15 -2
  9. sglang/lang/ir.py +1 -1
  10. sglang/srt/constrained/__init__.py +23 -1
  11. sglang/srt/constrained/fsm_cache.py +14 -3
  12. sglang/srt/layers/context_flashattention_nopad.py +1 -1
  13. sglang/srt/layers/extend_attention.py +7 -6
  14. sglang/srt/layers/radix_attention.py +2 -10
  15. sglang/srt/layers/token_attention.py +12 -4
  16. sglang/srt/managers/io_struct.py +3 -1
  17. sglang/srt/managers/router/infer_batch.py +6 -2
  18. sglang/srt/managers/router/model_rpc.py +45 -32
  19. sglang/srt/managers/router/model_runner.py +40 -25
  20. sglang/srt/managers/tokenizer_manager.py +2 -0
  21. sglang/srt/model_config.py +12 -5
  22. sglang/srt/models/gemma.py +340 -0
  23. sglang/srt/models/llama2.py +5 -5
  24. sglang/srt/models/llava.py +2 -4
  25. sglang/srt/models/mixtral.py +5 -5
  26. sglang/srt/models/qwen.py +4 -4
  27. sglang/srt/models/qwen2.py +5 -5
  28. sglang/srt/models/stablelm.py +293 -0
  29. sglang/srt/server.py +111 -47
  30. sglang/srt/server_args.py +44 -9
  31. sglang/srt/utils.py +1 -0
  32. sglang/test/test_utils.py +1 -1
  33. sglang/utils.py +15 -12
  34. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
  35. sglang-0.1.14.dist-info/RECORD +64 -0
  36. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
  37. sglang/srt/models/gpt_neox.py +0 -274
  38. sglang-0.1.12.dist-info/RECORD +0 -63
  39. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
@@ -4,8 +4,16 @@
 import torch
 import triton
 import triton.language as tl
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
+if global_server_args_dict.get("attention_reduce_in_fp32", False):
+    REDUCE_TRITON_TYPE = tl.float32
+    REDUCE_TORCH_TYPE = torch.float32
+else:
+    REDUCE_TRITON_TYPE = tl.float16
+    REDUCE_TORCH_TYPE = torch.float16
+
 
 @triton.jit
 def _fwd_kernel_stage1(
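Note: the hunk above reads a module-level flag once at import time and derives both the Triton and Torch reduction dtypes from it. A minimal standalone sketch of that selection (the dict contents here are illustrative, not sglang's defaults):

    import torch

    # stand-in for sglang's global_server_args_dict, populated from server args
    global_server_args_dict = {"attention_reduce_in_fp32": True}

    if global_server_args_dict.get("attention_reduce_in_fp32", False):
        REDUCE_TORCH_TYPE = torch.float32
    else:
        REDUCE_TORCH_TYPE = torch.float16

    # intermediate attention scores are then allocated in the chosen dtype
    att_m = torch.empty((8, 1024), dtype=REDUCE_TORCH_TYPE)
    print(att_m.dtype)  # torch.float32 when the flag is set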
@@ -49,7 +57,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark)
+        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -65,7 +73,7 @@ def _fwd_kernel_stage1(
             K_Buffer + offs_buf_k,
             mask=offs_n_new[:, None] < cur_batch_end_index,
             other=0.0,
-        )
+        ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
@@ -161,7 +169,7 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
     sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -299,7 +307,7 @@ def token_attention_fwd(
 ):
     if att_m is None:
         att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=q.dtype, device="cuda"
+            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
 
     _token_att_m_fwd(
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs of the prompts
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
 
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
 
-        # for accumulated prompt tokens from jump forward
-        self.orig_prompt_tokens = len(input_ids)
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0
 
         # For vision input
         self.pixel_values = None
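Note: together with the handle_finished_requests change further down, these fields drive the usage report. A toy sketch of the resulting arithmetic (all numbers illustrative):

    prompt_tokens = 10                     # len(input_ids) at request creation
    input_ids_len = 14                     # grew by 4 via jump-forward retokenization
    output_ids_len = 6                     # tokens currently in output_ids
    completion_tokens_wo_jump_forward = 9  # incremented once per real decode step

    meta_info = {
        "prompt_tokens": prompt_tokens,
        # counts jump-forward tokens as completions: (14 + 6) - 10 = 10
        "completion_tokens": input_ids_len + output_ids_len - prompt_tokens,
        # counts only tokens actually produced by decoding
        "completion_tokens_wo_jump_forward": completion_tokens_wo_jump_forward,
    }
    print(meta_info)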
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -57,17 +56,26 @@ class ModelRpcServer(rpyc.Service):
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
         )
+
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         self.model_runner = ModelRunner(
-            self.model_config,
-            server_args.mem_fraction_static,
-            tp_rank,
-            server_args.tp_size,
-            port_args.nccl_port,
-            server_args.load_format,
-            server_args.trust_remote_code,
-            server_args.model_mode,
+            model_config=self.model_config,
+            mem_fraction_static=server_args.mem_fraction_static,
+            tp_rank=tp_rank,
+            tp_size=server_args.tp_size,
+            nccl_port=port_args.nccl_port,
+            load_format=server_args.load_format,
+            trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -102,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
 
         # Init cache
-        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -208,6 +216,19 @@ class ModelRpcServer(rpyc.Service):
 
                     if self.out_pyobjs and self.running_batch.reqs[0].stream:
                         break
+
+                    if self.running_batch is not None and self.tp_rank == 0:
+                        if self.decode_forward_ct % 40 == 0:
+                            num_used = self.max_total_num_token - (
+                                self.token_to_kv_pool.available_size()
+                                + self.tree_cache.evictable_size()
+                            )
+                            logger.info(
+                                f"#running-req: {len(self.running_batch.reqs)}, "
+                                f"#token: {num_used}, "
+                                f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                                f"#queue-req: {len(self.forward_queue)}"
+                            )
             else:
                 # check the available size
                 available_size = (
@@ -221,19 +242,6 @@ class ModelRpcServer(rpyc.Service):
                         "KV cache pool leak detected!"
                     )
 
-        if self.running_batch is not None and self.tp_rank == 0:
-            if self.decode_forward_ct % 20 == 0:
-                num_used = self.max_total_num_token - (
-                    self.token_to_kv_pool.available_size()
-                    + self.tree_cache.evictable_size()
-                )
-                logger.info(
-                    f"#running-req: {len(self.running_batch.reqs)}, "
-                    f"#token: {num_used}, "
-                    f"token usage: {num_used / self.max_total_num_token:.2f}, "
-                    f"#queue-req: {len(self.forward_queue)}"
-                )
-
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -424,6 +432,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
 
@@ -431,9 +440,14 @@ class ModelRpcServer(rpyc.Service):
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
 
-                token_ids = req.input_ids + [next_token_ids[i]]
-                token_logprobs = [None] + req.logprob + [last_logprobs[i]]
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
+                # will be ignored.
+                prompt_token_len = len(req.logprob)
+                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
+                token_logprobs = req.logprob + [last_logprobs[i]]
                 req.token_logprob = list(zip(token_ids, token_logprobs))
+                if req.logprob_start_len == 0:
+                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
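Note: a toy walk-through of the new token_logprob assembly above, with made-up values. req.logprob only covers the last len(req.logprob) prompt tokens once logprob_start_len > 0, so the token ids are sliced to match, and the leading (first_token, None) pair is only prepended when logprobs start at position 0:

    input_ids = [101, 102, 103, 104]
    logprob = [-0.5, -0.7]              # logprobs for the last two prompt tokens only
    next_token_id, last_logprob = 200, -0.1
    logprob_start_len = 2

    prompt_token_len = len(logprob)
    token_ids = input_ids[-prompt_token_len:] + [next_token_id]
    token_logprobs = logprob + [last_logprob]
    token_logprob = list(zip(token_ids, token_logprobs))
    if logprob_start_len == 0:
        token_logprob = [(input_ids[0], None)] + token_logprob
    print(token_logprob)  # [(103, -0.5), (104, -0.7), (200, -0.1)]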
@@ -500,6 +514,7 @@ class ModelRpcServer(rpyc.Service):
 
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()
 
@@ -541,15 +556,13 @@ class ModelRpcServer(rpyc.Service):
                     req.sampling_params.skip_special_tokens
                 )
 
-                # For the length of input_ids, which will be accumulated during jump-forward.
-                # Use the original length of input_ids to calculate the token usage info.
                 meta_info = {
-                    "prompt_tokens": req.orig_prompt_tokens,
+                    "prompt_tokens": req.prompt_tokens,
                     "completion_tokens": len(req.input_ids)
                     + len(req.output_ids)
-                    - req.orig_prompt_tokens,
+                    - req.prompt_tokens,
+                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                 }
-
                 if req.return_logprob:
                     meta_info["prompt_logprob"] = req.logprob
                     meta_info["token_logprob"] = req.token_logprob
@@ -1,9 +1,10 @@
 import importlib
 import logging
+import inspect
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import List
+import importlib.resources
 
 import numpy as np
 import torch
@@ -13,27 +14,34 @@ from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+import importlib
+import pkgutil
+
 import sglang
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
+QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
 
 
-# for model_mode
-global_model_mode: List[str] = []
+# for server args in model endpoints
+global_server_args_dict: dict = None
 
 
 @lru_cache()
 def import_model_classes():
     model_arch_name_to_cls = {}
-    for module_path in (Path(sglang.__file__).parent / "srt" / "models").glob("*.py"):
-        module = importlib.import_module(f"sglang.srt.models.{module_path.stem}")
-        if hasattr(module, "EntryClass"):
-            model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
     return model_arch_name_to_cls
 
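Note: the rewritten import_model_classes switches from globbing loose .py files to pkgutil, which also works when the package is installed without source files on disk. A generic sketch of the same discovery pattern (the package name below is just the one from the diff):

    import importlib
    import pkgutil

    def collect_entry_classes(package_name):
        """Map EntryClass.__name__ -> EntryClass for every module in a package."""
        registry = {}
        package = importlib.import_module(package_name)
        for _finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
            if ispkg:
                continue
            module = importlib.import_module(name)
            entry = getattr(module, "EntryClass", None)
            if entry is not None:
                registry[entry.__name__] = entry
        return registry

    # e.g. collect_entry_classes("sglang.srt.models")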
 
@@ -81,7 +89,6 @@ class InputMetadata:
     return_logprob: bool = False
 
     # for flashinfer
-    use_flashinfer: bool = False
     qo_indptr: torch.Tensor = None
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
@@ -126,14 +133,21 @@ class InputMetadata:
             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
             )
-            self.prefill_wrapper.begin_forward(
+            args = [
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
-            )
+            ]
+
+            # flashinfer >= 0.0.3
+            # FIXME: Drop this when flashinfer updates to 0.0.4
+            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+                args.append(self.model_runner.model_config.head_dim)
+
+            self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
@@ -224,8 +238,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        ret.use_flashinfer = "flashinfer" in model_runner.model_mode
-        if ret.use_flashinfer:
+        if global_server_args_dict.get("enable_flashinfer", False):
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -241,7 +254,7 @@ class ModelRunner:
         nccl_port,
         load_format="auto",
         trust_remote_code=True,
-        model_mode: List[str] = (),
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -250,10 +263,9 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
-        self.model_mode = model_mode
 
-        global global_model_mode
-        global_model_mode = model_mode
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
@@ -292,9 +304,15 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-            quant_config_class = QUANTIONCONFIG_MAPPING.get(
-                hf_quant_config["quant_method"]
-            )
+            hf_quant_method = hf_quant_config["quant_method"]
+
+            # compat: autogptq uses is_marlin_format within quant config
+            if (hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]):
+                hf_quant_method = "marlin"
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
+
             if quant_config_class is None:
                 raise ValueError(
                     f"Unsupported quantization method: {hf_quant_config['quant_method']}"
@@ -319,9 +337,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
-        head_dim = (
-            self.model_config.hidden_size // self.model_config.num_attention_heads
-        )
+        head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
@@ -346,8 +362,7 @@ class ModelRunner:
             self.max_total_num_token,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.hidden_size
-            // self.model_config.num_attention_heads,
+            head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
 
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
+
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -1,7 +1,5 @@
-import os
-from typing import Optional, Union
+from typing import Optional
 
-import torch
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -11,15 +9,24 @@ class ModelConfig:
         path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
+        context_length: Optional[int] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if context_length is not None:
+            self.context_len = context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
+
         # Unify the config keys for hf_config
-        self.context_len = get_context_length(self.hf_config)
-        self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
+        self.head_dim = getattr(
+            self.hf_config,
+            "head_dim",
+            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+        )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
         if self.num_key_value_heads is None:
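Note: the ModelConfig change adds two fallbacks: an explicit context_length overrides the value derived from the HF config, and head_dim is read from the config when present, otherwise derived from hidden_size // num_attention_heads. A minimal sketch with a stand-in config object (get_context_length is replaced here by a plain attribute):

    from types import SimpleNamespace

    hf_config = SimpleNamespace(hidden_size=4096, num_attention_heads=32,
                                max_position_embeddings=8192)

    def resolve(hf_config, context_length=None):
        context_len = (context_length if context_length is not None
                       else hf_config.max_position_embeddings)
        head_dim = getattr(hf_config, "head_dim",
                           hf_config.hidden_size // hf_config.num_attention_heads)
        return context_len, head_dim

    print(resolve(hf_config))                       # (8192, 128)
    print(resolve(hf_config, context_length=4096))  # (4096, 128)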