sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +337 -0
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +115 -31
- sglang/check_env.py +3 -6
- sglang/srt/constrained/base_grammar_backend.py +4 -3
- sglang/srt/constrained/outlines_backend.py +39 -26
- sglang/srt/constrained/xgrammar_backend.py +58 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -14
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +210 -56
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +102 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +30 -27
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +3 -3
- sglang/srt/server.py +29 -2
- sglang/srt/server_args.py +97 -60
- sglang/srt/utils.py +103 -51
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +33 -22
- sglang/version.py +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.frequency_penalties
|
48
|
-
del self.cumulated_frequency_penalties
|
49
|
-
|
50
47
|
self.frequency_penalties = None
|
51
48
|
self.cumulated_frequency_penalties = None
|
52
49
|
|
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
|
|
62
59
|
logits -= self.cumulated_frequency_penalties
|
63
60
|
return logits
|
64
61
|
|
65
|
-
def _filter(
|
66
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
67
|
-
):
|
62
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
68
63
|
self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
|
69
64
|
self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
|
70
65
|
indices_tensor_to_keep
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
70
70
|
)
|
71
71
|
|
72
72
|
def _teardown(self):
|
73
|
-
del self.min_new_tokens
|
74
|
-
del self.stop_token_penalties
|
75
|
-
del self.len_output_tokens
|
76
|
-
|
77
73
|
self.min_new_tokens = None
|
78
74
|
self.stop_token_penalties = None
|
79
75
|
self.len_output_tokens = None
|
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|
89
85
|
logits[mask] += self.stop_token_penalties[mask]
|
90
86
|
return logits
|
91
87
|
|
92
|
-
def _filter(
|
93
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
94
|
-
):
|
88
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
95
89
|
self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
|
96
90
|
self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
|
97
91
|
self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedPresencePenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.presence_penalties
|
48
|
-
del self.cumulated_presence_penalties
|
49
|
-
|
50
47
|
self.presence_penalties = None
|
51
48
|
self.cumulated_presence_penalties = None
|
52
49
|
|
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
|
|
61
58
|
logits -= self.cumulated_presence_penalties
|
62
59
|
return logits
|
63
60
|
|
64
|
-
def _filter(
|
65
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
66
|
-
):
|
61
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
67
62
|
self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
|
68
63
|
self.cumulated_presence_penalties = self.cumulated_presence_penalties[
|
69
64
|
indices_tensor_to_keep
|
@@ -1,8 +1,8 @@
|
|
1
|
-
import
|
1
|
+
from typing import List
|
2
2
|
|
3
3
|
import torch
|
4
4
|
|
5
|
-
from
|
5
|
+
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
6
6
|
|
7
7
|
|
8
8
|
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|
44
44
|
)
|
45
45
|
|
46
46
|
def _teardown(self):
|
47
|
-
del self.repetition_penalties
|
48
|
-
del self.cumulated_repetition_penalties
|
49
|
-
|
50
47
|
self.repetition_penalties = None
|
51
48
|
self.cumulated_repetition_penalties = None
|
52
49
|
|
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|
65
62
|
logits * self.cumulated_repetition_penalties,
|
66
63
|
)
|
67
64
|
|
68
|
-
def _filter(
|
69
|
-
self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
|
70
|
-
):
|
65
|
+
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
71
66
|
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
72
67
|
self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
|
73
68
|
indices_tensor_to_keep
|
@@ -1,12 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import dataclasses
|
4
|
-
|
4
|
+
import logging
|
5
|
+
import threading
|
6
|
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
5
7
|
|
6
8
|
import torch
|
7
9
|
|
8
10
|
import sglang.srt.sampling.penaltylib as penaltylib
|
9
11
|
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
10
15
|
if TYPE_CHECKING:
|
11
16
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
12
17
|
|
@@ -27,10 +32,11 @@ class SamplingBatchInfo:
|
|
27
32
|
|
28
33
|
# Bias Tensors
|
29
34
|
vocab_size: int
|
35
|
+
grammars: Optional[List] = None
|
36
|
+
sampling_info_done: Optional[threading.Event] = None
|
30
37
|
logit_bias: torch.Tensor = None
|
31
38
|
vocab_mask: Optional[torch.Tensor] = None
|
32
|
-
|
33
|
-
grammars: Optional[List] = None
|
39
|
+
apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
34
40
|
|
35
41
|
# Penalizer
|
36
42
|
penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
|
@@ -42,10 +48,7 @@ class SamplingBatchInfo:
|
|
42
48
|
|
43
49
|
@classmethod
|
44
50
|
def from_schedule_batch(
|
45
|
-
cls,
|
46
|
-
batch: ScheduleBatch,
|
47
|
-
vocab_size: int,
|
48
|
-
disable_penalizer: bool,
|
51
|
+
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
|
49
52
|
):
|
50
53
|
reqs = batch.reqs
|
51
54
|
device = batch.device
|
@@ -73,12 +76,39 @@ class SamplingBatchInfo:
|
|
73
76
|
top_ks=top_ks,
|
74
77
|
min_ps=min_ps,
|
75
78
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
76
|
-
is_all_greedy=
|
79
|
+
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
77
80
|
vocab_size=vocab_size,
|
78
81
|
device=device,
|
79
82
|
)
|
80
83
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
81
84
|
|
85
|
+
if enable_overlap_schedule:
|
86
|
+
# TODO (lianmin): Some penalizers such as frequency and presence depend on model outputs,
|
87
|
+
# so it is kind of tricky to make it work with overlap scheduler.
|
88
|
+
# It requires correctly updating the penalty logits before the sampling and syncing the events.
|
89
|
+
# We will support them later.
|
90
|
+
penalizers = {
|
91
|
+
penaltylib.BatchedMinNewTokensPenalizer,
|
92
|
+
}
|
93
|
+
if (
|
94
|
+
any(req.sampling_params.frequency_penalty != 0.0 for req in reqs)
|
95
|
+
or any(req.sampling_params.presence_penalty != 0.0 for req in reqs)
|
96
|
+
or any(req.sampling_params.repetition_penalty != 1.0 for req in reqs)
|
97
|
+
):
|
98
|
+
logger.warning(
|
99
|
+
"frequency_penalty, presence_penalty, and repetition_penalty are not supported "
|
100
|
+
"when using the default overlap scheduler. They will be ignored. "
|
101
|
+
"Please add `--disable-overlap` when launching the server if you need these features. "
|
102
|
+
"The speed will be slower in that case."
|
103
|
+
)
|
104
|
+
else:
|
105
|
+
penalizers = {
|
106
|
+
penaltylib.BatchedFrequencyPenalizer,
|
107
|
+
penaltylib.BatchedMinNewTokensPenalizer,
|
108
|
+
penaltylib.BatchedPresencePenalizer,
|
109
|
+
penaltylib.BatchedRepetitionPenalizer,
|
110
|
+
}
|
111
|
+
|
82
112
|
# Each penalizer will do nothing if it evaluates itself as not required by looking at
|
83
113
|
# the sampling_params of the requests (See {_is_required()} of each penalizer). So this
|
84
114
|
# should not add hefty computation overhead other than simple checks.
|
@@ -86,20 +116,12 @@ class SamplingBatchInfo:
|
|
86
116
|
# While we choose not to even create the class instances if they are not required, this
|
87
117
|
# could add additional complexity to the {ScheduleBatch} class, especially we need to
|
88
118
|
# handle {filter_batch()} and {merge_batch()} cases as well.
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
device=batch.device,
|
96
|
-
Penalizers={
|
97
|
-
penaltylib.BatchedFrequencyPenalizer,
|
98
|
-
penaltylib.BatchedMinNewTokensPenalizer,
|
99
|
-
penaltylib.BatchedPresencePenalizer,
|
100
|
-
penaltylib.BatchedRepetitionPenalizer,
|
101
|
-
},
|
102
|
-
)
|
119
|
+
ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
|
120
|
+
vocab_size=vocab_size,
|
121
|
+
batch=batch,
|
122
|
+
device=batch.device,
|
123
|
+
Penalizers=penalizers,
|
124
|
+
)
|
103
125
|
|
104
126
|
# Handle logit bias but only allocate when needed
|
105
127
|
ret.logit_bias = None
|
@@ -110,9 +132,6 @@ class SamplingBatchInfo:
|
|
110
132
|
return len(self.temperatures)
|
111
133
|
|
112
134
|
def update_penalties(self):
|
113
|
-
if not self.penalizer_orchestrator:
|
114
|
-
return
|
115
|
-
|
116
135
|
self.scaling_penalties = None
|
117
136
|
self.linear_penalties = None
|
118
137
|
|
@@ -133,23 +152,28 @@ class SamplingBatchInfo:
|
|
133
152
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
134
153
|
|
135
154
|
def update_regex_vocab_mask(self):
|
136
|
-
if not self.grammars
|
155
|
+
if not self.grammars:
|
137
156
|
self.vocab_mask = None
|
157
|
+
self.apply_mask = None
|
138
158
|
return
|
139
159
|
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
160
|
+
# find a grammar from the list
|
161
|
+
grammar = next(grammar for grammar in self.grammars if grammar)
|
162
|
+
|
163
|
+
# maybe we can reuse the existing mask?
|
164
|
+
self.vocab_mask = grammar.allocate_vocab_mask(
|
165
|
+
vocab_size=self.vocab_size,
|
166
|
+
batch_size=len(self.temperatures),
|
144
167
|
device=self.device,
|
145
168
|
)
|
169
|
+
self.apply_mask = type(grammar).apply_vocab_mask # force to use static method
|
170
|
+
|
146
171
|
for i, grammar in enumerate(self.grammars):
|
147
172
|
if grammar is not None:
|
148
|
-
grammar.fill_vocab_mask(self.vocab_mask
|
173
|
+
grammar.fill_vocab_mask(self.vocab_mask, i)
|
149
174
|
|
150
175
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
151
|
-
|
152
|
-
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
176
|
+
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
|
153
177
|
|
154
178
|
for item in [
|
155
179
|
"temperatures",
|
@@ -188,8 +212,7 @@ class SamplingBatchInfo:
|
|
188
212
|
return None
|
189
213
|
|
190
214
|
def merge_batch(self, other: "SamplingBatchInfo"):
|
191
|
-
|
192
|
-
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
215
|
+
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
193
216
|
|
194
217
|
for item in [
|
195
218
|
"temperatures",
|
@@ -205,25 +228,3 @@ class SamplingBatchInfo:
|
|
205
228
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
206
229
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
207
230
|
)
|
208
|
-
|
209
|
-
def copy(self):
|
210
|
-
return SamplingBatchInfo(
|
211
|
-
temperatures=self.temperatures,
|
212
|
-
top_ps=self.top_ps,
|
213
|
-
top_ks=self.top_ks,
|
214
|
-
min_ps=self.min_ps,
|
215
|
-
is_all_greedy=self.is_all_greedy,
|
216
|
-
need_min_p_sampling=self.need_min_p_sampling,
|
217
|
-
vocab_size=self.vocab_size,
|
218
|
-
device=self.device,
|
219
|
-
)
|
220
|
-
|
221
|
-
def to(self, device: str):
|
222
|
-
for item in [
|
223
|
-
"temperatures",
|
224
|
-
"top_ps",
|
225
|
-
"top_ks",
|
226
|
-
"min_ps",
|
227
|
-
]:
|
228
|
-
value = getattr(self, item)
|
229
|
-
setattr(self, item, value.to(device, non_blocking=True))
|
@@ -24,7 +24,6 @@ class SamplingParams:
|
|
24
24
|
def __init__(
|
25
25
|
self,
|
26
26
|
max_new_tokens: int = 128,
|
27
|
-
min_new_tokens: int = 0,
|
28
27
|
stop: Optional[Union[str, List[str]]] = None,
|
29
28
|
stop_token_ids: Optional[List[int]] = None,
|
30
29
|
temperature: float = 1.0,
|
@@ -34,13 +33,14 @@ class SamplingParams:
|
|
34
33
|
frequency_penalty: float = 0.0,
|
35
34
|
presence_penalty: float = 0.0,
|
36
35
|
repetition_penalty: float = 1.0,
|
37
|
-
|
38
|
-
skip_special_tokens: bool = True,
|
36
|
+
min_new_tokens: int = 0,
|
39
37
|
spaces_between_special_tokens: bool = True,
|
40
38
|
regex: Optional[str] = None,
|
41
39
|
n: int = 1,
|
42
40
|
json_schema: Optional[str] = None,
|
43
41
|
no_stop_trim: bool = False,
|
42
|
+
ignore_eos: bool = False,
|
43
|
+
skip_special_tokens: bool = True,
|
44
44
|
) -> None:
|
45
45
|
self.temperature = temperature
|
46
46
|
self.top_p = top_p
|
sglang/srt/server.py
CHANGED
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
|
|
50
50
|
)
|
51
51
|
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
|
52
52
|
from sglang.srt.managers.io_struct import (
|
53
|
+
CloseSessionReqInput,
|
53
54
|
EmbeddingReqInput,
|
54
55
|
GenerateReqInput,
|
56
|
+
OpenSessionReqInput,
|
55
57
|
UpdateWeightReqInput,
|
56
58
|
)
|
57
59
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
@@ -139,6 +141,7 @@ async def get_model_info():
|
|
139
141
|
"""Get the model information."""
|
140
142
|
result = {
|
141
143
|
"model_path": tokenizer_manager.model_path,
|
144
|
+
"tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
|
142
145
|
"is_generation": tokenizer_manager.is_generation,
|
143
146
|
}
|
144
147
|
return result
|
@@ -214,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
|
|
214
217
|
)
|
215
218
|
|
216
219
|
|
220
|
+
@app.api_route("/open_session", methods=["GET", "POST"])
|
221
|
+
async def open_session(obj: OpenSessionReqInput, request: Request):
|
222
|
+
"""Open a session, and return its unique session id."""
|
223
|
+
try:
|
224
|
+
session_id = await tokenizer_manager.open_session(obj, request)
|
225
|
+
return session_id
|
226
|
+
except Exception as e:
|
227
|
+
return ORJSONResponse(
|
228
|
+
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
229
|
+
)
|
230
|
+
|
231
|
+
|
232
|
+
@app.api_route("/close_session", methods=["GET", "POST"])
|
233
|
+
async def close_session(obj: CloseSessionReqInput, request: Request):
|
234
|
+
"""Close the session"""
|
235
|
+
try:
|
236
|
+
await tokenizer_manager.close_session(obj, request)
|
237
|
+
return Response(status_code=200)
|
238
|
+
except Exception as e:
|
239
|
+
return ORJSONResponse(
|
240
|
+
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
241
|
+
)
|
242
|
+
|
243
|
+
|
217
244
|
@time_func_latency
|
218
245
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
219
246
|
"""Handle a generate request."""
|
@@ -391,7 +418,7 @@ def launch_engine(
|
|
391
418
|
)
|
392
419
|
for tp_rank in tp_rank_range:
|
393
420
|
reader, writer = mp.Pipe(duplex=False)
|
394
|
-
gpu_id = tp_rank % tp_size_per_node
|
421
|
+
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
|
395
422
|
proc = mp.Process(
|
396
423
|
target=run_scheduler_process,
|
397
424
|
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
|
@@ -768,7 +795,7 @@ class Engine:
|
|
768
795
|
self,
|
769
796
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
770
797
|
prompt: Optional[Union[List[str], str]] = None,
|
771
|
-
sampling_params: Optional[Dict] = None,
|
798
|
+
sampling_params: Optional[Union[List[Dict], Dict]] = None,
|
772
799
|
# The token ids for text; one can either specify text or input_ids.
|
773
800
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
774
801
|
return_logprob: Optional[Union[List[bool], bool]] = False,
|