sglang 0.3.1__py3-none-any.whl → 0.3.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sglang/bench_latency.py +10 -3
  2. sglang/bench_server_latency.py +187 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/global_config.py +5 -13
  5. sglang/lang/interpreter.py +0 -3
  6. sglang/srt/constrained/fsm_cache.py +5 -1
  7. sglang/srt/layers/activation.py +16 -1
  8. sglang/srt/layers/attention_backend.py +12 -12
  9. sglang/srt/layers/fused_moe/layer.py +27 -7
  10. sglang/srt/layers/layernorm.py +21 -6
  11. sglang/srt/layers/sampler.py +40 -98
  12. sglang/srt/lora/lora_manager.py +11 -8
  13. sglang/srt/managers/io_struct.py +3 -0
  14. sglang/srt/managers/policy_scheduler.py +49 -93
  15. sglang/srt/managers/schedule_batch.py +2 -1
  16. sglang/srt/managers/tp_worker.py +19 -13
  17. sglang/srt/model_executor/cuda_graph_runner.py +25 -13
  18. sglang/srt/model_executor/model_runner.py +37 -46
  19. sglang/srt/models/deepseek_v2.py +8 -3
  20. sglang/srt/models/llama.py +1 -3
  21. sglang/srt/models/llama_classification.py +2 -3
  22. sglang/srt/models/minicpm3.py +7 -3
  23. sglang/srt/models/olmoe.py +415 -0
  24. sglang/srt/models/xverse.py +1 -3
  25. sglang/srt/models/xverse_moe.py +1 -4
  26. sglang/srt/sampling/sampling_batch_info.py +3 -50
  27. sglang/srt/server.py +6 -1
  28. sglang/srt/server_args.py +39 -10
  29. sglang/srt/utils.py +7 -51
  30. sglang/test/few_shot_gsm8k.py +8 -2
  31. sglang/test/test_utils.py +1 -1
  32. sglang/version.py +1 -1
  33. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/METADATA +4 -5
  34. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/RECORD +37 -35
  35. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/WHEEL +1 -1
  36. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/LICENSE +0 -0
  37. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -1,5 +1,7 @@
 """
-Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+Benchmark the latency of running a single static batch.
+This script does not launch a server and uses the low-level APIs.
+It accepts arguments similar to those of launch_server.py.

 # Usage (latency test)
 ## with dummy weights:
@@ -63,7 +65,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import suppress_other_loggers
+from sglang.srt.utils import kill_child_process, suppress_other_loggers


 @dataclasses.dataclass
@@ -502,4 +504,9 @@ if __name__ == "__main__":
         format="%(message)s",
     )

-    main(server_args, bench_args)
+    try:
+        main(server_args, bench_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
sglang/bench_server_latency.py ADDED
@@ -0,0 +1,187 @@
+"""
+Benchmark the latency of serving a single batch with a real server.
+This script launches a server and uses the HTTP interface.
+It accepts arguments similar to those of launch_server.py.
+
+Usage:
+
+python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+"""
+
+import argparse
+import dataclasses
+import itertools
+import json
+import multiprocessing
+import os
+import time
+from typing import Tuple
+
+import numpy as np
+import requests
+
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
+
+
+@dataclasses.dataclass
+class BenchArgs:
+    run_name: str = "default"
+    batch_size: Tuple[int] = (1,)
+    input_len: Tuple[int] = (1024,)
+    output_len: Tuple[int] = (16,)
+    result_filename: str = "result.jsonl"
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument(
+            "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
+        )
+        parser.add_argument(
+            "--input-len", type=int, nargs="+", default=BenchArgs.input_len
+        )
+        parser.add_argument(
+            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
+        )
+        parser.add_argument(
+            "--result-filename", type=str, default=BenchArgs.result_filename
+        )
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to case the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
+
+
+def launch_server_process(server_args: ServerArgs):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = 600
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc, base_url
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+def run_one_case(
+    url: str,
+    batch_size: int,
+    input_len: int,
+    output_len: int,
+    run_name: str,
+    result_filename: str,
+):
+    input_ids = [
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
+        for _ in range(batch_size)
+    ]
+
+    tic = time.time()
+    response = requests.post(
+        url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+        },
+    )
+    latency = time.time() - tic
+
+    _ = response.json()
+    output_throughput = batch_size * output_len / latency
+    overall_throughput = batch_size * (input_len + output_len) / latency
+
+    print(f"batch size: {batch_size}")
+    print(f"latency: {latency:.2f} s")
+    print(f"output throughput: {output_throughput:.2f} token/s")
+    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+
+    if result_filename:
+        with open(result_filename, "a") as fout:
+            res = {
+                "run_name": run_name,
+                "batch_size": batch_size,
+                "input_len": input_len,
+                "output_len": output_len,
+                "latency": round(latency, 4),
+                "output_throughput": round(output_throughput, 2),
+                "overall_throughput": round(overall_throughput, 2),
+            }
+            fout.write(json.dumps(res) + "\n")
+
+
+def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
+    proc, base_url = launch_server_process(server_args)
+
+    # warmup
+    run_one_case(
+        base_url,
+        batch_size=16,
+        input_len=1024,
+        output_len=16,
+        run_name="",
+        result_filename="",
+    )
+
+    # benchmark
+    try:
+        for bs, il, ol in itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        ):
+            run_one_case(
+                base_url,
+                bs,
+                il,
+                ol,
+                bench_args.run_name,
+                bench_args.result_filename,
+            )
+    finally:
+        kill_child_process(proc.pid)
+
+    print(f"\nResults are saved to {bench_args.result_filename}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    BenchArgs.add_cli_args(parser)
+    # For this script, model-path is not required
+    assert (
+        parser._actions[1].option_strings[0] == "--model-path"
+    ), "options changed, this code need to be updated"
+    parser._actions[1].required = False
+    args = parser.parse_args()
+
+    server_args = ServerArgs.from_cli_args(args)
+    bench_args = BenchArgs.from_cli_args(args)
+
+    run_benchmark(server_args, bench_args)
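
For reference (this snippet is not from the package): the throughput figures printed by run_one_case follow directly from the formulas shown above. A quick sanity check in plain Python, using a made-up latency value:

batch_size, input_len, output_len = 16, 1024, 16
latency = 2.0  # hypothetical wall-clock seconds measured around the /generate request

output_throughput = batch_size * output_len / latency                  # 128.0 token/s
overall_throughput = batch_size * (input_len + output_len) / latency   # 8320.0 token/s
print(output_throughput, overall_throughput)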
sglang/bench_serving.py CHANGED
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

 """
-Benchmark online serving.
+Benchmark online serving with dynamic requests.

 Usage:
 python3 -m sglang.bench_serving --backend sglang --num-prompt 10
sglang/global_config.py CHANGED
@@ -1,5 +1,7 @@
 """Global configurations"""

+import os
+

 class GlobalConfig:
     def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
         self.base_min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001

-        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
-        # This can improve the speed for large batch sizes during prefill.
-        self.layer_sync_threshold = 8192
-
         # Runtime constants: others
         self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = 384 * 1024 * 1024
+        self.flashinfer_workspace_size = os.environ.get(
+            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        )

         # Output tokenization configs
         self.skip_special_tokens_in_output = True
         self.spaces_between_special_tokens_in_out = True

         # Interpreter optimization configs
-        self.eager_fill_image = False
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Deprecated
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"


 global_config = GlobalConfig()
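
For reference (this snippet is not from the package): the new lookup makes the flashinfer workspace size overridable via the FLASHINFER_WORKSPACE_SIZE environment variable. One detail worth knowing is that os.environ.get returns the integer default when the variable is unset, but the raw string when it is set; a minimal illustration:

import os

os.environ.pop("FLASHINFER_WORKSPACE_SIZE", None)
print(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))  # 402653184 (the int default)

os.environ["FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)
print(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))  # '536870912' (a string from the environment)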
sglang/lang/interpreter.py CHANGED
@@ -434,9 +434,6 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

-        # if global_config.eager_fill_image:
-        #     self.backend.fill_image(self)
-
     def _spec_gen(self, sampling_params):
         stop = sampling_params.stop
         max_new_tokens = sampling_params.max_new_tokens
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        constrained_json_whitespace_pattern=None,
     ):
         super().__init__(enable=enable)

@@ -63,11 +64,14 @@
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
            )
+        self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern

    def init_value(self, key):
        key_type, key_string = key
        if key_type == "json":
-            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+            regex = build_regex_from_schema(
+                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
+            )
        elif key_type == "regex":
            regex = key_string
        else:
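
For reference (this snippet is not from the package): the previously hard-coded whitespace_pattern r"[\n\t ]*" allows runs of spaces, tabs, and newlines between JSON tokens; the new constructor argument lets callers supply their own pattern, and passing None presumably defers to the outlines library default. A small re-based illustration of what the old pattern matches:

import re

ws = re.compile(r"[\n\t ]*")  # the pattern that used to be hard-coded in init_value

for s in ["", "  ", "\n\t ", " x"]:
    print(repr(s), "->", repr(ws.match(s).group(0)))
# ''      -> ''
# '  '    -> '  '
# '\n\t ' -> '\n\t '
# ' x'    -> ' '   (only the leading whitespace is consumed)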
sglang/srt/layers/activation.py CHANGED
@@ -13,12 +13,18 @@ limitations under the License.

 """Fused operators for activation layers."""

+import logging
 from typing import Optional

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -28,6 +34,8 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +143,10 @@ def get_act_fn(
             act_fn, intermediate_size, input_is_parallel, params_dtype
         )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
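
For reference: sglang.srt.utils.is_hip itself is not shown in this changeset. As an assumption for readers, a check like this is typically built on the ROCm build information that PyTorch exposes (the sketch below is illustrative, not the sglang source):

import torch

def is_hip() -> bool:
    # torch.version.hip is a version string on ROCm builds of PyTorch and None on CUDA builds.
    return torch.version.hip is not None

print(is_hip())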
sglang/srt/layers/attention_backend.py CHANGED
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+

 class AttentionBackend(ABC):
     """The base class of attention backends"""
@@ -150,7 +154,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Some heuristics to check whether to use ragged forward
         use_ragged = False
         if (
-            int(torch.sum(input_metadata.seq_lens)) > 4096
+            torch.sum(input_metadata.seq_lens).item() >= 4096
             and self.model_runner.sliding_window_size is None
         ):
             use_ragged = True
@@ -301,10 +305,6 @@ class FlashInferAttnBackend(AttentionBackend):
             layer.layer_id, input_metadata.out_cache_loc, k, v
         )

-        if total_num_tokens >= global_config.layer_sync_threshold:
-            # TODO: Revisit this. Why is this synchronize needed?
-            torch.cuda.synchronize()
-
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
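
For reference (this snippet is not from the package): the ragged-prefill heuristic now extracts the scalar with Tensor.item() and uses an inclusive threshold. For an integer seq_lens tensor, int(t) and t.item() return the same Python integer, so the observable difference in the comparison is only at exactly 4096 total tokens:

import torch

seq_lens = torch.tensor([2048, 2048])
total = torch.sum(seq_lens)
print(int(total), total.item())                 # 4096 4096
print(int(total) > 4096, total.item() >= 4096)  # False True: the boundary case now takes the ragged path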
sglang/srt/layers/fused_moe/layer.py CHANGED
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)


@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

     def process_weights_after_loading(self, layer: Module) -> None:

-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.a2_scale.max(), requires_grad=False
                 )

+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
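
For reference (this snippet is not from the package): the ROCm branch switches to torch.float8_e4m3fnuz because MI300-class GPUs use the fnuz fp8 variant, whose representable range differs from the OCP float8_e4m3fn format; that range difference is why the extra normalize_e4m3fn_to_e4m3fnuz rescaling step is applied. A quick way to compare the two dtypes (assumes a PyTorch build that exposes both):

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    info = torch.finfo(dtype)
    # e4m3fn tops out at 448 while e4m3fnuz tops out at 240, so scales must be adjusted on conversion.
    print(dtype, "min:", info.min, "max:", info.max)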
sglang/srt/layers/layernorm.py CHANGED
@@ -15,18 +15,26 @@ limitations under the License.

 """Fused operators for normalization layers."""

+import logging
 from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from flashinfer.norm import (
-    fused_add_rmsnorm,
-    gemma_fused_add_rmsnorm,
-    gemma_rmsnorm,
-    rmsnorm,
-)
+
+from sglang.srt.utils import is_hip
+
+if not is_hip():
+    from flashinfer.norm import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
+
 from vllm.model_executor.custom_op import CustomOp

+logger = logging.getLogger(__name__)
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +117,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
sglang/srt/layers/sampler.py CHANGED
@@ -1,51 +1,28 @@
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union

 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )

 logger = logging.getLogger(__name__)


-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -53,7 +30,18 @@
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits = logits.contiguous()
+        logits.div_(sampling_info.temperatures)
+        probs = torch.softmax(logits, dim=-1)
+        logits = None
+        del logits
+
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )

         if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +55,20 @@
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
-                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    probs,
+                    uniform_samples,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    filter_apply_order="joint",
                 )
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
         else:
@@ -80,48 +76,7 @@
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )

-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids


 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +96,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+    return batch_next_token_ids
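
For reference (this snippet is not from the package): the PyTorch fallback now returns only the sampled token ids and drops the success flag and the try/except around torch.multinomial. For readers unfamiliar with the filtering order, here is a minimal, self-contained sketch of torch-native top-k/top-p/min-p sampling in the same spirit as top_k_top_p_min_p_sampling_from_probs_torch; the function name, scalar thresholds, and toy shapes below are illustrative, not the sglang implementation:

import torch

def sample_top_k_top_p_min_p(probs: torch.Tensor, top_k: int, top_p: float, min_p: float) -> torch.Tensor:
    # Sort each row in descending order so cumulative sums give the top-p mass.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cum = torch.cumsum(probs_sort, dim=-1)
    # top-p: zero out tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(cum - probs_sort) > top_p] = 0.0
    # top-k: keep only the k highest-probability tokens per row.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_k] = 0.0
    # min-p: drop tokens below min_p times the per-row maximum probability.
    probs_sort[probs_sort < min_p * probs_sort[:, :1]] = 0.0
    # Renormalize, sample, and map back to original vocabulary indices.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)

# Toy usage: a batch of 2 rows over a vocabulary of 5 tokens.
probs = torch.softmax(torch.randn(2, 5), dim=-1)
print(sample_top_k_top_p_min_p(probs, top_k=3, top_p=0.9, min_p=0.05))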