sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +2 -4
- sglang/bench_serving.py +75 -26
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +13 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -16
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +55 -26
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +89 -70
- sglang/srt/managers/session_controller.py +14 -15
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +13 -15
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +59 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +196 -17
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
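The most significant reshuffle in this list is the split of `sglang/srt/layers/fused_moe` into `fused_moe_grok` (the Grok-specific kernels), a new vendored `fused_moe_triton` package, and a flattened `fused_moe_patch` module. Downstream imports move accordingly; a sketch based on the renames above and the import hunks shown later in this diff:

```python
# Before (0.3.6): the Triton MoE kernel came from vllm, and the
# torch-native fallback lived in a subpackage of fused_moe.
# from vllm.model_executor.layers.fused_moe import fused_moe
# from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

# After (0.3.6.post1): sglang vendors its own Triton kernels and flattens
# the torch-native patch into a top-level module.
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
```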
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """TokenizerManager is a process that tokenizes the text."""

 import asyncio
@@ -203,8 +201,18 @@ class TokenizerManager:
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = None
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            if not self.server_args.disable_radix_cache:
+                raise ValueError(
+                    "input_embeds is provided while disable_radix_cache is False. "
+                    "Please add `--disable-radix-cach` when you launch the server "
+                    "if you want to use input_embeds as inputs."
+                )
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids
@@ -218,10 +226,10 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
-        session_id = obj.session_id
-        session_rid = obj.session_rid
+        session_id = obj.session[0] if obj.session else None
+        session_rid = obj.session[1] if obj.session else None

-        if len(input_ids) >= self.context_len:
+        if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
@@ -244,7 +252,8 @@ class TokenizerManager:
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            obj.lora_path,
+            lora_path=obj.lora_path,
+            input_embeds=input_embeds,
             session_id=session_id,
             session_rid=session_rid,
         )
@@ -572,13 +581,11 @@ class TokenizerManager:
             out_dict = {
                 "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
             out_dict = {
                 "token_ids": recv_obj.output_ids[i],
                 "meta_info": recv_obj.meta_info[i],
-                "session_id": recv_obj.session_ids[i],
             }
         else:
             assert isinstance(recv_obj, BatchEmbeddingOut)
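Taken together, these hunks let a caller skip tokenization entirely by shipping precomputed embeddings, at the cost of disabling the radix cache. A minimal client-side sketch of the new path (the endpoint and payload shape are assumptions based on SGLang's `/generate` API, and the 4096 hidden size is hypothetical; none of this is confirmed by the diff itself):

```python
import requests

# Hypothetical embeddings for an 8-token prompt: one vector per token,
# shaped [num_tokens][hidden_size]. How they are produced is up to the caller.
input_embeds = [[0.0] * 4096 for _ in range(8)]

# The server must be launched with --disable-radix-cache, otherwise the
# TokenizerManager raises the ValueError added in the hunk above.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "input_embeds": input_embeds,
        "sampling_params": {"max_new_tokens": 16},
    },
)
print(resp.json())
```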
sglang/srt/managers/tp_worker.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A tensor parallel worker."""

 import logging
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """A tensor parallel worker."""

 import dataclasses
sglang/srt/metrics/collector.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Utilities for Prometheus Metrics Collection."""

 from dataclasses import dataclass
sglang/srt/metrics/func_timer.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Records the latency of some functions
 """
sglang/srt/mm_utils.py
CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

 # Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
 """
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -1,22 +1,20 @@
-"""
-Copyright 2023-2024 SGLang Team
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Run the model with cuda graph and torch.compile."""

+from __future__ import annotations
+
 import bisect
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -25,7 +23,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
+from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -67,7 +65,10 @@ def patch_model(
     _to_torch(model)
     monkey_patch_vllm_all_gather()
     backup_ca_comm = tp_group.ca_comm
-    tp_group.ca_comm = None
+    # Use custom-allreduce here.
+    # We found the custom allreduce is much faster than the built-in allreduce in torch,
+    # even with ENABLE_INTRA_NODE_COMM=1.
+    # tp_group.ca_comm = None
     yield torch.compile(
         torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
     )
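The `patch_model` context shown above temporarily compiles the model's forward while leaving CUDA graph capture to SGLang itself (via CudaGraphRunner). A stripped-down sketch of that pattern, with the vllm patching elided and names simplified, so this is illustrative rather than the actual implementation:

```python
from contextlib import contextmanager

import torch


@contextmanager
def compiled_forward(model: torch.nn.Module):
    # Compile a no-grad forward. "max-autotune-no-cudagraphs" autotunes
    # kernels but does not capture CUDA graphs, so the caller stays in
    # control of graph capture.
    yield torch.compile(
        torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
    )


# Usage sketch:
# with compiled_forward(model) as fwd:
#     logits = fwd(input_ids, positions, forward_batch)
```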
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,20 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """
 Store information about a forward batch.

@@ -31,6 +27,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+from __future__ import annotations
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional
@@ -132,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None

+    # For input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None

@@ -233,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )

         if ret.global_num_tokens is not None:
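These hunks give precomputed embeddings a path from the worker batch into the forward pass. A reduced sketch of the pattern (the classes here are illustrative stand-ins; only the `input_embeds` field and the `init_new` copy mirror the hunks above):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class WorkerBatchSketch:
    input_ids: torch.Tensor
    # None means the model embeds input_ids itself.
    input_embeds: Optional[torch.Tensor] = None


@dataclass
class ForwardBatchSketch:
    input_ids: torch.Tensor
    input_embeds: Optional[torch.Tensor] = None

    @classmethod
    def init_new(cls, batch: WorkerBatchSketch) -> "ForwardBatchSketch":
        # The optional tensor is copied straight through, as in the
        # init_new hunk above.
        return cls(input_ids=batch.input_ids, input_embeds=batch.input_embeds)
```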
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -1,18 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """ModelRunner runs the forward passes of the models."""

 import gc
@@ -63,6 +61,7 @@ from sglang.srt.utils import (
     is_hip,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
+    set_cpu_offload_max_bytes,
 )

 logger = logging.getLogger(__name__)
@@ -147,7 +146,9 @@ class ModelRunner:
             }
         )

-
+        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
+
+        # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
@@ -178,14 +179,15 @@ class ModelRunner:
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
         # Init torch distributed
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
-            torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
         # ToDO(liangan1):Just use gloo to bypass the initilization fail
         # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
-            torch.xpu.set_device(self.gpu_id)
             backend = "gloo"
+        elif self.device == "hpu":
+            backend = "hccl"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -240,15 +242,17 @@ class ModelRunner:
             )
             return get_model(vllm_config=vllm_config)
         except ImportError:
-            return get_model(
-                model_config=self.vllm_model_config,
-                load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
-                parallel_config=None,
-                scheduler_config=None,
-                lora_config=None,
-                cache_config=None,
-            )
+            pass
+
+        return get_model(
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+            parallel_config=None,
+            scheduler_config=None,
+            lora_config=None,
+            cache_config=None,
+        )

     def get_model_config_params(self):
         sig = inspect.signature(VllmModelConfig.__init__)
@@ -602,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
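Two of these hunks are worth unpacking: device setup now goes through the generic `torch.get_device_module`, and a Habana (`hpu`/`hccl`) branch is added. A condensed sketch of the resulting device-to-backend mapping (the helper function here is illustrative, not part of the package):

```python
import torch


def init_device_and_backend(device: str, gpu_id: int) -> str:
    # torch.get_device_module(device) resolves to torch.cuda, torch.xpu,
    # etc., so one call replaces the per-device set_device branches.
    torch.get_device_module(device).set_device(gpu_id)
    if device == "cuda":
        return "nccl"
    elif device == "xpu":
        # gloo is a stopgap until an xccl backend is available.
        return "gloo"
    elif device == "hpu":
        return "hccl"
    raise ValueError(f"Unsupported device: {device}")
```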
sglang/srt/models/chatglm.py
CHANGED
@@ -1,22 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
+
 from typing import Iterable, Optional, Tuple

 import torch
sglang/srt/models/commandr.py
CHANGED
@@ -1,19 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-# coding=utf-8
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -32,12 +29,14 @@ limitations under the License.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# ==============================================================================

 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1

 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
+
 from typing import Iterable, Optional, Tuple

 import torch
sglang/srt/models/dbrx.py
CHANGED
@@ -1,21 +1,20 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================

 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
-
+
 from typing import Iterable, Optional, Tuple

 import torch
@@ -25,11 +24,11 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

+from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,