sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -1,25 +1,31 @@
  """Common utilities."""

  import base64
+ import multiprocessing
+ import logging
  import os
  import random
  import socket
- import sys
  import time
- import traceback
  from importlib.metadata import PackageNotFoundError, version
  from io import BytesIO
  from typing import List, Optional

  import numpy as np
- import pydantic
+ import psutil
  import requests
+ import rpyc
  import torch
+ import triton
+ from rpyc.utils.server import ThreadedServer
  from fastapi.responses import JSONResponse
  from packaging import version as pkg_version
- from pydantic import BaseModel
  from starlette.middleware.base import BaseHTTPMiddleware

+
+ logger = logging.getLogger(__name__)
+
+
  show_time_cost = False
  time_infos = {}
@@ -90,37 +96,49 @@ def calculate_time(show=False, min_cost_ms=0.0):
      return wrapper


- def set_random_seed(seed: int) -> None:
-     random.seed(seed)
+ def get_available_gpu_memory(gpu_id, distributed=False):
+     """
+     Get available memory for cuda:gpu_id device.
+     When distributed is True, the available memory is the minimum available memory of all GPUs.
+     """
+     num_gpus = torch.cuda.device_count()
+     assert gpu_id < num_gpus

-     torch.manual_seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed_all(seed)
+     if torch.cuda.current_device() != gpu_id:
+         print(
+             f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+             "which may cause useless memory allocation for torch CUDA context.",
+         )
+
+     torch.cuda.empty_cache()
+     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

+     if distributed:
+         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
+             torch.device("cuda", gpu_id)
+         )
+         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
+         free_gpu_memory = tensor.item()

- def alloc_usable_network_port(num, used_list=()):
-     port_list = []
-     for port in range(10000, 65536):
-         if port in used_list:
-             continue
+     return free_gpu_memory / (1 << 30)

-         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-             try:
-                 s.bind(("", port))
-                 port_list.append(port)
-             except socket.error:
-                 pass

-         if len(port_list) == num:
-             return port_list
-     return None
+ def set_random_seed(seed: int) -> None:
+     """Set the random seed for all libraries."""
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)


- def check_port(port):
+ def is_port_available(port):
+     """Return whether a port is available."""
      with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
          try:
              s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
              s.bind(("", port))
+             s.listen(1)
              return True
          except socket.error:
              return False
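
Note on the reworked helpers above: `set_random_seed` now also seeds NumPy, and the renamed `is_port_available` additionally calls `listen(1)`, so a port that can be bound but not listened on is rejected. A minimal usage sketch (the probe loop below is illustrative, not sglang code):

    # Hypothetical harness; assumes sglang is installed and a CUDA device is present.
    import torch
    from sglang.srt.utils import get_available_gpu_memory, is_port_available, set_random_seed

    set_random_seed(42)  # seeds random, numpy, and torch (including all CUDA devices)

    if torch.cuda.is_available():
        free_gib = get_available_gpu_memory(gpu_id=0)  # free memory on cuda:0, in GiB
        print(f"GPU 0 free memory: {free_gib:.2f} GiB")

    # Find the first usable port at or above 30000.
    port = 30000
    while not is_port_available(port):
        port += 1
    print(f"first available port: {port}")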
@@ -130,41 +148,34 @@ def allocate_init_ports(
      port: Optional[int] = None,
      additional_ports: Optional[List[int]] = None,
      tp_size: int = 1,
+     dp_size: int = 1,
  ):
-     port = 30000 if port is None else port
-     additional_ports = [] if additional_ports is None else additional_ports
-     additional_ports = (
-         [additional_ports] if isinstance(additional_ports, int) else additional_ports
-     )
-     # first check on server port
-     if not check_port(port):
-         new_port = alloc_usable_network_port(1, used_list=[port])[0]
-         print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
-         port = new_port
-
-     # then we check on additional ports
-     additional_unique_ports = set(additional_ports) - {port}
-     # filter out ports that are already in use
-     can_use_ports = [port for port in additional_unique_ports if check_port(port)]
-
-     num_specified_ports = len(can_use_ports)
-     if num_specified_ports < 4 + tp_size:
-         addtional_can_use_ports = alloc_usable_network_port(
-             num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
-         )
-         can_use_ports.extend(addtional_can_use_ports)
+     """Allocate ports for all connections."""
+     if additional_ports:
+         ret_ports = [port] + additional_ports
+     else:
+         ret_ports = [port]
+
+     ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+     cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

-     additional_ports = can_use_ports[: 4 + tp_size]
-     return port, additional_ports
+     # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+     num_ports_needed = 4 + dp_size * (1 + tp_size)
+     while len(ret_ports) < num_ports_needed:
+         if cur_port not in ret_ports and is_port_available(cur_port):
+             ret_ports.append(cur_port)
+         cur_port += 1

+     if port is not None and ret_ports[0] != port:
+         logger.warn(
+             f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+         )

- def get_exception_traceback():
-     etype, value, tb = sys.exc_info()
-     err_str = "".join(traceback.format_exception(etype, value, tb))
-     return err_str
+     return ret_ports[0], ret_ports[1:num_ports_needed]


  def get_int_token_logit_bias(tokenizer, vocab_size):
+     """Get the logit bias for integer-only tokens."""
      # a bug when model's vocab size > tokenizer.vocab_size
      vocab_size = tokenizer.vocab_size
      logit_bias = np.zeros(vocab_size, dtype=np.float32)
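
The rewritten `allocate_init_ports` above sizes its pool from the deployment shape rather than the fixed `4 + tp_size` of 0.1.15: one port each for the HTTP server, tokenizer, controller, and detokenizer, plus one NCCL port and `tp_size` worker ports per data-parallel replica. A standalone sketch of just the arithmetic (values illustrative):

    # Reproduces the port-count formula from the diff above.
    def num_ports_needed(tp_size: int, dp_size: int) -> int:
        # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
        return 4 + dp_size * (1 + tp_size)

    assert num_ports_needed(tp_size=1, dp_size=1) == 6   # single-GPU server
    assert num_ports_needed(tp_size=4, dp_size=2) == 14  # 2 replicas of 4-way TP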
@@ -178,14 +189,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):

  def wrap_kernel_launcher(kernel):
      """A faster launcher for triton kernels."""
-     import torch.distributed as dist
-
-     if dist.is_initialized():
-         rank = dist.get_rank()
-     else:
-         rank = 0
+     if int(triton.__version__.split(".")[0]) >= 3:
+         return None

-     kernels = kernel.cache[rank].values()
+     gpu_id = torch.cuda.current_device()
+     kernels = kernel.cache[gpu_id].values()
      kernel = next(iter(kernels))

      # Different triton versions use different low-level names
@@ -245,20 +253,104 @@ def wrap_kernel_launcher(kernel):


  def is_multimodal_model(model):
-     if isinstance(model, str):
-         return "llava" in model or "yi-vl" in model
      from sglang.srt.model_config import ModelConfig

+     if isinstance(model, str):
+         model = model.lower()
+         return "llava" in model or "yi-vl" in model or "llava-next" in model
+
      if isinstance(model, ModelConfig):
          model_path = model.path.lower()
-         return "llava" in model_path or "yi-vl" in model_path
-     raise Exception("unrecognized type")
+         return (
+             "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+         )
+
+     raise ValueError("unrecognized type")
+
+
+ def decode_video_base64(video_base64):
+     from PIL import Image
+
+     # Decode the base64 string
+     video_bytes = base64.b64decode(video_base64)
+
+     # Placeholder for the start indices of each PNG image
+     img_starts = []
+
+     frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+     assert frame_format in [
+         "PNG",
+         "JPEG",
+     ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
+
+     if frame_format == "PNG":
+         # Find each PNG start signature to isolate images
+         i = 0
+         while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
+             # Check if we found the start of a PNG file
+             if (
+                 video_bytes[i] == 0x89
+                 and video_bytes[i + 1] == 0x50
+                 and video_bytes[i + 2] == 0x4E
+                 and video_bytes[i + 3] == 0x47
+                 and video_bytes[i + 4] == 0x0D
+                 and video_bytes[i + 5] == 0x0A
+                 and video_bytes[i + 6] == 0x1A
+                 and video_bytes[i + 7] == 0x0A
+             ):
+                 img_starts.append(i)
+                 i += 8  # Skip the PNG signature
+             else:
+                 i += 1
+     else:
+         # Find each JPEG start (0xFFD8) to isolate images
+         i = 0
+         while (
+             i < len(video_bytes) - 1
+         ):  # Adjusted for the length of the JPEG SOI signature
+             # Check if we found the start of a JPEG file
+             if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
+                 img_starts.append(i)
+                 # Move to the next byte to continue searching for the next image start
+                 i += 2
+             else:
+                 i += 1
+
+     frames = []
+     for start_idx in img_starts:
+         # Assuming each image is back-to-back, the end of one image is the start of another
+         # The last image goes until the end of the byte string
+         end_idx = (
+             img_starts[img_starts.index(start_idx) + 1]
+             if img_starts.index(start_idx) + 1 < len(img_starts)
+             else len(video_bytes)
+         )
+         img_bytes = video_bytes[start_idx:end_idx]
+
+         # Convert bytes to a PIL Image
+         img = Image.open(BytesIO(img_bytes))
+
+         # Convert PIL Image to a NumPy array
+         frame = np.array(img)
+
+         # Append the frame to the list of frames
+         frames.append(frame)
+
+     # Ensure there's at least one frame to avoid errors with np.stack
+     if frames:
+         return np.stack(frames, axis=0), img.size
+     else:
+         return np.array([]), (
+             0,
+             0,
+         )  # Return an empty array and size tuple if no frames were found


  def load_image(image_file):
      from PIL import Image

-     image = None
+     image = image_size = None

      if image_file.startswith("http://") or image_file.startswith("https://"):
          timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
@@ -269,10 +361,71 @@ def load_image(image_file):
      elif image_file.startswith("data:"):
          image_file = image_file.split(",")[1]
          image = Image.open(BytesIO(base64.b64decode(image_file)))
+     elif image_file.startswith("video:"):
+         image_file = image_file.replace("video:", "")
+         image, image_size = decode_video_base64(image_file)
      else:
          image = Image.open(BytesIO(base64.b64decode(image_file)))

-     return image
+     return image, image_size
+
+
+ def init_rpyc_service(service: rpyc.Service, port: int):
+     t = ThreadedServer(
+         service=service,
+         port=port,
+         protocol_config={
+             "allow_public_attrs": True,
+             "allow_pickle": True,
+             "sync_request_timeout": 3600
+         },
+     )
+     t.logger.setLevel(logging.WARN)
+     t.start()
+
+
+ def connect_to_rpyc_service(port, host="localhost"):
+     time.sleep(1)
+
+     repeat_count = 0
+     while repeat_count < 20:
+         try:
+             con = rpyc.connect(
+                 host,
+                 port,
+                 config={
+                     "allow_public_attrs": True,
+                     "allow_pickle": True,
+                     "sync_request_timeout": 3600
+                 },
+             )
+             break
+         except ConnectionRefusedError:
+             time.sleep(1)
+         repeat_count += 1
+     if repeat_count == 20:
+         raise RuntimeError("init rpc env error!")
+
+     return con.root
+
+
+ def start_rpyc_process(service: rpyc.Service, port: int):
+     # Return the proxy and the process
+     proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+     proc.start()
+     proxy = connect_to_rpyc_service(port)
+     assert proc.is_alive()
+     return proxy, proc
+
+
+ def suppress_other_loggers():
+     from vllm.logger import logger as vllm_default_logger
+
+     vllm_default_logger.setLevel(logging.WARN)
+     logging.getLogger("vllm.config").setLevel(logging.ERROR)
+     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
+     logging.getLogger("vllm.selector").setLevel(logging.WARN)
+     logging.getLogger("vllm.utils").setLevel(logging.WARN)


  def assert_pkg_version(pkg: str, min_version: str):
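
The rpyc helpers added above are the plumbing for the new multi-process controller/worker split (see `sglang/srt/managers/controller/` in the file list). A minimal sketch of how they compose; `EchoService` is a made-up example, not an sglang class:

    # Hypothetical demo: serve an rpyc service in a child process and call it.
    import rpyc
    from sglang.srt.utils import start_rpyc_process

    class EchoService(rpyc.Service):
        def exposed_echo(self, msg):
            return msg

    if __name__ == "__main__":
        proxy, proc = start_rpyc_process(EchoService, port=18861)
        assert proxy.echo("ping") == "ping"  # calls block up to sync_request_timeout (3600 s)
        proc.terminate()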
@@ -284,7 +437,30 @@ def assert_pkg_version(pkg: str, min_version: str):
                  f"is less than the minimum required version {min_version}"
              )
      except PackageNotFoundError:
-         raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+         raise Exception(
+             f"{pkg} with minimum required version {min_version} is not installed"
+         )
+
+
+ def kill_parent_process():
+     """Kill the parent process and all children of the parent process."""
+     current_process = psutil.Process()
+     parent_process = current_process.parent()
+     children = current_process.children(recursive=True)
+     for child in children:
+         if child.pid != current_process.pid:
+             os.kill(child.pid, 9)
+     os.kill(parent_process.pid, 9)
+
+
+ def monkey_patch_vllm_p2p_access_check():
+     """
+     Monkey patch the slow p2p access check in vllm.
+     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+     """
+     import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


  API_KEY_HEADER_NAME = "X-API-Key"
@@ -306,12 +482,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
          response = await call_next(request)
          return response

-
- # FIXME: Remove this once we drop support for pydantic 1.x
- IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
- def jsonify_pydantic_model(obj: BaseModel):
-     if IS_PYDANTIC_1:
-         return obj.json(ensure_ascii=False)
-     return obj.model_dump_json()
sglang/test/test_programs.py CHANGED
@@ -304,6 +304,7 @@ def test_image_qa():
          temperature=0,
          max_new_tokens=64,
      )
+
      assert (
          "taxi" in state.messages()[-1]["content"]
          or "car" in state.messages()[-1]["content"]
@@ -349,3 +350,46 @@ def test_regex():
      state = regex_gen.run()
      answer = state["answer"]
      assert re.match(regex, answer)
+
+
+ def test_completion_speculative():
+     @sgl.function(num_api_spec_tokens=64)
+     def gen_character_spec(s):
+         s += "Construct a character within the following format:\n"
+         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+         s += "\nPlease generate new Name, Birthday and Job.\n"
+         s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+
+     @sgl.function
+     def gen_character_no_spec(s):
+         s += "Construct a character within the following format:\n"
+         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+         s += "\nPlease generate new Name, Birthday and Job.\n"
+         s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+     token_usage = sgl.global_config.default_backend.token_usage
+
+     token_usage.reset()
+     gen_character_spec().sync()
+     usage_with_spec = token_usage.prompt_tokens
+
+     token_usage.reset()
+     gen_character_no_spec().sync()
+     usage_with_no_spec = token_usage.prompt_tokens
+
+     assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+ def test_chat_completion_speculative():
+     @sgl.function(num_api_spec_tokens=256)
+     def gen_character_spec(s):
+         s += sgl.system("You are a helpful assistant.")
+         s += sgl.user("Construct a character within the following format:")
+         s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+         s += sgl.user("Please generate new Name, Birthday and Job.\n")
+         s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+
+     gen_character_spec().sync()
sglang/test/test_utils.py CHANGED
@@ -9,7 +9,7 @@ import requests
  from sglang.backend.openai import OpenAI
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.global_config import global_config
- from sglang.srt.utils import get_exception_traceback
+ from sglang.utils import get_exception_traceback


  def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -88,6 +88,33 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
      return pred


+ def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+     import grpc
+     from ginfer import sampler_pb2, sampler_pb2_grpc
+
+     sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+     sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+     if stop is None:
+         stop_strings = None
+     else:
+         stop_strings = [stop]
+
+     sample_request = sampler_pb2.SampleTextRequest(
+         prompt=prompt,
+         settings=sampler_pb2.SampleSettings(
+             max_len=max_tokens,
+             rng_seed=0,
+             temperature=max(temperature, 1e-7),
+             nucleus_p=1,
+             stop_strings=stop_strings,
+         ),
+     )
+     stream = sampler.SampleText(sample_request)
+     response = "".join([x.text for x in stream])
+     return response
+
+
  def call_generate_guidance(
      prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
  ):
@@ -228,6 +255,7 @@ def add_common_other_args_and_parse(parser):
              "vllm",
              "outlines",
              "lightllm",
+             "ginfer",
              "guidance",
              "lmql",
              "srt-raw",
@@ -248,6 +276,7 @@ def add_common_other_args_and_parse(parser):
          "lightllm": 22000,
          "lmql": 23000,
          "srt-raw": 30000,
+         "ginfer": 9988,
      }
      args.port = default_port.get(args.backend, None)
      return args
@@ -283,6 +312,8 @@ def _get_call_generate(args):
          return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
      elif args.backend == "srt-raw":
          return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+     elif args.backend == "ginfer":
+         return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
      elif args.backend == "outlines":
          return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
      elif args.backend == "guidance":
sglang/utils.py CHANGED
@@ -2,40 +2,27 @@

  import base64
  import json
+ import logging
+ import signal
+ import sys
  import threading
+ import traceback
  import urllib.request
+ from concurrent.futures import ThreadPoolExecutor
  from io import BytesIO
  from json import dumps

+ import numpy as np
  import requests


- def get_available_gpu_memory(gpu_id, distributed=True):
-     """
-     Get available memory for cuda:gpu_id device.
-     When distributed is True, the available memory is the minimum available memory of all GPUs.
-     """
-     import torch
+ logger = logging.getLogger(__name__)

-     num_gpus = torch.cuda.device_count()
-     assert gpu_id < num_gpus

-     if torch.cuda.current_device() != gpu_id:
-         print(
-             f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-             "which may cause useless memory allocation for torch CUDA context.",
-         )
-
-     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
-
-     if distributed:
-         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-             torch.device("cuda", gpu_id)
-         )
-         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
-         free_gpu_memory = tensor.item()
-
-     return free_gpu_memory / (1 << 30)
+ def get_exception_traceback():
+     etype, value, tb = sys.exc_info()
+     err_str = "".join(traceback.format_exception(etype, value, tb))
+     return err_str


  def is_same_type(values):
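
`get_exception_traceback` moved here from `sglang.srt.utils` so that client-side code no longer pulls in server dependencies; the import fix in `sglang/test/test_utils.py` above is the matching call-site update. Its intended use, for reference:

    # Must be called while an exception is being handled.
    from sglang.utils import get_exception_traceback

    try:
        1 / 0
    except ZeroDivisionError:
        print(get_exception_traceback())  # the full formatted traceback as one string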
@@ -110,8 +97,12 @@ def http_request(
          data = None
      else:
          data = bytes(dumps(json), encoding="utf-8")
-     resp = urllib.request.urlopen(req, data=data, cafile=verify)
-     return HttpResponse(resp)
+
+     try:
+         resp = urllib.request.urlopen(req, data=data, cafile=verify)
+         return HttpResponse(resp)
+     except urllib.error.HTTPError as e:
+         return HttpResponse(e)


  def encode_image_base64(image_path):
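
With this change, an HTTP error status no longer escapes `http_request` as `urllib.error.HTTPError`; the error object (which, in urllib, still carries a status code and body) is wrapped in the same `HttpResponse` type as a success, so callers can inspect the result uniformly. A hedged sketch (the URL is a placeholder):

    from sglang.utils import http_request

    # Previously a 4xx/5xx response raised here; now the error response
    # comes back wrapped, for the caller to inspect.
    resp = http_request("http://localhost:30000/nonexistent")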
@@ -130,6 +121,75 @@ def encode_image_base64(image_path):
      return base64.b64encode(buffered.getvalue()).decode("utf-8")


+ def encode_frame(frame):
+     import cv2  # pip install opencv-python-headless
+     from PIL import Image
+
+     # Convert the frame to RGB (OpenCV uses BGR by default)
+     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+     # Convert the frame to PIL Image to easily convert to bytes
+     im_pil = Image.fromarray(frame)
+
+     # Convert to bytes
+     buffered = BytesIO()
+
+     # frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
+
+     im_pil.save(buffered, format="PNG")
+
+     frame_bytes = buffered.getvalue()
+
+     # Return the bytes of the frame
+     return frame_bytes
+
+
+ def encode_video_base64(video_path, num_frames=16):
+     import cv2  # pip install opencv-python-headless
+
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise IOError(f"Could not open video file:{video_path}")
+
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     print(f"target_frames: {num_frames}")
+
+     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+
+     frames = []
+     for i in range(total_frames):
+         ret, frame = cap.read()
+         if ret:
+             frames.append(frame)
+         else:
+             # Handle the case where the frame could not be read
+             # print(f"Warning: Could not read frame at index {i}.")
+             pass
+
+     cap.release()
+
+     # Safely select frames based on frame_indices, avoiding IndexError
+     frames = [frames[i] for i in frame_indices if i < len(frames)]
+
+     # If there are not enough frames, duplicate the last frame until we reach the target
+     while len(frames) < num_frames:
+         frames.append(frames[-1])
+
+     # Use ThreadPoolExecutor to process and encode frames in parallel
+     with ThreadPoolExecutor() as executor:
+         encoded_frames = list(executor.map(encode_frame, frames))
+
+     # encoded_frames = list(map(encode_frame, frames))
+
+     # Concatenate all frames bytes
+     video_bytes = b"".join(encoded_frames)
+
+     # Encode the concatenated bytes to base64
+     video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
+
+     return video_base64
+
+
  def _is_chinese_char(cp):
      """Checks whether CP is the codepoint of a CJK character."""
      # This defines a "chinese character" as anything in the CJK Unicode block:
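
This is the encode side of the pair: frames are PNG-encoded, concatenated, base64-encoded, and tagged with a "video:" prefix, which `load_image` in `sglang/srt/utils.py` (above) recognizes and hands to `decode_video_base64` to rebuild the frame stack. A round-trip sketch; `clip.mp4` is a placeholder path:

    # Requires opencv-python-headless, Pillow, and numpy.
    from sglang.srt.utils import load_image
    from sglang.utils import encode_video_base64

    video_b64 = encode_video_base64("clip.mp4", num_frames=16)  # "video:..." string
    frames, (width, height) = load_image(video_b64)             # dispatches on the prefix
    print(frames.shape)  # (16, height, width, channels)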
@@ -191,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
          raise RuntimeError()

      return ret_value[0]
+
+
+ def graceful_registry(sub_module_name):
+     def graceful_shutdown(signum, frame):
+         logger.info(
+             f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+         )
+         if signum == signal.SIGTERM:
+             logger.info(f"{sub_module_name} received sigterm")
+
+     signal.signal(signal.SIGTERM, graceful_shutdown)
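
`graceful_registry` installs a SIGTERM handler that logs a shutdown notice before the process exits; the new detokenizer and controller subprocesses are the kind of entrypoints that register it at startup. A minimal sketch of wiring it into a worker entrypoint (the worker loop below is illustrative):

    import logging
    import time

    from sglang.utils import graceful_registry

    logging.basicConfig(level=logging.INFO)
    graceful_registry("demo_worker")  # log when SIGTERM arrives

    while True:
        time.sleep(1)  # stand-in for the real worker loop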