sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/spec_info.py
@@ -1,46 +1,320 @@
+from __future__ import annotations
+
+import threading
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from enum import IntEnum, auto
-from
-
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 
+DraftWorkerClass = Callable[..., Any]
+DraftWorkerFactory = Callable[..., Any]
 
-class SpeculativeAlgorithm(IntEnum):
-    NONE = auto()
-    EAGLE = auto()
-    EAGLE3 = auto()
-    STANDALONE = auto()
-    NGRAM = auto()
 
-
-
+class _SpeculativeAlgorithmMeta(type):
+    def __iter__(cls) -> Iterator["SpeculativeAlgorithm"]:
+        return iter(cls._registration_order)
 
-    def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
 
-
-
+class SpeculativeAlgorithm(metaclass=_SpeculativeAlgorithmMeta):
+    """Registry-backed representation of speculative decoding algorithms."""
 
-
-        return self == SpeculativeAlgorithm.STANDALONE
+    __slots__ = ("name", "value", "_draft_worker_factory")
 
-
-
+    _registry_by_name: Dict[str, "SpeculativeAlgorithm"] = {}
+    _registry_by_value: Dict[int, "SpeculativeAlgorithm"] = {}
+    _registration_order: List["SpeculativeAlgorithm"] = []
+    _flags: DefaultDict[str, Set[int]] = defaultdict(set)
+    _next_value: int = 0
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        name: str,
+        value: int,
+        draft_worker_factory: Optional[DraftWorkerFactory] = None,
+    ):
+        self.name = name
+        self.value = value
+        self._draft_worker_factory = draft_worker_factory
+
+    def __repr__(self) -> str:  # pragma: no cover - trivial
+        return f"SpeculativeAlgorithm.{self.name}"
+
+    def __str__(self) -> str:  # pragma: no cover - trivial
+        return self.name
+
+    def __hash__(self) -> int:
+        return hash(self.value)
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, SpeculativeAlgorithm):
+            return self.value == other.value
+        return NotImplemented
+
+    def __int__(self) -> int:
+        return self.value
+
+    @classmethod
+    def register(
+        cls,
+        name: str,
+        *,
+        aliases: Optional[Sequence[str]] = None,
+        value: Optional[int] = None,
+        draft_worker_factory: Optional[DraftWorkerFactory] = None,
+    ) -> SpeculativeAlgorithm:
+        normalized_name = name.upper()
+        if normalized_name in cls._registry_by_name:
+            raise ValueError(
+                f"SpeculativeAlgorithm '{normalized_name}' already registered"
+            )
+
+        if value is None:
+            value = cls._next_value
+        cls._next_value = max(cls._next_value, value + 1)
+
+        algorithm = cls(
+            normalized_name,
+            value,
+            draft_worker_factory=draft_worker_factory,
+        )
+
+        cls._registry_by_name[normalized_name] = algorithm
+        cls._registry_by_value[value] = algorithm
+        cls._registration_order.append(algorithm)
+        setattr(cls, normalized_name, algorithm)
+
+        if aliases:
+            cls.register_aliases(algorithm, *aliases)
+
+        return algorithm
+
+    @classmethod
+    def register_aliases(cls, algorithm: SpeculativeAlgorithm, *aliases: str) -> None:
+        for alias in aliases:
+            cls._registry_by_name[alias.upper()] = algorithm
+
+    @classmethod
+    def register_draft_worker(
+        cls,
+        algorithm: SpeculativeAlgorithm | str,
+        factory: DraftWorkerFactory,
+    ) -> None:
+        algo = cls._ensure_algorithm(algorithm)
+        algo._draft_worker_factory = factory
+
+    @classmethod
+    def _ensure_algorithm(
+        cls, algorithm: SpeculativeAlgorithm | str
+    ) -> SpeculativeAlgorithm:
+        if isinstance(algorithm, SpeculativeAlgorithm):
+            return algorithm
+        if isinstance(algorithm, str):
+            return cls.from_string(algorithm)
+        raise TypeError(f"Unsupported algorithm identifier: {algorithm!r}")
+
+    @classmethod
+    def _add_flag(
+        cls, flag: str | Sequence[str], algorithm: SpeculativeAlgorithm | str
+    ) -> None:
+        algo = cls._ensure_algorithm(algorithm)
+        if isinstance(flag, str):
+            flag_iter = (flag,)
+        else:
+            flag_iter = flag
+        for flag_name in flag_iter:
+            cls._flags[flag_name.upper()].add(algo.value)
+
+    @classmethod
+    def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm:
+        if name is None:
+            return cls.NONE
+        try:
+            return cls._registry_by_name[name.upper()]
+        except KeyError as exc:
+            raise ValueError(f"Unknown speculative algorithm '{name}'") from exc
+
+    @classmethod
+    def from_value(cls, value: int) -> SpeculativeAlgorithm:
+        try:
+            return cls._registry_by_value[value]
+        except KeyError as exc:
+            raise ValueError(f"Unknown speculative algorithm id {value}") from exc
+
+    def _has_flag(self, flag: str) -> bool:
+        return self.value in type(self)._flags.get(flag.upper(), set())
+
+    def is_none(self) -> bool:
+        return self is SpeculativeAlgorithm.NONE
+
+    def is_eagle(self) -> bool:
+        return self._has_flag("EAGLE")
+
+    def is_eagle3(self) -> bool:
+        return self._has_flag("EAGLE3")
+
+    def is_standalone(self) -> bool:
+        return self._has_flag("STANDALONE")
+
+    def is_ngram(self) -> bool:
+        return self._has_flag("NGRAM")
+
+    def create_draft_worker(self, **factory_kwargs: Any) -> Any:
+        if self._draft_worker_factory is None:
+            return None
+        return self._draft_worker_factory(self, **factory_kwargs)
+
+
+# Registry helpers backed by `SpeculativeAlgorithm`.
+_LOCK = threading.RLock()
+_REGISTERED_WORKERS: Dict[SpeculativeAlgorithm, DraftWorkerClass] = {}
+_FLAG_MARKERS: Dict[str, Callable[[Union[SpeculativeAlgorithm, str]], None]] = {
+    "EAGLE": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE", algorithm),
+    "EAGLE3": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE3", algorithm),
+    "STANDALONE": lambda algorithm: SpeculativeAlgorithm._add_flag(
+        "STANDALONE", algorithm
+    ),
+    "NGRAM": lambda algorithm: SpeculativeAlgorithm._add_flag("NGRAM", algorithm),
+}
+
+
+def _wrap_worker_class(worker_cls: DraftWorkerClass) -> DraftWorkerFactory:
+    def _factory(_: SpeculativeAlgorithm, **kwargs: Any) -> Any:
+        return worker_cls(**kwargs)
+
+    return _factory
+
+
+def register_speculative_algorithm(
+    name: str,
+    worker_cls: DraftWorkerClass,
+    *,
+    aliases: Optional[Sequence[str]] = None,
+    flags: Optional[Iterable[str]] = None,
+    value: Optional[int] = None,
+    override_worker: bool = False,
+) -> SpeculativeAlgorithm:
+    """Register a speculative algorithm and the associated draft worker class.
+
+    Example:
+        >>> from sglang.srt.speculative.spec_info import register_speculative_algorithm
+        >>> register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))
+    """
+
+    name_upper = name.upper()
+    with _LOCK:
+        try:
+            algorithm = SpeculativeAlgorithm.from_string(name_upper)
+            exists = True
+        except ValueError:
+            algorithm = SpeculativeAlgorithm.register(
+                name_upper,
+                aliases=aliases,
+                value=value,
+            )
+            SpeculativeAlgorithm.register_draft_worker(
+                algorithm, _wrap_worker_class(worker_cls)
+            )
+            exists = False
+
+        if exists:
+            if aliases:
+                SpeculativeAlgorithm.register_aliases(algorithm, *aliases)
+            if not override_worker and algorithm in _REGISTERED_WORKERS:
+                raise ValueError(
+                    f"Worker already registered for {algorithm!r}. "
+                    "Pass override_worker=True to replace it."
+                )
+            SpeculativeAlgorithm.register_draft_worker(
+                algorithm, _wrap_worker_class(worker_cls)
+            )
+
+        _REGISTERED_WORKERS[algorithm] = worker_cls
+
+        if flags:
+            for flag in flags:
+                marker = _FLAG_MARKERS.get(flag.upper())
+                if marker is None:
+                    raise ValueError(f"Unsupported flag '{flag}'")
+                marker(algorithm)
+
+    return algorithm
+
+
+def list_registered_workers() -> Dict[str, DraftWorkerClass]:
+    """Return a snapshot of registered speculative worker classes keyed by algorithm name."""
+    with _LOCK:
+        return {algo.name: cls for algo, cls in _REGISTERED_WORKERS.items()}
+
+
+def _create_eagle_worker(**kwargs: Any) -> Any:
+    enable_overlap = kwargs.pop("enable_overlap", False)
+    if enable_overlap:
+        from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
+
+        return EAGLEWorkerV2(**kwargs)
+
+    from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+    return EAGLEWorker(**kwargs)
+
+
+def _create_standalone_worker(**kwargs: Any) -> Any:
+    from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+    return StandaloneWorker(**kwargs)
+
+
+def _create_ngram_worker(**kwargs: Any) -> Any:
+    from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+    return NGRAMWorker(**kwargs)
+
+
+# Register built-in algorithms.
+# Third-party integrations should import `SpeculativeAlgorithm` and either
+# call `register_speculative_algorithm` or use the helpers below to attach
+# additional draft workers.
+SpeculativeAlgorithm.register("NONE")
+
+register_speculative_algorithm(
+    "EAGLE",
+    aliases=("NEXTN",),
+    worker_cls=_create_eagle_worker,
+    flags=("EAGLE",),
+)
+
+register_speculative_algorithm(
+    "EAGLE3",
+    worker_cls=_create_eagle_worker,
+    flags=("EAGLE", "EAGLE3"),
+)
+
+register_speculative_algorithm(
+    "STANDALONE",
+    worker_cls=_create_standalone_worker,
+    flags=("STANDALONE",),
+)
+
+register_speculative_algorithm(
+    "NGRAM",
+    worker_cls=_create_ngram_worker,
+    flags=("NGRAM",),
+)
 
 
 class SpecInputType(IntEnum):
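The rewrite above turns `SpeculativeAlgorithm` from a fixed `IntEnum` into a registry, so out-of-tree draft workers can be plugged in without patching sglang. Below is a minimal sketch of how a third-party integration might use the new API, assuming the 0.5.4.post2 wheel is installed; `MyDraftWorker` and the keyword arguments passed to it are hypothetical and only stand in for a real draft-worker class:

```python
# Sketch only: MyDraftWorker is a hypothetical stand-in for a real draft worker.
from sglang.srt.speculative.spec_info import (
    SpeculativeAlgorithm,
    register_speculative_algorithm,
)


class MyDraftWorker:
    """Hypothetical draft worker; a real one would mirror EAGLEWorker's constructor."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Register once at import time. "MY_ALGO" then also becomes available as the
# class attribute SpeculativeAlgorithm.MY_ALGO (via setattr in register()).
algo = register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))

assert algo is SpeculativeAlgorithm.from_string("my_algo")  # lookup is case-insensitive
assert algo.is_eagle()  # the "EAGLE" flag set above drives the is_eagle() predicate

# Scheduler-side path: create_draft_worker forwards its kwargs to the
# registered worker class through the wrapping factory.
worker = algo.create_draft_worker(gpu_id=0)
```

The built-in algorithms (NONE, EAGLE, EAGLE3, STANDALONE, NGRAM) are registered through the same path at import time, which is why the former enum members keep working as class attributes.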
sglang/srt/speculative/spec_utils.py
@@ -19,16 +19,22 @@ from sglang.srt.distributed.parallel_state import (
 from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_info import EagleVerifyInput
 
 
-if
+if _is_cuda:
     from sgl_kernel import fast_topk
-elif
+elif _is_hip:
     from sgl_kernel import fast_topk
+else:
+    from sglang.srt.utils.common import fast_topk
 
 
 logger = logging.getLogger(__name__)
@@ -39,7 +45,7 @@ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
 SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
-TREE_SPEC_KERNEL_AVAILABLE =
+TREE_SPEC_KERNEL_AVAILABLE = _is_cuda  # This kernel is only available for CUDA now
 
 
 @triton.jit
@@ -103,6 +109,36 @@ def assign_req_to_token_pool(
         load_offset += BLOCK_SIZE
 
 
+def assign_req_to_token_pool_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    batch_size: int,
+):
+    if _is_cuda or _is_hip:
+        assign_req_to_token_pool[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        torch.ops.npu.cache_loc_assign(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+
+
 @triton.jit
 def assign_draft_cache_locs(
     req_pool_indices,
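`assign_req_to_token_pool_func` hides the backend split behind a single call: on CUDA/ROCm it launches the existing Triton kernel with a `(batch_size,)` grid and a power-of-two block constant, while on Ascend it imports `sgl_kernel_npu` and calls the `torch.ops.npu.cache_loc_assign` custom op. The sketch below is illustrative only and clarifies the block-size argument; the real `next_power_of_2` helper is imported from `sglang.srt.utils`, and this standalone version merely assumes the usual rounding-up semantics (Triton constexpr block sizes must be powers of two):

```python
# Minimal sketch of the assumed next_power_of_2 semantics; not sglang's implementation.
def next_power_of_2(n: int) -> int:
    # Round n up to the nearest power of two (1 for n <= 1).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


assert [next_power_of_2(n) for n in (1, 3, 8, 100)] == [1, 4, 8, 128]
```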
@@ -331,7 +367,7 @@ def get_target_cache_loc(
     )
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_src_tgt_cache_loc(
     seq_lens: torch.Tensor,
     out_cache_loc: torch.Tensor,
@@ -381,7 +417,7 @@ def filter_finished_cache_loc_kernel(
     )
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def create_accept_length_filter(
     accept_length: torch.Tensor,
     unfinished_index_device: torch.Tensor,
@@ -395,7 +431,7 @@ def create_accept_length_filter(
     return accept_length_filter
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
@@ -413,7 +449,7 @@ def select_top_k_tokens(
     tree_info = (
         topk_p.unsqueeze(1),  # shape: (b, 1, topk)
         topk_index,  # shape: (b, topk)
-        torch.arange(-1, topk, dtype=torch.long, device=
+        torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
         .unsqueeze(0)
         .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
     )