sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -52,9 +53,12 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_bool_env_var,
     is_cuda,
+    is_npu,
     next_power_of_2,
 )

+_is_npu = is_npu()
+
 if is_cuda():
     from sgl_kernel import segment_packbits  # noqa: F401

@@ -117,7 +121,11 @@ class EAGLEWorker(TpModelWorker):
         self.hot_token_id = None

         # Init draft worker
-        with empty_context():
+        if server_args.enable_dp_attention and self.speculative_algorithm.is_eagle3():
+            ctx = draft_tp_context(get_attention_tp_group())
+        else:
+            ctx = empty_context()
+        with ctx:
             super().__init__(
                 server_args=server_args,
                 gpu_id=gpu_id,
@@ -200,7 +208,7 @@ class EAGLEWorker(TpModelWorker):
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None

-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or _is_npu:
             return

         # Capture draft
@@ -940,7 +948,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_last_loc_large_page_size_top_k_1(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
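
A recurring change in this release gates torch.compile behind an NPU check rather than forking every call site: passing disable=True makes the decorator a pass-through, so the same function runs eagerly on Ascend and compiled elsewhere. A minimal sketch of the pattern; the is_npu probe below is a stand-in for sglang.srt.utils.is_npu, not the shipped helper:

import torch

def is_npu() -> bool:
    # Stand-in probe: true only on builds that ship the torch.npu backend.
    return hasattr(torch, "npu") and torch.npu.is_available()

_is_npu = is_npu()

# With disable=True, torch.compile returns the function unchanged, so the
# decorated code runs eagerly on NPU and compiled everywhere else.
@torch.compile(dynamic=True, disable=_is_npu)
def scale_and_shift(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0 + 1.0
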
sglang/srt/speculative/eagle_worker_v2.py
@@ -4,7 +4,6 @@ import time
 from typing import List, Optional, Tuple

 import torch
-from torch.cuda import Stream as CudaStream

 from sglang.srt.environ import envs
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -38,18 +37,21 @@ from sglang.srt.utils.common import (
     empty_context,
     fast_topk,
     get_available_gpu_memory,
+    is_npu,
     next_power_of_2,
 )

+_is_npu = is_npu()
+
 logger = logging.getLogger(__name__)


 def _get_plan_stream(
     device: str,
-) -> Tuple[Optional[CudaStream], contextlib.AbstractContextManager]:
+) -> Tuple[any, contextlib.AbstractContextManager]:
     if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
-        plan_stream: CudaStream = torch.get_device_module(device).Stream()
-        plan_stream_ctx = torch.cuda.stream(plan_stream)
+        plan_stream = torch.get_device_module(device).Stream()
+        plan_stream_ctx = torch.get_device_module(device).stream(plan_stream)
         return plan_stream, plan_stream_ctx
     else:
         return None, contextlib.nullcontext()
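
Stream handling in this file now resolves through torch.get_device_module(device), which maps "cuda" to torch.cuda and, on Ascend builds, "npu" to torch.npu, so one code path serves both backends. A self-contained sketch of that pattern, assuming a CUDA machine:

import torch

def make_side_stream(device: str = "cuda"):
    mod = torch.get_device_module(device)  # e.g. torch.cuda or torch.npu
    stream = mod.Stream()                  # side stream on that backend
    return stream, mod.stream(stream)      # stream plus a context manager entering it

if torch.cuda.is_available():
    stream, ctx = make_side_stream("cuda")
    with ctx:
        # Work queued here lands on the side stream.
        y = torch.ones(4, device="cuda") * 2
    # The default stream must wait before consuming y, mirroring the
    # wait_stream() calls in the hunks below.
    torch.cuda.current_stream().wait_stream(stream)
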
@@ -206,7 +208,7 @@ class EagleDraftWorker(BaseDraftWorker):
         self.cuda_graph_runner = None
         self.cuda_graph_runner_for_draft_extend = None

-        if self.server_args.disable_cuda_graph:
+        if self.server_args.disable_cuda_graph or _is_npu:
             return

         # Capture draft
@@ -456,7 +458,9 @@ class EagleDraftWorker(BaseDraftWorker):
         )

         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module(self.device).current_stream().wait_stream(
+                self.plan_stream
+            )

         # Run draft extend batch in the main compute stream
         draft_logits_output = self.draft_runner.model.forward(
@@ -577,7 +581,9 @@ class EAGLEWorkerV2(BaseSpecWorker):
         # Since batch.seq_lens is allocated in another stream, we need
         # record_stream() to prevent pytorch gc and reuse the gpu memory
         # while forward_stream is still running.
-        batch.seq_lens.record_stream(torch.cuda.current_stream())
+        batch.seq_lens.record_stream(
+            torch.get_device_module(self.device).current_stream()
+        )

         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
@@ -596,7 +602,7 @@

         # Correct some buffers due to the overlap plan
         if self.plan_stream:
-            torch.cuda.current_stream().wait_stream(self.plan_stream)
+            torch.get_device_module().current_stream().wait_stream(self.plan_stream)

         # Some values such as custom_mask and position depend on the output of draft,
         # so the previous plan step used the wrong values. Here, we need to run the related
@@ -628,7 +634,7 @@
             accept_index,
         ) = verify_input.sample(batch, logits_output)
         new_seq_lens = batch.seq_lens + accept_length
-        verify_done = torch.cuda.Event()
+        verify_done = torch.get_device_module(self.device).Event()
         verify_done.record()

         all_verified_id = predict[accept_index]
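
Events get the same device-agnostic rewrite as streams: torch.get_device_module(device).Event resolves to torch.cuda.Event on CUDA and torch.npu.Event on Ascend builds. A short sketch, assuming a CUDA machine:

import torch

if torch.cuda.is_available():
    device = "cuda"
    done = torch.get_device_module(device).Event()
    y = torch.randn(512, 512, device=device) @ torch.randn(512, 512, device=device)
    done.record()       # marks the point at which queued work finishes
    done.synchronize()  # host blocks until the marker is reached
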
sglang/srt/speculative/spec_info.py
@@ -1,46 +1,320 @@
+from __future__ import annotations
+
+import threading
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from enum import IntEnum, auto
-from functools import lru_cache
-from typing import List, Tuple
+from typing import (
+    Any,
+    Callable,
+    DefaultDict,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

 from sglang.srt.managers.schedule_batch import ModelWorkerBatch

+DraftWorkerClass = Callable[..., Any]
+DraftWorkerFactory = Callable[..., Any]

-class SpeculativeAlgorithm(IntEnum):
-    NONE = auto()
-    EAGLE = auto()
-    EAGLE3 = auto()
-    STANDALONE = auto()
-    NGRAM = auto()

-    def is_none(self):
-        return self == SpeculativeAlgorithm.NONE
+class _SpeculativeAlgorithmMeta(type):
+    def __iter__(cls) -> Iterator["SpeculativeAlgorithm"]:
+        return iter(cls._registration_order)

-    def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3

-    def is_eagle3(self):
-        return self == SpeculativeAlgorithm.EAGLE3
+class SpeculativeAlgorithm(metaclass=_SpeculativeAlgorithmMeta):
+    """Registry-backed representation of speculative decoding algorithms."""

-    def is_standalone(self):
-        return self == SpeculativeAlgorithm.STANDALONE
+    __slots__ = ("name", "value", "_draft_worker_factory")

-    def is_ngram(self):
-        return self == SpeculativeAlgorithm.NGRAM
+    _registry_by_name: Dict[str, "SpeculativeAlgorithm"] = {}
+    _registry_by_value: Dict[int, "SpeculativeAlgorithm"] = {}
+    _registration_order: List["SpeculativeAlgorithm"] = []
+    _flags: DefaultDict[str, Set[int]] = defaultdict(set)
+    _next_value: int = 0

-    @lru_cache(maxsize=None)
-    @staticmethod
-    def from_string(name: str):
-        name_map = {
-            "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
-            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
-            "NGRAM": SpeculativeAlgorithm.NGRAM,
-            None: SpeculativeAlgorithm.NONE,
-        }
-        if name is not None:
-            name = name.upper()
-        return name_map[name]
+    def __init__(
+        self,
+        name: str,
+        value: int,
+        draft_worker_factory: Optional[DraftWorkerFactory] = None,
+    ):
+        self.name = name
+        self.value = value
+        self._draft_worker_factory = draft_worker_factory
+
+    def __repr__(self) -> str:  # pragma: no cover - trivial
+        return f"SpeculativeAlgorithm.{self.name}"
+
+    def __str__(self) -> str:  # pragma: no cover - trivial
+        return self.name
+
+    def __hash__(self) -> int:
+        return hash(self.value)
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, SpeculativeAlgorithm):
+            return self.value == other.value
+        return NotImplemented
+
+    def __int__(self) -> int:
+        return self.value
+
+    @classmethod
+    def register(
+        cls,
+        name: str,
+        *,
+        aliases: Optional[Sequence[str]] = None,
+        value: Optional[int] = None,
+        draft_worker_factory: Optional[DraftWorkerFactory] = None,
+    ) -> SpeculativeAlgorithm:
+        normalized_name = name.upper()
+        if normalized_name in cls._registry_by_name:
+            raise ValueError(
+                f"SpeculativeAlgorithm '{normalized_name}' already registered"
+            )
+
+        if value is None:
+            value = cls._next_value
+        cls._next_value = max(cls._next_value, value + 1)
+
+        algorithm = cls(
+            normalized_name,
+            value,
+            draft_worker_factory=draft_worker_factory,
+        )
+
+        cls._registry_by_name[normalized_name] = algorithm
+        cls._registry_by_value[value] = algorithm
+        cls._registration_order.append(algorithm)
+        setattr(cls, normalized_name, algorithm)
+
+        if aliases:
+            cls.register_aliases(algorithm, *aliases)
+
+        return algorithm
+
+    @classmethod
+    def register_aliases(cls, algorithm: SpeculativeAlgorithm, *aliases: str) -> None:
+        for alias in aliases:
+            cls._registry_by_name[alias.upper()] = algorithm
+
+    @classmethod
+    def register_draft_worker(
+        cls,
+        algorithm: SpeculativeAlgorithm | str,
+        factory: DraftWorkerFactory,
+    ) -> None:
+        algo = cls._ensure_algorithm(algorithm)
+        algo._draft_worker_factory = factory
+
+    @classmethod
+    def _ensure_algorithm(
+        cls, algorithm: SpeculativeAlgorithm | str
+    ) -> SpeculativeAlgorithm:
+        if isinstance(algorithm, SpeculativeAlgorithm):
+            return algorithm
+        if isinstance(algorithm, str):
+            return cls.from_string(algorithm)
+        raise TypeError(f"Unsupported algorithm identifier: {algorithm!r}")
+
+    @classmethod
+    def _add_flag(
+        cls, flag: str | Sequence[str], algorithm: SpeculativeAlgorithm | str
+    ) -> None:
+        algo = cls._ensure_algorithm(algorithm)
+        if isinstance(flag, str):
+            flag_iter = (flag,)
+        else:
+            flag_iter = flag
+        for flag_name in flag_iter:
+            cls._flags[flag_name.upper()].add(algo.value)
+
+    @classmethod
+    def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm:
+        if name is None:
+            return cls.NONE
+        try:
+            return cls._registry_by_name[name.upper()]
+        except KeyError as exc:
+            raise ValueError(f"Unknown speculative algorithm '{name}'") from exc
+
+    @classmethod
+    def from_value(cls, value: int) -> SpeculativeAlgorithm:
+        try:
+            return cls._registry_by_value[value]
+        except KeyError as exc:
+            raise ValueError(f"Unknown speculative algorithm id {value}") from exc
+
+    def _has_flag(self, flag: str) -> bool:
+        return self.value in type(self)._flags.get(flag.upper(), set())
+
+    def is_none(self) -> bool:
+        return self is SpeculativeAlgorithm.NONE
+
+    def is_eagle(self) -> bool:
+        return self._has_flag("EAGLE")
+
+    def is_eagle3(self) -> bool:
+        return self._has_flag("EAGLE3")
+
+    def is_standalone(self) -> bool:
+        return self._has_flag("STANDALONE")
+
+    def is_ngram(self) -> bool:
+        return self._has_flag("NGRAM")
+
+    def create_draft_worker(self, **factory_kwargs: Any) -> Any:
+        if self._draft_worker_factory is None:
+            return None
+        return self._draft_worker_factory(self, **factory_kwargs)
+
+
+# Registry helpers backed by `SpeculativeAlgorithm`.
+_LOCK = threading.RLock()
+_REGISTERED_WORKERS: Dict[SpeculativeAlgorithm, DraftWorkerClass] = {}
+_FLAG_MARKERS: Dict[str, Callable[[Union[SpeculativeAlgorithm, str]], None]] = {
+    "EAGLE": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE", algorithm),
+    "EAGLE3": lambda algorithm: SpeculativeAlgorithm._add_flag("EAGLE3", algorithm),
+    "STANDALONE": lambda algorithm: SpeculativeAlgorithm._add_flag(
+        "STANDALONE", algorithm
+    ),
+    "NGRAM": lambda algorithm: SpeculativeAlgorithm._add_flag("NGRAM", algorithm),
+}
+
+
+def _wrap_worker_class(worker_cls: DraftWorkerClass) -> DraftWorkerFactory:
+    def _factory(_: SpeculativeAlgorithm, **kwargs: Any) -> Any:
+        return worker_cls(**kwargs)
+
+    return _factory
+
+
+def register_speculative_algorithm(
+    name: str,
+    worker_cls: DraftWorkerClass,
+    *,
+    aliases: Optional[Sequence[str]] = None,
+    flags: Optional[Iterable[str]] = None,
+    value: Optional[int] = None,
+    override_worker: bool = False,
+) -> SpeculativeAlgorithm:
+    """Register a speculative algorithm and the associated draft worker class.
+
+    Example:
+        >>> from sglang.srt.speculative.spec_info import register_speculative_algorithm
+        >>> register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))
+    """
+
+    name_upper = name.upper()
+    with _LOCK:
+        try:
+            algorithm = SpeculativeAlgorithm.from_string(name_upper)
+            exists = True
+        except ValueError:
+            algorithm = SpeculativeAlgorithm.register(
+                name_upper,
+                aliases=aliases,
+                value=value,
+            )
+            SpeculativeAlgorithm.register_draft_worker(
+                algorithm, _wrap_worker_class(worker_cls)
+            )
+            exists = False
+
+        if exists:
+            if aliases:
+                SpeculativeAlgorithm.register_aliases(algorithm, *aliases)
+            if not override_worker and algorithm in _REGISTERED_WORKERS:
+                raise ValueError(
+                    f"Worker already registered for {algorithm!r}. "
+                    "Pass override_worker=True to replace it."
+                )
+            SpeculativeAlgorithm.register_draft_worker(
+                algorithm, _wrap_worker_class(worker_cls)
+            )
+
+        _REGISTERED_WORKERS[algorithm] = worker_cls
+
+        if flags:
+            for flag in flags:
+                marker = _FLAG_MARKERS.get(flag.upper())
+                if marker is None:
+                    raise ValueError(f"Unsupported flag '{flag}'")
+                marker(algorithm)
+
+        return algorithm
+
+
+def list_registered_workers() -> Dict[str, DraftWorkerClass]:
+    """Return a snapshot of registered speculative worker classes keyed by algorithm name."""
+    with _LOCK:
+        return {algo.name: cls for algo, cls in _REGISTERED_WORKERS.items()}
+
+
+def _create_eagle_worker(**kwargs: Any) -> Any:
+    enable_overlap = kwargs.pop("enable_overlap", False)
+    if enable_overlap:
+        from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
+
+        return EAGLEWorkerV2(**kwargs)
+
+    from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+    return EAGLEWorker(**kwargs)
+
+
+def _create_standalone_worker(**kwargs: Any) -> Any:
+    from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+    return StandaloneWorker(**kwargs)
+
+
+def _create_ngram_worker(**kwargs: Any) -> Any:
+    from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+    return NGRAMWorker(**kwargs)
+
+
+# Register built-in algorithms.
+# Third-party integrations should import `SpeculativeAlgorithm` and either
+# call `register_speculative_algorithm` or use the helpers below to attach
+# additional draft workers.
+SpeculativeAlgorithm.register("NONE")
+
+register_speculative_algorithm(
+    "EAGLE",
+    aliases=("NEXTN",),
+    worker_cls=_create_eagle_worker,
+    flags=("EAGLE",),
+)
+
+register_speculative_algorithm(
+    "EAGLE3",
+    worker_cls=_create_eagle_worker,
+    flags=("EAGLE", "EAGLE3"),
+)
+
+register_speculative_algorithm(
+    "STANDALONE",
+    worker_cls=_create_standalone_worker,
+    flags=("STANDALONE",),
+)
+
+register_speculative_algorithm(
+    "NGRAM",
+    worker_cls=_create_ngram_worker,
+    flags=("NGRAM",),
+)


 class SpecInputType(IntEnum):
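
The rewrite above replaces the closed IntEnum with a registry so out-of-tree draft workers can be attached without patching sglang. A usage sketch against the new API; MyDraftWorker is a hypothetical third-party class, not part of sglang:

from sglang.srt.speculative.spec_info import (
    SpeculativeAlgorithm,
    register_speculative_algorithm,
)

class MyDraftWorker:
    # Any class constructible from the worker kwargs will do.
    def __init__(self, **kwargs):
        self.kwargs = kwargs

# Register under a new name and reuse the EAGLE behavior flag.
algo = register_speculative_algorithm("MY_ALGO", MyDraftWorker, flags=("EAGLE",))

assert SpeculativeAlgorithm.from_string("my_algo") is algo  # lookup is case-insensitive
assert algo.is_eagle() and not algo.is_eagle3()
worker = algo.create_draft_worker(gpu_id=0)  # instantiates MyDraftWorker(gpu_id=0)
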
sglang/srt/speculative/spec_utils.py
@@ -19,16 +19,22 @@ from sglang.srt.distributed.parallel_state import (
 from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, is_npu, next_power_of_2
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_npu = is_npu()

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_info import EagleVerifyInput


-if is_cuda():
+if _is_cuda:
     from sgl_kernel import fast_topk
-elif is_hip():
+elif _is_hip:
     from sgl_kernel import fast_topk
+else:
+    from sglang.srt.utils.common import fast_topk


 logger = logging.getLogger(__name__)
@@ -39,7 +45,7 @@ SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
 SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
-TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
+TREE_SPEC_KERNEL_AVAILABLE = _is_cuda  # This kernel is only available for CUDA now


 @triton.jit
@@ -103,6 +109,36 @@ def assign_req_to_token_pool(
         load_offset += BLOCK_SIZE


+def assign_req_to_token_pool_func(
+    req_pool_indices: torch.Tensor,
+    req_to_token: torch.Tensor,
+    start_offset: torch.Tensor,
+    end_offset: torch.Tensor,
+    out_cache_loc: torch.Tensor,
+    batch_size: int,
+):
+    if _is_cuda or _is_hip:
+        assign_req_to_token_pool[(batch_size,)](
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+            req_to_token.shape[1],
+            next_power_of_2(batch_size),
+        )
+    elif _is_npu:
+        import sgl_kernel_npu  # noqa: F401
+
+        torch.ops.npu.cache_loc_assign(
+            req_pool_indices,
+            req_to_token,
+            start_offset,
+            end_offset,
+            out_cache_loc,
+        )
+
+
 @triton.jit
 def assign_draft_cache_locs(
     req_pool_indices,
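
assign_req_to_token_pool_func dispatches to the Triton kernel on CUDA/HIP and to torch.ops.npu.cache_loc_assign from sgl_kernel_npu on Ascend. For orientation, a plain-PyTorch reading of what the kernel writes, under the assumption that out_cache_loc holds the newly allocated slots for all requests back to back; this is an illustration, not the shipped kernel:

import torch

def assign_req_to_token_pool_ref(
    req_pool_indices, req_to_token, start_offset, end_offset, out_cache_loc
):
    # Scatter each request's slice of out_cache_loc into its row of the
    # request-to-token map, between that request's start and end offsets.
    copied = 0
    for i in range(req_pool_indices.numel()):
        row = req_pool_indices[i].item()
        start, end = start_offset[i].item(), end_offset[i].item()
        n = end - start
        req_to_token[row, start:end] = out_cache_loc[copied : copied + n]
        copied += n
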
@@ -331,7 +367,7 @@ def get_target_cache_loc(
     )


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def get_src_tgt_cache_loc(
     seq_lens: torch.Tensor,
     out_cache_loc: torch.Tensor,
@@ -381,7 +417,7 @@ def filter_finished_cache_loc_kernel(
     )


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def create_accept_length_filter(
     accept_length: torch.Tensor,
     unfinished_index_device: torch.Tensor,
@@ -395,7 +431,7 @@ def create_accept_length_filter(
     return accept_length_filter


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, disable=_is_npu)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
@@ -413,7 +449,7 @@
         tree_info = (
             topk_p.unsqueeze(1),  # shape: (b, 1, topk)
             topk_index,  # shape: (b, topk)
-            torch.arange(-1, topk, dtype=torch.long, device="cuda")
+            torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
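
The final hunk drops a hard-coded device="cuda" in favor of the device of a tensor already in hand, which is what lets the tree construction run on NPU (or on CPU in tests). The same idea in isolation, as a standalone sketch:

import torch

def make_parent_offsets(hidden_states: torch.Tensor, topk: int) -> torch.Tensor:
    # Allocate on whatever device the inputs live on instead of assuming CUDA.
    return (
        torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
        .unsqueeze(0)
        .repeat(hidden_states.shape[0], 1)  # shape: (b, topk + 1)
    )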