sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -75,6 +75,7 @@ class ServerArgs:
  # Other runtime options
  tp_size: int = 1
  stream_interval: int = 1
+ stream_output: bool = False
  random_seed: Optional[int] = None
  constrained_json_whitespace_pattern: Optional[str] = None
  watchdog_timeout: float = 300
@@ -161,6 +162,8 @@ class ServerArgs:

  # Custom logit processor
  enable_custom_logit_processor: bool = False
+ tool_call_parser: str = None
+ enable_hierarchical_cache: bool = False

  def __post_init__(self):
  # Set missing default values
@@ -317,6 +320,7 @@ class ServerArgs:
  "dummy",
  "gguf",
  "bitsandbytes",
+ "layered",
  ],
  help="The format of the model weights to load. "
  '"auto" will try to load the weights in the safetensors format '
@@ -330,7 +334,10 @@ class ServerArgs:
  "which is mainly for profiling."
  '"gguf" will load the weights in the gguf format. '
  '"bitsandbytes" will load the weights using bitsandbytes '
- "quantization.",
+ "quantization."
+ '"layered" loads weights layer by layer so that one can quantize a '
+ "layer before loading another to make the peak memory envelope "
+ "smaller.",
  )
  parser.add_argument(
  "--trust-remote-code",
@@ -495,6 +502,11 @@ class ServerArgs:
  default=ServerArgs.stream_interval,
  help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
  )
+ parser.add_argument(
+ "--stream-output",
+ action="store_true",
+ help="Whether to output as a sequence of disjoint segments.",
+ )
  parser.add_argument(
  "--random-seed",
  type=int,
@@ -873,6 +885,19 @@ class ServerArgs:
  action="store_true",
  help="Enable users to pass custom logit processors to the server (disabled by default for security)",
  )
+ # Function Calling
+ parser.add_argument(
+ "--tool-call-parser",
+ type=str,
+ choices=["qwen25", "mistral", "llama3"],
+ default=ServerArgs.tool_call_parser,
+ help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+ )
+ parser.add_argument(
+ "--enable-hierarchical-cache",
+ action="store_true",
+ help="Enable hierarchical cache",
+ )

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
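A short usage sketch, not part of the diff, showing how the new ServerArgs fields above could be set; the model path is a placeholder, and the keyword names simply mirror the dataclass fields and CLI flags introduced in this release.

    # Sketch only: construct ServerArgs with the options added in 0.4.2.post1.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        stream_output=True,             # emit disjoint segments while streaming
        tool_call_parser="llama3",      # one of: qwen25, mistral, llama3
        enable_hierarchical_cache=True,
        load_format="layered",          # quantize each layer before loading the next
    )
    print(args.tool_call_parser)

The same options map to the CLI flags --stream-output, --tool-call-parser, --enable-hierarchical-cache, and --load-format layered shown in the hunks above.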
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
180
180
  class EAGLEDraftInput(SpecInfo):
181
181
  def __init__(self):
182
182
  self.prev_mode = ForwardMode.DECODE
183
- self.sample_output = None
184
183
 
185
184
  self.scores: torch.Tensor = None
186
185
  self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
190
189
  self.cache_list: List[torch.Tenor] = []
191
190
  self.iter = 0
192
191
 
192
+ # shape: (b, hidden_size)
193
193
  self.hidden_states: torch.Tensor = None
194
+ # shape: (b,)
194
195
  self.verified_id: torch.Tensor = None
196
+ # shape: (b, vocab_size)
197
+ self.sample_output: torch.Tensor = None
198
+
195
199
  self.positions: torch.Tensor = None
196
200
  self.accept_length: torch.Tensor = None
197
- self.has_finished: bool = False
198
- self.unfinished_index: List[int] = None
201
+ self.accept_length_cpu: List[int] = None
199
202
 
200
203
  def load_server_args(self, server_args: ServerArgs):
201
204
  self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
218
221
  :pre_len
219
222
  ] = req.prefix_indices
220
223
 
221
- batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
224
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
222
225
  out_cache_loc[pt : pt + req.extend_input_len]
223
226
  )
224
227
 
@@ -228,6 +231,14 @@ class EAGLEDraftInput(SpecInfo):
228
231
  assert len(batch.extend_lens) == 1
229
232
  batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
230
233
 
234
+ def filter_batch(
235
+ self,
236
+ new_indices: torch.Tensor,
237
+ ):
238
+ self.sample_output = self.sample_output[: len(new_indices)]
239
+ self.hidden_states = self.hidden_states[: len(new_indices)]
240
+ self.verified_id = self.verified_id[: len(new_indices)]
241
+
231
242
  def prepare_for_decode(self, batch: ScheduleBatch):
232
243
  prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
233
244
  top = torch.topk(prob, self.topk, dim=-1)
@@ -287,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
287
298
  self.cache_list.append(batch.out_cache_loc)
288
299
  self.positions = (
289
300
  batch.seq_lens[:, None]
290
- + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
301
+ + torch.full(
302
+ [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
303
+ )
291
304
  ).flatten()
292
305
 
293
306
  bs = len(batch.seq_lens)
@@ -304,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
304
317
 
305
318
  def prepare_extend_after_decode(self, batch: ScheduleBatch):
306
319
  batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
307
- batch.extend_lens = (self.accept_length + 1).tolist()
320
+ accept_length_cpu = batch.spec_info.accept_length_cpu
321
+ batch.extend_lens = [x + 1 for x in accept_length_cpu]
322
+ batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
323
+ seq_lens_cpu = batch.seq_lens.tolist()
308
324
 
309
325
  pt = 0
310
- seq_lens = batch.seq_lens.tolist()
311
-
312
326
  i = 0
313
-
314
327
  for req in batch.reqs:
315
328
  if req.finished():
316
329
  continue
317
330
  # assert seq_len - pre_len == req.extend_input_len
318
- input_len = self.accept_length[i] + 1
319
- seq_len = seq_lens[i]
331
+ input_len = batch.extend_lens[i]
332
+ seq_len = seq_lens_cpu[i]
320
333
  batch.req_to_token_pool.req_to_token[req.req_pool_idx][
321
334
  seq_len - input_len : seq_len
322
335
  ] = batch.out_cache_loc[pt : pt + input_len]
323
336
  pt += input_len
324
337
  i += 1
338
+ assert pt == batch.out_cache_loc.shape[0]
325
339
 
326
340
  self.positions = torch.empty_like(self.verified_id)
327
341
  new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -337,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
337
351
  triton.next_power_of_2(self.spec_steps + 1),
338
352
  )
339
353
 
340
- batch.seq_lens_sum = sum(batch.seq_lens)
354
+ batch.seq_lens_sum = sum(seq_lens_cpu)
341
355
  batch.input_ids = self.verified_id
342
356
  self.verified_id = new_verified_id
343
357
 
@@ -565,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
565
579
  finished_extend_len = {} # {rid:accept_length + 1}
566
580
  accept_index_cpu = accept_index.tolist()
567
581
  predict_cpu = predict.tolist()
582
+ has_finished = False
583
+
568
584
  # iterate every accepted token and check if req has finished after append the token
569
585
  # should be checked BEFORE free kv cache slots
570
586
  for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -578,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
578
594
  finished_extend_len[req.rid] = j + 1
579
595
  req.check_finished()
580
596
  if req.finished():
581
- draft_input.has_finished = True
597
+ has_finished = True
582
598
  # set all tokens after finished token to -1 and break
583
599
  accept_index[i, j + 1 :] = -1
584
600
  break
@@ -587,12 +603,12 @@ class EagleVerifyInput(SpecInfo):
587
603
  if not req.finished():
588
604
  new_accept_index.extend(new_accept_index_)
589
605
  unfinished_index.append(i)
606
+ req.spec_verify_ct += 1
590
607
  accept_length = (accept_index != -1).sum(dim=1) - 1
591
608
 
592
609
  accept_index = accept_index[accept_index != -1]
593
610
  accept_length_cpu = accept_length.tolist()
594
611
  verified_id = predict[accept_index]
595
- verified_id_cpu = verified_id.tolist()
596
612
 
597
613
  evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
598
614
  evict_mask[accept_index] = False
@@ -614,7 +630,13 @@ class EagleVerifyInput(SpecInfo):
614
630
  draft_input.verified_id = predict[new_accept_index]
615
631
  draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
616
632
  draft_input.accept_length = accept_length[unfinished_index]
617
- draft_input.unfinished_index = unfinished_index
633
+ draft_input.accept_length_cpu = [
634
+ accept_length_cpu[i] for i in unfinished_index
635
+ ]
636
+ if has_finished:
637
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
638
+ else:
639
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens
618
640
 
619
641
  logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
620
642
  return (
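As a side note on the positions change in prepare_for_decode above, torch.full writes the constant directly instead of materializing a ones tensor and multiplying. A small standalone check of the equivalence, not part of the diff, run on CPU with hypothetical values:

    import torch

    topk, cur_iter = 4, 2  # hypothetical values for illustration
    old = torch.ones([1, topk], dtype=torch.long) * cur_iter
    new = torch.full([1, topk], fill_value=cur_iter, dtype=torch.long)
    assert torch.equal(old, new)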
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+ from sglang.srt.utils import rank0_print


  class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):

  def forward_draft_decode(self, batch: ScheduleBatch):
  batch.spec_info.prepare_for_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)

  def forward_draft_extend(self, batch: ScheduleBatch):
  self._set_mem_pool(batch, self.model_runner)
  batch.spec_info.prepare_for_extend(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)
  self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
  batch.req_to_token_pool = runner.req_to_token_pool

  def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+ seq_lens_backup = batch.seq_lens
+
  self._set_mem_pool(batch, self.model_runner)
  batch.forward_mode = ForwardMode.DRAFT_EXTEND
- if batch.spec_info.has_finished:
- index = batch.spec_info.unfinished_index
- seq_lens = batch.seq_lens
- batch.seq_lens = batch.seq_lens[index]
-
  batch.spec_info.prepare_extend_after_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
-
- batch.spec_info.hidden_states = logits_output.hidden_states
  self.capture_for_decode(logits_output, forward_batch)
- batch.forward_mode = ForwardMode.DECODE
- if batch.spec_info.has_finished:
- batch.seq_lens = seq_lens
  self._set_mem_pool(batch, self.target_worker.model_runner)

+ # Restore backup.
+ # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+ batch.forward_mode = ForwardMode.DECODE
+ batch.seq_lens = seq_lens_backup
+
  def capture_for_decode(
  self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
  ):
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
  """Common utilities."""

  import base64
+ import ctypes
  import dataclasses
  import io
  import ipaddress
@@ -29,6 +30,7 @@ import shutil
  import signal
  import socket
  import subprocess
+ import sys
  import tempfile
  import time
  import warnings
@@ -59,7 +61,6 @@ from triton.runtime.cache import (
  default_dump_dir,
  default_override_dir,
  )
- from uvicorn.config import LOGGING_CONFIG

  logger = logging.getLogger(__name__)

@@ -73,7 +74,7 @@ def is_hip() -> bool:


  def is_cuda():
- return hasattr(torch, "cuda") and torch.cuda.is_available()
+ return hasattr(torch, "cuda") and torch.version.cuda is not None


  def is_cuda_alike():
@@ -443,8 +444,6 @@ def load_image(image_file: Union[str, bytes]):
  else:
  raise ValueError(f"Invalid image: {image}")

- # if image_size is None:
- # image_size = image.size
  return image, image_size


@@ -773,7 +772,7 @@ def get_zmq_socket(


  def dump_to_file(dirpath, name, value):
- from vllm.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed import get_tensor_model_parallel_rank

  if get_tensor_model_parallel_rank() != 0:
  return
@@ -1242,68 +1241,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
  return str(data)


- TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
-
-
- def parse_tool_response(text, tools, **kwargs):
- """Parse model response containing tool information.
-
- Args:
- text(str): model response in string format
- tools(List): tools from user request
- """
- if "<|plugin|>" in text: # internlm2
- text, action = text.split("<|action_start|><|plugin|>")
- action = action.split("<|action_end|>".strip())[0]
- action = action[action.find("{") :]
- action = json.loads(action)
- name, parameters = action["name"], json.dumps(
- action.get("parameters", action.get("arguments", {})), ensure_ascii=False
- )
- call_info_list = [(name, parameters)]
- elif "<function=" in text: # llama3.1
- action, _ = text.split("</function>")
- parameters = action[action.find("{") :]
- name = action.split("<function=")[1].split(">{")[0]
- call_info_list = [(name, parameters)]
- elif "<tool_call>" in text and "</tool_call>" in text: # qwen2.5
- # get tool_call in text
- pattern = r"<tool_call>(.*?)</tool_call>"
- match_result_list = re.findall(pattern, text, re.DOTALL)
- call_info_list = []
- for match_result in match_result_list:
- action = json.loads(match_result)
- call_info_list.append(
- (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
- )
- # get text outside of tags
- if not text.startswith("<tool_call>"):
- text = text[: text.find("<tool_call>")]
- elif not text.endswith("</tool_call>"):
- text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
- else:
- text = ""
- elif "<|python_tag|>" in text: # llama3.2
- _, action = text.split("<|python_tag|>")
- action = json.loads(action)
- name, parameters = action["name"], json.dumps(
- action.get("parameters", action.get("arguments", {})), ensure_ascii=False
- )
- call_info_list = [(name, parameters)]
- else:
- raise RuntimeError(f"Unexpected model response: {text}")
-
- call_info_list = [
- (
- [tool.function.name for tool in tools].index(call_info[0]),
- call_info[0],
- call_info[1],
- )
- for call_info in call_info_list
- ]
- return text, call_info_list
-
-
  def permute_weight(x: torch.Tensor) -> torch.Tensor:
  b_ = x.shape[0]
  n_ = x.shape[1]
@@ -1366,7 +1303,33 @@ def nullable_str(val: str):
  return val


+ def pyspy_dump_schedulers():
+ """py-spy dump on all scheduler in a local node."""
+ try:
+ pid = psutil.Process().pid
+ # Command to run py-spy with the PID
+ cmd = f"py-spy dump --pid {pid}"
+ result = subprocess.run(
+ cmd, shell=True, capture_output=True, text=True, check=True
+ )
+ logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+ except subprocess.CalledProcessError as e:
+ logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+
+
+ def kill_itself_when_parent_died():
+ if sys.platform == "linux":
+ # sigkill this process when parent worker manager dies
+ PR_SET_PDEATHSIG = 1
+ libc = ctypes.CDLL("libc.so.6")
+ libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+ else:
+ logger.warninig("kill_itself_when_parent_died is only supported in linux.")
+
+
  def set_uvicorn_logging_configs():
+ from uvicorn.config import LOGGING_CONFIG
+
  LOGGING_CONFIG["formatters"]["default"][
  "fmt"
  ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -1442,3 +1405,35 @@ def is_valid_ipv6_address(address: str) -> bool:
  return True
  except ValueError:
  return False
+
+
+ def rank0_print(msg: str):
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+ if get_tensor_model_parallel_rank() == 0:
+ print(msg, flush=True)
+
+
+ def launch_dummy_health_check_server(host, port):
+ import uvicorn
+ from fastapi import FastAPI, Response
+
+ app = FastAPI()
+
+ @app.get("/health")
+ async def health():
+ """Check the health of the http server."""
+ return Response(status_code=200)
+
+ @app.get("/health_generate")
+ async def health_generate():
+ """Check the health of the http server."""
+ return Response(status_code=200)
+
+ uvicorn.run(
+ app,
+ host=host,
+ port=port,
+ timeout_keep_alive=5,
+ loop="uvloop",
+ )
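An illustrative way to exercise launch_dummy_health_check_server from the hunk above, assuming a spare local port; the port number and the two-second wait are placeholders, not values taken from the package.

    import multiprocessing
    import time

    import requests
    from sglang.srt.utils import launch_dummy_health_check_server

    # Run the dummy server in a child process and probe the /health endpoint it exposes.
    proc = multiprocessing.Process(
        target=launch_dummy_health_check_server, args=("127.0.0.1", 30123)
    )
    proc.start()
    time.sleep(2)  # give uvicorn a moment to bind
    print(requests.get("http://127.0.0.1:30123/health").status_code)  # expect 200
    proc.terminate()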
sglang/test/test_programs.py CHANGED
@@ -535,6 +535,7 @@ def test_hellaswag_select():

  # Compute accuracy
  accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+ print(f"{accuracy=}, {accuracy_gen=}")
  assert np.abs(accuracy_gen - accuracy) < 0.05
  assert np.abs(latency_gen - latency) < 1

sglang/test/test_utils.py CHANGED
@@ -34,7 +34,7 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
  DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
  DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
  DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
- DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
@@ -42,6 +42,9 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-In
  DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
  DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"

+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+

  def is_in_ci():
  """Return whether it is in CI runner."""
@@ -132,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
  return pred


- def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
- raise NotImplementedError()
-
-
  def call_generate_guidance(
  prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
  ):
@@ -527,6 +526,48 @@ def get_similarities(vec1, vec2):
  return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)


+ def get_benchmark_args(
+ base_url="",
+ dataset_name="",
+ dataset_path="",
+ tokenizer="",
+ num_prompts=500,
+ random_input_len=4096,
+ random_output_len=2048,
+ request_rate=float("inf"),
+ disable_stream=False,
+ disable_ignore_eos=False,
+ ):
+ return SimpleNamespace(
+ backend="sglang",
+ base_url=base_url,
+ host=None,
+ port=None,
+ dataset_name=dataset_name,
+ dataset_path=dataset_path,
+ model=None,
+ tokenizer=tokenizer,
+ num_prompts=num_prompts,
+ sharegpt_output_len=None,
+ sharegpt_context_len=None,
+ random_input_len=random_input_len,
+ random_output_len=random_output_len,
+ random_range_ratio=0.0,
+ request_rate=request_rate,
+ multi=None,
+ output_file=None,
+ disable_tqdm=False,
+ disable_stream=disable_stream,
+ return_logprob=False,
+ seed=0,
+ disable_ignore_eos=disable_ignore_eos,
+ extra_request_body=None,
+ apply_chat_template=False,
+ profile=None,
+ lora_name=None,
+ )
+
+
  def run_bench_serving(
  model,
  num_prompts,
@@ -538,6 +579,7 @@ def run_bench_serving(
  random_input_len=4096,
  random_output_len=2048,
  disable_stream=False,
+ disable_ignore_eos=False,
  need_warmup=False,
  ):
  # Launch the server
@@ -550,32 +592,17 @@
  )

  # Run benchmark
- args = SimpleNamespace(
- backend="sglang",
+ args = get_benchmark_args(
  base_url=base_url,
- host=None,
- port=None,
  dataset_name=dataset_name,
  dataset_path=dataset_path,
- model=None,
  tokenizer=tokenizer,
  num_prompts=num_prompts,
- sharegpt_output_len=None,
- sharegpt_context_len=None,
  random_input_len=random_input_len,
  random_output_len=random_output_len,
- random_range_ratio=0.0,
  request_rate=request_rate,
- multi=None,
- seed=0,
- output_file=None,
- disable_tqdm=False,
  disable_stream=disable_stream,
- disable_ignore_eos=False,
- return_logprob=False,
- lora_name=None,
- extra_request_body=None,
- profile=None,
+ disable_ignore_eos=disable_ignore_eos,
  )

  try:
@@ -591,6 +618,38 @@
  return res


+ def run_bench_serving_multi(
+ model,
+ base_url,
+ other_server_args,
+ benchmark_args,
+ need_warmup=False,
+ ):
+ # Launch the server
+ process = popen_launch_server(
+ model,
+ base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=other_server_args,
+ )
+
+ # run benchmark for all
+ res_l = []
+ try:
+ for args in benchmark_args:
+ if need_warmup:
+ warmup_args = copy.deepcopy(args)
+ warmup_args.num_prompts = 16
+ run_benchmark(warmup_args)
+
+ res = run_benchmark(args)
+ res_l.append((args, res))
+ finally:
+ kill_process_tree(process.pid)
+
+ return res_l
+
+
  def run_bench_one_batch(model, other_args):
  command = [
  "python3",
sglang/utils.py CHANGED
@@ -373,3 +373,45 @@ class TypeBasedDispatcher:
  if isinstance(obj, ty):
  return fn(obj)
  raise ValueError(f"Invalid object: {obj}")
+
+
+ def trim_overlap(existing_text, new_chunk):
+ """
+ Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk'
+ and removes that overlap from the start of 'new_chunk'.
+ """
+ max_overlap = 0
+ max_possible = min(len(existing_text), len(new_chunk))
+ for i in range(max_possible, 0, -1):
+ if existing_text.endswith(new_chunk[:i]):
+ max_overlap = i
+ break
+ return new_chunk[max_overlap:]
+
+
+ def stream_and_merge(llm, prompt, sampling_params):
+ """
+ 1) Streams the text,
+ 2) Removes chunk overlaps,
+ 3) Returns the merged text.
+ """
+ final_text = ""
+ for chunk in llm.generate(prompt, sampling_params, stream=True):
+ chunk_text = chunk["text"]
+ cleaned_chunk = trim_overlap(final_text, chunk_text)
+ final_text += cleaned_chunk
+ return final_text
+
+
+ async def async_stream_and_merge(llm, prompt, sampling_params):
+ """
+ Streams tokens asynchronously, removes chunk overlaps,
+ and yields the cleaned chunk in real time for printing.
+ """
+ final_text = ""
+ generator = await llm.async_generate(prompt, sampling_params, stream=True)
+ async for chunk in generator:
+ chunk_text = chunk["text"]
+ cleaned_chunk = trim_overlap(final_text, chunk_text)
+ final_text += cleaned_chunk
+ yield cleaned_chunk # yield the non-overlapping portion
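A brief sketch of how these helpers pair with the offline engine's streaming API, assuming the sgl.Engine entry point; the model path and sampling parameters are placeholders.

    import sglang as sgl
    from sglang.utils import stream_and_merge

    llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
    sampling_params = {"temperature": 0.8, "max_new_tokens": 64}

    # stream_and_merge consumes the streamed chunks and trims overlapping prefixes.
    text = stream_and_merge(llm, "The capital of France is", sampling_params)
    print(text)
    llm.shutdown()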
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.1.post7"
+ __version__ = "0.4.2.post1"