sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/pooler.py
CHANGED
@@ -3,10 +3,13 @@
 
 from dataclasses import dataclass
 from enum import IntEnum
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig
 
+from sglang.srt.layers.activation import get_cross_encoder_activation_function
 from sglang.srt.model_executor.model_runner import ForwardBatch
 
 
@@ -54,3 +57,56 @@ class Pooler(nn.Module):
             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
 
         return EmbeddingPoolerOutput(embeddings=pooled_data)
+
+
+class CrossEncodingPooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `EmbeddingPoolerOutput`.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        classifier: nn.Module,
+        pooler: Optional[nn.Module] = None,
+    ):
+        super().__init__()
+        self.classifier = classifier
+        self.pooler = pooler
+        self.default_activation_function = get_cross_encoder_activation_function(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> EmbeddingPoolerOutput:
+        """Pools sentence pair scores from the hidden_states."""
+
+        prompt_lens = forward_batch.extend_seq_lens
+
+        offset = 0
+        pooled_data_lst = []
+        for prompt_len in prompt_lens:
+            pooled_data_i = hidden_states[offset : offset + prompt_len]
+
+            if self.pooler is not None:
+                final_shape_tensor = self.pooler(pooled_data_i, forward_batch)
+            else:
+                final_shape_tensor = self.classifier(pooled_data_i)
+
+            pooled_data_lst.append(final_shape_tensor)
+            offset += prompt_len
+
+        pooled_output = torch.stack(pooled_data_lst)
+
+        if self.pooler is not None:
+            # apply classifier once on the full batch if possible
+            pooled_output = self.classifier(pooled_output)
+
+        scores = self.default_activation_function(pooled_output).squeeze(-1)
+
+        return EmbeddingPoolerOutput(embeddings=scores)
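For context, the new `CrossEncodingPooler` scores concatenated sentence pairs by slicing the flattened hidden states per request and applying a classifier head plus an activation. Below is a minimal, standalone sketch of that per-sequence slicing loop with a toy linear classifier and made-up shapes; everything here is illustrative and not sglang's actual API (the "first token" pooling stands in for whatever pooler/classifier combination is configured).

import torch
import torch.nn as nn

# Toy stand-ins: 3 packed sequences of lengths 4, 2, 3 and a hidden size of 8.
prompt_lens = [4, 2, 3]
hidden_states = torch.randn(sum(prompt_lens), 8)
classifier = nn.Linear(8, 1)     # hypothetical cross-encoder head
activation = nn.Sigmoid()        # stand-in for the configured activation

offset, per_seq_logits = 0, []
for prompt_len in prompt_lens:
    seq_hidden = hidden_states[offset : offset + prompt_len]
    # Without a dedicated pooler, classify a pooled slice (here: the first token).
    per_seq_logits.append(classifier(seq_hidden[0]))
    offset += prompt_len

scores = activation(torch.stack(per_seq_logits)).squeeze(-1)
print(scores.shape)  # torch.Size([3]) -> one relevance score per sentence pair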
sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py
ADDED
@@ -0,0 +1 @@
+from .entrypoint import *
sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py}
RENAMED
@@ -5,34 +5,23 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple
 
-import torch
 from tqdm.contrib.concurrent import thread_map
 
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    ENABLE_JIT_DEEPGEMM,
+)
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var,
+from sglang.srt.utils import get_bool_env_var, get_int_env_var
 
 logger = logging.getLogger(__name__)
-_ENABLE_JIT_DEEPGEMM = False
 
-try:
-    import deep_gemm
+if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
     from deep_gemm import get_num_sms
     from deep_gemm.jit import build
-    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
     from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 
-    sm_version = get_device_sm()
-    if sm_version == 90:
-        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
-            _ENABLE_JIT_DEEPGEMM = True
-except ImportError:
-    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
-
-
-def get_enable_jit_deepgemm():
-    return _ENABLE_JIT_DEEPGEMM
-
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
@@ -52,8 +41,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
 _USE_NVRTC_DEFAULT = "0"
-if
+if ENABLE_JIT_DEEPGEMM:
     try:
+        from deep_gemm.jit.compiler import get_nvcc_compiler
+
         get_nvcc_compiler()
     except:
         logger.warning(
@@ -114,11 +105,12 @@ class DeepGemmKernelHelper:
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
 
 
+# TODO improve naming
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
             "Entering DeepGEMM JIT Pre-Compile session. "
-            "
+            "It may takes a long time (typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
             "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
@@ -127,6 +119,7 @@ def _compile_warning_1():
     )
 
 
+# TODO improve naming
 def _compile_warning_2():
     logger.warning(
         "Entering DeepGEMM JIT Single Kernel Compile session. "
@@ -238,6 +231,7 @@ def _compile_gemm_nt_f8f8bf16_one(
     _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
 
 
+# TODO further refactor warmup-related
 _KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
     DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
         name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
@@ -270,7 +264,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     num_groups: int,
     m_list: Optional[List[int]] = None,
 ) -> None:
-
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
 
@@ -304,56 +297,6 @@ def _maybe_compile_deep_gemm_one_type_all(
     thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
 
 
-def grouped_gemm_nt_f8f8bf16_masked(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    masked_m: torch.Tensor,
-    expected_m: int,
-):
-    num_groups, _, k = lhs[0].shape
-    _, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(expected_m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            lhs, rhs, out, masked_m, expected_m
-        )
-
-
-def grouped_gemm_nt_f8f8bf16_contig(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-    m_indices: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    num_groups, n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
-
-
-def gemm_nt_f8f8bf16(
-    lhs: Tuple[torch.Tensor, torch.Tensor],
-    rhs: Tuple[torch.Tensor, torch.Tensor],
-    out: torch.Tensor,
-):
-    m, k = lhs[0].shape
-    n, _ = rhs[0].shape
-
-    kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
-    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
-
-
 @contextmanager
 def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     if _IN_PRECOMPILE_STAGE:
@@ -368,7 +311,8 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         ret = origin_func(self, *args, **kwargs)
         if ret is None:
             kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-
+            if not DEEPGEMM_BLACKWELL:
+                _compile_warning_2()
             logger.warning(
                 f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
             )
@@ -380,13 +324,12 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
 
 @contextmanager
-def
-
+def deep_gemm_execution_hook(
+    m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
+):
+    # not supported yet
+    if not DEEPGEMM_BLACKWELL:
+        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+
+    with _log_jit_build(m, n, k, kernel_type):
         yield
-    else:
-        original_num_sms = deep_gemm.get_num_sms()
-        deep_gemm.set_num_sms(num_sms)
-        try:
-            yield
-        finally:
-            deep_gemm.set_num_sms(original_num_sms)
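The refactor funnels every DeepGEMM call through `deep_gemm_execution_hook`, a context manager that optionally triggers precompilation of a kernel shape and then wraps the JIT-build logging around the real call. A generic sketch of that hook pattern follows; the function names and the fake compile step are illustrative only, not sglang's internals.

import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)
_compiled = set()

def _maybe_compile(key):
    # Illustrative stand-in for _maybe_compile_deep_gemm_one_type_all.
    if key not in _compiled:
        logger.warning("JIT-compiling kernel %s, please wait...", key)
        _compiled.add(key)

@contextmanager
def execution_hook(m, n, k, kernel_type):
    _maybe_compile((kernel_type, n, k))  # warm up once per (kernel, n, k)
    yield                                # the wrapped GEMM runs here

with execution_hook(128, 4096, 7168, "gemm_nt_f8f8bf16"):
    pass  # the actual deep_gemm kernel call would go here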
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
ADDED
@@ -0,0 +1,32 @@
+import logging
+
+from sglang.srt.utils import get_bool_env_var, get_device_sm
+
+logger = logging.getLogger(__name__)
+
+
+def _compute_enable_deep_gemm():
+    sm_version = get_device_sm()
+    if sm_version < 90:
+        return False
+
+    try:
+        import deep_gemm
+    except ImportError:
+        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
+        return False
+
+    return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
+
+
+ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
+
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_BLACKWELL = True
+except ImportError:
+    DEEPGEMM_BLACKWELL = False
+
+DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
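configurer.py centralizes the environment probing that used to live inline in deep_gemm.py: GPU capability, the `SGL_ENABLE_JIT_DEEPGEMM` env var, and whether the installed DeepGEMM exposes the newer Blackwell-era API (`fp8_gemm_nt`). Below is a hedged sketch of the same optional-dependency probing pattern for an arbitrary package; the package and env-var names are placeholders, not part of sglang.

import importlib.util
import os

def _flag(name: str, default: str = "true") -> bool:
    # Minimal stand-in for get_bool_env_var.
    return os.environ.get(name, default).lower() in ("1", "true", "yes")

# The feature is on only if the optional dependency imports and the env var allows it.
HAS_OPTIONAL_DEP = importlib.util.find_spec("some_optional_package") is not None
FEATURE_ENABLED = HAS_OPTIONAL_DEP and _flag("MYAPP_ENABLE_FEATURE")

# A second flag keyed off which API generation the dependency provides.
NEW_API = False
if HAS_OPTIONAL_DEP:
    import some_optional_package
    NEW_API = hasattr(some_optional_package, "new_entry_point")

Computing these flags once at import time lets every other module branch on plain module-level constants instead of re-probing the environment.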
sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
ADDED
@@ -0,0 +1,110 @@
+import logging
+from contextlib import contextmanager
+from typing import Tuple
+
+import torch
+
+from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
+from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
+    DEEPGEMM_BLACKWELL,
+    DEEPGEMM_SCALE_UE8M0,
+    ENABLE_JIT_DEEPGEMM,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+    if DEEPGEMM_BLACKWELL:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import get_col_major_tma_aligned_tensor
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+        from deep_gemm import (
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_masked(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe=None,
+):
+    num_groups, _, k = lhs[0].shape
+    _, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
+
+    with compile_utils.deep_gemm_execution_hook(
+        expected_m, n, k, num_groups, kernel_type
+    ):
+        _grouped_gemm_nt_f8f8bf16_masked_raw(
+            lhs,
+            rhs,
+            out,
+            masked_m,
+            expected_m,
+            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
+        )
+
+
+def grouped_gemm_nt_f8f8bf16_contig(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+    m_indices: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    num_groups, n, _ = rhs[0].shape
+    kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+
+
+def gemm_nt_f8f8bf16(
+    lhs: Tuple[torch.Tensor, torch.Tensor],
+    rhs: Tuple[torch.Tensor, torch.Tensor],
+    out: torch.Tensor,
+):
+    m, k = lhs[0].shape
+    n, _ = rhs[0].shape
+    num_groups = 1
+    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
+
+    with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
+        _gemm_nt_f8f8bf16_raw(
+            lhs,
+            rhs,
+            out,
+        )
+
+
+def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
+    compile_utils.update_deep_gemm_config(gpu_id, server_args)
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
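entrypoint.py is now the only place that imports `deep_gemm` directly; callers go through `deep_gemm_wrapper.gemm_nt_f8f8bf16` and friends, and `configure_deep_gemm_num_sms` temporarily overrides the SM count around a region of work. The save/override/restore shape of that context manager is the standard pattern sketched below, with a plain module-level setting standing in for DeepGEMM's SM count (names here are illustrative).

from contextlib import contextmanager

_NUM_SMS = 132  # stand-in for deep_gemm.get_num_sms()/set_num_sms()

def get_num_sms() -> int:
    return _NUM_SMS

def set_num_sms(value: int) -> None:
    global _NUM_SMS
    _NUM_SMS = value

@contextmanager
def configure_num_sms(num_sms):
    if num_sms is None:          # no override requested: do nothing
        yield
    else:
        original = get_num_sms() # save, override, and always restore
        set_num_sms(num_sms)
        try:
            yield
        finally:
            set_num_sms(original)

with configure_num_sms(64):
    assert get_num_sms() == 64
assert get_num_sms() == 132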
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -23,7 +23,8 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
@@ -44,10 +45,6 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )
 
-    from sglang.srt.layers.quantization.deep_gemm import (
-        gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -67,7 +64,6 @@ else:
     fp8_max = torch.finfo(fp8_dtype).max
     fp8_min = -fp8_max
 
-
 if supports_custom_op():
 
     def deep_gemm_fp8_fp8_bf16_nt(
@@ -77,7 +73,7 @@ if supports_custom_op():
         Bs: torch.Tensor,
         C: torch.Tensor,
     ) -> None:
-
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     def deep_gemm_fp8_fp8_bf16_nt_fake(
         A: torch.Tensor,
@@ -280,6 +276,7 @@ def sglang_per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -287,8 +284,21 @@ def sglang_per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        # TODO(FIXME): Fix cuda kernel and recover here to empty.
+        x_s = torch.zeros(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).transpose(0, 1)[:x_s_mn, :]
+    elif column_major_scales:
        if scale_tma_aligned:
+            # TODO extract "align" function
             # aligned to 4 * sizeof(float)
             aligned_size = (x.shape[-2] + 3) // 4 * 4
             x_s = torch.empty(
@@ -309,7 +319,9 @@ def sglang_per_token_group_quant_fp8(
                 dtype=torch.float32,
             )
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
     return x_q, x_s
 
@@ -754,7 +766,15 @@ def prepare_block_fp8_matmul_inputs(
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-
+
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif As.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     M = A.numel() // A.shape[-1]
 
@@ -762,8 +782,17 @@ def prepare_block_fp8_matmul_inputs(
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-
-
+
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
 
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
@@ -782,12 +811,12 @@ def w8a8_block_fp8_matmul_deepgemm(
     M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
     # Deepgemm only supports output tensor type as bfloat16
-    assert C.dtype == torch.bfloat16 and
+    assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
 
     if supports_custom_op():
         torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
     else:
-
+        deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
 
     return C
 
@@ -881,7 +910,7 @@ def w8a8_block_fp8_matmul(
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    if output_dtype == torch.bfloat16 and
+    if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
         return w8a8_block_fp8_matmul_deepgemm(
             A, B, As, Bs, block_size, output_dtype=output_dtype
        )
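With `scale_ue8m0=True`, the per-token-group scales are no longer stored as one float32 per 128-wide group: four UE8M0 exponents are packed into each int32, and both dimensions of the scale buffer are padded to multiples of 4 before the buffer is transposed back and cropped. A small sketch of that shape arithmetic follows; the local `align` mirrors the helper imported from `sglang.math_utils` in the diff, and the sizes are illustrative.

import torch

def align(x: int, multiple: int) -> int:
    # Assumed semantics of sglang.math_utils.align: round up to a multiple.
    return (x + multiple - 1) // multiple * multiple

mn, k, group_size = 7, 512, 128        # illustrative token count and hidden size
scale_cols = k // group_size           # one UE8M0 exponent per 128-wide group -> 4
aligned_mn = align(mn, 4)              # 8
aligned_k = align(scale_cols, 4)       # 4

# Packed buffer: 4 exponents per int32; allocate it transposed, then view back
# and crop to the real row count, as in the scale_ue8m0 branch of the diff.
x_s = torch.zeros((aligned_k // 4, aligned_mn), dtype=torch.int32).transpose(0, 1)[:mn, :]
print(x_s.shape)  # torch.Size([7, 1]) -> one int32 holds the 4 group exponents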
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,9 +1,10 @@
-import os
-from curses import flash
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
 
+from sglang.math_utils import align
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
 
@@ -14,7 +15,6 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
     fp8_dtype,
     fp8_max,
@@ -137,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
         return cutlass_w8a8_block_fp8_linear_with_fallback
     elif _use_aiter:
         return aiter_w8a8_block_fp8_linear
-    elif
+    elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
        return deepgemm_w8a8_block_fp8_linear_with_fallback
     else:
         return triton_w8a8_block_fp8_linear
@@ -238,7 +238,14 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
+
+    # NOTE(alcanderian): Useless when scale is packed to int32
+    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+    #     _check_ue8m0("x_scale", x_scale)
+    #     _check_ue8m0("weight_scale", ws)
+
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -247,6 +254,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _check_ue8m0(name, x):
+    x_ceil = ceil_to_ue8m0(x)
+    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -369,27 +381,80 @@ def block_quant_dequant(
     The output is an unquantized tensor with dtype.
     """
     block_n, block_k = block_size[0], block_size[1]
-    n, k = x_q_block.shape
-    n_tiles = (n + block_n - 1) // block_n
-    k_tiles = (k + block_k - 1) // block_k
-    assert n_tiles == x_s.shape[0]
-    assert k_tiles == x_s.shape[1]
+    *_, n, k = x_q_block.shape
 
-
+    # ... n_scale k_scale -> ... (n_scale block_n) (k_scale block_k)
+    x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(
+        block_k, dim=-1
+    )
+    x_scale_repeat = x_scale_repeat[..., :n, :k]
+
+    return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
+
+
+def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
+    assert isinstance(weight, torch.nn.Parameter)
+    assert isinstance(weight_scale_inv, torch.nn.Parameter)
+    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
+        weight, weight_scale_inv, weight_block_size
+    )
+
+
+def _requant_weight_ue8m0(
+    weight: torch.Tensor,
+    weight_scale_inv: torch.Tensor,
+    weight_block_size: List[int],
+):
+    assert weight_block_size == [128, 128]
+
+    *_, n, k = weight.shape
+
+    weight_dequant = block_quant_dequant(
+        weight,
+        weight_scale_inv,
+        weight_block_size,
+        torch.bfloat16,
+    )
+
+    weight_dequant_flat = weight_dequant.view((-1, k))
+    out_w_flat, out_s_flat = per_block_cast_to_fp8(weight_dequant_flat)
+
+    out_w = out_w_flat.view(weight.shape)
+    out_s = out_s_flat.view(weight_scale_inv.shape)
+
+    # NOTE copy and modified from DeepGEMM
+    def _transform_scale(sf, mn: int):
+        import deep_gemm.utils.layout
+
+        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
+        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        return sf
+
+    out_s = _transform_scale(out_s, mn=out_w.shape[-2])
+
+    return out_w, out_s
+
+
+# COPIED FROM DeepGEMM
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = ceil_to_ue8m0(x_amax / 448.0)
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2)
+    )
 
-    for j in range(n_tiles):
-        for i in range(k_tiles):
-            x_q_block_tile = x_q_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile = x_dq_block[
-                j * block_n : min((j + 1) * block_n, n),
-                i * block_k : min((i + 1) * block_k, k),
-            ]
-            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
 
-
+# COPIED FROM DeepGEMM
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
 
 
 def channel_quant_to_tensor_quant(
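The UE8M0 requantization path relies on `ceil_to_ue8m0`, which rounds every scale up to the nearest power of two so it can be stored as a bare 8-bit exponent, and on the broadcast-style dequant that replaces the old per-tile Python loop. A quick self-contained check of both ideas, using toy sizes rather than the real 128x128 blocks:

import torch

def ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
    # Same formula as the diff: smallest power of two >= |x|.
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

scales = torch.tensor([0.3, 1.0, 5.0])
print(ceil_to_ue8m0(scales))  # -> 0.5, 1.0, 8.0

# Broadcast dequant: expand an (n_tiles, k_tiles) scale grid over (n, k) data
# with repeat_interleave instead of a nested tile loop.
block_n = block_k = 2
x_q = torch.ones(4, 4)
x_s = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x_scale_repeat = x_s.repeat_interleave(block_n, dim=-2).repeat_interleave(block_k, dim=-1)
x_dq = x_q * x_scale_repeat[:4, :4]
print(x_dq[0, 3], x_dq[3, 0])  # 2.0 and 3.0, i.e. each tile picked up its own scale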