sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. sglang/bench_offline_throughput.py +1 -0
  2. sglang/srt/configs/model_config.py +11 -2
  3. sglang/srt/layers/attention/__init__.py +0 -1
  4. sglang/srt/layers/attention/flashinfer_backend.py +54 -41
  5. sglang/srt/layers/logits_processor.py +30 -2
  6. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
  7. sglang/srt/layers/quantization/fp8.py +42 -2
  8. sglang/srt/layers/quantization/fp8_kernel.py +77 -18
  9. sglang/srt/layers/quantization/fp8_utils.py +8 -2
  10. sglang/srt/managers/io_struct.py +29 -8
  11. sglang/srt/managers/schedule_batch.py +22 -15
  12. sglang/srt/managers/scheduler.py +60 -20
  13. sglang/srt/managers/session_controller.py +102 -27
  14. sglang/srt/managers/tokenizer_manager.py +41 -10
  15. sglang/srt/managers/tp_worker.py +7 -0
  16. sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
  17. sglang/srt/model_executor/forward_batch_info.py +42 -3
  18. sglang/srt/model_executor/model_runner.py +4 -0
  19. sglang/srt/models/llama.py +11 -0
  20. sglang/srt/models/llama_eagle.py +132 -0
  21. sglang/srt/openai_api/adapter.py +60 -2
  22. sglang/srt/openai_api/protocol.py +48 -0
  23. sglang/srt/server.py +26 -3
  24. sglang/srt/server_args.py +17 -30
  25. sglang/srt/speculative/spec_info.py +19 -0
  26. sglang/srt/utils.py +62 -0
  27. sglang/version.py +1 -1
  28. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +3 -3
  29. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +32 -30
  30. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
  31. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
  32. {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py

@@ -12,12 +12,23 @@
  # limitations under the License.
  # ==============================================================================

- from typing import List, Tuple
+ import functools
+ import json
+ import logging
+ import os
+ from typing import Any, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl

+ from sglang.srt.utils import get_device_name, is_hip
+
+ is_hip_ = is_hip()
+ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+ logger = logging.getLogger(__name__)
+

  @triton.jit
  def _per_token_group_quant_fp8(
@@ -65,7 +76,7 @@ def per_token_group_quant_fp8(
      x: torch.Tensor,
      group_size: int,
      eps: float = 1e-10,
-     dtype: torch.dtype = torch.float8_e4m3fn,
+     dtype: torch.dtype = fp8_type_,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """Function to perform per-token-group quantization on an input tensor `x`.

@@ -87,9 +98,13 @@ def per_token_group_quant_fp8(
      assert x.is_contiguous(), "`x` is not contiguous"

      finfo = torch.finfo(dtype)
-     fp8_min = finfo.min
      fp8_max = finfo.max

+     if is_hip_:
+         fp8_max = 224.0
+
+     fp8_min = -fp8_max
+
      x_q = torch.empty_like(x, device=x.device, dtype=dtype)
      M = x.numel() // group_size
      N = group_size
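
Note: this hunk, together with the input_to_float8 change in fp8_utils.py further below, caps the usable FP8 range at ±224.0 on ROCm instead of taking torch.finfo(dtype).max (448 for float8_e4m3fn, 240 for float8_e4m3fnuz), presumably to leave saturation headroom on the e4m3fnuz path. A minimal sketch of the resulting per-tensor scale computation, assuming only PyTorch (the helper name is illustrative, not from the package):

import torch

def fp8_scale(x: torch.Tensor, use_hip_range: bool) -> torch.Tensor:
    # Mirrors the clamping above: absmax scaling against a lowered FP8 ceiling on HIP.
    dtype = torch.float8_e4m3fnuz if use_hip_range else torch.float8_e4m3fn
    fp8_max = 224.0 if use_hip_range else torch.finfo(dtype).max
    amax = x.abs().max().clamp(min=1e-12)
    return fp8_max / amax

x = torch.randn(4, 8)
print(fp8_scale(x, use_hip_range=False), fp8_scale(x, use_hip_range=True))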
@@ -205,6 +220,48 @@ def _w8a8_block_fp8_matmul(
      tl.store(c_ptrs, c, mask=c_mask)


+ @functools.lru_cache
+ def get_w8a8_block_fp8_configs(
+     N: int, K: int, block_n: int, block_k: int
+ ) -> Optional[Dict[int, Any]]:
+     """
+     Return optimized configurations for the w8a8 block fp8 kernel.
+
+     The return value will be a dictionary that maps an irregular grid of
+     batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+     kernel on a given batch size bs, the closest batch size in the grid should
+     be picked and the associated configuration chosen to invoke the kernel.
+     """
+
+     # First look up if an optimized configuration is available in the configs
+     # directory
+     device_name = get_device_name().replace(" ", "_")
+     json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+     config_file_path = os.path.join(
+         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+     )
+     if os.path.exists(config_file_path):
+         with open(config_file_path) as f:
+             logger.info(
+                 "Using configuration from %s for W8A8 Block FP8 kernel.",
+                 config_file_path,
+             )
+             # If a configuration has been found, return it
+             return {int(key): val for key, val in json.load(f).items()}
+
+     # If no optimized configuration is available, we will use the default
+     # configuration
+     logger.warning(
+         (
+             "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
+             "Config file not found at %s"
+         ),
+         config_file_path,
+     )
+     return None
+
+
  def w8a8_block_fp8_matmul(
      A: torch.Tensor,
      B: torch.Tensor,
@@ -245,17 +302,22 @@ def w8a8_block_fp8_matmul(
      C_shape = A.shape[:-1] + (N,)
      C = A.new_empty(C_shape, dtype=output_dtype)

-     # TODO(HandH1998):
-     # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
-     # BLOCK_SIZE_K must be divisable by block_k
-     # BLOCK_SIZE_N and BLOCK_SIZE_M has no requirements
-     BLOCK_SIZE_M = 128
-     if M < BLOCK_SIZE_M:
-         BLOCK_SIZE_M = triton.next_power_of_2(M)
-         BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
-     BLOCK_SIZE_K = block_k
-     assert block_k % BLOCK_SIZE_K == 0
-     BLOCK_SIZE_N = block_n
+     configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+     if configs:
+         # If an optimal configuration map has been found, look up the
+         # optimal config
+         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+     else:
+         # Default config
+         # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+         config = {
+             "BLOCK_SIZE_M": 64,
+             "BLOCK_SIZE_N": block_size[0],
+             "BLOCK_SIZE_K": block_size[1],
+             "GROUP_SIZE_M": 32,
+             "num_warps": 4,
+             "num_stages": 3,
+         }

      def grid(META):
          return (
@@ -283,10 +345,7 @@ def w8a8_block_fp8_matmul(
          As.stride(-1),
          Bs.stride(1),
          Bs.stride(0),
-         BLOCK_SIZE_M=BLOCK_SIZE_M,
-         BLOCK_SIZE_N=BLOCK_SIZE_N,
-         BLOCK_SIZE_K=BLOCK_SIZE_K,
-         GROUP_SIZE_M=8,
+         **config,
      )

      return C
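
Note: get_w8a8_block_fp8_configs keys tuned launch parameters by (N, K) and block shape, with entries indexed by batch size, and w8a8_block_fp8_matmul then picks the entry closest to the actual M. A small illustration of that nearest-batch-size selection, using made-up config contents (the values are not from any shipped JSON file):

# Illustrative only: mimics the closest-batch-size lookup used above.
configs = {
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    64: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
}

M = 48  # number of rows (tokens) in the current batch
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # picks the entry tuned for batch size 64, the closest grid point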
sglang/srt/layers/quantization/fp8_utils.py

@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
      per_token_group_quant_fp8,
      w8a8_block_fp8_matmul,
  )
+ from sglang.srt.utils import is_hip
+
+ is_hip_ = is_hip()


  def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
      finfo = torch.finfo(dtype)
      min_val, max_val = x.aminmax()
      amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-     scale = finfo.max / amax
-     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+     fp8_max = finfo.max
+     if is_hip_:
+         fp8_max = 224.0
+     scale = fp8_max / amax
+     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
      return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


sglang/srt/managers/io_struct.py

@@ -21,10 +21,20 @@ from dataclasses import dataclass
  from enum import Enum
  from typing import Dict, List, Optional, Tuple, Union

+ import torch
+
  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.sampling.sampling_params import SamplingParams


+ @dataclass
+ class SessionParams:
+     id: Optional[str] = None
+     rid: Optional[str] = None
+     offset: Optional[int] = None
+     replace: Optional[bool] = None
+
+
  @dataclass
  class GenerateReqInput:
      # The input prompt. It can be a single prompt or a batch of prompts.
@@ -56,10 +66,8 @@ class GenerateReqInput:
      # LoRA related
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

-     # Session id info for continual prompting
-     session: Optional[
-         Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-     ] = None
+     # Session info for continual prompting
+     session_params: Optional[Union[List[Dict], Dict]] = None

      def normalize_batch_and_arguments(self):
          if (
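
Note: with this change a generate request carries session information as a plain dict (or a list of dicts for a batch) whose keys mirror the SessionParams dataclass above, replacing the old (session_id, rid) tuple. A hypothetical payload, with illustrative values and assumed field semantics:

# Hypothetical continual-prompting request body; keys follow SessionParams.
generate_request = {
    "text": "And then?",
    "session_params": {
        "id": "sess-1234",   # id of a previously opened session
        "rid": "req-0001",   # earlier request in that session to continue from (assumed meaning)
        "offset": -1,        # position within that request's output (assumed meaning)
        "replace": True,     # whether to overwrite the continued branch (assumed meaning)
    },
    "sampling_params": {"max_new_tokens": 64},
}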
@@ -221,9 +229,8 @@ class TokenizedGenerateReqInput:
      # The input embeds
      input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

-     # Session id info for continual prompting
-     session_id: Optional[str] = None
-     session_rid: Optional[str] = None
+     # Session info for continual prompting
+     session_params: Optional[SessionParams] = None


  @dataclass
@@ -407,6 +414,18 @@ class UpdateWeightsFromDistributedReqOutput:
      message: str


+ @dataclass
+ class UpdateWeightsFromTensorReqInput:
+     name: str
+     tensor: torch.Tensor
+
+
+ @dataclass
+ class UpdateWeightsFromTensorReqOutput:
+     success: bool
+     message: str
+
+
  @dataclass
  class InitWeightsUpdateGroupReqInput:
      # The master address
@@ -454,6 +473,7 @@ class ProfileReq(Enum):
  @dataclass
  class OpenSessionReqInput:
      capacity_of_str_len: int
+     session_id: Optional[str] = None


  @dataclass
@@ -463,4 +483,5 @@

  @dataclass
  class OpenSessionReqOutput:
-     session_id: str
+     session_id: Optional[str]
+     success: bool
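
Note: OpenSessionReqInput now lets the caller propose its own session id, and OpenSessionReqOutput reports whether the open actually succeeded rather than assuming it. A sketch using only the dataclasses from this diff (values are illustrative):

from sglang.srt.managers.io_struct import OpenSessionReqInput, OpenSessionReqOutput

# The caller may propose a session id up front.
req = OpenSessionReqInput(capacity_of_str_len=32, session_id="sess-1234")

# The scheduler answers with success=False (instead of silently reusing the id)
# when the id already exists or is None; see the open_session change below.
resp = OpenSessionReqOutput(session_id="sess-1234", success=True)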
sglang/srt/managers/schedule_batch.py

@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

  import dataclasses
  import logging
- from typing import List, Optional, Tuple, Union
+ from typing import List, Optional, Set, Tuple, Union

  import numpy as np
  import torch
@@ -209,6 +209,7 @@ class Req:
          lora_path: Optional[str] = None,
          input_embeds: Optional[List[List[float]]] = None,
          session_id: Optional[str] = None,
+         eos_token_ids: Optional[Set[int]] = None,
      ):
          # Input and output info
          self.rid = rid
@@ -236,6 +237,7 @@ class Req:
          self.finished_reason = None
          self.to_abort = False
          self.stream = stream
+         self.eos_token_ids = eos_token_ids

          # For incremental decoding
          # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@

          last_token_id = self.output_ids[-1]

-         matched_eos = False
-
-         # Check stop token ids
-         if self.sampling_params.stop_token_ids:
-             matched_eos = last_token_id in self.sampling_params.stop_token_ids
-         if self.tokenizer is not None:
-             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-             if self.tokenizer.additional_stop_token_ids:
-                 matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
-         if matched_eos and not self.sampling_params.ignore_eos:
-             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-             return
+         if not self.sampling_params.ignore_eos:
+             matched_eos = False
+
+             # Check stop token ids
+             if self.sampling_params.stop_token_ids:
+                 matched_eos = last_token_id in self.sampling_params.stop_token_ids
+             if self.eos_token_ids:
+                 matched_eos |= last_token_id in self.eos_token_ids
+             if self.tokenizer is not None:
+                 matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                 if self.tokenizer.additional_stop_token_ids:
+                     matched_eos |= (
+                         last_token_id in self.tokenizer.additional_stop_token_ids
+                     )
+             if matched_eos:
+                 self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                 return

          # Check stop strings
          if len(self.sampling_params.stop_strs) > 0:
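
Note: the new Req.eos_token_ids set (filled from model_config.hf_eos_token_id when the scheduler builds the request, see below) lets the stop check cover models whose generation config lists several EOS ids, independently of the tokenizer. A standalone sketch of the same matching order with toy token ids (the helper is illustrative, not part of the package):

from typing import Optional, Set

def matches_eos(last_token_id: int,
                stop_token_ids: Set[int],
                eos_token_ids: Optional[Set[int]],
                tokenizer_eos_id: Optional[int]) -> bool:
    # Same order of checks as the rewritten stop check above.
    matched = last_token_id in stop_token_ids
    if eos_token_ids:
        matched |= last_token_id in eos_token_ids
    if tokenizer_eos_id is not None:
        matched |= last_token_id == tokenizer_eos_id
    return matched

# A model with two EOS ids (toy values): either one finishes the request.
print(matches_eos(128009, set(), {128001, 128009}, 128001))  # True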
@@ -836,8 +843,8 @@ class ScheduleBatch:
          # TODO (lianmin): Revisit this. It should be seq_len - 1
          self.extend_logprob_start_lens.extend([0] * running_bs)

-     def check_decode_mem(self):
-         bs = len(self.reqs)
+     def check_decode_mem(self, buf_multiplier=1):
+         bs = len(self.reqs) * buf_multiplier
          if self.token_to_kv_pool.available_size() >= bs:
              return True

sglang/srt/managers/scheduler.py

@@ -22,7 +22,7 @@ import warnings
  from collections import deque
  from concurrent import futures
  from types import SimpleNamespace
- from typing import Callable, Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple

  import psutil
  import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
      UpdateWeightFromDiskReqOutput,
      UpdateWeightsFromDistributedReqInput,
      UpdateWeightsFromDistributedReqOutput,
+     UpdateWeightsFromTensorReqInput,
+     UpdateWeightsFromTensorReqOutput,
  )
  from sglang.srt.managers.schedule_batch import (
      FINISH_ABORT,
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)

- # Test retract decode
+ # Test retract decode for debugging purposes
  test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")


@@ -127,12 +129,12 @@ class Scheduler:
              )

          if server_args.skip_tokenizer_init:
-             # Directly send to the tokenizer/api
+             # Directly send to the TokenizerManager
              self.send_to_detokenizer = get_zmq_socket(
                  context, zmq.PUSH, port_args.tokenizer_ipc_name
              )
          else:
-             # Send to the detokenizer
+             # Send to the DetokenizerManager
              self.send_to_detokenizer = get_zmq_socket(
                  context, zmq.PUSH, port_args.detokenizer_ipc_name
              )
@@ -383,7 +385,8 @@ class Scheduler:
              self.process_input_requests(recv_reqs)

              batch = self.get_next_batch_to_run()
-             if self.server_args.enable_dp_attention:
+
+             if self.server_args.enable_dp_attention:  # TODO: simplify this
                  batch = self.prepare_dp_attn_batch(batch)

              self.cur_batch = batch
@@ -392,7 +395,7 @@ class Scheduler:
                  result = self.run_batch(batch)
                  self.process_batch_result(batch, result)
              else:
-                 # Self-check and re-init some states when the server is idle
+                 # When the server is idle, so self-check and re-init some states
                  self.check_memory()
                  self.new_token_ratio = self.init_new_token_ratio

@@ -409,12 +412,13 @@

              batch = self.get_next_batch_to_run()
              self.cur_batch = batch
+
              if batch:
                  result = self.run_batch(batch)
                  result_queue.append((batch.copy(), result))

                  if self.last_batch is None:
-                     # A dummy first batch to start the pipeline for overlap scheduler.
+                     # Create a dummy first batch to start the pipeline for overlap scheduler.
                      # It is now used for triggering the sampling_info_done event.
                      tmp_batch = ScheduleBatch(
                          reqs=None,
424
428
  self.process_batch_result(tmp_batch, None)
425
429
 
426
430
  if self.last_batch:
431
+ # Process the results of the last batch
427
432
  tmp_batch, tmp_result = result_queue.popleft()
428
433
  tmp_batch.next_batch_sampling_info = (
429
434
  self.tp_worker.cur_sampling_info if batch else None
430
435
  )
431
436
  self.process_batch_result(tmp_batch, tmp_result)
432
437
  elif batch is None:
433
- # Self-check and re-init some states when the server is idle
438
+ # When the server is idle, so self-check and re-init some states
434
439
  self.check_memory()
435
440
  self.new_token_ratio = self.init_new_token_ratio
436
441
 
437
442
  self.last_batch = batch
438
443
 
439
- def recv_requests(self):
444
+ def recv_requests(self) -> List[Req]:
445
+ """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
440
446
  if self.tp_rank == 0 or self.server_args.enable_dp_attention:
441
447
  recv_reqs = []
442
448
 
@@ -478,6 +484,11 @@ class Scheduler:
478
484
  self.send_to_tokenizer.send_pyobj(
479
485
  UpdateWeightsFromDistributedReqOutput(success, message)
480
486
  )
487
+ elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
488
+ success, message = self.update_weights_from_tensor(recv_req)
489
+ self.send_to_tokenizer.send_pyobj(
490
+ UpdateWeightsFromTensorReqOutput(success, message)
491
+ )
481
492
  elif isinstance(recv_req, GetWeightsByNameReqInput):
482
493
  parameter = self.get_weights_by_name(recv_req)
483
494
  self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -487,8 +498,10 @@ class Scheduler:
                  else:
                      self.stop_profile()
              elif isinstance(recv_req, OpenSessionReqInput):
-                 session_id = self.open_session(recv_req)
-                 self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
+                 session_id, success = self.open_session(recv_req)
+                 self.send_to_tokenizer.send_pyobj(
+                     OpenSessionReqOutput(session_id=session_id, success=success)
+                 )
              elif isinstance(recv_req, CloseSessionReqInput):
                  self.close_session(recv_req)
              else:
@@ -499,7 +512,11 @@ class Scheduler:
          recv_req: TokenizedGenerateReqInput,
      ):
          # Create a new request
-         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+         if (
+             recv_req.session_params is None
+             or recv_req.session_params.id is None
+             or recv_req.session_params.id not in self.sessions
+         ):

              if recv_req.input_embeds is not None:
                  # Generate fake input_ids based on the length of input_embeds
@@ -517,18 +534,22 @@ class Scheduler:
                  stream=recv_req.stream,
                  lora_path=recv_req.lora_path,
                  input_embeds=recv_req.input_embeds,
+                 eos_token_ids=self.model_config.hf_eos_token_id,
              )
              req.tokenizer = self.tokenizer

-             if recv_req.session_id is not None:
+             if (
+                 recv_req.session_params is not None
+                 and recv_req.session_params.id is not None
+             ):
                  req.finished_reason = FINISH_ABORT(
-                     f"Invalid request: session id {recv_req.session_id} does not exist"
+                     f"Invalid request: session id {recv_req.session_params.id} does not exist"
                  )
                  self.waiting_queue.append(req)
                  return
          else:
-             # Create a new request from a previsou session
-             session = self.sessions[recv_req.session_id]
+             # Create a new request from a previous session
+             session = self.sessions[recv_req.session_params.id]
              req = session.create_req(recv_req, self.tokenizer)
              if isinstance(req.finished_reason, FINISH_ABORT):
                  self.waiting_queue.append(req)
@@ -804,6 +825,8 @@ class Scheduler:
                  if res == AddReqResult.NO_TOKEN:
                      self.batch_is_full = True
                  break
+             if self.server_args.prefill_only_one_req:
+                 break

          # Update waiting queue
          can_run_list = adder.can_run_list
@@ -1457,6 +1480,17 @@ class Scheduler:
              logger.error(message)
          return success, message

+     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+         """Update the online model parameter from tensors."""
+         success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+         # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+         if success:
+             flash_cache_success = self.flush_cache()
+             assert flash_cache_success, "Cache flush failed after updating weights"
+         else:
+             logger.error(message)
+         return success, message
+
      def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
          parameter = self.tp_worker.get_weights_by_name(recv_req)
          return parameter
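
Note: the new handler mirrors update_weights_from_distributed: the tp_worker swaps in the incoming tensor and the KV cache is flushed so stale entries are not reused. A minimal sketch of what a caller could send, using only the dataclass added in this diff (the parameter name and shape are illustrative):

import torch
from sglang.srt.managers.io_struct import UpdateWeightsFromTensorReqInput

# One named parameter and its replacement weights; the name must match a
# parameter of the loaded model.
req = UpdateWeightsFromTensorReqInput(
    name="model.layers.0.self_attn.qkv_proj.weight",
    tensor=torch.zeros(3072, 4096, dtype=torch.bfloat16),
)
# The scheduler replies with UpdateWeightsFromTensorReqOutput(success, message).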
@@ -1475,16 +1509,20 @@ class Scheduler:
          )
          logger.info("Profiler is done")

-     def open_session(self, recv_req: OpenSessionReqInput) -> str:
+     def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
          # handle error
          session_id = recv_req.session_id
          if session_id in self.sessions:
              logger.warning(f"session id {session_id} already exist, cannot open.")
+             return session_id, False
+         elif session_id is None:
+             logger.warning(f"session id is None, cannot open.")
+             return session_id, False
          else:
              self.sessions[session_id] = Session(
                  recv_req.capacity_of_str_len, session_id
              )
-         return session_id
+             return session_id, True

      def close_session(self, recv_req: CloseSessionReqInput):
          # handle error
@@ -1509,18 +1547,20 @@ def run_scheduler_process(
      if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
          dp_rank = int(os.environ["SGLANG_DP_RANK"])

+     # Configue the logger
      if dp_rank is None:
          configure_logger(server_args, prefix=f" TP{tp_rank}")
      else:
          configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+     suppress_other_loggers()

-     # set cpu affinity to this gpu process
+     # Set cpu affinity to this gpu process
      if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
          set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-     suppress_other_loggers()
      parent_process = psutil.Process().parent()

+     # Create a scheduler and run the event loop
      try:
          scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
          pipe_writer.send(