sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -20,8 +20,10 @@ from vllm.distributed import (
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
+    PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
+    RowvLLMParameter,
 )
 
 from sglang.srt.layers.quantization.base_config import (
@@ -39,6 +41,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
+    "GPTQLinearMethod",
 ]
 
 
@@ -50,7 +53,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def
+def adjust_bitsandbytes_4bit_shard(
     param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -207,7 +210,6 @@ class ReplicatedLinear(LinearBase):
             self.output_size,
             self.params_dtype,
             weight_loader=self.weight_loader,
-            prefix=prefix,
         )
 
         if bias:
@@ -315,7 +317,6 @@ class ColumnParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if bias:
             self.bias = Parameter(
@@ -345,8 +346,12 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
 
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
         param_data = param.data
-
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -454,17 +459,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight
-
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
 
-
-
-
-
-
-
-
-
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -526,26 +536,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )
 
-
-            if
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = tp_rank * shard_size
-
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
             # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -595,7 +596,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             if (
-                isinstance(param, PackedvLLMParameter)
+                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                 and param.packed_dim == param.output_dim
             ):
                 shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -617,7 +618,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -760,7 +761,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
             if (
-                isinstance(param, PackedvLLMParameter)
+                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                 and param.packed_dim == param.output_dim
             ):
                 shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
@@ -780,10 +781,10 @@ class QKVParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
-                param.
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param)
-                param.
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
@@ -818,17 +819,22 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight
-
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
 
-
-
-
-
-
-
-
-
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -863,6 +869,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     self.total_num_kv_heads * self.head_size,
                 ),
             ]
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.
@@ -877,6 +885,29 @@ class QKVParallelLinear(ColumnParallelLinear):
                        param, shard_size, shard_offset
                    )
 
+                if use_bitsandbytes_4bit:
+                    orig_qkv_offsets = {
+                        "q": (0, self.total_num_heads * self.head_size),
+                        "k": (
+                            self.total_num_heads * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "v": (
+                            (self.total_num_heads + self.total_num_kv_heads)
+                            * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "total": (
+                            (self.total_num_heads + 2 * self.total_num_kv_heads)
+                            * self.head_size,
+                            0,
+                        ),
+                    }
+
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_qkv_offsets, shard_id
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
@@ -910,8 +941,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )
 
-
-            if
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (
@@ -927,29 +958,22 @@ class QKVParallelLinear(ColumnParallelLinear):
                         0,
                     ),
                 }
-                shard_size, shard_offset =
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
                 shard_id = tp_rank
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-
+
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
             # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1037,7 +1061,6 @@ class RowParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError(
@@ -1061,6 +1084,7 @@ class RowParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1076,7 +1100,9 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
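The recurring change in these hunks is a `use_bitsandbytes_4bit` flag threaded through the weight loaders: when it is set, the loader keeps the checkpoint tensor as-is instead of calling `narrow()`, because a bitsandbytes 4-bit checkpoint already holds only the current tensor-parallel rank's shard. A minimal sketch of that branch, using an illustrative helper name that is not part of sglang:

```python
import torch


def load_rank_shard(
    param_data: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_dim: int,
    tp_rank: int,
    use_bitsandbytes_4bit: bool,
) -> torch.Tensor:
    """Illustrative only: mirrors the narrow-or-skip branch added above."""
    if use_bitsandbytes_4bit:
        # bitsandbytes 4-bit checkpoints already store only this rank's portion.
        return loaded_weight
    shard_size = param_data.shape[shard_dim]
    start_idx = tp_rank * shard_size
    return loaded_weight.narrow(shard_dim, start_idx, shard_size)


# With tp_size=2, rank 1 receives rows 4..7 of a full [8, 4] weight.
full = torch.randn(8, 4)
local = torch.empty(4, 4)
print(load_rank_shard(local, full, shard_dim=0, tp_rank=1,
                      use_bitsandbytes_4bit=False).shape)  # torch.Size([4, 4])
```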
sglang/srt/layers/logits_processor.py
CHANGED
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None
 
     # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None
 
     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None
 
 
 @dataclasses.dataclass
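Giving these fields `None` defaults means code paths that only need the next-token logits can construct the output without filling in every logprob field. A hedged sketch, assuming `next_token_logits` is the only required field of the dataclass (only the fields visible in this hunk are known here, and the tensor contents are made up):

```python
import torch

from sglang.srt.layers.logits_processor import LogitsProcessorOutput

out = LogitsProcessorOutput(next_token_logits=torch.randn(2, 32000))
assert out.next_token_logprobs is None and out.output_top_logprobs is None
```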
sglang/srt/layers/rotary_embedding.py
ADDED
@@ -0,0 +1,112 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""MRotaryEmbedding"""
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+
+class MRotaryEmbedding:
+    """Rotary Embedding with Multimodal Sections."""
+
+    @staticmethod
+    def get_input_positions(
+        input_tokens: torch.Tensor,
+        image_grid_thw: Union[List[List[int]], torch.Tensor],
+        vision_start_token_id: int,
+        spatial_merge_size: int,
+        context_len: int = 0,
+    ) -> Tuple[List[List[int]], int]:
+        """Get mrope input positions and delta value."""
+
+        if isinstance(image_grid_thw, torch.Tensor):
+            image_grid_thw = image_grid_thw.tolist()
+
+        vision_start_indices = torch.argwhere(
+            input_tokens == vision_start_token_id
+        ).squeeze(1)
+        image_indices = vision_start_indices + 1
+        image_nums = image_indices.shape[0]
+        llm_pos_ids_list: list = []
+
+        st = 0
+        input_tokens_len = input_tokens.shape[0]
+        for image_index in range(image_nums):
+            ed = image_indices[image_index].item()
+            t, h, w = (
+                image_grid_thw[image_index][0],
+                image_grid_thw[image_index][1],
+                image_grid_thw[image_index][2],
+            )
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                torch.arange(llm_grid_t)
+                .view(-1, 1)
+                .expand(-1, llm_grid_h * llm_grid_w)
+                .flatten()
+            )
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < input_tokens_len:
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = input_tokens_len - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:]
+        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
+        return llm_positions.tolist(), mrope_position_delta
+
+    @staticmethod
+    def get_next_input_positions(
+        mrope_position_delta: int,
+        context_len: int,
+        seq_len: int,
+    ) -> List[List[int]]:
+        return [
+            list(
+                range(
+                    context_len + mrope_position_delta, seq_len + mrope_position_delta
+                )
+            )
+            for _ in range(3)
+        ]
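A small usage sketch of the new helper. The token ids, the vision-start id, and the image grid below are made up for illustration; only the class, method names, and argument meanings come from the file above.

```python
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

VISION_START = 151652  # hypothetical vision-start token id (model specific)
# 2 text tokens, the vision-start token, 4 image placeholders
# (grid t=1, h=4, w=4 with spatial_merge_size=2 -> 1*2*2 tokens), 2 text tokens.
input_tokens = torch.tensor([10, 11, VISION_START, 0, 0, 0, 0, 12, 13])

positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens,
    image_grid_thw=[[1, 4, 4]],
    vision_start_token_id=VISION_START,
    spatial_merge_size=2,
)
# positions is a 3 x len(input_tokens) list (temporal/height/width ids);
# delta (-2 here) shifts the position ids used for subsequent decode steps:
next_positions = MRotaryEmbedding.get_next_input_positions(
    delta, context_len=9, seq_len=10
)  # [[7], [7], [7]]
```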
sglang/srt/layers/sampler.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Union
 
 import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
         top_p_renorm_prob,
     )
 
+
+# Crash on warning if we are running CI tests
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -33,56 +39,62 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        # Post process logits
         logits = logits.contiguous()
-
-
-
-
-
-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(
-
-
-
-
-
-
-
-
-
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
                 )
-
-
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
-
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+        return batch_next_token_ids.to(torch.int32)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
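The torch-native fallback branch calls `top_k_top_p_min_p_sampling_from_probs_torch(probs, top_ks, top_ps, min_ps)` with per-request tensors; its body is not shown in this diff. A rough standalone sketch of what such a filter-and-sample step can look like (the function name and the details below are illustrative, not the actual sglang implementation):

```python
import torch


def sketch_top_k_top_p_min_p_sample(probs, top_ks, top_ps, min_ps):
    # probs: [batch, vocab] softmax output; top_ks/top_ps/min_ps: [batch] tensors.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # top-p: zero tokens once the cumulative mass before them exceeds top_p
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: keep only the k highest-probability tokens per row
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # min-p: drop tokens below min_p times the per-row maximum probability
    probs_sort[probs_sort < probs_sort[:, :1] * min_ps.view(-1, 1)] = 0.0
    # renormalize, sample, and map back to original vocabulary indices
    sampled = torch.multinomial(probs_sort / probs_sort.sum(-1, keepdim=True), 1)
    return probs_idx.gather(-1, sampled).view(-1)


probs = torch.softmax(torch.randn(2, 16), dim=-1)
print(sketch_top_k_top_p_min_p_sample(
    probs,
    top_ks=torch.tensor([5, 8]),
    top_ps=torch.tensor([0.9, 0.8]),
    min_ps=torch.tensor([0.0, 0.05]),
))
```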
sglang/srt/lora/lora.py
CHANGED
@@ -351,7 +351,9 @@ class LoRAAdapter(nn.Module):
         loader = DefaultModelLoader(self.load_config)
         revision = getattr(self.config.hf_config, "revision", None)
         for name, loaded_weight in loader._get_weights_iterator(
-
+            DefaultModelLoader.Source(
+                model_path, revision=revision, fall_back_to_pt=True
+            )
         ):
             match = re.search(r"layers\.(\d+)\.", name)
             if match is not None:
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchStrOut,
     BatchTokenIDOut,
+    GetMemPoolSizeReqOutput,
     UpdateWeightReqOutput,
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
                 # If it is a weight update request, no detokenization is needed.
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
            elif self.tokenizer is None:
                # If the tokenizer is skipped, no detokenization is needed
                self.send_to_tokenizer.send_pyobj(recv_obj)