sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +77 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/scheduler.py +60 -20
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +41 -10
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -0
- sglang/srt/models/llama.py +11 -0
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +60 -2
- sglang/srt/openai_api/protocol.py +48 -0
- sglang/srt/server.py +26 -3
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +3 -3
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +32 -30
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -12,12 +12,23 @@
 # limitations under the License.
 # ==============================================================================

-
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
 import triton.language as tl

+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+logger = logging.getLogger(__name__)
+

 @triton.jit
 def _per_token_group_quant_fp8(
@@ -65,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype =
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.

@@ -87,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"

     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max

+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
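Note: the change above defaults the quantization dtype to `fp8_type_` (e4m3fnuz on ROCm) and caps `fp8_max` at 224.0 on HIP builds. Below is a minimal pure-PyTorch sketch of the per-token-group scaling that the Triton kernel implements, assuming a CUDA build; the function name and eps handling are illustrative only, not the package's API.

```python
import torch

def per_token_group_quant_sketch(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Assumes the last dimension is divisible by group_size, as the real kernel does.
    fp8_dtype = torch.float8_e4m3fn              # torch.float8_e4m3fnuz when is_hip_ is True
    fp8_max = float(torch.finfo(fp8_dtype).max)  # the diff lowers this to 224.0 on HIP
    fp8_min = -fp8_max

    groups = x.reshape(-1, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / fp8_max                       # one scale per token group
    x_q = (groups / scale).clamp(fp8_min, fp8_max).to(fp8_dtype).reshape_as(x)
    return x_q, scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_sketch(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([8])
```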
@@ -205,6 +220,48 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)


+@functools.lru_cache
+def get_w8a8_block_fp8_configs(
+    N: int, K: int, block_n: int, block_k: int
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the w8a8 block fp8 kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = get_device_name().replace(" ", "_")
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                "Using configuration from %s for W8A8 Block FP8 kernel.",
+                config_file_path,
+            )
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    logger.warning(
+        (
+            "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
+            "Config file not found at %s"
+        ),
+        config_file_path,
+    )
+    return None
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
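The new `get_w8a8_block_fp8_configs` helper resolves a tuned kernel config from a JSON file named after the weight shape, device, and block shape; callers then pick the entry whose batch-size key is closest to M. A standalone sketch of that lookup, with made-up shapes, device name, and a fake config dict:

```python
import json

N, K, block_n, block_k = 4096, 512, 128, 128
device_name = "NVIDIA_H100_80GB_HBM3"  # get_device_name().replace(" ", "_") in the diff
json_file_name = (
    f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,"
    f"block_shape=[{block_n}, {block_k}].json"
)
print(json_file_name)

# Suppose the file maps batch sizes to Triton launch configs (keys become ints):
configs = {int(k): v for k, v in json.loads(
    '{"1": {"BLOCK_SIZE_M": 16}, "64": {"BLOCK_SIZE_M": 64}, "1024": {"BLOCK_SIZE_M": 128}}'
).items()}

M = 200  # number of rows in A for this call
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # picks the grid point closest to M, here the 64 entry
```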
@@ -245,17 +302,22 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)

-
-
-
-
-
-
-
-
-
-
-
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }

     def grid(META):
         return (
@@ -283,10 +345,7 @@ def w8a8_block_fp8_matmul(
         As.stride(-1),
         Bs.stride(1),
         Bs.stride(0),
-
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,
-        GROUP_SIZE_M=8,
+        **config,
     )

     return C
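For reference, the selected config is splatted into the Triton launch via `**config`, and the block sizes determine how many programs are launched. The exact grid function is outside this hunk, so the arithmetic below is an assumption based on the usual tiled-matmul layout:

```python
M, N, K = 200, 4096, 512
block_size = (128, 128)  # (block_n, block_k) as used in the diff

config = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": block_size[0],
    "BLOCK_SIZE_K": block_size[1],
    "GROUP_SIZE_M": 32,
    "num_warps": 4,
    "num_stages": 3,
}

cdiv = lambda a, b: (a + b - 1) // b
# One program per (M-tile, N-tile); GROUP_SIZE_M typically only affects tile ordering.
grid = (cdiv(M, config["BLOCK_SIZE_M"]) * cdiv(N, config["BLOCK_SIZE_N"]),)
print(grid)  # (128,)
```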
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()


 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-
-
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

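A hedged sketch of what the modified `input_to_float8` computes on a non-HIP build: a per-tensor scale from the absolute max, a saturating cast, and the reciprocal scale returned for dequantization. The body paraphrases the diff; the wrapper function name is illustrative.

```python
import torch

def input_to_float8_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    fp8_max = finfo.max        # the diff lowers this to 224.0 when running on HIP
    scale = fp8_max / amax
    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(8, 16)
x_fp8, inv_scale = input_to_float8_sketch(x)
print(x_fp8.dtype, inv_scale.item())  # torch.float8_e4m3fn and the dequant scale
```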
sglang/srt/managers/io_struct.py
CHANGED
@@ -21,10 +21,20 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union

+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams


+@dataclass
+class SessionParams:
+    id: Optional[str] = None
+    rid: Optional[str] = None
+    offset: Optional[int] = None
+    replace: Optional[bool] = None
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -56,10 +66,8 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

-    # Session
-
-        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-    ] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None

     def normalize_batch_and_arguments(self):
         if (
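With this change, per-request session state moves from a positional tuple to a `session_params` dict, which is parsed into the new `SessionParams` dataclass. A sketch of what such a payload might look like; only the dataclass fields come from the diff, the surrounding request fields and values are assumptions.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SessionParams:                 # mirrors the dataclass added in io_struct.py
    id: Optional[str] = None         # which open session to continue
    rid: Optional[str] = None        # request id within the session to branch from
    offset: Optional[int] = None     # token offset to resume at
    replace: Optional[bool] = None   # whether to replace the branch instead of appending

# On the wire, GenerateReqInput.session_params is a plain dict (or a list of dicts
# for a batch); the tokenizer side converts it into SessionParams.
generate_payload = {
    "text": "And then what happened?",
    "session_params": {"id": "sess-1234", "rid": "req-0007", "offset": -1, "replace": False},
}
print(SessionParams(**generate_payload["session_params"]))
```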
@@ -221,9 +229,8 @@ class TokenizedGenerateReqInput:
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

-    # Session
-
-    session_rid: Optional[str] = None
+    # Session info for continual prompting
+    session_params: Optional[SessionParams] = None


 @dataclass
@@ -407,6 +414,18 @@ class UpdateWeightsFromDistributedReqOutput:
     message: str


+@dataclass
+class UpdateWeightsFromTensorReqInput:
+    name: str
+    tensor: torch.Tensor
+
+
+@dataclass
+class UpdateWeightsFromTensorReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
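The new request/response pair lets a caller push a single named tensor into the running model. A minimal construction sketch; the Llama-style parameter name is an illustrative assumption, and how the tensor is transported between processes is outside this hunk.

```python
from dataclasses import dataclass
import torch

@dataclass
class UpdateWeightsFromTensorReqInput:
    name: str             # parameter name as it appears in the model's state dict
    tensor: torch.Tensor  # replacement weight

@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str

req = UpdateWeightsFromTensorReqInput(
    name="model.layers.0.self_attn.q_proj.weight",  # hypothetical parameter name
    tensor=torch.zeros(16, 16),
)
print(req.name, tuple(req.tensor.shape))
```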
@@ -454,6 +473,7 @@ class ProfileReq(Enum):
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
+    session_id: Optional[str] = None


 @dataclass
@@ -463,4 +483,5 @@ class CloseSessionReqInput:

 @dataclass
 class OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch

 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:

         last_token_id = self.output_ids[-1]

-
-
-
-
-
-
-
-
-
-
-
-
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return

         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
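The rewritten EOS check ORs several stop sources together unless `ignore_eos` is set. A standalone sketch of that predicate, pulled out of `Req.check_finished` for illustration; the helper and its stand-in arguments are assumptions, and `tokenizer.additional_stop_token_ids` is omitted for brevity.

```python
from typing import Optional, Set

def matched_eos(
    last_token_id: int,
    ignore_eos: bool,
    stop_token_ids: Set[int],            # sampling_params.stop_token_ids
    eos_token_ids: Optional[Set[int]],   # generation-config eos ids passed to Req
    tokenizer_eos_id: Optional[int],     # tokenizer.eos_token_id
) -> bool:
    if ignore_eos:
        return False
    matched = last_token_id in stop_token_ids
    if eos_token_ids:
        matched |= last_token_id in eos_token_ids
    if tokenizer_eos_id is not None:
        matched |= last_token_id == tokenizer_eos_id
    return matched

print(matched_eos(2, False, set(), {2, 32000}, 2))  # True: hits both eos sources
print(matched_eos(2, True, set(), {2}, 2))          # False: ignore_eos wins
```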
@@ -836,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)

-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True

sglang/srt/managers/scheduler.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
 from collections import deque
 from concurrent import futures
 from types import SimpleNamespace
-from typing import
+from typing import Dict, List, Optional, Tuple

 import psutil
 import setproctitle
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromDistributedReqOutput,
+    UpdateWeightsFromTensorReqInput,
+    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Test retract decode
+# Test retract decode for debugging purposes
 test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")

@@ -127,12 +129,12 @@ class Scheduler:
         )

         if server_args.skip_tokenizer_init:
-            # Directly send to the
+            # Directly send to the TokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name
             )
         else:
-            # Send to the
+            # Send to the DetokenizerManager
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name
             )
@@ -383,7 +385,8 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
                 batch = self.prepare_dp_attn_batch(batch)

             self.cur_batch = batch
@@ -392,7 +395,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                #
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -409,12 +412,13 @@ class Scheduler:

             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
+
             if batch:
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    #
+                    # Create a dummy first batch to start the pipeline for overlap scheduler.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,
@@ -424,19 +428,21 @@ class Scheduler:
                     self.process_batch_result(tmp_batch, None)

             if self.last_batch:
+                # Process the results of the last batch
                 tmp_batch, tmp_result = result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                #
+                # When the server is idle, so self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

             self.last_batch = batch

-    def recv_requests(self):
+    def recv_requests(self) -> List[Req]:
+        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []

@@ -478,6 +484,11 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(
                     UpdateWeightsFromDistributedReqOutput(success, message)
                 )
+            elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
+                success, message = self.update_weights_from_tensor(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromTensorReqOutput(success, message)
+                )
             elif isinstance(recv_req, GetWeightsByNameReqInput):
                 parameter = self.get_weights_by_name(recv_req)
                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
@@ -487,8 +498,10 @@ class Scheduler:
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, OpenSessionReqInput):
-                session_id = self.open_session(recv_req)
-                self.send_to_tokenizer.send_pyobj(
+                session_id, success = self.open_session(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    OpenSessionReqOutput(session_id=session_id, success=success)
+                )
             elif isinstance(recv_req, CloseSessionReqInput):
                 self.close_session(recv_req)
             else:
@@ -499,7 +512,11 @@ class Scheduler:
         recv_req: TokenizedGenerateReqInput,
     ):
         # Create a new request
-        if
+        if (
+            recv_req.session_params is None
+            or recv_req.session_params.id is None
+            or recv_req.session_params.id not in self.sessions
+        ):

             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
@@ -517,18 +534,22 @@ class Scheduler:
                 stream=recv_req.stream,
                 lora_path=recv_req.lora_path,
                 input_embeds=recv_req.input_embeds,
+                eos_token_ids=self.model_config.hf_eos_token_id,
             )
             req.tokenizer = self.tokenizer

-            if
+            if (
+                recv_req.session_params is not None
+                and recv_req.session_params.id is not None
+            ):
                 req.finished_reason = FINISH_ABORT(
-                    f"Invalid request: session id {recv_req.
+                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
                 self.waiting_queue.append(req)
                 return
         else:
-            # Create a new request from a
-            session = self.sessions[recv_req.
+            # Create a new request from a previous session
+            session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
                 self.waiting_queue.append(req)
@@ -804,6 +825,8 @@ class Scheduler:
                 if res == AddReqResult.NO_TOKEN:
                     self.batch_is_full = True
                 break
+            if self.server_args.prefill_only_one_req:
+                break

         # Update waiting queue
         can_run_list = adder.can_run_list
@@ -1457,6 +1480,17 @@ class Scheduler:
             logger.error(message)
         return success, message

+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        """Update the online model parameter from tensors."""
+        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
+        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return parameter
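A sketch of the control flow the new `update_weights_from_tensor` method implements: delegate the named-tensor update to the TP worker, then flush the cache on success so stale entries are not reused. The stub worker and scheduler below are illustrative stand-ins, not the package's classes.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("scheduler-sketch")

class _StubTpWorker:
    def update_weights_from_tensor(self, recv_req):
        # A real worker would copy recv_req.tensor into the named parameter.
        return True, f"updated {recv_req!r}"

class _SchedulerSketch:
    def __init__(self):
        self.tp_worker = _StubTpWorker()

    def flush_cache(self) -> bool:
        return True  # the real method clears the prefix/KV caches

    def update_weights_from_tensor(self, recv_req):
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        if success:
            assert self.flush_cache(), "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

print(_SchedulerSketch().update_weights_from_tensor("q_proj.weight"))
```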
@@ -1475,16 +1509,20 @@ class Scheduler:
         )
         logger.info("Profiler is done")

-    def open_session(self, recv_req: OpenSessionReqInput) -> str:
+    def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
         # handle error
         session_id = recv_req.session_id
         if session_id in self.sessions:
             logger.warning(f"session id {session_id} already exist, cannot open.")
+            return session_id, False
+        elif session_id is None:
+            logger.warning(f"session id is None, cannot open.")
+            return session_id, False
         else:
             self.sessions[session_id] = Session(
                 recv_req.capacity_of_str_len, session_id
             )
-
+            return session_id, True

     def close_session(self, recv_req: CloseSessionReqInput):
         # handle error
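`open_session` now returns a `(session_id, success)` tuple so the tokenizer side can tell a rejected open (duplicate or missing id) apart from a successful one. A small stand-alone sketch of that contract; the store class is an assumption.

```python
from typing import Dict, Optional, Tuple

class _SessionStoreSketch:
    def __init__(self):
        self.sessions: Dict[str, dict] = {}

    def open_session(self, session_id: Optional[str]) -> Tuple[Optional[str], bool]:
        if session_id is None or session_id in self.sessions:
            return session_id, False      # mirrors the two warning branches in the diff
        self.sessions[session_id] = {}    # the real code stores a Session object
        return session_id, True

store = _SessionStoreSketch()
print(store.open_session("sess-1"))  # ('sess-1', True)
print(store.open_session("sess-1"))  # ('sess-1', False): already open
print(store.open_session(None))      # (None, False): an id is required
```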
@@ -1509,18 +1547,20 @@ def run_scheduler_process(
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

+    # Configue the logger
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+    suppress_other_loggers()

-    #
+    # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    suppress_other_loggers()
     parent_process = psutil.Process().parent()

+    # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send(