sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -1,24 +1,31 @@
 """Common utilities."""
 
 import base64
+import fcntl
+import logging
+import multiprocessing
 import os
 import random
 import socket
+import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
-import pydantic
+import psutil
 import requests
+import rpyc
 import torch
+import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from pydantic import BaseModel
+from rpyc.utils.server import ThreadedServer
 from starlette.middleware.base import BaseHTTPMiddleware
 
-from sglang.utils import get_exception_traceback
+logger = logging.getLogger(__name__)
+
 
 show_time_cost = False
 time_infos = {}
@@ -90,7 +97,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -104,6 +111,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
             "which may cause useless memory allocation for torch CUDA context.",
         )
 
+    torch.cuda.empty_cache()
     free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
     if distributed:
@@ -117,38 +125,21 @@
 
 
 def set_random_seed(seed: int) -> None:
+    """Set the random seed for all libraries."""
     random.seed(seed)
-
+    np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
 
 
-def alloc_usable_network_port(num, used_list=()):
-    port_list = []
-    for port in range(10000, 65536):
-        if port in used_list:
-            continue
-
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-            try:
-                s.bind(("", port))
-                s.listen(1)  # Attempt to listen on the port
-                port_list.append(port)
-            except socket.error:
-                pass  # If any error occurs, this port is not usable
-
-        if len(port_list) == num:
-            return port_list
-    return None
-
-
-def check_port(port):
+def is_port_available(port):
+    """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         try:
             s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             s.bind(("", port))
+            s.listen(1)
             return True
         except socket.error:
             return False
@@ -158,35 +149,34 @@ def allocate_init_ports(
     port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
+    dp_size: int = 1,
 ):
-    port = 30000 if port is None else port
-    additional_ports = [] if additional_ports is None else additional_ports
-    additional_ports = (
-        [additional_ports] if isinstance(additional_ports, int) else additional_ports
-    )
-    # first check on server port
-    if not check_port(port):
-        new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
-        port = new_port
-
-    # then we check on additional ports
-    additional_unique_ports = set(additional_ports) - {port}
-    # filter out ports that are already in use
-    can_use_ports = [port for port in additional_unique_ports if check_port(port)]
-
-    num_specified_ports = len(can_use_ports)
-    if num_specified_ports < 4 + tp_size:
-        addtional_can_use_ports = alloc_usable_network_port(
-            num=4 + tp_size - num_specified_ports, used_list=can_use_ports + [port]
+    """Allocate ports for all connections."""
+    if additional_ports:
+        ret_ports = [port] + additional_ports
+    else:
+        ret_ports = [port]
+
+    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
+    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
+
+    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
+    num_ports_needed = 4 + dp_size * (1 + tp_size)
+    while len(ret_ports) < num_ports_needed:
+        if cur_port not in ret_ports and is_port_available(cur_port):
+            ret_ports.append(cur_port)
+        cur_port += 1
+
+    if port is not None and ret_ports[0] != port:
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )
-        can_use_ports.extend(addtional_can_use_ports)
 
-    additional_ports = can_use_ports[: 4 + tp_size]
-    return port, additional_ports
+    return ret_ports[0], ret_ports[1:num_ports_needed]
 
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
+    """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -200,14 +190,11 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
 
 def wrap_kernel_launcher(kernel):
     """A faster launcher for triton kernels."""
-    import torch.distributed as dist
-
-    if dist.is_initialized():
-        rank = dist.get_rank()
-    else:
-        rank = 0
+    if int(triton.__version__.split(".")[0]) >= 3:
+        return None
 
-    kernels = kernel.cache[rank].values()
+    gpu_id = torch.cuda.current_device()
+    kernels = kernel.cache[gpu_id].values()
     kernel = next(iter(kernels))
 
     # Different trition versions use different low-level names
@@ -275,7 +262,9 @@ def is_multimodal_model(model):
 
     if isinstance(model, ModelConfig):
         model_path = model.path.lower()
-        return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        return (
+            "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
+        )
 
     raise ValueError("unrecognized type")
 
@@ -382,20 +371,145 @@ def load_image(image_file):
     return image, image_size
 
 
-def assert_pkg_version(pkg: str, min_version: str):
+def connect_rpyc_service(host, port):
+    repeat_count = 0
+    while repeat_count < 20:
+        try:
+            con = rpyc.connect(
+                host,
+                port,
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
+            )
+            break
+        except ConnectionRefusedError as e:
+            time.sleep(1)
+        repeat_count += 1
+    if repeat_count == 20:
+        raise RuntimeError(f"Connect rpyc error: {e}")
+
+    return con.root
+
+
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
+    proc.start()
+    return proc
+
+
+def suppress_other_loggers():
+    from vllm.logger import logger as vllm_default_logger
+
+    vllm_default_logger.setLevel(logging.WARN)
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)
+    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
+        logging.WARN
+    )
+    logging.getLogger("vllm.selector").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+
+
+def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
         installed_version = version(pkg)
         if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
             raise Exception(
-                f"{pkg} is installed with version {installed_version} which "
-                f"is less than the minimum required version {min_version}"
+                f"{pkg} is installed with version {installed_version}, which "
+                f"is less than the minimum required version {min_version}. " +
+                message
             )
     except PackageNotFoundError:
         raise Exception(
-            f"{pkg} with minimum required version {min_version} is not installed"
+            f"{pkg} with minimum required version {min_version} is not installed. " +
+            message
        )
 
 
+def kill_parent_process():
+    """Kill the parent process and all children of the parent process."""
+    current_process = psutil.Process()
+    parent_process = current_process.parent()
+    children = current_process.children(recursive=True)
+    for child in children:
+        if child.pid != current_process.pid:
+            os.kill(child.pid, 9)
+    os.kill(parent_process.pid, 9)
+
+
+def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+    """
+    Monkey patch the slow p2p access check in vllm.
+    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
+    """
+
+    # TODO: need a better check than just dev str name match
+    # compat: skip RTX 40 series as they do not have P2P feature and even checking for them may cause errors
+    device_name = torch.cuda.get_device_name(gpu_id)
+    if "RTX 40" not in device_name:
+        import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+
+        setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
+
+
+def monkey_patch_vllm_dummy_weight_loader():
+    """
+    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
+    """
+
+    from vllm.model_executor.model_loader.loader import (
+        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
+        ParallelConfig, SchedulerConfig, CacheConfig, nn,
+        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
+        DummyModelLoader
+    )
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config,
+                                          cache_config)
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
+                if hasattr(module, "process_weights_after_loading"):
+                    module.process_weights_after_loading()
+
+        # NOTE(woosuk): For accurate performance evaluation, we assign
+        # random values to the weights.
+        initialize_dummy_weights(model)
+        return model.eval()
+
+    setattr(DummyModelLoader, "load_model", load_model)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
 
 
@@ -416,11 +530,64 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         return response
 
 
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
 
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
 
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
+    dist.barrier()
+    dist.destroy_process_group()
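
Note on the allocate_init_ports rewrite above: the new version grows a single pool of free ports until it covers every connection, sized per the in-line comment as HTTP + Tokenizer + Controller + Detokenizer plus dp_size * (nccl + tp_size). A minimal standalone sketch of that accounting (illustrative only, not code from the package):

    def num_ports_needed(tp_size: int = 1, dp_size: int = 1) -> int:
        # 4 fixed ports (HTTP, tokenizer, controller, detokenizer) plus one nccl
        # port and tp_size model ports per data-parallel replica
        return 4 + dp_size * (1 + tp_size)

    assert num_ports_needed() == 6                        # default single-GPU launch
    assert num_ports_needed(tp_size=8, dp_size=2) == 22   # 2 replicas of 8-way TP

The function then returns the first port as the server port and the remaining num_ports_needed - 1 ports as additional_ports.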
sglang/test/test_programs.py CHANGED
@@ -1,6 +1,4 @@
-"""
-This file contains the SGL programs used for unit testing.
-"""
+"""This file contains the SGL programs used for unit testing."""
 
 import json
 import re
@@ -304,6 +302,7 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
+
     assert (
         "taxi" in state.messages()[-1]["content"]
         or "car" in state.messages()[-1]["content"]
@@ -349,3 +348,66 @@ def test_regex():
     state = regex_gen.run()
     answer = state["answer"]
     assert re.match(regex, answer)
+
+
+def test_completion_speculative():
+    @sgl.function(num_api_spec_tokens=64)
+    def gen_character_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    @sgl.function
+    def gen_character_no_spec(s):
+        s += "Construct a character within the following format:\n"
+        s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        s += "\nPlease generate new Name, Birthday and Job.\n"
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
+        s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
+
+    token_usage = sgl.global_config.default_backend.token_usage
+
+    token_usage.reset()
+    gen_character_spec().sync()
+    usage_with_spec = token_usage.prompt_tokens
+
+    token_usage.reset()
+    gen_character_no_spec().sync()
+    usage_with_no_spec = token_usage.prompt_tokens
+
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"
+
+
+def test_chat_completion_speculative():
+    @sgl.function(num_api_spec_tokens=256)
+    def gen_character_spec(s):
+        s += sgl.system("You are a helpful assistant.")
+        s += sgl.user("Construct a character within the following format:")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
+        s += sgl.user("Please generate new Name, Birthday and Job.\n")
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )
+
+    gen_character_spec().sync()
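
The two speculative-execution tests added above assume an API default backend, since they read sgl.global_config.default_backend.token_usage. A hedged setup sketch for running the completion variant (backend and model name are illustrative, not prescribed by this diff):

    import sglang as sgl
    from sglang.test.test_programs import test_completion_speculative

    # Assumes OPENAI_API_KEY is set in the environment.
    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
    test_completion_speculative()

The assertion inside the test checks that the speculative program (num_api_spec_tokens=64) consumes fewer prompt tokens than the non-speculative version of the same program.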
sglang/test/test_utils.py CHANGED
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback
 
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -88,6 +88,33 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred
 
 
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
+    import grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
+
+    sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+    sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+    if stop is None:
+        stop_strings = None
+    else:
+        stop_strings = [stop]
+
+    sample_request = sampler_pb2.SampleTextRequest(
+        prompt=prompt,
+        settings=sampler_pb2.SampleSettings(
+            max_len=max_tokens,
+            rng_seed=0,
+            temperature=max(temperature, 1e-7),
+            nucleus_p=1,
+            stop_strings=stop_strings,
+        ),
+    )
+    stream = sampler.SampleText(sample_request)
+    response = "".join([x.text for x in stream])
+    return response
+
+
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -228,6 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",
@@ -248,6 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -283,6 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
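
The new ginfer helper can also be called directly; a hedged example, assuming a ginfer gRPC sampler is reachable on the default port 9988 registered above (prompt and host are illustrative):

    from sglang.test.test_utils import call_generate_ginfer

    out = call_generate_ginfer(
        "The capital of France is",
        temperature=0,
        max_tokens=16,
        url="http://127.0.0.1:9988",
    )
    print(out)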
sglang/utils.py CHANGED
@@ -2,7 +2,8 @@
 
 import base64
 import json
-import os
+import logging
+import signal
 import sys
 import threading
 import traceback
@@ -14,6 +15,8 @@ from json import dumps
 import numpy as np
 import requests
 
+logger = logging.getLogger(__name__)
+
 
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
@@ -93,8 +96,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)
 
 
 def encode_image_base64(image_path):
@@ -137,7 +144,8 @@
 
 
 def encode_video_base64(video_path, num_frames=16):
-    import cv2
+    import cv2  # pip install opencv-python-headless
+
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
         raise IOError(f"Could not open video file:{video_path}")
@@ -242,3 +250,14 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
 
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
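
The new graceful_registry helper only installs a SIGTERM handler that logs the shutdown request; exiting remains the caller's responsibility. A hedged usage sketch for a subprocess entry point (the module name and loop are illustrative):

    from sglang.utils import graceful_registry

    def detokenizer_main():
        # Log when the parent sends SIGTERM; the worker's own loop decides when to stop.
        graceful_registry("detokenizer")
        ...  # serve requests until the parent tears the process down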