sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sglang/srt/configs/model_config.py +15 -6
  2. sglang/srt/layers/attention/flashinfer_backend.py +17 -3
  3. sglang/srt/layers/linear.py +36 -98
  4. sglang/srt/layers/moe/fused_moe_triton/layer.py +37 -9
  5. sglang/srt/layers/moe/topk.py +4 -2
  6. sglang/srt/layers/parameter.py +24 -16
  7. sglang/srt/layers/quantization/__init__.py +2 -0
  8. sglang/srt/layers/quantization/fp8.py +106 -52
  9. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  10. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  11. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  12. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  13. sglang/srt/layers/radix_attention.py +2 -0
  14. sglang/srt/layers/vocab_parallel_embedding.py +15 -2
  15. sglang/srt/managers/configure_logging.py +43 -0
  16. sglang/srt/managers/detokenizer_manager.py +0 -2
  17. sglang/srt/managers/io_struct.py +29 -13
  18. sglang/srt/managers/scheduler.py +48 -9
  19. sglang/srt/managers/tokenizer_manager.py +109 -49
  20. sglang/srt/mem_cache/memory_pool.py +107 -52
  21. sglang/srt/metrics/collector.py +10 -5
  22. sglang/srt/model_executor/model_runner.py +43 -6
  23. sglang/srt/models/llama.py +37 -2
  24. sglang/srt/models/qwen2.py +11 -0
  25. sglang/srt/models/qwen2_eagle.py +131 -0
  26. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  27. sglang/srt/sampling/sampling_batch_info.py +14 -5
  28. sglang/srt/sampling/sampling_params.py +1 -1
  29. sglang/srt/server.py +114 -61
  30. sglang/srt/server_args.py +27 -18
  31. sglang/srt/speculative/eagle_worker.py +1 -0
  32. sglang/srt/torch_memory_saver_adapter.py +59 -0
  33. sglang/srt/utils.py +29 -0
  34. sglang/version.py +1 -1
  35. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +12 -10
  36. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +39 -34
  37. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
  38. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +0 -0
  39. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py CHANGED
@@ -50,10 +50,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
     monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
@@ -165,6 +167,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -271,11 +277,35 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
+
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
 
         # Parse other args
         self.sliding_window_size = (
@@ -393,7 +423,7 @@ class ModelRunner:
 
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:
@@ -516,6 +546,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -563,6 +596,7 @@ class ModelRunner:
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
             use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -575,6 +609,7 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -585,6 +620,7 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -594,6 +630,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
            )
         logger.info(
             f"Memory pool end. "

sglang/srt/models/llama.py CHANGED
@@ -22,8 +22,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -299,6 +303,30 @@ class LlamaModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
 
 class LlamaForCausalLM(nn.Module):
 
@@ -534,9 +562,16 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
 
 
-EntryClass = [LlamaForCausalLM, Phi3ForCausalLM]
+class InternLM3ForCausalLM(LlamaForCausalLM):
+    pass
+
+
+EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM]
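
Note: load_kv_cache_scales copies one scaling factor per layer onto attn.k_scale / attn.v_scale, which the fp8_e4m3 KV cache introduced in model_runner.py uses to dequantize cached keys and values. As a standalone illustration of why such a scale matters (this is not sglang's kernel code, and the tensor shapes and scale value below are arbitrary), a float8_e4m3fn round trip looks like:

    # Toy float8_e4m3fn round trip; requires a PyTorch build with float8 dtypes.
    import torch

    def fp8_roundtrip(kv: torch.Tensor, scale: float) -> torch.Tensor:
        # Quantize: scale down so values fit e4m3's narrow range, then cast.
        kv_fp8 = (kv / scale).to(torch.float8_e4m3fn)
        # Dequantize: cast back to a wide dtype and undo the scaling.
        return kv_fp8.to(torch.float32) * scale

    kv = torch.randn(4, 128)
    err = (kv - fp8_roundtrip(kv, scale=1.0)).abs().max()
    print(f"max round-trip error: {err:.4f}")  # small, bounded by e4m3 precision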

sglang/srt/models/qwen2.py CHANGED
@@ -362,5 +362,16 @@ class Qwen2ForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 EntryClass = Qwen2ForCausalLM
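
Note: get_embed_and_head / set_embed_and_head mirror the hooks llama.py already exposes. They let external code read or swap the embedding and LM-head tensors of a loaded Qwen2ForCausalLM without reloading weights, which is what allows an EAGLE draft model (see qwen2_eagle.py below) to reuse the target model's vocabulary tensors. A hypothetical sketch of that handoff, with made-up function and variable names:

    # Hypothetical sketch: let a draft model reuse the target's vocab tensors.
    def share_vocab_tensors(target_model, draft_model):
        embed, head = target_model.get_embed_and_head()
        # set_embed_and_head() deletes the draft's own copies and empties the
        # CUDA cache, so only one copy of these large tensors stays resident.
        draft_model.set_embed_and_head(embed, head)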

sglang/srt/models/qwen2_eagle.py ADDED
@@ -0,0 +1,131 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
+
+Qwen2Config = None
+
+
+class Qwen2DecoderLayer(Qwen2DecoderLayer):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if layer_id == 0:
+            del self.input_layernorm
+            setattr(self, "input_layernorm", lambda x: x)
+
+
+class Qwen2Model(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                Qwen2DecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        hidden_states = self.fc(
+            torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+        )
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        return hidden_states + residual
+
+
+class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+        self.logits_processor = LogitsProcessor(config)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            super().load_weights([(name, loaded_weight)])
+
+
+EntryClass = [Qwen2ForCausalLMEagle]
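
Note: Qwen2ForCausalLMEagle.load_weights prefixes every non-lm_head checkpoint key with "model." before delegating to Qwen2ForCausalLM.load_weights, so EAGLE draft checkpoints that store their decoder weights at the top level still map onto self.model. A standalone illustration of that renaming rule (the sample keys are made up):

    # Standalone illustration of the key remapping in load_weights().
    sample_keys = [
        "fc.weight",                           # EAGLE's extra fusion layer
        "layers.0.self_attn.qkv_proj.weight",  # a decoder-layer weight
        "lm_head.weight",                      # kept as-is
    ]
    remapped = [k if "lm_head" in k else "model." + k for k in sample_keys]
    # -> ['model.fc.weight', 'model.layers.0.self_attn.qkv_proj.weight', 'lm_head.weight']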

sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py CHANGED
@@ -3,6 +3,11 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.where(
-            logits > 0,
-            logits / self.cumulated_repetition_penalties,
-            logits * self.cumulated_repetition_penalties,
-        )
+        if is_cuda:
+            return sampling_scaling_penalties(
+                logits, self.cumulated_repetition_penalties
+            )
+        else:
+            return torch.where(
+                logits > 0,
+                logits / self.cumulated_repetition_penalties,
+                logits * self.cumulated_repetition_penalties,
+            )
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
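
Note: both branches of _apply compute the same penalty: positive logits of previously generated tokens are divided by the penalty and negative ones are multiplied by it, so for a penalty greater than 1 the token always becomes less likely. The sgl_kernel sampling_scaling_penalties op is used here as a fused CUDA replacement for the torch.where fallback. A quick check of the fallback formula on a toy tensor:

    # Pure-PyTorch check of the penalty formula used by the non-CUDA fallback.
    import torch

    logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
    penalties = torch.tensor([1.2, 1.2, 1.0, 2.0])  # 1.0 means "no penalty"

    penalized = torch.where(
        logits > 0,
        logits / penalties,  # positive logits shrink toward zero
        logits * penalties,  # negative logits move further below zero
    )
    print(penalized)  # tensor([ 1.6667, -1.2000,  0.5000, -6.0000])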

sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 
+from sglang.srt.utils import is_cuda_available
+
+is_cuda = is_cuda_available()
+if is_cuda:
+    from sgl_kernel import sampling_scaling_penalties
+
 import sglang.srt.sampling.penaltylib as penaltylib
 
 logger = logging.getLogger(__name__)
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            logits[:] = torch.where(
-                logits > 0,
-                logits / self.scaling_penalties,
-                logits * self.scaling_penalties,
-            )
+            if is_cuda:
+                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+            else:
+                logits[:] = torch.where(
+                    logits > 0,
+                    logits / self.scaling_penalties,
+                    logits * self.scaling_penalties,
+                )
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:

sglang/srt/sampling/sampling_params.py CHANGED
@@ -23,7 +23,7 @@ class SamplingParams:
     The sampling parameters.
 
     See docs/references/sampling_params.md or
-    https://sgl-project.github.io/references/sampling_params.html
+    https://docs.sglang.ai/references/sampling_params.html
     for the documentation.
     """
 

sglang/srt/server.py CHANGED
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -52,11 +54,14 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -157,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -174,8 +235,7 @@ async def flush_cache():
     )
 
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -185,8 +245,7 @@ async def start_profile_async():
     )
 
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
@@ -255,6 +314,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -279,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Close the session"""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -438,6 +470,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
@@ -454,7 +490,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
@@ -471,7 +508,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -611,6 +649,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     # The child processes will send SIGQUIT to this process when any error happens
     # This process then clean up the whole process tree
     def sigquit_handler(signum, frame):
+        logger.error(
+            "Received sigquit from a child proces. It usually means the child failed."
+        )
         kill_process_tree(os.getpid())
 
     signal.signal(signal.SIGQUIT, sigquit_handler)
@@ -894,6 +935,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+
+
 class Runtime:
     """