sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/eagle_worker.py

@@ -1,11 +1,14 @@
 import logging
 import os
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -27,11 +30,23 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.utils import get_available_gpu_memory
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)


+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+        yield
+
+
 class EAGLEWorker(TpModelWorker):

     def __init__(
@@ -52,6 +67,9 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )

         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -67,7 +85,13 @@ class EAGLEWorker(TpModelWorker):
         )

         # Load hot token ids
-        if server_args.speculative_token_map is not None:
+        if self.speculative_algorithm.is_eagle3():
+            if server_args.speculative_token_map is not None:
+                logger.warning(
+                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
+                )
+            self.hot_token_id = None
+        elif server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
                 f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
@@ -76,30 +100,47 @@ class EAGLEWorker(TpModelWorker):
             self.hot_token_id = None

         # Init draft worker
-        super().__init__(
-            gpu_id=gpu_id,
-            tp_rank=tp_rank,
-            server_args=server_args,
-            nccl_port=nccl_port,
-            dp_rank=dp_rank,
-            is_draft_worker=True,
-            req_to_token_pool=self.req_to_token_pool,
-            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-        )
+        with empty_context():
+            super().__init__(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=nccl_port,
+                dp_rank=dp_rank,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )

-        # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-        if self.hot_token_id is not None:
-            head = head.clone()
-            self.hot_token_id = self.hot_token_id.to(head.device)
-            head.data = head.data[self.hot_token_id]
-        self.draft_model_runner.model.set_embed_and_head(embed, head)
+
+        if self.speculative_algorithm.is_eagle3():
+            # EAGLE3 models don't share lm_head
+            self.draft_model_runner.model.set_embed(embed)
+
+            # grab hot token ids
+            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
+                embed.device
+            )
+        else:
+            if self.hot_token_id is not None:
+                head = head.clone()
+                self.hot_token_id = self.hot_token_id.to(head.device)
+                head.data = head.data[self.hot_token_id]
+
+            # Share the embedding and lm_head
+            self.draft_model_runner.model.set_embed_and_head(embed, head)
+
+        # Init attention backend and cuda graphs
         self.draft_model_runner.server_args.disable_cuda_graph = (
             backup_disable_cuda_graph
         )
-
-        self.init_attention_backend()
-        self.init_cuda_graphs()
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()

     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
@@ -109,52 +150,70 @@ class EAGLEWorker(TpModelWorker):
             )

             self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )

             self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashinfer_mla":
             from sglang.srt.layers.attention.flashinfer_mla_backend import (
                 FlashInferMLAMultiStepDraftBackend,
             )

             self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                self.model_runner,
+                self.draft_model_runner,
                 self.topk,
                 self.speculative_num_steps,
             )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
+
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_runner_for_draft_extend = None

         if self.server_args.disable_cuda_graph:
             return

+        # Capture draft
         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

+        # Capture extend
+        if self.draft_extend_attn_backend:
+            raise NotImplementedError()
+
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -164,8 +223,8 @@ class EAGLEWorker(TpModelWorker):
     ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
         """Run speculative decoding forward.

-        NOTE: Many states of batch is modified as you go through. It is not guaranteed
-        the final output batch doesn't have the same state as the input.
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
+        the final output batch have the same state as the input.

         Args:
             batch: The batch to run forward. The state of the batch is modified as it runs.
@@ -173,30 +232,42 @@
            A tuple of the final logit output of the target model, next tokens accepeted,
            the batch id (used for overlap schedule), and number of accepeted tokens.
        """
-        assert not batch.spec_algorithm.is_none()
        if batch.forward_mode.is_decode():
-            spec_info, to_free_cache_loc = self.draft(batch)
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info, to_free_cache_loc = self.draft(batch)
            logits_output, verify_output, model_worker_batch = self.verify(
                batch, spec_info
            )
+
            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-            # if it is None, means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                self.forward_draft_extend_after_decode(batch)

+            # If it is None, it means all requests are finished
+            if batch.spec_info.verified_id is not None:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(batch)
            return (
                logits_output,
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
            )
-
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(
+                    ForwardBatch.init_new(
+                        model_worker_batch, self.target_worker.model_runner
+                    )
+                )
+            )
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
        else:
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
-            self.forward_draft_extend(
-                batch, logits_output.hidden_states, next_token_ids
-            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids
+                )
            return logits_output, next_token_ids, bid, 0

    def forward_target_extend(
@@ -226,6 +297,13 @@ class EAGLEWorker(TpModelWorker):
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info

+        # Accumulate penalty
+        if batch.sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                spec_info.verified_id.to(torch.int64)
+            )
+
         # Allocate cache locations
         out_cache_loc = batch.alloc_token_slots(
             num_seqs * self.topk * self.speculative_num_steps
@@ -275,9 +353,7 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
-            batch.sampling_info.is_all_greedy,
         )
-
         return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
@@ -307,7 +383,7 @@
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])

-            # we don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
+            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
             if i == self.speculative_num_steps - 1:
                 break

@@ -322,7 +398,7 @@
             spec_info.hidden_states = hidden_states

             # Run forward
-            logits_output = self.model_runner.model.forward(
+            logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
             self._detect_nan_if_needed(logits_output)
@@ -351,11 +427,10 @@
         # Post process based on verified outputs.
         # Pick indices that we care (accepeted)
         logits_output.next_token_logits = logits_output.next_token_logits[
-            res.accepeted_indices_cpu
-        ]
-        logits_output.hidden_states = logits_output.hidden_states[
-            res.accepeted_indices_cpu
+            res.accepeted_indices
         ]
+        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices]
+
         # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
@@ -407,7 +482,7 @@
             batch_next_token_ids,
         ]

-        # Add output logprobs to the request.
+        # Add output logprobs to the request
         pt = 0
         next_token_logprobs = logits_output.next_token_logprobs.tolist()
         verified_ids = batch_next_token_ids.tolist()
@@ -456,27 +531,38 @@
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
-        seq_lens_backup = batch.seq_lens
+        # Backup fileds that will be modified in-place
+        seq_lens_backup = batch.seq_lens.clone()
+        req_pool_indices_backup = batch.req_pool_indices
+        accept_length_backup = batch.spec_info.accept_length
+        return_logprob_backup = batch.return_logprob
+
+        # Prepare metadata
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
+        batch.spec_info.prepare_extend_after_decode(
+            batch,
+            self.speculative_num_steps,
+        )
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
-        # We don't need logprob for this extend.
-        original_return_logprob = batch.return_logprob
         batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+
+        # Run
         logits_output = self.draft_model_runner.forward(forward_batch)
+
         self._detect_nan_if_needed(logits_output)
-        assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup
+        batch.spec_info.accept_length = accept_length_backup
+        batch.return_logprob = return_logprob_backup

     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
@@ -489,7 +575,7 @@
         if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
-                logger.warning("Detected errors during sampling! NaN in the logits.")
+                logger.error("Detected errors during sampling! NaN in the logits.")
                 raise ValueError("Detected errors during sampling! NaN in the logits.")

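Note on the wiring above: `init_attention_backend()` and `init_cuda_graphs()` are now entered through `self.draft_tp_context`, which is `draft_tp_context` only when `enable_dp_attention` is set and otherwise `empty_context` from sglang/srt/utils.py. That helper is not shown in this diff; the sketch below assumes it is a no-op context manager that accepts and ignores the same arguments, which is what the call sites imply.

    from contextlib import contextmanager

    @contextmanager
    def empty_context(*args, **kwargs):
        # Assumed behavior: accept the same arguments as draft_tp_context
        # but patch nothing, so the `with` block runs unmodified.
        yield

    @contextmanager
    def draft_tp_context(tp_group):
        # Stand-in for the real helper added above, which wraps the block in
        # disable_dp_size() and patch_tensor_parallel_group(tp_group).
        yield

    enable_dp_attention = False
    ctx = draft_tp_context if enable_dp_attention else empty_context
    with ctx("draft-tp-group"):  # same call shape either way
        pass  # init_attention_backend() / init_cuda_graphs() would run here

Because both callables take the TP group, the call site stays identical whether DP attention is on or off.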
sglang/srt/speculative/spec_info.py

@@ -4,17 +4,22 @@ from enum import IntEnum, auto
 class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
+    EAGLE3 = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
+        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
+
+    def is_eagle3(self):
+        return self == SpeculativeAlgorithm.EAGLE3

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
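
The spec_info.py hunk above is what the EAGLE3 branches in eagle_worker.py key off. A small sketch of how the new enum value resolves, assuming `from_string` finishes by looking the name up in `name_map` (only the first half of that method is visible in this diff):

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    algo = SpeculativeAlgorithm.from_string("EAGLE3")
    assert algo.is_eagle()    # EAGLE3 now counts as part of the EAGLE family
    assert algo.is_eagle3()   # but can still be told apart where it matters

    # None (no speculative algorithm configured) maps to NONE via name_map.
    assert SpeculativeAlgorithm.from_string(None).is_none()
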
sglang/srt/torch_memory_saver_adapter.py

@@ -1,3 +1,4 @@
+import logging
 from abc import ABC
 from contextlib import contextmanager

@@ -8,6 +9,8 @@ try:
 except ImportError:
     pass

+logger = logging.getLogger(__name__)
+

 class TorchMemorySaverAdapter(ABC):
     @staticmethod
@@ -16,6 +19,13 @@ class TorchMemorySaverAdapter(ABC):
             _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
         )

+    def check_validity(self, caller_name):
+        if not self.enabled:
+            logger.warning(
+                f"`{caller_name}` will not save memory because torch_memory_saver is not enabled. "
+                f"Potential causes: `enable_memory_saver` is false, or torch_memory_saver has installation issues."
+            )
+
     def configure_subprocess(self):
         raise NotImplementedError

@@ -28,6 +38,10 @@ class TorchMemorySaverAdapter(ABC):
     def resume(self):
         raise NotImplementedError

+    @property
+    def enabled(self):
+        raise NotImplementedError
+

 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def configure_subprocess(self):
@@ -42,6 +56,10 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def resume(self):
         return _primary_memory_saver.resume()

+    @property
+    def enabled(self):
+        return _primary_memory_saver.enabled
+

 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     @contextmanager
@@ -57,3 +75,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):

     def resume(self):
         pass
+
+    @property
+    def enabled(self):
+        return False
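
Finally, a short usage sketch for the new `enabled` property and `check_validity` helper in torch_memory_saver_adapter.py. It assumes the factory whose body appears in the `create` hunk above has the signature `create(enable: bool)`, which is not itself shown in this diff; the `caller_name` string is purely illustrative.

    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

    # With enable=False the no-op adapter is returned, so `enabled` is False
    # and check_validity() emits a warning instead of failing silently.
    adapter = TorchMemorySaverAdapter.create(enable=False)
    assert adapter.enabled is False
    adapter.check_validity(caller_name="release_memory_occupation")

    # With enable=True, `enabled` defers to _primary_memory_saver.enabled,
    # so it also reflects whether torch_memory_saver imported successfully.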