sglang 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +5 -0
  3. sglang/global_config.py +4 -1
  4. sglang/lang/chat_template.py +9 -2
  5. sglang/lang/interpreter.py +52 -19
  6. sglang/lang/ir.py +12 -9
  7. sglang/lang/tracer.py +1 -1
  8. sglang/launch_server.py +1 -2
  9. sglang/launch_server_llavavid.py +31 -0
  10. sglang/srt/flush_cache.py +16 -0
  11. sglang/srt/hf_transformers_utils.py +8 -1
  12. sglang/srt/managers/io_struct.py +15 -3
  13. sglang/srt/managers/router/infer_batch.py +31 -19
  14. sglang/srt/managers/router/manager.py +6 -8
  15. sglang/srt/managers/router/model_rpc.py +59 -23
  16. sglang/srt/managers/router/model_runner.py +6 -6
  17. sglang/srt/managers/router/radix_cache.py +47 -17
  18. sglang/srt/managers/router/scheduler.py +17 -28
  19. sglang/srt/managers/tokenizer_manager.py +54 -22
  20. sglang/srt/model_config.py +4 -0
  21. sglang/srt/models/commandr.py +6 -10
  22. sglang/srt/models/dbrx.py +14 -15
  23. sglang/srt/models/gemma.py +7 -10
  24. sglang/srt/models/llama2.py +7 -10
  25. sglang/srt/models/llava.py +2 -6
  26. sglang/srt/models/llavavid.py +307 -0
  27. sglang/srt/models/mixtral.py +7 -13
  28. sglang/srt/models/qwen.py +20 -13
  29. sglang/srt/models/qwen2.py +7 -10
  30. sglang/srt/models/stablelm.py +13 -12
  31. sglang/srt/models/yivl.py +1 -4
  32. sglang/srt/server.py +32 -18
  33. sglang/srt/server_args.py +9 -6
  34. sglang/srt/utils.py +126 -17
  35. sglang/srt/weight_utils.py +66 -51
  36. sglang/utils.py +77 -26
  37. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/METADATA +9 -5
  38. sglang-0.1.16.dist-info/RECORD +72 -0
  39. sglang-0.1.15.dist-info/RECORD +0 -69
  40. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  41. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  42. {sglang-0.1.15.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
sglang/srt/models/stablelm.py CHANGED
@@ -7,35 +7,31 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.distributed import (
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()
 
@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
sglang/srt/models/yivl.py CHANGED
@@ -6,16 +6,13 @@ from typing import List, Optional
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
+from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
sglang/srt/server.py CHANGED
@@ -20,7 +20,7 @@ import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -90,8 +90,11 @@ async def generate_request(obj: GenerateReqInput):
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    return ret
+    try:
+        ret = await tokenizer_manager.generate_request(obj).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse({"error": str(e)}, status_code=400)
 
 
 @app.post("/v1/completions")
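The generate handler above now converts a ValueError from the tokenizer manager into a structured 400 response instead of an unhandled exception (a 500). A minimal client-side sketch of the new behavior, assuming a server on localhost:30000; which inputs actually raise ValueError, and the error text, depend on the tokenizer manager's validation:

import requests

# Hypothetical invalid payload, for illustration only.
res = requests.post(
    "http://localhost:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": -1}},
)
if res.status_code == 400:
    print("rejected:", res.json()["error"])  # body built by the JSONResponse above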
@@ -104,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
     global tokenizer_manager
 
     logging.basicConfig(
@@ -137,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     )
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
 
     proc_router = mp.Process(
         target=start_router_process,
-        args=(
-            server_args,
-            port_args,
-            pipe_router_writer,
-        ),
+        args=(server_args, port_args, pipe_router_writer, model_overide_args),
     )
     proc_router.start()
     proc_detoken = mp.Process(
@@ -167,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
-        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
+        print(
+            f"Initialization failed. router_init_state: {router_init_state}", flush=True
+        )
+        print(
+            f"Initialization failed. detoken_init_state: {detoken_init_state}",
+            flush=True,
+        )
         sys.exit(1)
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
@@ -186,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
         time.sleep(0.5)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            success = True  # Set flag to True if request succeeds
             break
         except requests.exceptions.RequestException as e:
             pass
@@ -202,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 },
             },
             headers=headers,
-            timeout=60,
+            timeout=600,
         )
         assert res.status_code == 200
     except Exception as e:
@@ -232,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
 class Runtime:
     def __init__(
         self,
-        log_evel="error",
+        log_evel: str = "error",
+        model_overide_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -241,7 +247,10 @@ class Runtime:
 
         # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
+            self.server_args.port,
+            self.server_args.additional_ports,
+            self.server_args.tp_size,
+        )
 
         self.url = self.server_args.url()
         self.generate_url = (
@@ -250,7 +259,10 @@ class Runtime:
 
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
+        proc = mp.Process(
+            target=launch_server,
+            args=(self.server_args, pipe_writer, model_overide_args),
+        )
         proc.start()
         pipe_writer.close()
         self.pid = proc.pid
@@ -262,7 +274,9 @@ class Runtime:
 
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Initialization failed. Please see the error messages above.")
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
 
         self.endpoint = RuntimeEndpoint(self.url)
 
@@ -314,4 +328,4 @@ class Runtime:
             pos += len(cur)
 
     def __del__(self):
-        self.shutdown()
+        self.shutdown()
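Read together, the server.py hunks thread one new optional parameter, model_overide_args (spelling as in upstream), from Runtime.__init__ through launch_server and TokenizerManager into start_router_process; this is what lets entry points such as the new launch_server_llavavid.py inject model-specific config overrides. A sketch of how a caller might use it; the model path and override keys below are illustrative, since accepted keys depend on the model's Hugging Face config:

from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="llava-hf/llava-1.5-7b-hf",  # hypothetical model path
    model_overide_args={"num_frames": 16},  # hypothetical override key
)
try:
    print(runtime.url)
finally:
    runtime.shutdown()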
sglang/srt/server_args.py CHANGED
@@ -80,10 +80,12 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host,
-                            help="The host of the server.")
-        parser.add_argument("--port", type=int, default=ServerArgs.port,
-                            help="The port of the server.")
+        parser.add_argument(
+            "--host", type=str, default=ServerArgs.host, help="The host of the server."
+        )
+        parser.add_argument(
+            "--port", type=int, default=ServerArgs.port, help="The port of the server."
+        )
         parser.add_argument(
             "--additional-ports",
             type=int,
@@ -149,7 +151,8 @@ class ServerArgs:
             "--schedule-heuristic",
             type=str,
             default=ServerArgs.schedule_heuristic,
-            help="Schudule mode: [lpm, weight, random, fcfs]",
+            choices=["lpm", "random", "fcfs", "dfs-weight"],
+            help="Scheduling Heuristic.",
         )
         parser.add_argument(
             "--schedule-conservativeness",
@@ -260,4 +263,4 @@ class PortArgs:
     router_port: int
     detokenizer_port: int
     nccl_port: int
-    model_rpc_ports: List[int]
+    model_rpc_ports: List[int]
sglang/srt/utils.py CHANGED
@@ -4,9 +4,7 @@ import base64
 import os
 import random
 import socket
-import sys
 import time
-import traceback
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
@@ -20,6 +18,8 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware
 
+from sglang.utils import get_exception_traceback
+
 show_time_cost = False
 time_infos = {}
 
@@ -90,6 +90,32 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
+def get_available_gpu_memory(gpu_id, distributed=True):
+    """
+    Get available memory for cuda:gpu_id device.
+    When distributed is True, the available memory is the minimum available memory of all GPUs.
+    """
+    num_gpus = torch.cuda.device_count()
+    assert gpu_id < num_gpus
+
+    if torch.cuda.current_device() != gpu_id:
+        print(
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            "which may cause useless memory allocation for torch CUDA context.",
+        )
+
+    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
+
+    if distributed:
+        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+            torch.device("cuda", gpu_id)
+        )
+        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+        free_gpu_memory = tensor.item()
+
+    return free_gpu_memory / (1 << 30)
+
+
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
 
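get_available_gpu_memory reports free memory in GiB (free bytes divided by 1 << 30); with distributed=True it all-reduces with MIN across ranks so every tensor-parallel worker plans against the most constrained GPU. A minimal single-process sketch; without an initialized torch.distributed process group, distributed must be False here:

import torch

from sglang.srt.utils import get_available_gpu_memory

if torch.cuda.is_available():
    # Single GPU, no process group: skip the all-reduce branch.
    free_gib = get_available_gpu_memory(gpu_id=0, distributed=False)
    print(f"cuda:0 free memory: {free_gib:.2f} GiB")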
@@ -105,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
             continue
 
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             try:
                 s.bind(("", port))
+                s.listen(1)  # Attempt to listen on the port
                 port_list.append(port)
             except socket.error:
-                pass
+                pass  # If any error occurs, this port is not usable
 
         if len(port_list) == num:
             return port_list
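A note on the probe change: SO_REUSEADDR lets the test socket bind ports left in TIME_WAIT by a previous run, and the added listen(1) catches cases where bind() succeeds but the port still cannot accept connections; per the comments in the hunk, any error marks the port unusable. A standalone sketch of the same probe, using a hypothetical helper name:

import socket

def port_is_free(port: int) -> bool:
    # Mirrors the probe in alloc_usable_network_port above: bind, then
    # listen, treating any socket error as "port not usable".
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("", port))
            s.listen(1)
            return True
        except socket.error:
            return False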
@@ -158,12 +186,6 @@ def allocate_init_ports(
     return port, additional_ports
 
 
-def get_exception_traceback():
-    etype, value, tb = sys.exc_info()
-    err_str = "".join(traceback.format_exception(etype, value, tb))
-    return err_str
-
-
 def get_int_token_logit_bias(tokenizer, vocab_size):
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
@@ -245,20 +267,102 @@ def wrap_kernel_launcher(kernel):
 
 
 def is_multimodal_model(model):
-    if isinstance(model, str):
-        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
 
+    if isinstance(model, str):
+        model = model.lower()
+        return "llava" in model or "yi-vl" in model or "llava-next" in model
+
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path
-    raise Exception("unrecognized type")
+        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+
+    raise ValueError("unrecognized type")
+
+
+def decode_video_base64(video_base64):
+    from PIL import Image
+
+    # Decode the base64 string
+    video_bytes = base64.b64decode(video_base64)
+
+    # Placeholder for the start indices of each PNG image
+    img_starts = []
+
+    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+    assert frame_format in [
+        "PNG",
+        "JPEG",
+    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+    if frame_format == "PNG":
+        # Find each PNG start signature to isolate images
+        i = 0
+        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+            # Check if we found the start of a PNG file
+            if (
+                video_bytes[i] == 0x89
+                and video_bytes[i + 1] == 0x50
+                and video_bytes[i + 2] == 0x4E
+                and video_bytes[i + 3] == 0x47
+                and video_bytes[i + 4] == 0x0D
+                and video_bytes[i + 5] == 0x0A
+                and video_bytes[i + 6] == 0x1A
+                and video_bytes[i + 7] == 0x0A
+            ):
+                img_starts.append(i)
+                i += 8  # Skip the PNG signature
+            else:
+                i += 1
+    else:
+        # Find each JPEG start (0xFFD8) to isolate images
+        i = 0
+        while (
+            i < len(video_bytes) - 1
+        ):  # Adjusted for the length of the JPEG SOI signature
+            # Check if we found the start of a JPEG file
+            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                img_starts.append(i)
+                # Move to the next byte to continue searching for the next image start
+                i += 2
+            else:
+                i += 1
+
+    frames = []
+    for start_idx in img_starts:
+        # Assuming each image is back-to-back, the end of one image is the start of another
+        # The last image goes until the end of the byte string
+        end_idx = (
+            img_starts[img_starts.index(start_idx) + 1]
+            if img_starts.index(start_idx) + 1 < len(img_starts)
+            else len(video_bytes)
+        )
+        img_bytes = video_bytes[start_idx:end_idx]
+
+        # Convert bytes to a PIL Image
+        img = Image.open(BytesIO(img_bytes))
+
+        # Convert PIL Image to a NumPy array
+        frame = np.array(img)
+
+        # Append the frame to the list of frames
+        frames.append(frame)
+
+    # Ensure there's at least one frame to avoid errors with np.stack
+    if frames:
+        return np.stack(frames, axis=0), img.size
+    else:
+        return np.array([]), (
+            0,
+            0,
+        )  # Return an empty array and size tuple if no frames were found
 
 
 def load_image(image_file):
     from PIL import Image
 
-    image = None
+    image = image_size = None
 
     if image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
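decode_video_base64 expects its input to be complete PNG (or JPEG) files concatenated back-to-back and then base64-encoded; frame boundaries are recovered by scanning for the 8-byte PNG signature (or the JPEG SOI marker 0xFF 0xD8). For context, a matching encoder that a client could use to produce such a payload might look like this sketch (not part of this package):

import base64
from io import BytesIO
from typing import Iterable

from PIL import Image

def encode_frames_base64(frames: Iterable[Image.Image]) -> str:
    # Write each frame as a complete PNG into one buffer; the decoder
    # above recovers the frame boundaries via the PNG signature.
    buf = BytesIO()
    for frame in frames:
        frame.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")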
@@ -269,10 +373,13 @@ def load_image(image_file):
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
+    elif image_file.startswith("video:"):
+        image_file = image_file.replace("video:", "")
+        image, image_size = decode_video_base64(image_file)
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
 
-    return image
+    return image, image_size
 
 
 def assert_pkg_version(pkg: str, min_version: str):
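Note the signature change: load_image now returns a (image, image_size) pair, where image_size stays None for still images and, on the new video: path, is the frame size reported by decode_video_base64 (PIL's (width, height) of the last decoded frame). Existing call sites must unpack two values; illustrative usage with placeholder inputs:

image, image_size = load_image("https://example.com/cat.png")  # image_size is None
frames, frames_size = load_image("video:" + video_base64)  # frame stack plus (width, height)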
@@ -284,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
             f"is less than the minimum required version {min_version}"
         )
     except PackageNotFoundError:
-        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+        raise Exception(
+            f"{pkg} with minimum required version {min_version} is not installed"
+        )
 
 
 API_KEY_HEADER_NAME = "X-API-Key"
@@ -314,4 +423,4 @@ IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
 def jsonify_pydantic_model(obj: BaseModel):
     if IS_PYDANTIC_1:
         return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+    return obj.model_dump_json()