sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (54)
  1. sglang/bench_offline_throughput.py +309 -0
  2. sglang/bench_serving.py +148 -24
  3. sglang/srt/configs/model_config.py +5 -2
  4. sglang/srt/constrained/__init__.py +2 -66
  5. sglang/srt/constrained/base_grammar_backend.py +73 -0
  6. sglang/srt/constrained/outlines_backend.py +165 -0
  7. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  8. sglang/srt/constrained/xgrammar_backend.py +150 -0
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  11. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  12. sglang/srt/layers/fused_moe/patch.py +4 -2
  13. sglang/srt/layers/quantization/base_config.py +4 -6
  14. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  15. sglang/srt/managers/detokenizer_manager.py +0 -14
  16. sglang/srt/managers/io_struct.py +5 -3
  17. sglang/srt/managers/schedule_batch.py +14 -20
  18. sglang/srt/managers/scheduler.py +159 -96
  19. sglang/srt/managers/tokenizer_manager.py +81 -17
  20. sglang/srt/metrics/collector.py +211 -0
  21. sglang/srt/metrics/func_timer.py +108 -0
  22. sglang/srt/mm_utils.py +1 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  24. sglang/srt/model_executor/forward_batch_info.py +7 -3
  25. sglang/srt/model_executor/model_runner.py +6 -2
  26. sglang/srt/models/gemma2_reward.py +69 -0
  27. sglang/srt/models/gpt2.py +31 -37
  28. sglang/srt/models/internlm2_reward.py +62 -0
  29. sglang/srt/models/llama.py +11 -6
  30. sglang/srt/models/llama_reward.py +5 -26
  31. sglang/srt/models/qwen2_vl.py +5 -7
  32. sglang/srt/openai_api/adapter.py +11 -4
  33. sglang/srt/openai_api/protocol.py +29 -26
  34. sglang/srt/sampling/sampling_batch_info.py +2 -3
  35. sglang/srt/sampling/sampling_params.py +2 -16
  36. sglang/srt/server.py +60 -17
  37. sglang/srt/server_args.py +66 -25
  38. sglang/srt/utils.py +120 -0
  39. sglang/test/simple_eval_common.py +1 -1
  40. sglang/test/simple_eval_humaneval.py +2 -2
  41. sglang/test/simple_eval_mgsm.py +2 -2
  42. sglang/test/test_utils.py +21 -7
  43. sglang/utils.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
  46. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
  47. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
  48. sglang/srt/constrained/base_tool_cache.py +0 -65
  49. sglang/srt/constrained/bnf_cache.py +0 -61
  50. sglang/srt/constrained/fsm_cache.py +0 -95
  51. sglang/srt/constrained/grammar.py +0 -190
  52. sglang/srt/constrained/jump_forward.py +0 -203
  53. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
  54. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/server.py CHANGED
@@ -30,12 +30,11 @@ import time
 from http import HTTPStatus
 from typing import AsyncIterator, Dict, List, Optional, Union
 
-import orjson
-
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
+import orjson
 import requests
 import uvicorn
 import uvloop
@@ -57,6 +56,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
@@ -74,12 +74,15 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
+    add_prometheus_middleware,
     assert_pkg_version,
     configure_logger,
+    delete_directory,
     is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -90,8 +93,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
 app = FastAPI()
-tokenizer_manager: TokenizerManager = None
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -100,6 +101,10 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+tokenizer_manager: TokenizerManager = None
+
+##### Native API endpoints #####
+
 
 
 @app.get("/health")
@@ -110,9 +115,16 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
-    gri = GenerateReqInput(
-        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-    )
+
+    if tokenizer_manager.is_generation:
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+    else:
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
             break
@@ -127,6 +139,7 @@ async def get_model_info():
     """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
+        "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
         "is_generation": tokenizer_manager.is_generation,
     }
     return result
@@ -185,6 +198,7 @@ async def get_memory_pool_size():
 
 
 @app.post("/update_weights")
+@time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
@@ -201,7 +215,7 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
         )
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
+@time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
@@ -234,10 +248,12 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         )
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+@time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
     try:
@@ -253,7 +269,8 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)
 
 
-async def judge_request(obj: EmbeddingReqInput, request: Request):
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
@@ -264,21 +281,27 @@ async def judge_request(obj: EmbeddingReqInput, request: Request):
         )
 
 
-app.post("/judge")(judge_request)
-app.put("/judge")(judge_request)
+app.post("/classify")(classify_request)
+app.put("/classify")(classify_request)
+
+
+##### OpenAI-compatible API endpoints #####
 
 
 @app.post("/v1/completions")
+@time_func_latency
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
 
 
 @app.post("/v1/chat/completions")
+@time_func_latency
 async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
 @app.post("/v1/embeddings", response_class=ORJSONResponse)
+@time_func_latency
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
@@ -432,13 +455,17 @@ def launch_server(
     1. The HTTP server and Tokenizer Manager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
-
     launch_engine(server_args=server_args)
 
     # Add api key authorization
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)
 
+    # add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
     # Send a warmup request
     t = threading.Thread(
         target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
@@ -475,6 +502,10 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     # Set ulimit
     set_ulimit()
 
@@ -523,6 +554,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         return
 
     model_info = res.json()
+
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
@@ -560,6 +592,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
 
+    if server_args.delete_ckpt_after_loading:
+        delete_directory(server_args.model_path)
+
 
 class Runtime:
     """
@@ -720,12 +755,12 @@ class Engine:
 
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
-
+
         # runtime server default log level is log
         # offline engine works in scripts, so we set it to error
 
-        if 'log_level' not in kwargs:
-            kwargs['log_level'] = 'error'
+        if "log_level" not in kwargs:
+            kwargs["log_level"] = "error"
 
         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)
@@ -734,7 +769,7 @@ class Engine:
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
-        sampling_params: Optional[Dict] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -840,4 +875,12 @@ class Engine:
         else:
             return tokenizer_manager.tokenizer
 
-    # TODO (ByronHsu): encode
+    def encode(
+        self,
+        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+    ):
+        obj = EmbeddingReqInput(text=prompt)
+
+        # get the current event loop
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(encode_request(obj, None))
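The diff above adds a synchronous Engine.encode() entry point for offline embedding use. The following is a minimal, hypothetical usage sketch (not taken from this release): the model path is a placeholder and the is_embedding flag name is an assumption about the existing ServerArgs options, not something shown in this diff.

# Hypothetical usage sketch of the new Engine.encode() shown above.
# The model path is a placeholder; is_embedding is an assumed ServerArgs flag.
from sglang.srt.server import Engine

engine = Engine(model_path="/path/to/an-embedding-model", is_embedding=True)
result = engine.encode("What is the capital of France?")  # returns the embedding result
engine.shutdown()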
sglang/srt/server_args.py CHANGED
@@ -22,7 +22,12 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -63,25 +68,27 @@ class ServerArgs:
     stream_interval: int = 1
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
-    decode_log_interval: int = 40
+    watchdog_timeout: float = 300
+    download_dir: Optional[str] = None
 
     # Logging
     log_level: str = "info"
     log_level_http: Optional[str] = None
     log_requests: bool = False
     show_time_cost: bool = False
+    enable_metrics: bool = False
+    decode_log_interval: int = 40
 
-    # Other
+    # API related
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
     enable_cache_report: bool = False
-    watchdog_timeout: float = 600
 
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    # Distributed args
+    # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
     nnodes: int = 1
     node_rank: int = 0
@@ -110,7 +117,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
-    disable_regex_jump_forward: bool = False
+    disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
@@ -127,6 +134,7 @@ class ServerArgs:
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
+    delete_ckpt_after_loading: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -140,6 +148,9 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -153,8 +164,14 @@ class ServerArgs:
             else:
                 self.mem_fraction_static = 0.88
 
-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
+        # Adjust for GPUs with small memory capacities
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
 
         # Deprecation warnings
         if self.disable_flashinfer:
@@ -204,6 +221,7 @@
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
+        # Model and port args
         parser.add_argument(
             "--model-path",
             type=str,
@@ -323,6 +341,8 @@
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+
+        # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -367,6 +387,8 @@
             default=ServerArgs.schedule_conservativeness,
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
+
+        # Other runtime options
        parser.add_argument(
             "--tensor-parallel-size",
             "--tp-size",
@@ -392,6 +414,20 @@
             default=ServerArgs.constrained_json_whitespace_pattern,
             help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
         )
+        parser.add_argument(
+            "--watchdog-timeout",
+            type=float,
+            default=ServerArgs.watchdog_timeout,
+            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
+        )
+        parser.add_argument(
+            "--download-dir",
+            type=str,
+            default=ServerArgs.download_dir,
+            help="Model download directory.",
+        )
+
+        # Logging
         parser.add_argument(
             "--log-level",
             type=str,
@@ -414,6 +450,19 @@
             action="store_true",
             help="Show time cost of custom marks.",
         )
+        parser.add_argument(
+            "--enable-metrics",
+            action="store_true",
+            help="Enable log prometheus metrics.",
+        )
+        parser.add_argument(
+            "--decode-log-interval",
+            type=int,
+            default=ServerArgs.decode_log_interval,
+            help="The log interval of decode batch",
+        )
+
+        # API related
         parser.add_argument(
             "--api-key",
             type=str,
@@ -431,18 +480,6 @@
             action="store_true",
             help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
         )
-        parser.add_argument(
-            "--watchdog-timeout",
-            type=float,
-            default=ServerArgs.watchdog_timeout,
-            help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
-        )
-        parser.add_argument(
-            "--decode-log-interval",
-            type=int,
-            default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch"
-        )
 
         # Data parallelism
         parser.add_argument(
@@ -463,7 +500,7 @@
             ],
         )
 
-        # Multi-node distributed serving args
+        # Multi-node distributed serving
         parser.add_argument(
             "--dist-init-addr",
             "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
@@ -558,7 +595,7 @@
             type=str,
             choices=["xgrammar", "outlines"],
             default=ServerArgs.grammar_backend,
-            help="Choose the backend for constrained decoding.",
+            help="Choose the backend for grammar-guided decoding.",
         )
 
         # Optimization/debug options
@@ -578,9 +615,9 @@
             help="Disable RadixAttention for prefix caching.",
         )
         parser.add_argument(
-            "--disable-regex-jump-forward",
+            "--disable-jump-forward",
             action="store_true",
-            help="Disable regex jump-forward.",
+            help="Disable jump-forward for grammar-guided decoding.",
         )
         parser.add_argument(
             "--disable-cuda-graph",
@@ -600,7 +637,6 @@
         parser.add_argument(
             "--disable-custom-all-reduce",
             action="store_true",
-            default=False,
             help="Disable the custom all-reduce kernel and fall back to NCCL.",
         )
         parser.add_argument(
@@ -670,6 +706,11 @@
             "This can potentially increase throughput but may also increase time-to-first-token latency. "
             "The default value is 1, meaning only run one decoding step at a time.",
         )
+        parser.add_argument(
+            "--delete-ckpt-after-loading",
+            action="store_true",
+            help="Delete the model checkpoint after loading the model.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
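For reference, the newly added and regrouped ServerArgs fields above can also be set when constructing the dataclass directly. The sketch below is hypothetical (the model path and download directory are placeholders); the defaults it restates are the ones visible in this diff.

# Hypothetical sketch exercising the new/relocated ServerArgs fields from this diff.
# model_path and download_dir values are placeholders.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",
    enable_metrics=True,              # new: expose Prometheus metrics under /metrics
    decode_log_interval=40,           # moved into the "Logging" group
    watchdog_timeout=300,             # default lowered from 600 to 300 seconds
    download_dir="/tmp/sglang-models",  # new: model download directory
    delete_ckpt_after_loading=False,  # new: optionally delete the checkpoint after load
)
# Note: __post_init__ now calls get_gpu_memory_capacity(), so constructing
# ServerArgs requires a working nvidia-smi on the machine.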
sglang/srt/utils.py CHANGED
@@ -22,8 +22,13 @@ import logging
 import os
 import pickle
 import random
+import re
 import resource
+import shutil
+import signal
 import socket
+import subprocess
+import tempfile
 import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
@@ -35,9 +40,11 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from starlette.routing import Mount
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
@@ -379,6 +386,10 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     if include_self:
         try:
             itself.kill()
+
+            # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
+            # so we send an additional signal to kill them.
+            itself.send_signal(signal.SIGINT)
         except psutil.NoSuchProcess:
             pass
 
@@ -704,3 +715,112 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
         raise ValueError(f"Unsupported socket type: {socket_type}")
 
     return socket
+
+
+def dump_to_file(dirpath, name, value):
+    from vllm.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() != 0:
+        return
+
+    os.makedirs(dirpath, exist_ok=True)
+    if value.dtype is torch.bfloat16:
+        value = value.float()
+    value = value.cpu().numpy()
+    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
+    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
+    np.save(output_filename, value)
+
+
+def is_triton_3():
+    return triton.__version__.startswith("3.")
+
+
+def maybe_torch_compile(*args, **kwargs):
+    """
+    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
+    Therefore, we disable it here.
+    """
+
+    def decorator(func):
+        if is_triton_3():
+            return torch.compile(*args, **kwargs)(func)
+        return func
+
+    return decorator
+
+
+def delete_directory(dirpath):
+    try:
+        # This will remove the directory and all its contents
+        shutil.rmtree(dirpath)
+    except OSError as e:
+        print(f"Warning: {dirpath} : {e.strerror}")
+
+
+# Temporary directory for prometheus multiprocess mode
+# Cleaned up automatically when this object is garbage collected
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+
+def set_prometheus_multiproc_dir():
+    # Set prometheus multiprocess directory
+    # sglang uses prometheus multiprocess mode
+    # we need to set this before importing prometheus_client
+    # https://prometheus.github.io/client_python/multiprocess/
+    global prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
+        )
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
+
+
+def add_prometheus_middleware(app):
+    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )
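The two Prometheus helpers added above (set_prometheus_multiproc_dir and add_prometheus_middleware) are wired into launch_server earlier in this diff. Below is a minimal standalone sketch of the same pattern; the FastAPI app object is hypothetical and not part of sglang.

# Minimal standalone sketch of the multiprocess metrics pattern used above.
# The FastAPI app here is hypothetical; sglang's real wiring lives in server.py.
from fastapi import FastAPI

from sglang.srt.utils import add_prometheus_middleware, set_prometheus_multiproc_dir

set_prometheus_multiproc_dir()  # must run before prometheus_client is imported
app = FastAPI()
add_prometheus_middleware(app)  # mounts a /metrics ASGI sub-app on the FastAPI app

# After serving `app` (e.g. with uvicorn), GET /metrics returns the metrics
# aggregated across all processes that write to PROMETHEUS_MULTIPROC_DIR.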
sglang/test/simple_eval_common.py CHANGED
@@ -320,7 +320,7 @@ jinja_env = jinja2.Environment(
 _message_template = """
 <div class="message {{ role }}">
     <div class="role">
-    {{ role }}
+    {{ role }}
     {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
     </div>
     <div class="content">
sglang/test/simple_eval_humaneval.py CHANGED
@@ -2,8 +2,8 @@
 
 """
 HumanEval: Evaluating Large Language Models Trained on Code
-Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
-https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
+Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba
+https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 """
 
 import random
sglang/test/simple_eval_mgsm.py CHANGED
@@ -1,10 +1,10 @@
 # Adapted from https://github.com/openai/simple-evals/
 
 """
-MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
+MGSM: Multilingual Grade School Math Benchmark (MGSM) is a benchmark of grade-school math problems.
 Language Models are Multilingual Chain-of-Thought Reasoners
 Freda Shi, Mirac Suzgun, Markus Freitag, Xuezhi Wang, Suraj Srivats, Soroush Vosoughi, Hyung Won Chung, Yi Tay, Sebastian Ruder, Denny Zhou, Dipanjan Das, Jason Wei
-https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
+https://arxiv.org/abs/2210.03057 reference: https://github.com/google-research/url-nlp
 """
 
 import re