sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +15 -6
- sglang/srt/layers/attention/flashinfer_backend.py +17 -3
- sglang/srt/layers/linear.py +36 -98
- sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +24 -16
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +106 -52
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -2
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/scheduler.py +48 -9
- sglang/srt/managers/tokenizer_manager.py +109 -49
- sglang/srt/mem_cache/memory_pool.py +107 -52
- sglang/srt/metrics/collector.py +10 -5
- sglang/srt/model_executor/model_runner.py +43 -6
- sglang/srt/models/llama.py +37 -2
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +14 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +114 -61
- sglang/srt/server_args.py +27 -18
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -50,10 +50,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -165,6 +167,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -271,11 +277,35 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
+
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
 
         # Parse other args
         self.sliding_window_size = (
@@ -393,7 +423,7 @@ class ModelRunner:
 
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:
@@ -516,6 +546,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -563,6 +596,7 @@ class ModelRunner:
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
             use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
            self.model_config.attention_arch == AttentionArch.MLA
@@ -575,6 +609,7 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -585,6 +620,7 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -594,6 +630,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
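Taken together, the `model_runner.py` changes wrap model loading in the new memory-saver region, thread a single `enable_memory_saver` flag into every KV pool constructor, and add a CUDA-only `fp8_e4m3` KV cache dtype that can optionally load per-layer scaling factors. A minimal sketch of how these options might be exercised through the offline engine, assuming `sglang.Engine` forwards keyword arguments to `ServerArgs` as in earlier releases; the model path and scale file below are placeholders, not part of this diff:

```python
# Sketch only: the model path and kv_cache_scales.json are hypothetical.
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",    # placeholder model
    kv_cache_dtype="fp8_e4m3",                        # new branch in ModelRunner
    quantization_param_path="kv_cache_scales.json",   # optional per-layer scales
    enable_memory_saver=True,                         # flag threaded into the KV pools
)
print(engine.generate("Hello", {"max_new_tokens": 8}))
```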
sglang/srt/models/llama.py
CHANGED
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
 
 class LlamaForCausalLM(nn.Module):
 
@@ -534,9 +562,16 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
 
 
-EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
+class InternLM3ForCausalLM(LlamaForCausalLM):
+    pass
+
+
+EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]
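`LlamaModel.load_kv_cache_scales` delegates parsing to vLLM's `kv_cache_scales_loader` and writes one scalar per layer onto the attention module's `k_scale`/`v_scale`. A rough way to check the effect after loading, assuming `model` is a `LlamaForCausalLM` built by `get_model()` with the FP8 KV cache options above:

```python
# Sketch: inspect the per-layer scales that load_kv_cache_scales() is expected to set.
for idx, layer in enumerate(model.model.layers):
    attn = layer.self_attn.attn
    print(idx, getattr(attn, "k_scale", None), getattr(attn, "v_scale", None))
```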
sglang/srt/models/qwen2.py
CHANGED
@@ -362,5 +362,16 @@ class Qwen2ForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 EntryClass = Qwen2ForCausalLM
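`get_embed_and_head` / `set_embed_and_head` let another component reuse a loaded Qwen2 model's embedding and LM head tensors instead of keeping duplicate copies, which is the kind of sharing an EAGLE draft worker needs. A hedged sketch, assuming `target` and `draft` are already-loaded `Qwen2ForCausalLM` and `Qwen2ForCausalLMEagle` instances (the draft class is added in the next file):

```python
# Sketch: share the target model's embedding and LM head with a draft model.
embed, head = target.get_embed_and_head()
draft.set_embed_and_head(embed, head)  # drops the draft's own copies, then rebinds
```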
sglang/srt/models/qwen2_eagle.py
ADDED
@@ -0,0 +1,131 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
+
+Qwen2Config = None
+
+
+class Qwen2DecoderLayer(Qwen2DecoderLayer):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class Qwen2Model(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen2DecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [Qwen2ForCausalLMEagle]
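The new `Qwen2ForCausalLMEagle` is a draft model for EAGLE speculative decoding: layer 0 drops its `input_layernorm`, and the embedding output is concatenated with the target model's hidden states and projected back down through `fc`. A sketch of wiring it up through the offline engine; the speculative `ServerArgs` names and model paths are assumptions, not taken from this diff:

```python
# Sketch only: argument names and paths below are assumed, not confirmed by this diff.
import sglang as sgl

engine = sgl.Engine(
    model_path="Qwen/Qwen2-7B-Instruct",                  # placeholder target model
    speculative_algorithm="EAGLE",
    speculative_draft_model_path="path/to/qwen2-eagle",   # loaded as Qwen2ForCausalLMEagle
    speculative_num_steps=5,
    speculative_eagle_topk=8,
    speculative_num_draft_tokens=64,
)
```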
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
CHANGED
@@ -3,6 +3,11 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
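The penalizer now prefers the fused `sgl_kernel.sampling_scaling_penalties` CUDA op and keeps the `torch.where` formulation as the non-CUDA fallback. The two paths compute the same thing: positive logits are divided by the penalty, negative logits are multiplied by it. A tiny demonstration of the fallback math:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
penalties = torch.tensor([[1.2, 1.2, 1.2]])

# Equivalent to the else-branch above: shrink positives, push negatives further down.
out = torch.where(logits > 0, logits / penalties, logits * penalties)
print(out)  # tensor([[ 1.6667, -1.2000,  0.4167]])
```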
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
 import sglang.srt.sampling.penaltylib as penaltylib
 
 logger = logging.getLogger(__name__)
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
sglang/srt/server.py
CHANGED
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -52,11 +54,14 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -157,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -174,8 +235,7 @@ async def flush_cache():
     )
 
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -185,8 +245,7 @@ async def start_profile_async():
     )
 
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
@@ -255,6 +314,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -279,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Close the session"""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -438,6 +470,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -454,7 +490,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
@@ -471,7 +508,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -611,6 +649,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     # The child processes will send SIGQUIT to this process when any error happens
     # This process then clean up the whole process tree
     def sigquit_handler(signum, frame):
+        logger.error(
+            "Received sigquit from a child proces. It usually means the child failed."
+        )
         kill_process_tree(os.getpid())
 
     signal.signal(signal.SIGQUIT, sigquit_handler)
@@ -894,6 +935,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+
 
 class Runtime:
     """
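The new `/release_memory_occupation` and `/resume_memory_occupation` endpoints, together with the matching `Engine` methods and the memory-saver adapter, let a caller temporarily hand GPU memory back between serving phases. A sketch of the intended call pattern, assuming `sglang.Engine` accepts `ServerArgs` fields as keyword arguments; the model path is a placeholder:

```python
# Sketch only: the model path is hypothetical; what happens between release and
# resume (e.g. a training step in an RLHF loop) is outside this diff.
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2-7B-Instruct", enable_memory_saver=True)

engine.release_memory_occupation()   # hand GPU memory back via the memory saver
# ... run other GPU work here ...
engine.resume_memory_occupation()    # re-occupy memory before serving again

print(engine.generate("Hello", {"max_new_tokens": 8}))
```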