sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sglang/srt/configs/model_config.py +15 -6
  2. sglang/srt/layers/attention/flashinfer_backend.py +17 -3
  3. sglang/srt/layers/linear.py +36 -98
  4. sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
  5. sglang/srt/layers/moe/topk.py +4 -2
  6. sglang/srt/layers/parameter.py +24 -16
  7. sglang/srt/layers/quantization/__init__.py +2 -0
  8. sglang/srt/layers/quantization/fp8.py +106 -52
  9. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  10. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  11. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  12. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  13. sglang/srt/layers/radix_attention.py +2 -0
  14. sglang/srt/layers/vocab_parallel_embedding.py +15 -2
  15. sglang/srt/managers/configure_logging.py +43 -0
  16. sglang/srt/managers/detokenizer_manager.py +0 -2
  17. sglang/srt/managers/io_struct.py +29 -13
  18. sglang/srt/managers/scheduler.py +48 -9
  19. sglang/srt/managers/tokenizer_manager.py +109 -49
  20. sglang/srt/mem_cache/memory_pool.py +107 -52
  21. sglang/srt/metrics/collector.py +10 -5
  22. sglang/srt/model_executor/model_runner.py +43 -6
  23. sglang/srt/models/llama.py +37 -2
  24. sglang/srt/models/qwen2.py +11 -0
  25. sglang/srt/models/qwen2_eagle.py +131 -0
  26. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  27. sglang/srt/sampling/sampling_batch_info.py +14 -5
  28. sglang/srt/sampling/sampling_params.py +1 -1
  29. sglang/srt/server.py +114 -61
  30. sglang/srt/server_args.py +27 -18
  31. sglang/srt/speculative/eagle_worker.py +1 -0
  32. sglang/srt/torch_memory_saver_adapter.py +59 -0
  33. sglang/srt/utils.py +29 -0
  34. sglang/version.py +1 -1
  35. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
  36. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
  37. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  38. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
  39. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional

 import torch
@@ -40,12 +39,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )

 ACTIVATION_SCHEMES = ["static", "dynamic"]

+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)


@@ -161,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False

         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -567,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

@@ -616,18 +618,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +672,7 @@ class Fp8MoEMethod:
             )

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +722,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )

-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
-                layer.w13_weight = torch.nn.Parameter(
-                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
-                layer.w2_weight = torch.nn.Parameter(
-                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                    requires_grad=False,
-                )
-                torch.cuda.empty_cache()
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return

     def apply(
@@ -752,27 +778,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

-        # Expert fusion with FP8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            use_fp8_w8a8=True,
-            w1_scale=(
-                layer.w13_weight_scale_inv
-                if self.block_quant
-                else layer.w13_weight_scale
-            ),
-            w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-            ),
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            block_shape=self.quant_config.weight_block_size,
-        )
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )


 class Fp8KVCacheMethod(BaseKVCacheMethod):
sglang/srt/layers/quantization/fp8_utils.py

@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple

 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter

+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py (new file)

@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
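For reference, a minimal eager-PyTorch sketch of what per_token_quant_int8 above computes (symmetric per-token int8 quantization with one float32 scale per row); the reference function name is illustrative and not part of the package:

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor):
        # One float32 scale per token (row); int8 values are round(x / scale), bounded by +/-127.
        absmax = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
        scales = absmax / 127
        x_q = torch.round(x.float() / scales).to(torch.int8)
        return x_q, scales

    # x_q.float() * scales approximately reconstructs x.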
sglang/srt/layers/quantization/modelopt_quant.py

@@ -11,9 +11,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

 from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
sglang/srt/layers/quantization/w8a8_int8.py (new file)

@@ -0,0 +1,117 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+
+
+class W8A8Int8Config(QuantizationConfig):
+    """Config class for W8A8 Int8 Quantization.
+
+    - Weight: static, per-channel, symmetric
+    - Activation: dynamic, per-token, symmetric
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 75
+
+    @classmethod
+    def get_name(self) -> str:
+        return "w8a8_int8"
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
+        return cls()
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.linear import LinearBase
+
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class W8A8Int8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: W8A8Int8Config):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs
+    ):
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        self.logical_widths = output_partition_sizes
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
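For reference, a rough dequantized PyTorch sketch of what the int8_scaled_mm call in W8A8Int8LinearMethod.apply computes, assuming (per create_weights and process_weights_after_loading above) an already-transposed int8 weight of shape [in_features, out_features] and a per-output-channel float32 weight_scale of shape [out_features, 1]; the reference function name is illustrative and int8_scaled_mm itself is the fused sgl_kernel CUDA op:

    import torch

    def int8_scaled_mm_ref(x_q, weight_t, x_scale, weight_scale, out_dtype, bias=None):
        # x_q: [M, K] int8, weight_t: [K, N] int8, x_scale: [M, 1] float32, weight_scale: [N, 1] float32
        out = (x_q.float() @ weight_t.float()) * x_scale * weight_scale.t()
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)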
sglang/srt/layers/radix_attention.py

@@ -47,6 +47,8 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
+        self.k_scale = 1.0
+        self.v_scale = 1.0

     def forward(
         self,
sglang/srt/layers/vocab_parallel_embedding.py

@@ -220,6 +220,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.quant_config = quant_config
@@ -236,6 +237,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
+        self.use_presharded_weights = use_presharded_weights
+        if use_presharded_weights:
+            assert (
+                num_added_embeddings == 0
+            ), "Lora is not supported with presharded weights."
+
         self.org_vocab_size_padded = pad_vocab_size(
             self.org_vocab_size, self.padding_size
         )
@@ -447,10 +454,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             start_idx = start_idx // packed_factor
             shard_size = shard_size // packed_factor
         else:
-            assert loaded_weight.shape[output_dim] == self.org_vocab_size
+            assert loaded_weight.shape[output_dim] == (
+                self.org_vocab_size
+                // (self.tp_size if self.use_presharded_weights else 1)
+            )

         # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0] :].data.fill_(0)

@@ -514,6 +525,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
@@ -523,6 +535,7 @@ class ParallelLMHead(VocabParallelEmbedding):
             padding_size,
             quant_config,
             prefix,
+            use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
         if bias:
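For reference, the presharded-weights arithmetic behind the relaxed shape assertion above: each tensor-parallel rank of a presharded checkpoint already stores its own slice of the vocabulary table, so the loader expects org_vocab_size // tp_size rows and skips the narrow() step. A small illustrative check (numbers are made up):

    org_vocab_size, tp_size, use_presharded_weights = 32000, 4, True
    expected_rows = org_vocab_size // (tp_size if use_presharded_weights else 1)
    assert expected_rows == 8000  # each rank loads its 8000-row shard directly, no narrow()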
sglang/srt/managers/configure_logging.py (new file)

@@ -0,0 +1,43 @@
+"""
+Copyright 2023-2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Configure the logging settings of a server.
+
+Usage:
+python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    parser.add_argument(
+        "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
+    )
+    parser.add_argument("--dump-requests-threshold", type=int, default=1000)
+    args = parser.parse_args()
+
+    response = requests.post(
+        args.url + "/configure_logging",
+        json={
+            "dump_requests_folder": args.dump_requests_folder,
+            "dump_requests_threshold": args.dump_requests_threshold,
+        },
+    )
+    assert response.status_code == 200
sglang/srt/managers/detokenizer_manager.py

@@ -181,8 +181,6 @@ class DetokenizerManager:
                     finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
                     prompt_tokens=recv_obj.prompt_tokens,
-                    origin_input_ids=recv_obj.origin_input_ids,
-                    output_ids=recv_obj.output_ids,
                     completion_tokens=recv_obj.completion_tokens,
                     cached_tokens=recv_obj.cached_tokens,
                     input_token_logprobs_val=recv_obj.input_token_logprobs_val,
sglang/srt/managers/io_struct.py

@@ -19,9 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
 import uuid
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
+from typing import Dict, List, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -323,9 +321,7 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when --return-token-ids` is set
-    origin_input_ids: Optional[List[int]]
-    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
+    # Only used when `--skip-tokenizer-init` is on
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -356,14 +352,7 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]

-    # The token ids
-    origin_input_ids: Optional[List[int]]
-    output_ids: Optional[List[int]]
-
     # Token counts
-    # real input and output tokens can be get from
-    # origin_input_ids and output_ids by enabling --return_token_ids
-    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
@@ -468,6 +457,26 @@ class GetWeightsByNameReqOutput:
     parameter: list


+@dataclass
+class ReleaseMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ReleaseMemoryOccupationReqOutput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqInput:
+    pass
+
+
+@dataclass
+class ResumeMemoryOccupationReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
@@ -479,6 +488,13 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2


+@dataclass
+class ConfigureLoggingReq:
+    log_requests: Optional[bool] = None
+    dump_requests_folder: Optional[str] = None
+    dump_requests_threshold: Optional[int] = None
+
+
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int