sglang-0.2.10-py3-none-any.whl → sglang-0.2.12-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -1,17 +1,21 @@
 import json
 from typing import List, Optional

-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request


 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -43,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -51,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -208,20 +212,14 @@
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5

         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]

         # Compute logprob
         data = {
@@ -230,27 +228,35 @@
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

-        return (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -262,6 +268,17 @@
         )
         self._assert_success(res)

+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
sglang/lang/choices.py ADDED
@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+
+@dataclass
+class ChoicesDecision:
+    decision: str
+    meta_info: Optional[Dict[str, Any]] = None
+
+
+class ChoicesSamplingMethod(ABC):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return False
+
+    @abstractmethod
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision: ...
+
+
+class TokenLengthNormalized(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest token length normalized prompt logprob."""
+        best_choice = choices[np.argmax(normalized_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+
+token_length_normalized = TokenLengthNormalized()
+
+
+class GreedyTokenSelection(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option based on greedy logprob selection. For overlapping options
+        where one option is a subset of a longer option, extend the shorter option using
+        its average logprob for comparison against the longer option."""
+
+        num_options = len(choices)
+        max_tokens = max(len(option) for option in input_token_logprobs)
+        logprob_matrix = self._build_logprob_matrix(
+            input_token_logprobs, max_tokens, num_options
+        )
+        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
+
+        best_choice = choices[remaining[0]]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "greedy_logprob_matrix": logprob_matrix.tolist(),
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
+        logprob_matrix = np.zeros((num_options, max_tokens))
+        for i, option in enumerate(input_token_logprobs):
+            actual_logprobs = [token[0] for token in option]
+            avg_logprob = np.mean(actual_logprobs)
+            logprob_matrix[i, : len(option)] = actual_logprobs
+            if len(option) < max_tokens:
+                logprob_matrix[i, len(option) :] = avg_logprob
+        return logprob_matrix
+
+    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
+        remaining = np.arange(num_options)
+        for j in range(max_tokens):
+            max_logprob = np.max(logprob_matrix[remaining, j])
+            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
+            if len(remaining) == 1:
+                break
+        return remaining
+
+
+greedy_token_selection = GreedyTokenSelection()
+
+
+class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return True
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest average token logprob once normalized by
+        the unconditional token logprobs.
+
+        The first unconditional token logprob is assumed to be None. If so, it is
+        replaced with 0 for the purposes of normalization."""
+
+        if unconditional_token_logprobs is None:
+            raise ValueError(
+                "Unconditional token logprobs are required for this method."
+            )
+
+        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
+            input_token_logprobs, unconditional_token_logprobs
+        )
+
+        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "unconditional_token_logprobs": unconditional_token_logprobs,
+            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
+        normalized_unconditional_prompt_logprobs = []
+        for inputs, unconditionals in zip(
+            input_token_logprobs, unconditional_token_logprobs
+        ):
+            inputs_logprobs = np.array([token[0] for token in inputs])
+            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
+            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
+            normalized_unconditional_prompt_logprobs.append(
+                float(np.mean(inputs_logprobs - unconditionals_logprobs))
+            )
+        return normalized_unconditional_prompt_logprobs
+
+
+unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
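The three strategies exported here (token_length_normalized, greedy_token_selection, unconditional_likelihood_normalized) are chosen per select call. A minimal usage sketch, assuming sgl.select forwards a new choices_method keyword (the api.py change listed above, +10 -2, is consistent with this) and that a local sglang server is running:

    import sglang as sgl
    from sglang.lang.choices import greedy_token_selection

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    @sgl.function
    def triage(s, question):
        s += question
        # Compare options token-by-token instead of by length-normalized prompt logprob.
        s += sgl.select("verdict", ["safe", "unsafe"], choices_method=greedy_token_selection)

    state = triage.run(question="Is this request safe to execute? ")
    print(state["verdict"])
    print(state.get_meta_info("verdict")["greedy_logprob_matrix"])  # surfaced via ChoicesDecision.meta_info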
sglang/lang/compiler.py CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
     def run(
         self,
         *,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -155,7 +155,7 @@
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
sglang/lang/interpreter.py CHANGED
@@ -538,24 +538,17 @@ class StreamExecutor:
             self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
-        ) = self.backend.select(self, expr.choices, expr.temperature)
+        choices_decision = self.backend.select(
+            self, expr.choices, expr.temperature, expr.choices_method
+        )
         if expr.name is not None:
             name = expr.name
-            self.variables[name] = decision
-            self.meta_info[name] = {
-                "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "input_token_logprobs": input_token_logprobs,
-                "output_token_logprobs": output_token_logprobs,
-            }
+            self.variables[name] = choices_decision.decision
+            self.meta_info[name] = choices_decision.meta_info
             self.variable_event[name].set()
             if self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.text_ += decision
+        self.text_ += choices_decision.decision

     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
sglang/lang/ir.py CHANGED
@@ -6,6 +6,7 @@ import warnings
 from typing import List, Optional, Union

 from sglang.global_config import global_config
+from sglang.lang.choices import ChoicesSamplingMethod

 REGEX_INT = r"[-+]?[0-9]+"
 REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
@@ -15,7 +16,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg

 @dataclasses.dataclass
 class SglSamplingParams:
-    max_new_tokens: int = 16
+    max_new_tokens: int = 128
     stop: Union[str, List[str]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
@@ -139,7 +140,7 @@ class SglFunction:
     def run(
         self,
         *args,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -178,7 +179,7 @@
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):


 class SglSelect(SglExpr):
-    def __init__(self, name: str, choices: List[str], temperature: float):
+
+    def __init__(
+        self,
+        name: str,
+        choices: List[str],
+        temperature: float,
+        choices_method: ChoicesSamplingMethod,
+    ):
         super().__init__()
         self.name = name
         self.choices = choices
         self.temperature = temperature
+        self.choices_method = choices_method

     def __repr__(self):
-        return f"Select({self.name}, choices={self.choices})"
+        return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"


 class SglFork(SglExpr):
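Note that ir.py and compiler.py both raise the default max_new_tokens from 16 to 128. Code that depended on the old short default should pin it explicitly; a one-line sketch (the function name is illustrative):

    state = my_func.run(question="...", max_new_tokens=16)  # restore the old default explicitly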
sglang/srt/constrained/base_tool_cache.py CHANGED
@@ -54,7 +54,7 @@ class BaseToolCache:
         return val

     def init_value(self, key):
-        raise NotImplementedError
+        raise NotImplementedError()

     def get_cache_hit_rate(self):
         if self.metrics["total"] == 0:
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache


 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)

-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return

sglang/srt/layers/activation.py ADDED
@@ -0,0 +1,33 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for activation layers."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp
+
+
+class SiluAndMul(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
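For reference, forward_native spells out the fused kernel's math: split the last dimension in half, apply SiLU to the first half, and use it to gate the second half. A CPU-only check of that identity, independent of flashinfer:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8)                  # last dim is 2 * d
    d = x.shape[-1] // 2
    out = F.silu(x[..., :d]) * x[..., d:]  # what SiluAndMul.forward_native computes
    assert out.shape == (2, 4)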
sglang/srt/layers/{token_attention.py → decode_attention.py} RENAMED
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)


-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )


-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )


-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )

-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
@@ -324,7 +328,7 @@
         sm_scale,
         logit_cap,
     )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
         att_m,
         v_buffer,
         o,
sglang/srt/layers/extend_attention.py CHANGED
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl

-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd

 CUDA_CAPABILITY = torch.cuda.get_device_capability()

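For downstream imports, the two attention-layer renames in this release are path changes; the function parameter lists visible in the hunks above are otherwise untouched. A migration sketch:

    # 0.2.10
    # from sglang.srt.layers.token_attention import token_attention_fwd
    # from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

    # 0.2.12
    from sglang.srt.layers.decode_attention import decode_attention_fwd
    from sglang.srt.layers.prefill_attention import context_attention_fwd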
sglang/srt/layers/layernorm.py ADDED
@@ -0,0 +1,65 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for normalization layers."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp
+
+
+class RMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
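forward_native doubles as a readable spec for the flashinfer kernels used in forward_cuda: compute y = x * rsqrt(mean(x^2) + eps) in fp32, cast back, then scale by the learned weight. A standalone sketch of that math:

    import torch

    x = torch.randn(2, 4)
    weight = torch.ones(4)   # RMSNorm initializes its weight to ones
    eps = 1e-6

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps) * weight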
sglang/srt/layers/logits_processor.py CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


 @dataclasses.dataclass
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()

+        if hasattr(self.config, "final_logit_softcapping"):
+            all_logits /= self.config.final_logit_softcapping
+            all_logits = torch.tanh(all_logits)
+            all_logits *= self.config.final_logit_softcapping
+
         all_logprobs = all_logits
         del all_logits, hidden_states
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
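The new block applies final-logit softcapping, logits ← cap · tanh(logits / cap), which bounds logits to (−cap, cap); Gemma 2 ships such a cap in its config as final_logit_softcapping (30.0 for the final logits). A standalone sketch of the transform:

    import torch

    cap = 30.0
    logits = torch.tensor([5.0, 60.0, -120.0])
    capped = cap * torch.tanh(logits / cap)  # ≈ [4.95, 28.92, -29.98]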
sglang/srt/layers/pooler.py ADDED
@@ -0,0 +1,50 @@
+# adapted from
+# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+from dataclasses import dataclass
+from enum import IntEnum
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+class PoolingType(IntEnum):
+    LAST = 0
+
+
+@dataclass
+class EmbeddingPoolerOutput:
+    embeddings: torch.Tensor
+
+
+class Pooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `PoolerOutput`.
+    Attributes:
+        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+        normalize: Whether to normalize the pooled data.
+    """
+
+    def __init__(self, pooling_type: PoolingType, normalize: bool):
+        super().__init__()
+        self.pooling_type = pooling_type
+        self.normalize = normalize
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+    ) -> EmbeddingPoolerOutput:
+        if self.pooling_type == PoolingType.LAST:
+            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            pooled_data = hidden_states[last_token_indices]
+        else:
+            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+        if self.normalize:
+            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+        return EmbeddingPoolerOutput(embeddings=pooled_data)
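LAST pooling selects one hidden state per packed sequence: the cumulative sum of per-sequence lengths, minus one, indexes each sequence's final token. A pure-torch sketch with illustrative shapes:

    import torch

    extend_seq_lens = torch.tensor([3, 2])      # two sequences packed into one batch
    hidden_states = torch.randn(5, 8)           # 3 + 2 tokens, hidden size 8
    last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1  # tensor([2, 4])
    pooled = hidden_states[last_token_indices]  # (2, 8): one vector per sequence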