sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -35,11 +35,12 @@ class ServerArgs:
  tokenizer_mode: str = "auto"
  skip_tokenizer_init: bool = False
  load_format: str = "auto"
+ trust_remote_code: bool = True
  dtype: str = "auto"
  kv_cache_dtype: str = "auto"
- trust_remote_code: bool = True
- context_length: Optional[int] = None
  quantization: Optional[str] = None
+ context_length: Optional[int] = None
+ device: str = "cuda"
  served_model_name: Optional[str] = None
  chat_template: Optional[str] = None
  is_embedding: bool = False
@@ -72,6 +73,7 @@ class ServerArgs:
  # Other
  api_key: Optional[str] = None
  file_storage_pth: str = "SGLang_storage"
+ enable_cache_report: bool = False

  # Data parallelism
  dp_size: int = 1
@@ -85,10 +87,23 @@ class ServerArgs:
  # Model override args in JSON
  json_model_override_args: str = "{}"

- # Optimization/debug options
+ # Double Sparsity
+ enable_double_sparsity: bool = False
+ ds_channel_config_path: str = None
+ ds_heavy_channel_num: int = 32
+ ds_heavy_token_num: int = 256
+ ds_heavy_channel_type: str = "qk"
+ ds_sparse_decode_threshold: int = 4096
+
+ # LoRA
+ lora_paths: Optional[List[str]] = None
+ max_loras_per_batch: int = 8
+
+ # Kernel backend
  attention_backend: Optional[str] = None
  sampling_backend: Optional[str] = None

+ # Optimization/debug options
  disable_flashinfer: bool = False
  disable_flashinfer_sampling: bool = False
  disable_radix_cache: bool = False
@@ -98,16 +113,16 @@ class ServerArgs:
  disable_disk_cache: bool = False
  disable_custom_all_reduce: bool = False
  disable_mla: bool = False
+ disable_penalizer: bool = False
+ disable_nan_detection: bool = False
+ enable_overlap_schedule: bool = False
  enable_mixed_chunk: bool = False
  enable_torch_compile: bool = False
  max_torch_compile_bs: int = 32
  torchao_config: str = ""
  enable_p2p_check: bool = False
  triton_attention_reduce_in_fp32: bool = False
-
- # LoRA
- lora_paths: Optional[List[str]] = None
- max_loras_per_batch: int = 8
+ num_continuous_decode_steps: int = 1

  def __post_init__(self):
  # Set missing default values
@@ -223,6 +238,11 @@ class ServerArgs:
  '"dummy" will initialize the weights with random values, '
  "which is mainly for profiling.",
  )
+ parser.add_argument(
+ "--trust-remote-code",
+ action="store_true",
+ help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+ )
  parser.add_argument(
  "--dtype",
  type=str,
@@ -244,17 +264,6 @@ class ServerArgs:
  choices=["auto", "fp8_e5m2"],
  help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
  )
- parser.add_argument(
- "--trust-remote-code",
- action="store_true",
- help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
- )
- parser.add_argument(
- "--context-length",
- type=int,
- default=ServerArgs.context_length,
- help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
- )
  parser.add_argument(
  "--quantization",
  type=str,
@@ -270,6 +279,19 @@ class ServerArgs:
  ],
  help="The quantization method.",
  )
+ parser.add_argument(
+ "--context-length",
+ type=int,
+ default=ServerArgs.context_length,
+ help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ choices=["cuda", "xpu"],
+ help="The device type.",
+ )
  parser.add_argument(
  "--served-model-name",
  type=str,
@@ -390,6 +412,11 @@ class ServerArgs:
  default=ServerArgs.file_storage_pth,
  help="The path of the file storage in backend.",
  )
+ parser.add_argument(
+ "--enable-cache-report",
+ action="store_true",
+ help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
+ )

  # Data parallelism
  parser.add_argument(
@@ -432,7 +459,60 @@ class ServerArgs:
  default=ServerArgs.json_model_override_args,
  )

- # Optimization/debug options
+ # Double Sparsity
+ parser.add_argument(
+ "--enable-double-sparsity",
+ action="store_true",
+ help="Enable double sparsity attention",
+ )
+ parser.add_argument(
+ "--ds-channel-config-path",
+ type=str,
+ default=ServerArgs.ds_channel_config_path,
+ help="The path of the double sparsity channel config",
+ )
+ parser.add_argument(
+ "--ds-heavy-channel-num",
+ type=int,
+ default=ServerArgs.ds_heavy_channel_num,
+ help="The number of heavy channels in double sparsity attention",
+ )
+ parser.add_argument(
+ "--ds-heavy-token-num",
+ type=int,
+ default=ServerArgs.ds_heavy_token_num,
+ help="The number of heavy tokens in double sparsity attention",
+ )
+ parser.add_argument(
+ "--ds-heavy-channel-type",
+ type=str,
+ default=ServerArgs.ds_heavy_channel_type,
+ help="The type of heavy channels in double sparsity attention",
+ )
+ parser.add_argument(
+ "--ds-sparse-decode-threshold",
+ type=int,
+ default=ServerArgs.ds_sparse_decode_threshold,
+ help="The type of heavy channels in double sparsity attention",
+ )
+
+ # LoRA
+ parser.add_argument(
+ "--lora-paths",
+ type=str,
+ nargs="*",
+ default=None,
+ action=LoRAPathAction,
+ help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+ )
+ parser.add_argument(
+ "--max-loras-per-batch",
+ type=int,
+ default=8,
+ help="Maximum number of adapters for a running batch, include base-only request",
+ )
+
+ # Kernel backend
  parser.add_argument(
  "--attention-backend",
  type=str,
@@ -447,6 +527,8 @@ class ServerArgs:
  default=ServerArgs.sampling_backend,
  help="Choose the kernels for sampling layers.",
  )
+
+ # Optimization/debug options
  parser.add_argument(
  "--disable-flashinfer",
  action="store_true",
@@ -493,6 +575,21 @@ class ServerArgs:
  action="store_true",
  help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
  )
+ parser.add_argument(
+ "--disable-penalizer",
+ action="store_true",
+ help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
+ )
+ parser.add_argument(
+ "--disable-nan-detection",
+ action="store_true",
+ help="Disable the NaN detection for better performance.",
+ )
+ parser.add_argument(
+ "--enable-overlap-schedule",
+ action="store_true",
+ help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+ )
  parser.add_argument(
  "--enable-mixed-chunk",
  action="store_true",
@@ -527,25 +624,12 @@ class ServerArgs:
  "This only affects Triton attention kernels.",
  )
  parser.add_argument(
- "--efficient-weight-load",
- action="store_true",
- help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
- )
-
- # LoRA options
- parser.add_argument(
- "--lora-paths",
- type=str,
- nargs="*",
- default=None,
- action=LoRAPathAction,
- help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
- )
- parser.add_argument(
- "--max-loras-per-batch",
+ "--num-continuous-decode-steps",
  type=int,
- default=8,
- help="Maximum number of adapters for a running batch, include base-only request",
+ default=ServerArgs.num_continuous_decode_steps,
+ help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+ "This can potentially increase throughput but may also increase time-to-first-token latency. "
+ "The default value is 1, meaning only run one decoding step at a time.",
  )

  @classmethod
@@ -566,7 +650,7 @@ class ServerArgs:
  self.tp_size % self.nnodes == 0
  ), "tp_size must be divisible by number of nodes"
  assert not (
- self.dp_size > 1 and self.node_rank is not None
+ self.dp_size > 1 and self.nnodes != 1
  ), "multi-node data parallel is not supported"
  assert (
  self.max_loras_per_batch > 0
@@ -575,11 +659,6 @@ class ServerArgs:
  and (self.lora_paths is None or self.disable_radix_cache)
  ), "compatibility of lora and cuda graph and radix attention is in progress"

- assert self.dp_size == 1, (
- "The support for data parallelism is temporarily disabled during refactor. "
- "Please use sglang<=0.3.2 or wait for later updates."
- )
-
  if isinstance(self.lora_paths, list):
  lora_paths = self.lora_paths
  self.lora_paths = {}
@@ -618,11 +697,11 @@ class PortArgs:
  # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
  detokenizer_ipc_name: str

- # The port for nccl initialization for multiple TP groups (torch.dist)
- nccl_ports: List[int]
+ # The port for nccl initialization (torch.dist)
+ nccl_port: int

- @classmethod
- def init_new(self, server_args):
+ @staticmethod
+ def init_new(server_args) -> "PortArgs":
  port = server_args.port + 1
  while True:
  if is_port_available(port):
@@ -633,7 +712,7 @@ class PortArgs:
  tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
  scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
  detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
- nccl_ports=[port],
+ nccl_port=port,
  )

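The hunks above regroup the ServerArgs fields into Double Sparsity, LoRA, kernel backend, and optimization/debug sections, and add device, enable_cache_report, and num_continuous_decode_steps. The sketch below is illustrative only: the field names come from the diff, while model_path, the exact required fields, and direct dataclass construction as an entry point are assumptions.

# Hypothetical sketch: constructing the 0.3.4 ServerArgs dataclass directly.
# Field names are taken from the diff above; model_path is an assumed required field.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumption, not shown in this diff
    device="cuda",                      # new in 0.3.4: "cuda" or "xpu"
    enable_cache_report=True,           # report cached tokens in usage.prompt_tokens_details
    num_continuous_decode_steps=2,      # >1 reduces scheduling overhead, may raise TTFT
    enable_overlap_schedule=False,      # experimental CPU/GPU overlap, off by default
)
print(args.device, args.num_continuous_decode_steps)
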
sglang/srt/utils.py CHANGED
@@ -35,7 +35,7 @@ import psutil
  import requests
  import torch
  import torch.distributed as dist
- from fastapi.responses import JSONResponse
+ from fastapi.responses import ORJSONResponse
  from packaging import version as pkg_version
  from torch import nn
  from torch.profiler import ProfilerActivity, profile, record_function
@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
  return wrapper


- def get_available_gpu_memory(gpu_id, distributed=False):
+ def get_available_gpu_memory(device, gpu_id, distributed=False):
  """
  Get available memory for cuda:gpu_id device.
  When distributed is True, the available memory is the minimum available memory of all GPUs.
  """
- num_gpus = torch.cuda.device_count()
- assert gpu_id < num_gpus
+ if device == "cuda":
+ num_gpus = torch.cuda.device_count()
+ assert gpu_id < num_gpus
+
+ if torch.cuda.current_device() != gpu_id:
+ print(
+ f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+ "which may cause useless memory allocation for torch CUDA context.",
+ )

- if torch.cuda.current_device() != gpu_id:
- print(
- f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
- "which may cause useless memory allocation for torch CUDA context.",
- )
+ torch.cuda.empty_cache()
+ free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

- torch.cuda.empty_cache()
- free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+ elif device == "xpu":
+ num_gpus = torch.xpu.device_count()
+ assert gpu_id < num_gpus
+
+ if torch.xpu.current_device() != gpu_id:
+ print(
+ f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
+ "which may cause useless memory allocation for torch XPU context.",
+ )
+ torch.xpu.empty_cache()
+ used_memory = torch.xpu.memory_allocated()
+ total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
+ free_gpu_memory = total_gpu_memory - used_memory

  if distributed:
  tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
- torch.device("cuda", gpu_id)
+ torch.device(device, gpu_id)
  )
  torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
  free_gpu_memory = tensor.item()
@@ -551,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
  if request.url.path.startswith("/health"):
  return await call_next(request)
  if request.headers.get("Authorization") != "Bearer " + api_key:
- return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+ return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
  return await call_next(request)

@@ -569,10 +584,11 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

  def configure_logger(server_args, prefix: str = ""):
  format = f"[%(asctime)s{prefix}] %(message)s"
+ # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
  logging.basicConfig(
  level=getattr(logging, server_args.log_level.upper()),
  format=format,
- datefmt="%H:%M:%S",
+ datefmt="%Y-%m-%d %H:%M:%S",
  force=True,
  )

@@ -675,3 +691,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
  prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
  step_counter += 1
  return result
+
+
+ def first_rank_print(*args, **kwargs):
+ if torch.cuda.current_device() == 0:
+ print(*args, **kwargs)
+ else:
+ pass
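get_available_gpu_memory now takes the device string as its first argument so the same helper serves CUDA and XPU. A minimal call sketch, using only the signature visible in the hunk above (the unit and any scaling of the returned value are handled outside the lines shown and are not assumed here):

# Sketch based on the new signature: get_available_gpu_memory(device, gpu_id, distributed=False)
import torch
from sglang.srt.utils import get_available_gpu_memory

if torch.cuda.is_available():
    free_mem = get_available_gpu_memory("cuda", gpu_id=0)
    print(f"free memory on cuda:0 -> {free_mem}")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    free_mem = get_available_gpu_memory("xpu", gpu_id=0)
    print(f"free memory on xpu:0 -> {free_mem}")
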
sglang/test/few_shot_gsm8k.py CHANGED
@@ -76,7 +76,9 @@ def run_eval(args):
  def few_shot_gsm8k(s, question):
  s += few_shot_examples + question
  s += sgl.gen(
- "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+ "answer",
+ max_tokens=args.max_new_tokens,
+ stop=["Question", "Assistant:", "<|separator|>"],
  )

  #####################################
@@ -131,6 +133,7 @@ if __name__ == "__main__":
  parser.add_argument("--num-shots", type=int, default=5)
  parser.add_argument("--data-path", type=str, default="test.jsonl")
  parser.add_argument("--num-questions", type=int, default=200)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
  parser.add_argument("--parallel", type=int, default=128)
  parser.add_argument("--host", type=str, default="http://127.0.0.1")
  parser.add_argument("--port", type=int, default=30000)
sglang/test/few_shot_gsm8k_engine.py ADDED
@@ -0,0 +1,144 @@
+ import argparse
+ import ast
+ import asyncio
+ import json
+ import re
+ import time
+
+ import numpy as np
+
+ import sglang as sgl
+ from sglang.api import set_default_backend
+ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+ INVALID = -9999999
+
+
+ def get_one_example(lines, i, include_answer):
+ ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+ if include_answer:
+ ret += " " + lines[i]["answer"]
+ return ret
+
+
+ def get_few_shot_examples(lines, k):
+ ret = ""
+ for i in range(k):
+ ret += get_one_example(lines, i, True) + "\n\n"
+ return ret
+
+
+ def get_answer_value(answer_str):
+ answer_str = answer_str.replace(",", "")
+ numbers = re.findall(r"\d+", answer_str)
+ if len(numbers) < 1:
+ return INVALID
+ try:
+ return ast.literal_eval(numbers[-1])
+ except SyntaxError:
+ return INVALID
+
+
+ async def concurrent_generate(engine, prompts, sampling_param):
+ tasks = []
+ for prompt in prompts:
+ tasks.append(asyncio.create_task(engine.async_generate(prompt, sampling_param)))
+
+ outputs = await asyncio.gather(*tasks)
+ return outputs
+
+
+ def run_eval(args):
+ # Select backend
+ engine = sgl.Engine(model_path=args.model_path, log_level="error")
+
+ if args.local_data_path is None:
+ # Read data
+ url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+ filename = download_and_cache_file(url)
+ else:
+ filename = args.local_data_path
+
+ lines = list(read_jsonl(filename))
+
+ # Construct prompts
+ num_questions = args.num_questions
+ num_shots = args.num_shots
+ few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+ questions = []
+ labels = []
+ for i in range(len(lines[:num_questions])):
+ questions.append(get_one_example(lines, i, False))
+ labels.append(get_answer_value(lines[i]["answer"]))
+ assert all(l != INVALID for l in labels)
+ arguments = [{"question": q} for q in questions]
+
+ # construct the prompts
+ prompts = []
+ for i, arg in enumerate(arguments):
+ q = arg["question"]
+ prompt = few_shot_examples + q
+ prompts.append(prompt)
+
+ sampling_param = {
+ "stop": ["Question", "Assistant:", "<|separator|>"],
+ "max_new_tokens": 512,
+ "temperature": 0,
+ }
+
+ # Run requests
+ tic = time.time()
+
+ loop = asyncio.get_event_loop()
+
+ outputs = loop.run_until_complete(
+ concurrent_generate(engine, prompts, sampling_param)
+ )
+
+ # End requests
+ latency = time.time() - tic
+
+ # Shutdown the engine
+ engine.shutdown()
+
+ # Parse output
+ preds = []
+
+ for output in outputs:
+ preds.append(get_answer_value(output["text"]))
+
+ # Compute accuracy
+ acc = np.mean(np.array(preds) == np.array(labels))
+ invalid = np.mean(np.array(preds) == INVALID)
+
+ # Compute speed
+ num_output_tokens = sum(
+ output["meta_info"]["completion_tokens"] for output in outputs
+ )
+ output_throughput = num_output_tokens / latency
+
+ # Print results
+ print(f"Accuracy: {acc:.3f}")
+ print(f"Invalid: {invalid:.3f}")
+ print(f"Latency: {latency:.3f} s")
+ print(f"Output throughput: {output_throughput:.3f} token/s")
+
+ return {
+ "accuracy": acc,
+ "latency": latency,
+ "output_throughput": output_throughput,
+ }
+
+
+ if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model-path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
+ )
+ parser.add_argument("--local-data-path", type=Optional[str], default=None)
+ parser.add_argument("--num-shots", type=int, default=5)
+ parser.add_argument("--num-questions", type=int, default=200)
+ args = parser.parse_args()
+ metrics = run_eval(args)
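The new engine-based benchmark exposes run_eval(args), which builds the GSM8K prompts, fans requests out through Engine.async_generate, and returns accuracy, latency, and output throughput. A small driver sketch follows; the Namespace fields mirror the argparse options in the file above, while the concrete values and the direct-call pattern are assumptions, and running it needs a GPU plus the model weights:

# Sketch: driving the new benchmark programmatically instead of via the CLI.
import argparse
from sglang.test.few_shot_gsm8k_engine import run_eval

args = argparse.Namespace(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    local_data_path=None,     # None -> download the GSM8K test split
    num_shots=5,
    num_questions=20,
)
metrics = run_eval(args)      # {"accuracy": ..., "latency": ..., "output_throughput": ...}
print(metrics)
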
sglang/test/srt/sampling/penaltylib/utils.py CHANGED
@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
  msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
  )

- actual = orchestrator.apply(
- torch.ones(
- size=(len(case.test_subjects), self.vocab_size),
- dtype=torch.float32,
- device=self.device,
- )
+ original = torch.ones(
+ size=(len(case.test_subjects), self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
  )
+ actual = orchestrator.apply(original.clone())

  expected = torch.cat(
  tensors=[
  subject.steps[0].expected_logits
  for subject in case.test_subjects
  ],
  )
+ if actual is None:
+ actual = original
  torch.testing.assert_close(
  actual=actual,
  expected=expected,
@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
  device=self.device,
  )
  )
+ if actual_logits is None:
+ continue
  filtered_expected_logits = torch.cat(
  tensors=[
  subject.steps[0].expected_logits
@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
  msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
  )

- actual_logits = orchestrator.apply(
- torch.ones(
- size=(len(filtered_subjects), self.vocab_size),
- dtype=torch.float32,
- device=self.device,
- )
+ original = torch.ones(
+ size=(len(filtered_subjects), self.vocab_size),
+ dtype=torch.float32,
+ device=self.device,
  )
+ actual_logits = orchestrator.apply(original.clone())
  filtered_expected_logits = torch.cat(
  tensors=[
  subject.steps[i].expected_logits
  for subject in filtered_subjects
  ],
  )
+ if actual_logits is None:
+ actual_logits = original
  torch.testing.assert_close(
  actual=actual_logits,
  expected=filtered_expected_logits,
sglang/version.py CHANGED
@@ -1 +1 @@
- __version__ = "0.3.3"
+ __version__ = "0.3.4"