sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/bench_serving.py +11 -3
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
- sglang/srt/layers/moe/topk.py +14 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +91 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +71 -34
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +95 -55
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -6
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/llama.py +13 -2
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +79 -2
- sglang/srt/openai_api/protocol.py +50 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +45 -39
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -11,12 +11,17 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -267,8 +272,14 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-    if num_experts >= 224:
+    if not_hip and num_experts >= 224:
+        token_cnts_buffer = torch.empty(
+            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+        )
+        cumsum_buffer = torch.empty(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -276,6 +287,8 @@ def moe_align_block_size(
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
         )
     else:
         ops.moe_align_block_size(
@@ -379,14 +392,25 @@ def invoke_fused_moe_kernel(
     )
 
 
-def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
+def get_config_file_name(
+    E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
+) -> str:
     device_name = get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
+    block_shape_selector = (
+        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
+    )
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(
+    E: int,
+    N: int,
+    dtype: Optional[str],
+    block_n: Optional[int] = 0,
+    block_k: Optional[int] = 0,
+) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -398,7 +422,7 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
 
     # First look up if an optimized configuration is available in the configs
     # directory
-    json_file_name = get_config_file_name(E, N, dtype)
+    json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
 
     config_file_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
@@ -429,25 +453,37 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
     is_marlin: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> Dict[str, int]:
     if dtype == "fp8_w8a8":
-        config = {
-            "BLOCK_SIZE_M": 128,
-            "BLOCK_SIZE_N": 256,
-            "BLOCK_SIZE_K": 128,
-            "GROUP_SIZE_M": 32,
-            "num_warps": 8,
-            "num_stages": 4,
-        }
-        if M <= E:
+        if block_shape is None:
             config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 256,
                 "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
+                "GROUP_SIZE_M": 32,
+                "num_warps": 8,
                 "num_stages": 4,
             }
+            if M <= E:
+                config = {
+                    "BLOCK_SIZE_M": 64,
+                    "BLOCK_SIZE_N": 128,
+                    "BLOCK_SIZE_K": 128,
+                    "GROUP_SIZE_M": 1,
+                    "num_warps": 4,
+                    "num_stages": 4,
+                }
+        else:
+            # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_shape[0],
+                "BLOCK_SIZE_K": block_shape[1],
+                "GROUP_SIZE_M": 32,
+                "num_warps": 4,
+                "num_stages": 3,
+            }
     else:
         config = {
             "BLOCK_SIZE_M": 64,
@@ -483,7 +519,9 @@ def try_get_optimal_moe_config(
     else:
         # First try to load optimal config from the file
         E, _, N = w2_shape
-        configs = get_moe_configs(E, N, dtype)
+        block_n = block_shape[0] if block_shape else 0
+        block_k = block_shape[1] if block_shape else 0
+        configs = get_moe_configs(E, N, dtype, block_n, block_k)
 
     if configs:
         # If an optimal configuration map has been found, look up the
@@ -491,14 +529,9 @@ def try_get_optimal_moe_config(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Else use the default config
-        config = get_default_config(
-            M, E, N, w1_shape[2], top_k, dtype, is_marlin
-        )
-        # BLOCK_K must be divisable by block_shape[1]
-        # BLOCK_N and BLOCK_M has no requirements
-        if block_shape is not None:
-            config["BLOCK_SIZE_N"] = block_shape[0]
-            config["BLOCK_SIZE_K"] = block_shape[1]
+        config = get_default_config(
+            M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
+        )
     return config
 
 
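Taken together, the fused_moe.py hunks make the tuned-config lookup aware of the block-quantization shape: the JSON file name can now embed block_shape=[block_n, block_k], and get_default_config falls back to a block-shaped default when no tuning file matches. A minimal standalone sketch of the naming scheme follows; example_config_file_name and the hard-coded device name are invented for illustration (the real get_config_file_name queries the active GPU via get_device_name):

from typing import List, Optional


def example_config_file_name(
    E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
) -> str:
    # Mirrors the f-string in the hunk above; the device name is a placeholder.
    device_name = "NVIDIA_H100_80GB_HBM3"
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = (
        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    )
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"


print(example_config_file_name(256, 7168, "fp8_w8a8", [128, 128]))
# E=256,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json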
sglang/srt/layers/moe/topk.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Optional
 
 import torch
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -273,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, require_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, require_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -370,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=None,
                 bias=bias,
             )
 
@@ -548,8 +560,36 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
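The fp8.py change above drops the module-level padding_size import and re-imports it inside process_weights_after_loading with the comment "Avoid circular import". A generic, hypothetical sketch of that pattern follows; the module names mod_a/mod_b are invented, and they are written to a temporary directory only so the snippet runs on its own:

import os
import sys
import tempfile
import textwrap

tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "mod_a.py"), "w") as f:
    f.write(textwrap.dedent("""
        import mod_b  # mod_a needs mod_b at import time

        padding_size = 128

        def use_b():
            return mod_b.helper()
    """))
with open(os.path.join(tmp, "mod_b.py"), "w") as f:
    f.write(textwrap.dedent("""
        # A module-level "import mod_a" here would close the cycle and fail with a
        # partially initialized module; deferring the import into the function that
        # needs it (as fp8.py does above) breaks the cycle.
        def helper():
            from mod_a import padding_size  # deferred import
            return padding_size + 1
    """))

sys.path.insert(0, tmp)
import mod_a

print(mod_a.use_b())  # 129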
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -1,9 +1,34 @@
-from typing import List, Tuple
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+logger = logging.getLogger(__name__)
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -51,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -73,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max
 
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
@@ -191,6 +220,48 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@functools.lru_cache
+def get_w8a8_block_fp8_configs(
+    N: int, K: int, block_n: int, block_k: int
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the w8a8 block fp8 kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = get_device_name().replace(" ", "_")
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                "Using configuration from %s for W8A8 Block FP8 kernel.",
+                config_file_path,
+            )
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    logger.warning(
+        (
+            "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
+            "Config file not found at %s"
+        ),
+        config_file_path,
+    )
+    return None
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -231,17 +302,22 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
 
-
-
-
-
-
-
-
-
-
-
-
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
 
     def grid(META):
         return (
@@ -269,10 +345,7 @@ def w8a8_block_fp8_matmul(
         As.stride(-1),
         Bs.stride(1),
         Bs.stride(0),
-        BLOCK_SIZE_M=BLOCK_SIZE_M,
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,
-        GROUP_SIZE_M=8,
+        **config,
     )
 
     return C
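Both the fused MoE path and the new get_w8a8_block_fp8_configs fallback pick a tuning entry by nearest batch size. A small self-contained sketch of that selection rule, with a made-up config table:

# Hypothetical tuning table: a sparse grid of batch sizes M mapped to kernel configs.
example_configs = {
    1: {"BLOCK_SIZE_M": 16, "num_warps": 4},
    16: {"BLOCK_SIZE_M": 32, "num_warps": 4},
    256: {"BLOCK_SIZE_M": 64, "num_warps": 8},
    4096: {"BLOCK_SIZE_M": 128, "num_warps": 8},
}


def pick_config(configs, M):
    # Same rule as in the hunks above: choose the entry whose key is closest to M.
    return configs[min(configs.keys(), key=lambda x: abs(x - M))]


print(pick_config(example_configs, 200))   # falls back to the M=256 entry
print(pick_config(example_configs, 3000))  # falls back to the M=4096 entry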
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
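The fp8_kernel.py and fp8_utils.py changes above cap the usable fp8 maximum at 224.0 on HIP before computing the quantization scale. A minimal sketch of that scale-and-clamp step in plain torch; the on_hip flag stands in for is_hip(), and the e4m3fn maximum (448.0) is hard-coded so the snippet runs even on torch builds without fp8 dtypes:

import torch


def quantize_fp8_style(x: torch.Tensor, on_hip: bool = False):
    # Mirrors the input_to_float8 logic in the hunk above, but returns the
    # scaled-and-clamped float tensor instead of casting to an fp8 dtype.
    amax = x.abs().max().clamp(min=1e-12)
    fp8_max = 224.0 if on_hip else 448.0  # HIP cap from the diff vs e4m3fn max
    scale = fp8_max / amax
    x_scaled = (x * scale).clamp(min=-fp8_max, max=fp8_max)
    return x_scaled, scale.float().reciprocal()


x = torch.tensor([0.5, -2.0, 3.0])
x_q, inv_scale = quantize_fp8_style(x, on_hip=True)
print(x_q, inv_scale)  # largest magnitude maps to 224.0; inv_scale recovers the original range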
sglang/srt/managers/io_struct.py
CHANGED
@@ -21,10 +21,20 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
 
 
+@dataclass
+class SessionParams:
+    id: Optional[str] = None
+    rid: Optional[str] = None
+    offset: Optional[int] = None
+    replace: Optional[bool] = None
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -56,10 +66,8 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Session
-    session: Optional[
-        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-    ] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -221,9 +229,8 @@ class TokenizedGenerateReqInput:
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
-    # Session
-    session_id: Optional[str] = None
-    session_rid: Optional[str] = None
+    # Session info for continual prompting
+    session_params: Optional[SessionParams] = None
 
 
 @dataclass
@@ -407,6 +414,18 @@ class UpdateWeightsFromDistributedReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightsFromTensorReqInput:
+    name: str
+    tensor: torch.Tensor
+
+
+@dataclass
+class UpdateWeightsFromTensorReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
@@ -454,6 +473,7 @@ class ProfileReq(Enum):
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
+    session_id: Optional[str] = None
 
 
 @dataclass
@@ -463,4 +483,5 @@ class CloseSessionReqInput:
 
 @dataclass
 class OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return
 
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
@@ -836,8 +843,8 @@ class ScheduleBatch:
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def check_decode_mem(self):
-        bs = len(self.reqs)
+    def check_decode_mem(self, buf_multiplier=1):
+        bs = len(self.reqs) * buf_multiplier
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
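The reworked stop check above ORs together three stop sources (per-request stop_token_ids, the new model-config eos_token_ids, and the tokenizer's EOS ids) and only applies when ignore_eos is off. A standalone sketch of the same decision outside the Req class, with made-up token ids:

from typing import Optional, Set


def matched_stop(
    last_token_id: int,
    stop_token_ids: Optional[Set[int]] = None,
    eos_token_ids: Optional[Set[int]] = None,
    tokenizer_eos_id: Optional[int] = None,
    ignore_eos: bool = False,
) -> bool:
    # Mirrors the control flow added in the check_finished hunk above.
    if ignore_eos:
        return False
    matched = False
    if stop_token_ids:
        matched = last_token_id in stop_token_ids
    if eos_token_ids:
        matched |= last_token_id in eos_token_ids
    if tokenizer_eos_id is not None:
        matched |= last_token_id == tokenizer_eos_id
    return matched


# Token id 2 is a common EOS id in many tokenizers; purely illustrative here.
print(matched_stop(2, stop_token_ids={128009}, eos_token_ids={2}, tokenizer_eos_id=2))  # True
print(matched_stop(2, eos_token_ids={2}, ignore_eos=True))                              # False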