sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
@@ -122,66 +122,17 @@ class ModelRunner:
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    self.server_args.attention_backend = "flashinfer_mla"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
+        self.model_specific_adjustment()
 
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
-
-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
 
+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -203,6 +154,7 @@ class ModelRunner:
             }
         )
 
+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
         # Get memory before model loading
@@ -216,18 +168,6 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -244,9 +184,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        # Init memory pool and attention backends
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -260,10 +202,63 @@ class ModelRunner:
             self.cuda_graph_runner = None
             self.init_attention_backend()
 
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)
 
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -400,6 +395,18 @@ class ModelRunner:
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
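As a note on the relocated barrier above: it now runs at the end of model loading rather than in the constructor, so every tensor-parallel rank waits until all peers have finished their shard, and a rank that stalls or crashes surfaces as a timeout instead of a silent hang. A minimal standalone sketch of the same pattern, not package code (`cpu_group`, `tp_rank`, and `timeout_s` are placeholder names):

import datetime

import torch.distributed as dist


def wait_for_peer_ranks(cpu_group, tp_rank, timeout_s=1800):
    # Every rank calls this after loading its shard; monitored_barrier raises on
    # the ranks that did arrive if some peer never shows up within the timeout.
    try:
        dist.monitored_barrier(
            group=cpu_group,
            timeout=datetime.timedelta(seconds=timeout_s),
            wait_all_ranks=True,
        )
    except RuntimeError as e:
        raise ValueError(
            f"TP rank {tp_rank} finished loading, but a peer rank did not "
            "(possible OOM or slow node)."
        ) from e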
@@ -710,15 +717,6 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -753,6 +751,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -770,6 +779,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -878,18 +891,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
 
         if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
sglang/srt/server_args.py CHANGED
@@ -71,7 +71,6 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    prefill_only_one_req: bool = False
 
     # Other runtime options
     tp_size: int = 1
@@ -277,19 +276,17 @@ class ServerArgs:
             self.speculative_algorithm = "EAGLE"
 
         if self.speculative_algorithm == "EAGLE":
-            self.disable_overlap_schedule = True
-            self.prefill_only_one_req = True
-            self.disable_cuda_graph_padding = True
             if self.max_running_requests is None:
                 self.max_running_requests = 32
+            self.disable_overlap_schedule = True
+            self.disable_cuda_graph_padding = True
             logger.info(
                 "Overlap scheduler are disabled because of using "
                 "eagle speculative decoding."
-                "Max running request set to 32 because of using eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
-            assert self.speculative_num_steps < self.speculative_num_draft_tokens
+            # assert self.speculative_num_steps < self.speculative_num_draft_tokens
 
         # GGUF
         if (
@@ -509,12 +506,6 @@ class ServerArgs:
             default=ServerArgs.cpu_offload_gb,
             help="How many GBs of RAM to reserve for CPU offloading",
         )
-        parser.add_argument(
-            "--prefill-only-one-req",
-            type=bool,
-            help="If true, we only prefill one request at one prefill batch",
-            default=ServerArgs.prefill_only_one_req,
-        )
 
         # Other runtime options
         parser.add_argument(
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(
 
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
-    parent_list = torch.cat(parents_list[:-1], dim=1)
+
+    if len(parents_list) > 1:
+        parent_list = torch.cat(parents_list[:-1], dim=1)
+    else:
+        batch_size = parents_list[0].shape[0]
+        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
 
     return parent_list, top_scores_index, draft_tokens
 
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import bisect
-import time
 from typing import TYPE_CHECKING, Callable
 
 import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:
 
         run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         with torch.cuda.graph(
             graph, pool=get_global_graph_memory_pool(), stream=stream
         ):
             out = run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         set_global_graph_memory_pool(graph.pool())
         return graph, out
 
@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:
 
         # Attention backend
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch
+            forward_batch, forward_batch.batch_size
         )
 
         # Replay
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, List
 
 import torch
 import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
             batch.input_ids[pt : pt + extend_len] = torch.concat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
+            pt += extend_len
 
     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
         assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
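The one-line addition above advances the running write offset pt after each request; without it the offset stays at its starting value and later requests overwrite the first slice of batch.input_ids. A self-contained sketch of the same offset-accumulation pattern (pack_segments is an illustrative helper, not package code):

import torch


def pack_segments(segments):
    # Pack variable-length 1-D segments into one flat buffer; the offset must
    # advance after each write, which is what the added line does per request.
    out = torch.empty(sum(s.numel() for s in segments), dtype=torch.long)
    pt = 0
    for seg in segments:
        out[pt : pt + seg.numel()] = seg
        pt += seg.numel()
    return out


print(pack_segments([torch.tensor([1, 2]), torch.tensor([3, 4, 5])]))
# tensor([1, 2, 3, 4, 5])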
@@ -1,20 +1,19 @@
 import logging
 import os
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import torch
 from huggingface_hub import snapshot_download
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import get_available_gpu_memory
 
 logger = logging.getLogger(__name__)
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
-        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
 
         # Do not capture cuda graph in `super().__init__()`
-        # We will capture it later
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
 
-        # Lossy optimization by using hot tokens
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None
 
-        # We share the allocator with a target worker. Draft/target worker
-        # owns its own KV cache.
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            target_worker.get_memory_pool()
-        )
-
-        # Init target worker
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-
-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
-        )
-        self.server_args = server_args
-        self.use_nan_detection = self.server_args.enable_nan_detection
-        self.device = self.model_runner.device
-        self.gpu_id = self.model_runner.gpu_id
 
         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
             backup_disable_cuda_graph
         )
 
+        self.init_attention_backend()
+        self.init_cuda_graphs()
+
+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
             )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input
 
+        if batch.return_logprob:
+            # Compute output logprobs using the sampler.
+            num_tokens_per_req = [
+                accept + 1 for accept in res.accept_length_per_req_cpu
+            ]
+            self.target_worker.model_runner.update_output_logprobs(
+                logits_output,
+                batch.sampling_info,
+                batch.top_logprobs_nums,
+                batch.token_ids_logprobs,
+                res.verified_id,
+                # +1 for bonus token.
+                num_tokens_per_req=num_tokens_per_req,
+            )
+
+            # Add output logprobs to the request.
+            pt = 0
+            # NOTE: tolist() of these values are skipped when output is processed
+            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
+            verified_ids = res.verified_id.tolist()
+            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+                for _ in range(num_tokens):
+                    if req.return_logprob:
+                        token_id = verified_ids[pt]
+                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                        req.output_token_logprobs_idx.append(token_id)
+                        if req.top_logprobs_num > 0:
+                            req.output_top_logprobs_val.append(
+                                res.logits_output.next_token_top_logprobs_val[pt]
+                            )
+                            req.output_top_logprobs_idx.append(
+                                res.logits_output.next_token_top_logprobs_idx[pt]
+                            )
+                    pt += 1
+
         return logits_output, res, model_worker_batch
 
     def forward_draft_extend(
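The block added above walks a flat array of per-token logprobs with a single cursor pt, consuming accept_length + 1 entries per request (the +1 is the bonus token, per the in-code comment). A reduced sketch of that split-by-counts pattern (scatter_by_counts is an illustrative helper, not package code):

def scatter_by_counts(values, counts):
    # Split a flat list into per-request chunks by advancing one cursor,
    # mirroring the pt cursor in the block above.
    out, pt = [], 0
    for n in counts:
        out.append(values[pt : pt + n])
        pt += n
    return out


print(scatter_by_counts([0.1, 0.2, 0.3, 0.4, 0.5], [2, 3]))
# [[0.1, 0.2], [0.3, 0.4, 0.5]]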
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        forward_batch.return_logprob = False
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
 
@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states
 
     def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.use_nan_detection:
+        if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
                 logger.warning("Detected errors during sampling! NaN in the logits.")
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.3.post3"
+__version__ = "0.4.3.post4"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: sglang
-Version: 0.4.3.post3
+Version: 0.4.3.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
          Version 2.0, January 2004
@@ -239,6 +239,7 @@ Requires-Dist: xgrammar==0.1.14; extra == "runtime-common"
 Requires-Dist: ninja; extra == "runtime-common"
 Requires-Dist: transformers==4.48.3; extra == "runtime-common"
 Requires-Dist: llguidance>=0.6.15; extra == "runtime-common"
+Requires-Dist: datasets; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: sgl-kernel==0.0.3.post6; extra == "srt"