sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/lang/backend/__init__.py
DELETED
File without changes
sglang/srt/function_call/harmony_tool_parser.py
DELETED
@@ -1,130 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Harmony tool call parser for processing tool calls in harmony models."""
-
-import uuid
-from typing import List, Optional, Tuple
-
-from sglang.srt.entrypoints.openai.protocol import (
-    ChatMessage,
-    FunctionResponse,
-    ToolCall,
-)
-
-
-class HarmonyToolCallParser:
-    """Parser for extracting tool calls from harmony model outputs."""
-
-    def extract_tool_calls_from_message(self, msg) -> Optional[ToolCall]:
-        """
-        Extract tool call from a single message if it's a tool call.
-
-        Args:
-            msg: The harmony message
-
-        Returns:
-            ToolCall if the message is a tool call, None otherwise
-        """
-        if (
-            msg.channel == "commentary"
-            and msg.recipient
-            and msg.recipient.startswith("functions.")
-        ):
-            function_name = msg.recipient.split(".")[-1]
-            arguments = msg.content[0].text if msg.content else "{}"
-
-            return ToolCall(
-                id=f"call_{uuid.uuid4().hex[:24]}",
-                function=FunctionResponse(
-                    name=function_name,
-                    arguments=arguments,
-                ),
-            )
-        return None
-
-    def process_streaming_chunk(
-        self,
-        harmony_parser,
-        index: int,
-        tool_call_trackers: dict,
-        stream_buffers: dict,
-    ) -> Tuple[Optional[dict], bool, Optional[str]]:
-        """
-        Process a streaming chunk for tool calls.
-
-        Args:
-            harmony_parser: The harmony parser instance
-            index: The choice index
-            tool_call_trackers: Dict tracking tool calls per choice
-            stream_buffers: Dict for buffering content
-
-        Returns:
-            Tuple of (tool_call_data, is_tool_call, delta)
-        """
-        # Check if we're in a tool call
-        is_tool_call = (
-            harmony_parser.current_channel == "commentary"
-            and harmony_parser.current_recipient
-            and harmony_parser.current_recipient.startswith("functions.")
-        )
-
-        delta = harmony_parser.last_content_delta or ""
-        tool_call_data = None
-
-        if is_tool_call:
-            # Handle tool call streaming
-            function_name = harmony_parser.current_recipient.split(".")[-1]
-
-            # Track tool call indices per choice
-            if index not in tool_call_trackers:
-                tool_call_trackers[index] = {"count": 0, "current_function": None}
-
-            # Check if we just started a new tool call
-            tool_call_tracker = tool_call_trackers[index]
-            if tool_call_tracker["current_function"] != function_name:
-                # New tool call started
-                tool_call_tracker["current_function"] = function_name
-                tool_call_index = tool_call_tracker["count"]
-                tool_call_tracker["count"] += 1
-
-                # Store the tool call index for this function
-                tool_call_key = f"{index}_{function_name}"
-                stream_buffers[tool_call_key] = {
-                    "index": tool_call_index,
-                    "content": "",
-                }
-
-                tool_call_data = {
-                    "id": f"call_{uuid.uuid4().hex[:24]}",
-                    "index": tool_call_index,
-                    "function_name": function_name,
-                    "arguments": delta,
-                    "is_first_chunk": True,
-                }
-            else:
-                # Subsequent chunks for the same tool call
-                tool_call_key = f"{index}_{function_name}"
-                tool_call_index = stream_buffers[tool_call_key]["index"]
-
-                tool_call_data = {
-                    "id": None,
-                    "index": tool_call_index,
-                    "function_name": None,
-                    "arguments": delta,
-                    "is_first_chunk": False,
-                }
-
-            stream_buffers[tool_call_key]["content"] += delta
-
-        return tool_call_data, is_tool_call, delta
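Note: the removed parser keyed tool-call detection entirely off a harmony message's channel and recipient fields. A minimal standalone sketch of that check, with SimpleNamespace as a stand-in for the real harmony message type (the stand-in and the sample payload are assumptions, not sglang API):

import uuid
from types import SimpleNamespace

def extract_function_call(msg):
    # Mirrors the removed parser's rule: a message is a tool call iff it is
    # on the "commentary" channel and addressed to a "functions.*" recipient.
    if (
        msg.channel == "commentary"
        and msg.recipient
        and msg.recipient.startswith("functions.")
    ):
        return {
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "name": msg.recipient.split(".")[-1],
            "arguments": msg.content[0].text if msg.content else "{}",
        }
    return None

msg = SimpleNamespace(
    channel="commentary",
    recipient="functions.get_weather",
    content=[SimpleNamespace(text='{"city": "Paris"}')],
)
print(extract_function_call(msg))  # {'id': 'call_…', 'name': 'get_weather', ...}

Per the file summary above, 0.5.0rc2 adds sglang/srt/function_call/gpt_oss_detector.py (+331), which appears to take over this parsing role.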
sglang/srt/layers/quantization/scalar_type.py
DELETED
@@ -1,352 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import functools
-import struct
-from dataclasses import dataclass
-from enum import Enum
-from typing import Optional, Union
-
-_SCALAR_TYPES_ID_MAP = {}
-
-
-# Mirrors enum in `core/scalar_type.hpp`
-class NanRepr(Enum):
-    NONE = 0  # nans are not supported
-    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
-    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
-
-
-# This ScalarType class is a parallel implementation of the C++ ScalarType
-# class found in csrc/core/scalar_type.hpp. These two classes should be kept
-# in sync until the inductor fully supports custom C++ classes.
-@dataclass(frozen=True)
-class ScalarType:
-    """
-    ScalarType can represent a wide range of floating point and integer
-    types, in particular it can be used to represent sub-byte data types
-    (something that torch.dtype currently does not support). It is also
-    capable of representing types with a bias, i.e.:
-      `stored_value = value + bias`,
-    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
-    of 8). The implementation for this class can be found in
-    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
-    with that file.
-    """
-
-    exponent: int
-    """
-    Number of bits in the exponent if this is a floating point type
-    (zero if this an integer type)
-    """
-
-    mantissa: int
-    """
-    Number of bits in the mantissa if this is a floating point type,
-    or the number bits representing an integer excluding the sign bit if
-    this an integer type.
-    """
-
-    signed: bool
-    "If the type is signed (i.e. has a sign bit)"
-
-    bias: int
-    """
-    bias used to encode the values in this scalar type
-    (value = stored_value - bias, default 0) for example if we store the
-    type as an unsigned integer with a bias of 128 then the value 0 will be
-    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
-    """
-
-    _finite_values_only: bool = False
-    """
-    Private: if infs are supported, used `has_infs()` instead.
-    """
-
-    nan_repr: NanRepr = NanRepr.IEEE_754
-    """
-    How NaNs are represent in this scalar type, returns NanRepr value.
-    (not applicable for integer types)
-    """
-
-    def _floating_point_max_int(self) -> int:
-        assert (
-            self.mantissa <= 52 and self.exponent <= 11
-        ), f"Cannot represent max/min as a double for type {self.__str__()}"
-
-        max_mantissa = (1 << self.mantissa) - 1
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
-            max_mantissa = max_mantissa - 1
-
-        max_exponent = (1 << self.exponent) - 2
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
-            assert (
-                self.exponent < 11
-            ), f"Cannot represent max/min as a double for type {self.__str__()}"
-            max_exponent = max_exponent + 1
-
-        # adjust the exponent to match that of a double
-        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
-        # e is the exponent bits), there is some precedent for non-standard
-        # biases, example `float8_e4m3b11fnuz` here:
-        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
-        # complication we are just assuming the standard exponent bias until
-        # there is a need to support non-standard biases
-        exponent_bias = (1 << (self.exponent - 1)) - 1
-        exponent_bias_double = (1 << 10) - 1  # double e = 11
-
-        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
-
-        # shift the mantissa and exponent into the proper positions for an
-        # IEEE double and bitwise-or them together.
-        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
-
-    def _floating_point_max(self) -> float:
-        double_raw = self._floating_point_max_int()
-        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
-
-    def _raw_max(self) -> Union[int, float]:
-        if self.is_floating_point():
-            return self._floating_point_max()
-        else:
-            assert (
-                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
-            ), "Cannot represent max as an int"
-            return (1 << self.mantissa) - 1
-
-    def _raw_min(self) -> Union[int, float]:
-        if self.is_floating_point():
-            assert (
-                self.is_signed()
-            ), "We currently assume all floating point types are signed"
-            sign_bit_double = 1 << 63
-
-            max_raw = self._floating_point_max_int()
-            min_raw = max_raw | sign_bit_double
-            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
-        else:
-            assert (
-                not self.is_signed() or self.size_bits <= 64
-            ), "Cannot represent min as a int64_t"
-
-            if self.is_signed():
-                return -(1 << (self.size_bits - 1))
-            else:
-                return 0
-
-    @functools.cached_property
-    def id(self) -> int:
-        """
-        Convert the ScalarType to an int which can be passed to pytorch custom
-        ops. This layout of the int must be kept in sync with the C++
-        ScalarType's from_id method.
-        """
-        val = 0
-        offset = 0
-
-        def or_and_advance(member, bit_width):
-            nonlocal val
-            nonlocal offset
-            bit_mask = (1 << bit_width) - 1
-            val = val | (int(member) & bit_mask) << offset
-            offset = offset + bit_width
-
-        or_and_advance(self.exponent, 8)
-        or_and_advance(self.mantissa, 8)
-        or_and_advance(self.signed, 1)
-        or_and_advance(self.bias, 32)
-        or_and_advance(self._finite_values_only, 1)
-        or_and_advance(self.nan_repr.value, 8)
-
-        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
-
-        _SCALAR_TYPES_ID_MAP[val] = self
-
-        return val
-
-    @property
-    def size_bits(self) -> int:
-        return self.exponent + self.mantissa + int(self.signed)
-
-    def min(self) -> Union[int, float]:
-        """
-        Min representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_min() - self.bias
-
-    def max(self) -> Union[int, float]:
-        """
-        Max representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_max() - self.bias
-
-    def is_signed(self) -> bool:
-        """
-        If the type is signed (i.e. has a sign bit), same as `signed`
-        added for consistency with:
-        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
-        """
-        return self.signed
-
-    def is_floating_point(self) -> bool:
-        "If the type is a floating point type"
-        return self.exponent != 0
-
-    def is_integer(self) -> bool:
-        "If the type is an integer type"
-        return self.exponent == 0
-
-    def has_bias(self) -> bool:
-        "If the type has a non-zero bias"
-        return self.bias != 0
-
-    def has_infs(self) -> bool:
-        "If the type is floating point and supports infinity"
-        return not self._finite_values_only
-
-    def has_nans(self) -> bool:
-        return self.nan_repr != NanRepr.NONE.value
-
-    def is_ieee_754(self) -> bool:
-        """
-        If the type is a floating point type that follows IEEE 754
-        conventions
-        """
-        return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
-
-    def __str__(self) -> str:
-        """
-        naming generally follows: https://github.com/jax-ml/ml_dtypes
-        for floating point types (leading f) the scheme is:
-        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-        flags:
-          - no-flags: means it follows IEEE 754 conventions
-          - f: means finite values only (no infinities)
-          - n: means nans are supported (non-standard encoding)
-        for integer types the scheme is:
-          `[u]int<size_bits>[b<bias>]`
-          - if bias is not present it means its zero
-        """
-        if self.is_floating_point():
-            ret = (
-                "float"
-                + str(self.size_bits)
-                + "_e"
-                + str(self.exponent)
-                + "m"
-                + str(self.mantissa)
-            )
-
-            if not self.is_ieee_754():
-                if self._finite_values_only:
-                    ret = ret + "f"
-                if self.nan_repr != NanRepr.NONE:
-                    ret = ret + "n"
-
-            return ret
-        else:
-            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
-            if self.has_bias():
-                ret = ret + "b" + str(self.bias)
-            return ret
-
-    def __repr__(self) -> str:
-        return "ScalarType." + self.__str__()
-
-    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
-    # opcheck to work.
-    def __len__(self) -> int:
-        raise TypeError
-
-    #
-    # Convenience Constructors
-    #
-
-    @classmethod
-    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        "Create a signed integer scalar type (size_bits includes sign-bit)."
-        ret = cls(0, size_bits - 1, True, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        """Create a unsigned integer scalar type."""
-        ret = cls(0, size_bits, False, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
-        """
-        Create a standard floating point type
-        (i.e. follows IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        ret = cls(exponent, mantissa, True, 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_(
-        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
-    ) -> "ScalarType":
-        """
-        Create a non-standard floating point type
-        (i.e. does not follow IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        assert nan_repr != NanRepr.IEEE_754, (
-            "use `float_IEEE754` constructor for floating point types that "
-            "follow IEEE 754 conventions"
-        )
-        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def from_id(cls, scalar_type_id: int):
-        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
-            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
-        return _SCALAR_TYPES_ID_MAP[scalar_type_id]
-
-
-# naming generally follows: https://github.com/jax-ml/ml_dtypes
-# for floating point types (leading f) the scheme is:
-#   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-# flags:
-#   - no-flags: means it follows IEEE 754 conventions
-#   - f: means finite values only (no infinities)
-#   - n: means nans are supported (non-standard encoding)
-# for integer types the scheme is:
-#   `[u]int<size_bits>[b<bias>]`
-#   - if bias is not present it means its zero
-
-
-class scalar_types:
-    int4 = ScalarType.int_(4, None)
-    uint4 = ScalarType.uint(4, None)
-    int8 = ScalarType.int_(8, None)
-    uint8 = ScalarType.uint(8, None)
-    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
-    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
-    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
-    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
-
-    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
-    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
-
-    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
-
-    # "gptq" types
-    uint2b2 = ScalarType.uint(2, 2)
-    uint3b4 = ScalarType.uint(3, 4)
-    uint4b8 = ScalarType.uint(4, 8)
-    uint8b128 = ScalarType.uint(8, 128)
-
-    # colloquial names
-    bfloat16 = float16_e8m7
-    float16 = float16_e5m10
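The bias arithmetic above is compact, so a small worked example helps. Using the GPTQ 4-bit type the removed file defined as uint4b8 (4 stored bits, bias 8), the stored range 0..15 maps to logical values -8..7, and __str__ encodes the same facts in the name (plain-int sketch, no dependency on the removed module):

# Worked example of the removed class's bias arithmetic.
size_bits, bias = 4, 8                      # scalar_types.uint4b8 = ScalarType.uint(4, 8)
raw_min, raw_max = 0, (1 << size_bits) - 1  # unsigned storage range: 0..15
print(raw_min - bias, raw_max - bias)       # -8 7, matching uint4b8.min() / uint4b8.max()
print(f"uint{size_bits}b{bias}")            # "uint4b8", matching str(scalar_types.uint4b8)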
sglang/srt/lora/backend/flashinfer_backend.py
DELETED
@@ -1,131 +0,0 @@
-from typing import Tuple
-
-import torch
-
-from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-from sglang.srt.lora.utils import LoRABatchInfo
-from sglang.srt.utils import is_flashinfer_available
-
-if is_flashinfer_available():
-    from flashinfer import SegmentGEMMWrapper
-
-
-class FlashInferLoRABackend(BaseLoRABackend):
-
-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        super().__init__(name, batch_info)
-
-        # Set up SGemm Wrapper from flashinfer
-        # FIXME wait for flashinfer segment gemm update
-        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
-        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
-
-    def run_lora_a_sgemm(
-        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
-    ) -> torch.Tensor:
-
-        return self.segment_gemm.run(
-            x=x,
-            weights=weights,
-            batch_size=self.batch_info.bs,
-            weight_column_major=True,
-            seg_indptr=self.batch_info.seg_indptr,
-            weight_indices=self.batch_info.weight_indices,
-        )
-
-    def run_lora_b_sgemm(
-        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
-    ) -> torch.Tensor:
-
-        return (
-            self.segment_gemm.run(
-                x=x,
-                weights=weights,
-                batch_size=self.batch_info.bs,
-                weight_column_major=True,
-                seg_indptr=self.batch_info.seg_indptr,
-                weight_indices=self.batch_info.weight_indices,
-            )
-            * self.batch_info.scalings[0]
-        )
-
-    def run_qkv_lora(
-        self,
-        x: torch.Tensor,
-        qkv_lora_a: torch.Tensor,
-        qkv_lora_b: Tuple[torch.Tensor],
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-
-        assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
-
-        # Shape of lora_a_output: (s, 3 * r)
-        lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
-
-        q_lora_b, kv_lora_b = qkv_lora_b
-        lora_rank = kv_lora_b.shape[-1]
-        output_dim_q = q_lora_b.shape[-2]
-        output_dim_kv = kv_lora_b.shape[-2]
-        lora_output = torch.empty(
-            (x.shape[0], output_dim_q + 2 * output_dim_kv),
-            device=x.device,
-            dtype=x.dtype,
-        )
-
-        # q
-        lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
-            x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
-        )
-
-        # kv
-        lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
-            self.run_lora_b_sgemm(
-                x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
-                weights=kv_lora_b[0],
-            )
-        )
-
-        lora_output[
-            :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
-        ] = self.run_lora_b_sgemm(
-            x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
-            weights=kv_lora_b[1],
-        )
-
-        return lora_output * self.batch_info.scalings[0]
-
-    def run_gate_up_lora(
-        self,
-        x: torch.Tensor,
-        gate_up_lora_a: torch.Tensor,
-        gate_up_lora_b: Tuple[torch.Tensor],
-        *args,
-        **kwargs,
-    ) -> torch.Tensor:
-
-        assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
-        lora_rank = gate_up_lora_b[0].shape[-1]
-        output_dim = gate_up_lora_b[0].shape[-2]
-
-        # Shape of lora_a_output: (s, 2 * r)
-        lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
-
-        lora_output = torch.empty(
-            (x.shape[0], 2 * output_dim),
-            device=x.device,
-            dtype=x.dtype,
-        )
-
-        # Compute lora for gate and up proj respectively
-        lora_output[:, :output_dim] = self.run_lora_b_sgemm(
-            x=lora_a_output[:, :lora_rank].contiguous(),
-            weights=gate_up_lora_b[0],
-        )
-
-        lora_output[:, output_dim:] = self.run_lora_b_sgemm(
-            x=lora_a_output[:, lora_rank:].contiguous(),
-            weights=gate_up_lora_b[1],
-        )
-
-        return lora_output * self.batch_info.scalings[0]
/sglang/{api.py → lang/api.py}
RENAMED
File without changes

{sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL
RENAMED
File without changes

{sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE
RENAMED
File without changes

{sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt
RENAMED
File without changes