sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py CHANGED
@@ -20,8 +20,10 @@ from vllm.distributed import (
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
+    PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
+    RowvLLMParameter,
 )

 from sglang.srt.layers.quantization.base_config import (
@@ -39,6 +41,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
+    "GPTQLinearMethod",
 ]


@@ -50,7 +53,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


-def adjust_bitsandbytes_shard(
+def adjust_bitsandbytes_4bit_shard(
     param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -207,7 +210,6 @@ class ReplicatedLinear(LinearBase):
             self.output_size,
             self.params_dtype,
             weight_loader=self.weight_loader,
-            prefix=prefix,
         )

         if bias:
@@ -315,7 +317,6 @@ class ColumnParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if bias:
             self.bias = Parameter(
@@ -345,8 +346,12 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
         param_data = param.data
-        if output_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -454,17 +459,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -526,26 +536,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )

-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -595,7 +596,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # If quantized, we need to adjust the offset and size to account
         # for the packing.
         if (
-            isinstance(param, PackedvLLMParameter)
+            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim
        ):
            shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -617,7 +618,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -760,7 +761,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         # If quantized, we need to adjust the offset and size to account
         # for the packing.
         if (
-            isinstance(param, PackedvLLMParameter)
+            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
            and param.packed_dim == param.output_dim
        ):
            shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -780,10 +781,10 @@ class QKVParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
@@ -818,17 +819,22 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -863,6 +869,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     self.total_num_kv_heads * self.head_size,
                 ),
             ]
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
@@ -877,6 +885,29 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )

+                if use_bitsandbytes_4bit:
+                    orig_qkv_offsets = {
+                        "q": (0, self.total_num_heads * self.head_size),
+                        "k": (
+                            self.total_num_heads * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "v": (
+                            (self.total_num_heads + self.total_num_kv_heads)
+                            * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "total": (
+                            (self.total_num_heads + 2 * self.total_num_kv_heads)
+                            * self.head_size,
+                            0,
+                        ),
+                    }
+
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_qkv_offsets, shard_id
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
@@ -910,8 +941,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )

-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (
@@ -927,29 +958,22 @@ class QKVParallelLinear(ColumnParallelLinear):
                         0,
                     ),
                 }
-                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id
                 )

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
                 shard_id = tp_rank
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1037,7 +1061,6 @@ class RowParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError(
@@ -1061,6 +1084,7 @@ class RowParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1076,7 +1100,9 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
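Most of the linear.py changes repeat one pattern: bitsandbytes 4-bit checkpoints already store each tensor-parallel rank's portion of a weight, so the loader must skip the usual narrow along the sharded dimension. A minimal sketch of that guard, using standalone tensors rather than the actual sglang parameter classes:

    from typing import Optional

    import torch


    def load_column_shard(
        param_data: torch.Tensor,
        loaded_weight: torch.Tensor,
        output_dim: Optional[int],
        tp_rank: int,
        use_bitsandbytes_4bit: bool,
    ) -> torch.Tensor:
        """Sketch of the conditional tensor-parallel narrow added in this diff."""
        if output_dim is not None and not use_bitsandbytes_4bit:
            # Ordinary checkpoints hold the full weight; cut out this rank's slice.
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # bitsandbytes 4-bit checkpoints already hold only this rank's slice,
        # so the loaded weight is used as-is.
        return loaded_weight

The renamed adjust_bitsandbytes_4bit_shard helper plays the same role for fused QKV weights, remapping each shard's offset and size into the pre-sharded layout.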
sglang/srt/layers/logits_processor.py CHANGED
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None

     # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None

     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None


 @dataclasses.dataclass
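Giving the logprob fields default values of None lets call sites build a LogitsProcessorOutput from the next-token logits alone (the common decode path) and fill in the rest only when logprobs were requested. A minimal sketch of the same dataclass pattern, using a stand-in class rather than the real one:

    import dataclasses
    from typing import List, Optional

    import torch


    @dataclasses.dataclass
    class OutputSketch:
        # Always present: logits for the next token of each sequence.
        next_token_logits: torch.Tensor
        # Optional extras, populated only when logprobs are requested.
        next_token_logprobs: Optional[torch.Tensor] = None
        input_top_logprobs: Optional[List] = None


    # Decode-only call sites can now omit every field except the logits.
    out = OutputSketch(next_token_logits=torch.zeros(2, 8))
    assert out.next_token_logprobs is None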
sglang/srt/layers/rotary_embedding.py ADDED
@@ -0,0 +1,112 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""MRotaryEmbedding"""
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+
+class MRotaryEmbedding:
+    """Rotary Embedding with Multimodal Sections."""
+
+    @staticmethod
+    def get_input_positions(
+        input_tokens: torch.Tensor,
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        vision_start_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+    ) -> Tuple[List[List[int]], int]:
+        """Get mrope input positions and delta value."""
+
+        if isinstance(image_grid_thw, torch.Tensor):
+            image_grid_thw = image_grid_thw.tolist()
+
+        vision_start_indices = torch.argwhere(
+            input_tokens == vision_start_token_id
+        ).squeeze(1)
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
+        llm_pos_ids_list: list = []
+
+        st = 0
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                torch.arange(llm_grid_t)
+                .view(-1, 1)
+                .expand(-1, llm_grid_h * llm_grid_w)
+                .flatten()
+            )
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < input_tokens_len:
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = input_tokens_len - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:]
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_next_input_positions(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> List[List[int]]:
+        return [
+            list(
+                range(
+                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+                )
+            )
+            for _ in range(3)
+        ]
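The new MRotaryEmbedding helper computes multimodal rotary positions of the kind used by the Qwen2-VL support added in this release: text tokens advance the temporal, height, and width sections together, while image tokens get positions laid out over the image grid. A rough usage sketch; the token ids and grid values below are invented for illustration and only the call shape follows the code above:

    import torch

    from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

    # Invented example: token id 5 marks the vision start, followed by a 1x4x4
    # image grid that merges to 2x2 patches (spatial_merge_size=2), i.e. 4 tokens.
    input_tokens = torch.tensor([1, 2, 5, 9, 9, 9, 9, 3, 4])
    positions, delta = MRotaryEmbedding.get_input_positions(
        input_tokens=input_tokens,
        image_grid_thw=[[1, 4, 4]],
        vision_start_token_id=5,
        spatial_merge_size=2,
    )
    # positions has 3 rows (temporal, height, width), one column per input token;
    # delta is later passed to get_next_input_positions for the decode steps.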
sglang/srt/layers/sampler.py CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union

 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )

+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)


@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits
-
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
                 )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-                    filter_apply_order="joint",
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )

-        if not torch.all(success):
-            logger.warning("Detected errors during sampling!")
-            batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)


 def top_k_top_p_min_p_sampling_from_probs_torch(
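The sampler rewrite moves temperature scaling and the softmax inside the non-greedy branch, so all-greedy batches take an argmax over the raw logits, NaNs are patched in the logits rather than the probabilities, and the returned token ids are cast to int32. A condensed sketch of the new control flow, with a plain torch multinomial standing in for the flashinfer and pytorch sampling backends:

    import torch


    def sample_sketch(
        logits: torch.Tensor,        # [batch, vocab]
        temperatures: torch.Tensor,  # [batch, 1]
        is_all_greedy: bool,
    ) -> torch.Tensor:
        """Condensed view of the reordered Sampler.forward logic."""
        if torch.any(torch.isnan(logits)):
            # NaNs are now masked in the logits with a large negative value.
            logits = torch.where(
                torch.isnan(logits), torch.full_like(logits, -1e5), logits
            )

        if is_all_greedy:
            # Greedy decoding skips temperature scaling and softmax entirely.
            next_ids = torch.argmax(logits, dim=-1)
        else:
            probs = torch.softmax(logits / temperatures, dim=-1)
            # The real code dispatches to flashinfer kernels or a torch fallback
            # (with top-k/top-p/min-p filtering); multinomial stands in here.
            next_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Token ids are now returned as int32.
        return next_ids.to(torch.int32)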
sglang/srt/lora/lora.py CHANGED
@@ -351,7 +351,9 @@ class LoRAAdapter(nn.Module):
         loader = DefaultModelLoader(self.load_config)
         revision = getattr(self.config.hf_config, "revision", None)
         for name, loaded_weight in loader._get_weights_iterator(
-            model_path, revision=revision, fall_back_to_pt=True
+            DefaultModelLoader.Source(
+                model_path, revision=revision, fall_back_to_pt=True
+            )
         ):
             match = re.search(r"layers\.(\d+)\.", name)
             if match is not None:
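For context, this appears to track a vLLM model-loader API change: _get_weights_iterator now takes a single DefaultModelLoader.Source record bundling the model path, revision, and fall_back_to_pt flag, instead of separate arguments.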
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -156,7 +156,7 @@ class DataParallelController:
         else:
             # Send other control messages to all workers
             for worker in self.workers:
-                worker.queue.put(recv_req)
+                worker.send_pyobj(recv_req)


 def run_data_parallel_controller_process(
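The per-worker handle is evidently no longer an in-process queue but a ZeroMQ socket, so control messages are now broadcast with send_pyobj, the same path used for regular request traffic.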
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    GetMemPoolSizeReqOutput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
                 # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
             elif self.tokenizer is None:
                 # If the tokenizer is skipped, no detokenization is needed
                 self.send_to_tokenizer.send_pyobj(recv_obj)
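Like the weight-update acknowledgement handled just above it, GetMemPoolSizeReqOutput carries no token ids, so the detokenizer simply forwards it back to the tokenizer manager without any detokenization.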