sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
  req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
  req.prefix_indices = []
  req.sampling_params = sampling_params
- req.input_ids = req.origin_input_ids
+ req.fill_ids = req.origin_input_ids
  reqs.append(req)

  return input_ids, reqs
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
  ):
  for i in range(len(reqs)):
  req = reqs[i]
- req.input_ids += input_ids[i][bench_args.cut_len :]
+ req.fill_ids += input_ids[i][bench_args.cut_len :]
  req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
  i, : bench_args.cut_len
  ]
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
  req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
  req.prefix_indices = []
  req.sampling_params = sampling_params
- req.input_ids = req.origin_input_ids
+ req.fill_ids = req.origin_input_ids
  reqs.append(req)

  return reqs
@@ -238,7 +238,7 @@ def correctness_test(

  # Decode
  output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
- for _ in range(bench_args.output_len):
+ for _ in range(bench_args.output_len[0]):
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
  for i in range(len(reqs)):
  output_ids[i].append(next_token_ids[i])
@@ -332,6 +332,7 @@ def latency_test(
  )

  # Warm up
+ rank_print("Warmup ...")
  latency_test_run_once(
  bench_args.run_name,
  model_runner,
@@ -341,6 +342,7 @@ def latency_test(
  bench_args.input_len[0],
  4, # shorter decoding to speed up the warmup
  )
+ rank_print("Benchmark ...")

  # Run the sweep
  result_list = []
sglang/bench_serving.py CHANGED
@@ -24,7 +24,7 @@ import warnings
  from argparse import ArgumentParser
  from dataclasses import dataclass, field
  from datetime import datetime
- from typing import AsyncGenerator, List, Optional, Tuple, Union
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

  import aiohttp
  import numpy as np
@@ -39,6 +39,8 @@ from transformers import (

  AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

+ global args
+

  @dataclass
  class RequestFuncInput:
@@ -47,6 +49,7 @@ class RequestFuncInput:
  prompt_len: int
  output_len: int
  model: str
+ extra_request_body: Dict[str, Any]


  @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
  "stream": True,
  "min_length": request_func_input.output_len,
  "end_id": 1048576,
+ **request_func_input.extra_request_body,
  }
  if args.disable_ignore_eos:
  del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
  "max_tokens": request_func_input.output_len,
  "stream": not args.disable_stream,
  "ignore_eos": not args.disable_ignore_eos,
+ **request_func_input.extra_request_body,
  }
  headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

@@ -192,7 +197,8 @@ async def async_request_openai_completions(
  output.ttft = ttft

  # Decoding phase
- output.itl.append(timestamp - most_recent_timestamp)
+ else:
+ output.itl.append(timestamp - most_recent_timestamp)

  most_recent_timestamp = timestamp
  generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
  request_rate: float,
  disable_tqdm: bool,
  enable_multi: bool,
+ extra_request_body: Dict[str, Any],
  ):
  if backend in ASYNC_REQUEST_FUNCS:
  request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@ async def benchmark(
  api_url=api_url,
  prompt_len=test_prompt_len,
  output_len=test_output_len,
+ extra_request_body=extra_request_body,
  )
  test_output = await request_func(request_func_input=test_input)
  if not test_output.success:
@@ -578,6 +586,7 @@ async def benchmark(
  api_url=api_url,
  prompt_len=prompt_len,
  output_len=output_len,
+ extra_request_body=extra_request_body,
  )
  tasks.append(
  asyncio.create_task(
@@ -660,19 +669,20 @@ async def benchmark(
  "backend": args.backend,
  "dataset_name": args.dataset_name,
  "request_rate": request_rate,
- "total_input": metrics.total_input,
- "total_output": metrics.total_output,
- "total_output_retokenized": metrics.total_output_retokenized,
- "mean_e2e_latency": metrics.mean_e2e_latency_ms,
- "median_e2e_latency": metrics.median_e2e_latency_ms,
- "median_ttft": metrics.median_ttft_ms,
- "median_itl": metrics.median_itl_ms,
- "output_token_throughput": metrics.output_throughput,
+ "total_input_tokens": metrics.total_input,
+ "total_output_tokens": metrics.total_output,
+ "total_output_tokens_retokenized": metrics.total_output_retokenized,
+ "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+ "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+ "median_ttft_ms": metrics.median_ttft_ms,
+ "median_itl_ms": metrics.median_itl_ms,
+ "output_throughput": metrics.output_throughput,
  "sharegpt_output_len": args.sharegpt_output_len,
  "random_input_len": args.random_input_len,
  "random_output_len": args.random_output_len,
  "random_range_ratio": args.random_range_ratio,
- "benchmark_duration": benchmark_duration,
+ "duration": benchmark_duration,
+ "completed": metrics.completed,
  }
  else:
  print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
  return False


- def fire(args: argparse.Namespace):
+ def run_benchmark(args_: argparse.Namespace):
+ global args
+ args = args_
+
+ set_ulimit()
  random.seed(args.seed)
  np.random.seed(args.seed)

+ extra_request_body = {}
+ if args.extra_request_body:
+ extra_request_body = json.loads(args.extra_request_body)
+
  if args.port is None:
  args.port = {
  "sglang": 30000,
@@ -838,10 +856,11 @@ def fire(args: argparse.Namespace):
  request_rate=rate,
  disable_tqdm=args.disable_tqdm,
  enable_multi=args.multi,
+ extra_request_body=extra_request_body,
  )
  )
  else:
- asyncio.run(
+ return asyncio.run(
  benchmark(
  backend=backend,
  api_url=api_url,
@@ -851,6 +870,7 @@ def fire(args: argparse.Namespace):
  request_rate=args.request_rate,
  disable_tqdm=args.disable_tqdm,
  enable_multi=args.multi,
+ extra_request_body=extra_request_body,
  )
  )

@@ -949,11 +969,6 @@ if __name__ == "__main__":
  "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
  )
  parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
- parser.add_argument(
- "--disable-tqdm",
- action="store_true",
- help="Specify to disable tqdm progress bar.",
- )
  parser.add_argument(
  "--multi",
  action="store_true",
@@ -966,6 +981,11 @@ if __name__ == "__main__":
  help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
  )
  parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+ parser.add_argument(
+ "--disable-tqdm",
+ action="store_true",
+ help="Specify to disable tqdm progress bar.",
+ )
  parser.add_argument(
  "--disable-stream",
  action="store_true",
@@ -976,8 +996,12 @@ if __name__ == "__main__":
  action="store_true",
  help="Disable ignoring EOS.",
  )
-
- set_ulimit()
-
+ parser.add_argument(
+ "--extra-request-body",
+ metavar='{"key1": "value1", "key2": "value2"}',
+ type=str,
+ help="Append given JSON object to the request payload. You can use this to specify "
+ "additional generate params like sampling params.",
+ )
  args = parser.parse_args()
- fire(args)
+ run_benchmark(args)
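Two things change here for callers: the entry point is renamed from fire to run_benchmark and now returns the result of asyncio.run, so the benchmark can be driven programmatically, and the new --extra-request-body flag merges a user-supplied JSON object into every request payload. A usage sketch (the sampling key shown is illustrative; whatever you pass must be accepted by the target backend's API):

    python3 -m sglang.bench_serving --backend sglang \
        --extra-request-body '{"temperature": 0.0}'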
sglang/lang/compiler.py CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
  def run(
  self,
  *,
- max_new_tokens: int = 16,
+ max_new_tokens: int = 128,
  stop: Union[str, List[str]] = (),
  temperature: float = 1.0,
  top_p: float = 1.0,
@@ -155,7 +155,7 @@ class CompiledFunction:
  self,
  batch_kwargs,
  *,
- max_new_tokens: int = 16,
+ max_new_tokens: int = 128,
  stop: Union[str, List[str]] = (),
  temperature: float = 1.0,
  top_p: float = 1.0,
sglang/lang/ir.py CHANGED
@@ -16,7 +16,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg

  @dataclasses.dataclass
  class SglSamplingParams:
- max_new_tokens: int = 16
+ max_new_tokens: int = 128
  stop: Union[str, List[str]] = ()
  temperature: float = 1.0
  top_p: float = 1.0
@@ -140,7 +140,7 @@ class SglFunction:
  def run(
  self,
  *args,
- max_new_tokens: int = 16,
+ max_new_tokens: int = 128,
  stop: Union[str, List[str]] = (),
  temperature: float = 1.0,
  top_p: float = 1.0,
@@ -179,7 +179,7 @@ class SglFunction:
  self,
  batch_kwargs,
  *,
- max_new_tokens: int = 16,
+ max_new_tokens: int = 128,
  stop: Union[str, List[str]] = (),
  temperature: float = 1.0,
  top_p: float = 1.0,
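The default max_new_tokens moves from 16 to 128 across SglSamplingParams, SglFunction.run, and the compiled variants, so frontend programs that silently relied on the old short budget will now generate longer outputs. A minimal sketch of pinning the limit explicitly with the standard sglang frontend API (the prompt is made up):

    import sglang as sgl

    @sgl.function
    def answer(s, question):
        s += "Q: " + question + "\n"
        # pin the old 16-token budget instead of relying on the new default
        s += "A:" + sgl.gen("ans", max_new_tokens=16)

Per-call overrides also work, since run still accepts the parameter: answer.run(question="...", max_new_tokens=16).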
sglang/srt/constrained/base_tool_cache.py CHANGED
@@ -54,7 +54,7 @@ class BaseToolCache:
  return val

  def init_value(self, key):
- raise NotImplementedError
+ raise NotImplementedError()

  def get_cache_hit_rate(self):
  if self.metrics["total"] == 0:
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache


  class FSMCache(BaseToolCache):
- def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+ def __init__(
+ self,
+ tokenizer_path,
+ tokenizer_args_dict,
+ enable=True,
+ skip_tokenizer_init=False,
+ ):
  super().__init__(enable=enable)

- if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+ if (
+ skip_tokenizer_init
+ or tokenizer_path.endswith(".json")
+ or tokenizer_path.endswith(".model")
+ ):
  # Do not support TiktokenTokenizer or SentencePieceTokenizer
  return
sglang/srt/layers/activation.py ADDED
@@ -0,0 +1,33 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """Fused operators for activation layers."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from flashinfer.activation import silu_and_mul
+ from vllm.model_executor.custom_op import CustomOp
+
+
+ class SiluAndMul(CustomOp):
+     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         return F.silu(x[..., :d]) * x[..., d:]
+
+     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+         d = x.shape[-1] // 2
+         output_shape = x.shape[:-1] + (d,)
+         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+         silu_and_mul(x, out)
+         return out
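SiluAndMul is the gated-MLP activation: the input concatenates the gate and up projections along the last dimension, and the output is silu(gate) * up at half the width. A quick shape-and-semantics check of the native path (pure PyTorch, no flashinfer required; the sizes are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 2 * 11008)          # [tokens, 2 * intermediate_size]
    d = x.shape[-1] // 2
    ref = F.silu(x[..., :d]) * x[..., d:]  # what forward_native computes
    assert ref.shape == (4, 11008)

forward_cuda produces the same result via flashinfer's fused silu_and_mul kernel, skipping the intermediate silu tensor.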
sglang/srt/layers/{token_attention.py → decode_attention.py} RENAMED
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ """
+ Memory-efficient attention for decoding.
+ """
+
  # Adapted from
  # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
  # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
  tl.store(out_ptrs, acc)


- def _token_att_m_fwd(
+ def _decode_att_m_fwd(
  q,
  k_buffer,
  att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
  )


- def _token_softmax_reducev_fwd(
+ def _decode_softmax_reducev_fwd(
  logics,
  v_buffer,
  o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
  )


- def token_attention_fwd(
+ def decode_attention_fwd(
  q,
  k_buffer,
  v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
  (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
  )

- _token_att_m_fwd(
+ _decode_att_m_fwd(
  q,
  k_buffer,
  att_m,
@@ -324,7 +328,7 @@ def token_attention_fwd(
  sm_scale,
  logit_cap,
  )
- _token_softmax_reducev_fwd(
+ _decode_softmax_reducev_fwd(
  att_m,
  v_buffer,
  o,
sglang/srt/layers/extend_attention.py CHANGED
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ """
+ Memory-efficient attention for prefill.
+ It supports page size = 1 and prefill with KV cache (i.e. extend).
+ """
+
  import torch
  import triton
  import triton.language as tl

- from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+ from sglang.srt.layers.prefill_attention import context_attention_fwd

  CUDA_CAPABILITY = torch.cuda.get_device_capability()
sglang/srt/layers/layernorm.py ADDED
@@ -0,0 +1,65 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """Fused operators for normalization layers."""
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+ from vllm.model_executor.custom_op import CustomOp
+
+
+ class RMSNorm(CustomOp):
+     def __init__(
+         self,
+         hidden_size: int,
+         eps: float = 1e-6,
+     ) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward_cuda(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+         if residual is not None:
+             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+             return x, residual
+         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+         return out
+
+     def forward_native(
+         self,
+         x: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+         orig_dtype = x.dtype
+         x = x.to(torch.float32)
+         if residual is not None:
+             x = x + residual.to(torch.float32)
+             residual = x.to(orig_dtype)
+
+         variance = x.pow(2).mean(dim=-1, keepdim=True)
+         x = x * torch.rsqrt(variance + self.variance_epsilon)
+         x = x.to(orig_dtype) * self.weight
+         if residual is None:
+             return x
+         else:
+             return x, residual
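One behavioral detail worth noting: on the CUDA path, flashinfer's fused_add_rmsnorm mutates x and residual in place and returns them, while forward_native allocates fresh tensors, so callers should not assume the inputs survive the fused call unchanged. A minimal out-of-place restatement of the math (no residual, pure PyTorch):

    import torch

    def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # mirrors RMSNorm.forward_native without a residual: scale by the
        # reciprocal root-mean-square over the last dimension, then by weight
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight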
sglang/srt/layers/logits_processor.py CHANGED
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
  all_logits = tensor_model_parallel_all_gather(all_logits)
  all_logits = all_logits[:, : self.config.vocab_size].float()

+ if hasattr(self.config, "final_logit_softcapping"):
+ all_logits /= self.config.final_logit_softcapping
+ all_logits = torch.tanh(all_logits)
+ all_logits *= self.config.final_logit_softcapping
+
  all_logprobs = all_logits
  del all_logits, hidden_states
  all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
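The new branch applies final-logit softcapping as used by Gemma 2: each logit is squashed into (-cap, cap) via cap * tanh(logit / cap), which is smooth and approximately the identity near zero. A standalone restatement (the cap of 30.0 is Gemma 2's published final_logit_softcapping value, shown here only as an example):

    import torch

    def softcap_logits(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
        # equivalent to the in-place sequence above: divide, tanh, multiply back
        return cap * torch.tanh(logits / cap)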
sglang/srt/layers/pooler.py ADDED
@@ -0,0 +1,50 @@
+ # adapted from
+ # https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+ from dataclasses import dataclass
+ from enum import IntEnum
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+ class PoolingType(IntEnum):
+     LAST = 0
+
+
+ @dataclass
+ class EmbeddingPoolerOutput:
+     embeddings: torch.Tensor
+
+
+ class Pooler(nn.Module):
+     """A layer that pools specific information from hidden states.
+     This layer does the following:
+     1. Extracts specific tokens or aggregates data based on pooling method.
+     2. Normalizes output if specified.
+     3. Returns structured results as `PoolerOutput`.
+     Attributes:
+         pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+         normalize: Whether to normalize the pooled data.
+     """
+
+     def __init__(self, pooling_type: PoolingType, normalize: bool):
+         super().__init__()
+         self.pooling_type = pooling_type
+         self.normalize = normalize
+
+     def forward(
+         self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+     ) -> EmbeddingPoolerOutput:
+         if self.pooling_type == PoolingType.LAST:
+             last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+             pooled_data = hidden_states[last_token_indices]
+         else:
+             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+         if self.normalize:
+             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+         return EmbeddingPoolerOutput(embeddings=pooled_data)
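The LAST pooling branch relies on a cumulative-sum trick: with requests packed back to back and extend_seq_lens holding each request's token count, cumsum(lens) - 1 gives the flat index of each request's final hidden state. A tiny illustration with made-up lengths and hidden size:

    import torch

    extend_seq_lens = torch.tensor([3, 5, 2])        # tokens per request, packed contiguously
    last = torch.cumsum(extend_seq_lens, dim=0) - 1  # tensor([2, 7, 9])
    hidden = torch.randn(10, 4096)                   # 3 + 5 + 2 = 10 packed token states
    pooled = hidden[last]                            # one embedding per request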
sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} RENAMED
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ """
+ Memory-efficient attention for prefill.
+ It supports page size = 1.
+ """
+
  # Adapted from
  # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
  import torch
sglang/srt/layers/radix_attention.py CHANGED
@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
  from torch import nn

  from sglang.global_config import global_config
+ from sglang.srt.layers.decode_attention import decode_attention_fwd
  from sglang.srt.layers.extend_attention import extend_attention_fwd
- from sglang.srt.layers.token_attention import token_attention_fwd
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.model_executor.model_runner import global_server_args_dict

@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
  o = torch.empty_like(q)
  self.store_kv_cache(k, v, input_metadata)

- token_attention_fwd(
+ decode_attention_fwd(
  q.view(-1, self.tp_q_head_num, self.qk_head_dim),
  input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
  input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -25,10 +25,14 @@ import zmq
  import zmq.asyncio

  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+ from sglang.srt.managers.io_struct import (
+ BatchEmbeddingOut,
+ BatchStrOut,
+ BatchTokenIDOut,
+ )
  from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+ from sglang.utils import find_printable_text, get_exception_traceback

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -55,20 +59,40 @@ class DetokenizerManager:
  self.send_to_tokenizer = context.socket(zmq.PUSH)
  self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

- self.tokenizer = get_tokenizer(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- )
+ if server_args.skip_tokenizer_init:
+ self.tokenizer = None
+ else:
+ self.tokenizer = get_tokenizer(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ )

  self.decode_status = {}

  async def handle_loop(self):
  while True:
  recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+ if isinstance(recv_obj, BatchEmbeddingOut):
+ self.send_to_tokenizer.send_pyobj(
+ BatchEmbeddingOut(
+ rids=recv_obj.rids,
+ embeddings=recv_obj.embeddings,
+ meta_info=recv_obj.meta_info,
+ finished_reason=recv_obj.finished_reason,
+ )
+ )
+ continue
+
  assert isinstance(recv_obj, BatchTokenIDOut)
  bs = len(recv_obj.rids)

+ if self.tokenizer is None:
+ # Send BatchTokenIDOut if no tokenizer init'ed.
+ self.send_to_tokenizer.send_pyobj(recv_obj)
+ continue
+
  # Initialize decode status
  read_ids, surr_ids = [], []
  for i in range(bs):
@@ -140,8 +164,6 @@ def start_detokenizer_process(
  port_args: PortArgs,
  pipe_writer,
  ):
- graceful_registry(inspect.currentframe().f_code.co_name)
-
  try:
  manager = DetokenizerManager(server_args, port_args)
  except Exception: