sglang 0.1.24__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/__init__.py CHANGED
@@ -1,5 +1,3 @@
-__version__ = "0.1.24"
-
 # SGL API Components
 from sglang.api import (
     Runtime,
@@ -32,6 +30,8 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.backend.vertexai import VertexAI
 
+from .version import __version__
+
 # public APIs management
 __all__ = [
     "global_config",
sglang/srt/managers/controller/model_runner.py CHANGED
@@ -15,6 +15,7 @@ from flashinfer import (
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+from torch.nn.parameter import Parameter
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -22,6 +23,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -38,6 +40,18 @@ from sglang.srt.utils import (
 logger = logging.getLogger("srt.model_runner")
 
 
+def is_llama3_405b_fp8(model_config):
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
 class ModelRunner:
     def __init__(
         self,
@@ -118,6 +132,9 @@ class ModelRunner:
             seed=42,
             skip_tokenizer_init=True,
         )
+        if is_llama3_405b_fp8(self.model_config):
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     return model_arch_name_to_cls[model_arch]
 
 
+def get_original_weight(loaded_weight, head_dim):
+    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+    dim = loaded_weight.shape[1]
+    for i in range(n_kv_head):
+        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+            2 * i * head_dim : (2 * i + 1) * head_dim, :
+        ]
+    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+    return original_kv_weight
+
+
+def get_weight_loader_srt(weight_loader):
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    return weight_loader_srt
+
+
 # Monkey patch model loader
 setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+original_weight_loader = QKVParallelLinear.weight_loader
+setattr(
+    QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
+)
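
This patch appears aimed at the fbgemm-fp8 Llama 3.1 405B checkpoint: the config check matches its dimensions (hidden size 16384, 126 layers), the KV-head count is forced to 8, and the wrapped weight loader keeps every even-indexed head block whenever a K or V shard arrives with twice the expected rows, consistent with a checkpoint that stores each KV head twice. A minimal sketch of the same de-interleaving arithmetic on a toy tensor (sizes and the duplicated layout are assumptions for illustration; the function restates the diff's logic standalone):

```python
import torch

def get_original_weight(loaded_weight: torch.Tensor, head_dim: int) -> torch.Tensor:
    # Same logic as the diff: compact every other head block to the front,
    # then truncate to the real number of KV heads.
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim
        ]
    return loaded_weight[: n_kv_head * head_dim]

# Toy check: 2 real KV heads, head_dim=4, hidden size 8.
head_dim, dim = 4, 8
real = torch.randn(2 * head_dim, dim)  # the weight we want back: [h0, h1]
# Assumed checkpoint layout: each head block stored twice back to back,
# i.e. [h0, h0, h1, h1], doubling the leading dimension.
duplicated = torch.cat(
    [real[i * head_dim : (i + 1) * head_dim].repeat(2, 1) for i in range(2)]
)
assert torch.equal(get_original_weight(duplicated.clone(), head_dim), real)
```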
sglang/srt/server.py CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    maybe_set_triton_cache_manager,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -201,6 +202,11 @@ def launch_server(
         "reinstall the latest version by following the instructions "
         "at https://docs.flashinfer.ai/installation.html.",
     )
+
+    if server_args.tp_size // server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
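
The guard activates the custom Triton cache manager only when a replica runs more than one tensor-parallel worker (as the diff's `tp_size // dp_size` condition computes it), i.e. when several processes on the node would JIT-compile Triton kernels concurrently. A minimal sketch of the condition with hypothetical sizes:

```python
import os

# Hypothetical server_args values: tp_size=8 split over dp_size=2 replicas,
# so 4 worker processes compile kernels side by side.
tp_size, dp_size = 8, 2
if tp_size // dp_size > 1:
    # Same effect as maybe_set_triton_cache_manager(): set the variable only
    # if the user has not already chosen a cache manager.
    os.environ.setdefault("TRITON_CACHE_MANAGER", "sglang.srt.utils:CustomCacheManager")
```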
sglang/srt/utils.py CHANGED
@@ -18,10 +18,15 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
-import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from starlette.middleware.base import BaseHTTPMiddleware
+from triton.runtime.cache import (
+    FileCacheManager,
+    default_cache_dir,
+    default_dump_dir,
+    default_override_dir,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -460,6 +465,44 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)
 
 
+def maybe_set_triton_cache_manager() -> None:
+    """Set environment variable to tell Triton to use a
+    custom cache manager"""
+    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manger is None:
+        manager = "sglang.srt.utils:CustomCacheManager"
+        logger.info("Setting Triton cache manager to: %s", manager)
+        os.environ["TRITON_CACHE_MANAGER"] = manager
+
+
+class CustomCacheManager(FileCacheManager):
+    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
+    def __init__(self, key, override=False, dump=False):
+
+        self.key = key
+        self.lock_path = None
+        if dump:
+            self.cache_dir = default_dump_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        elif override:
+            self.cache_dir = default_override_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+        else:
+            # create cache directory if it doesn't exist
+            self.cache_dir = (
+                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            )
+            if self.cache_dir:
+                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                self.cache_dir = os.path.join(self.cache_dir, self.key)
+                self.lock_path = os.path.join(self.cache_dir, "lock")
+                os.makedirs(self.cache_dir, exist_ok=True)
+            else:
+                raise RuntimeError("Could not create or locate cache dir")
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
 
 
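The essential change in CustomCacheManager is the `_{os.getpid()}` suffix in the default branch: each worker process gets its own Triton kernel-cache directory, so concurrent compilations cannot clobber each other's files. A minimal sketch of the path scheme (helper name and example paths are illustrative, not sglang API):

```python
import os

def per_process_cache_dir(base_dir: str, key: str) -> str:
    # Mirrors the default branch above: suffix the base directory with the
    # worker's PID before appending the per-kernel key.
    return os.path.join(f"{base_dir}_{os.getpid()}", key)

# e.g. "/tmp/triton_cache_41523/kernel-abc123" in a worker with PID 41523
print(per_process_cache_dir("/tmp/triton_cache", "kernel-abc123"))
```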
sglang/version.py ADDED
@@ -0,0 +1 @@
+__version__ = "0.1.25"
{sglang-0.1.24.dist-info → sglang-0.1.25.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.24
+Version: 0.1.25
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                                  Version 2.0, January 2004
@@ -282,7 +282,7 @@ The core features include:
 
 ### Method 1: With pip
 ```
-pip install --upgrade pip setuptools wheel
+pip install --upgrade pip
 pip install "sglang[all]"
 
 # Install FlashInfer CUDA kernels
@@ -405,7 +405,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 
 ### Supported Models
 
-- Llama / Llama 2 / Llama 3
+- Llama / Llama 2 / Llama 3 / Llama 3.1
 - Mistral / Mixtral
 - Gemma / Gemma 2
 - Qwen / Qwen 2 / Qwen 2 MoE
{sglang-0.1.24.dist-info → sglang-0.1.25.dist-info}/RECORD RENAMED
@@ -1,6 +1,5 @@
-sglang/__init__.py,sha256=nMs6lYeKcQpYArIaZLQ2VGNleY1dVvdBFaHyG7fpOsA,1141
+sglang/__init__.py,sha256=UV7VlXhXrwi00Zg45iNB9KcnmrwLjdMtjMz06AiafY0,1151
 sglang/api.py,sha256=1JARbc1wNYF6tODdUpgmNgTyLOvMnxdTBctLvEwzGTY,5565
-sglang/bench.py,sha256=p34wnfMRdiedOUf9GKGZkkNxehmyTzK6Q1O20q_SGjY,21841
 sglang/bench_latency.py,sha256=UPy6WhrddMTDX7HqIeHNhCn5vF0YMOKxJlQRvhMC8zU,10552
 sglang/bench_serving.py,sha256=zKGgVX3S-ggUvOxvEM4AszzXRPRVU6NGNnBG5vAAvRY,34577
 sglang/check_env.py,sha256=CscuPMlf68dkgZf0m-FiLpUisNNDoihMck4qhLOeV1Q,4124
@@ -8,13 +7,7 @@ sglang/global_config.py,sha256=QG-ABVJksKK_llvUx7fSZcmK4GGCs-hBUVcM4LCr7Nw,1749
 sglang/launch_server.py,sha256=Gg8CwNlTCCfg1dF65ZT9ePLxOT9LKtY79GhIPG6PCrU,358
 sglang/launch_server_llavavid.py,sha256=40uaazMsavKuk6YXFa5v37kdUpFGuealgJJeph1g8gU,1025
 sglang/utils.py,sha256=arJuwOAEX445M2NL9SAOi6jBNu0-cfU04PLAr-hIH3U,8168
-sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sglang/backend/anthropic.py,sha256=iJjXiDMZbtvX2XNG78MG9kM7SpZq9hmXVuzT_T18elw,2076
-sglang/backend/base_backend.py,sha256=APiMht4WYECLCOGRPCEUF6lX-an1vjVe2dWoMSgymWY,1831
-sglang/backend/litellm.py,sha256=ZqsEZXgxLge-Fh3SMr1XkVPU7z3FKntpRppNwd1a12s,2447
-sglang/backend/openai.py,sha256=Id4vDzfefG9R7AqJBMXqYmKHv2FMu0PBSYEGbK7Q510,14803
-sglang/backend/runtime_endpoint.py,sha256=PAdnQBj3yQNtgw8GH9F1ecGE7HhxGa2T7Tz_c--H2aE,9203
-sglang/backend/vertexai.py,sha256=98toR-L0OTi4dYHaSmmzJdlQ2qN_0lImoKZFlVgYLRE,4850
+sglang/version.py,sha256=Ej7LsXg-6CASlaEHsZkUoLDpYEfHeFKdIeXMIM0esgA,23
 sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/lang/chat_template.py,sha256=psIlhaDo70twgLrx5Lgln03metLEA3-FZuixeI0Y7Ao,13309
 sglang/lang/compiler.py,sha256=UiXUmPR9wBAPtnORrLcyQX8Uh0ZL0nKeV8ZgBozAJPw,7531
@@ -34,12 +27,10 @@ sglang/srt/hf_transformers_utils.py,sha256=94mOI93B2xOmXKqfJfEoGxqHgwwlWNbPHgsA4
 sglang/srt/memory_pool.py,sha256=FhJk5GtYortO3MJIsMMQ-o49agwDHVX1aEQH2LITq6c,3949
 sglang/srt/mm_utils.py,sha256=OptgAHDX-73Bk4jAdr2BOAJtiEXJNzPrMhaM-dy275c,8889
 sglang/srt/model_config.py,sha256=lZu1D-XLVMETHS6FBMoPn8Uowa9QFGe95d3SuWrr2q8,5282
-sglang/srt/openai_api_adapter.py,sha256=iw-FquXQeM2Z4nxOoYGFPjTkIdgA8rQkh_IcmJRy-R0,15143
-sglang/srt/openai_protocol.py,sha256=lGBhfxG6jmgUkMOh2NpBK9w9TUTRZKrsfHdW7XYhKKI,5700
 sglang/srt/sampling_params.py,sha256=OI11asr1Bd_E5soDjih614v4flgWxdMZU9HAF0aBafQ,3062
-sglang/srt/server.py,sha256=JC6rs8mkWg2mWwriwZvYEZyO514_HJFOUNda-pu8U_4,14369
+sglang/srt/server.py,sha256=DXhcJt0V24a7yhydP1abPrK1qqV3qt7r8cyOMVOAI4M,14611
 sglang/srt/server_args.py,sha256=aF6L35mEB-FU3BL_ooKuCIcOXLhYLxA9-MjpaOTQRCo,13189
-sglang/srt/utils.py,sha256=ZB9WLlZ_GpKVpPJiETrYkqH10J8iWrN_4buxDnQoA88,18568
+sglang/srt/utils.py,sha256=bUp3SLzbDms0dvuETaccDPAGRHOIGW5A61pqH62XiT0,20370
 sglang/srt/constrained/__init__.py,sha256=5LB3_mDTMW6wcRkFA5J2Rd5HPHHEKRyiELhe4gtlBYM,1472
 sglang/srt/constrained/base_cache.py,sha256=QQjmFEiT8jlOskJoZobhrDl2TKB-B4b1LPQo9JQCP_w,1405
 sglang/srt/constrained/fsm_cache.py,sha256=P4qNDHHxpKpTnYL_8V1R6OFXlUwbM6ZcBdzddpcBgb4,1135
@@ -57,11 +48,10 @@ sglang/srt/managers/detokenizer_manager.py,sha256=8rN2cdMr61LWy07lingEqLnNy0W5Re
 sglang/srt/managers/io_struct.py,sha256=Y6jW3p0cNg0jcrEQNki1H8MMEWxwWA4p6Y-xVgUVWaI,5404
 sglang/srt/managers/tokenizer_manager.py,sha256=SbivhFhZUR9HU9pLTe93MlYprAFAHzOU3KMBA2piQUk,19308
 sglang/srt/managers/controller/cuda_graph_runner.py,sha256=0aRqA1_34oJ557Zn8PjpJecex5bBWJdnCmBlcDVvYO0,8509
-sglang/srt/managers/controller/dp_worker.py,sha256=ES3-jyxGfHzpgVoXub_3qjVygwfWYWpfN4vuVWU23Gs,3675
 sglang/srt/managers/controller/infer_batch.py,sha256=SKwCwhnZ_CNlG0mVCEc4X0e4HNjJFke-c8zdWP3TzjQ,34186
 sglang/srt/managers/controller/manager_multi.py,sha256=DT8Y9RF5OyTxlrLEZYz4claNWir3UrVztdOZaVPiA6g,6077
 sglang/srt/managers/controller/manager_single.py,sha256=2xO_iWK6tWvc0B31nKbe2N3klxwQBJmPTnFhNjzhVSI,4566
-sglang/srt/managers/controller/model_runner.py,sha256=927tf6nJjLjEDgz2wCDj2kvpZ-E_rAVm8PVKFVfP4p8,13951
+sglang/srt/managers/controller/model_runner.py,sha256=FwZ7FU7nhJsYhtoTNxYFc4e6oMEwSqOh8ohXOKtFPKc,15828
 sglang/srt/managers/controller/radix_cache.py,sha256=tx8LEQpqLxipw9UUVj4D1YQLMMDmWnjDYv8oDlOl-co,8210
 sglang/srt/managers/controller/schedule_heuristic.py,sha256=SQAGzPS3aB_TPj7rnPBhewwyR6W1sVwW4D3zG3JUY00,2714
 sglang/srt/managers/controller/tp_worker.py,sha256=yjz-Xzl0zEy4QSU-EYneZH5vi3oHtBuXTtYe4VuDp2g,30517
@@ -90,16 +80,13 @@ sglang/srt/models/qwen2_moe.py,sha256=oHNoo45myV5kitkls2GWVzuGt1Q4pRHN2nLlXEltFI
 sglang/srt/models/stablelm.py,sha256=Z_XCDSHY_QMz3lZwwkZdIZjEOizZjLYJU9GDi8o08qQ,10802
 sglang/srt/models/yivl.py,sha256=55KPrQ-dVplI0hh2WCSugjc1luE0J2UAafjZxu_7Xuc,4367
 sglang/srt/openai_api/adapter.py,sha256=eirFYVGIp5D-UrQLqW5dRJOQYKmzF9nmgCzFeUOb2z8,15737
-sglang/srt/openai_api/api_adapter.py,sha256=eirFYVGIp5D-UrQLqW5dRJOQYKmzF9nmgCzFeUOb2z8,15737
-sglang/srt/openai_api/openai_api_adapter.py,sha256=5pDaktIEteHxp3qN89U_U3ndd7N0FIfUZAM06YeziUY,15687
-sglang/srt/openai_api/openai_protocol.py,sha256=lGBhfxG6jmgUkMOh2NpBK9w9TUTRZKrsfHdW7XYhKKI,5700
 sglang/srt/openai_api/protocol.py,sha256=j7ifIR2SFQxTwaHAd9ksM096vfffcNltzTH4sg7H0RA,5739
 sglang/test/test_conversation.py,sha256=gF_AyOxQgpPQBPnA57-kq-M0p_zFu-rBDMFgAq655Rw,1596
 sglang/test/test_openai_protocol.py,sha256=DVx3r6hrb8oRqbo5AYIleldxbqMBTtb-gtORM6t_Y1c,1661
 sglang/test/test_programs.py,sha256=uefeHUFKT2NJESOujj-CsnPXdw1aQQN2TzUbPCHJjGs,13654
 sglang/test/test_utils.py,sha256=kD_fQe3WroZ9Kc3NBRKPiZOFJ_JD2uEE9XIvPp6AD9Y,11048
-sglang-0.1.24.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sglang-0.1.24.dist-info/METADATA,sha256=_HKFljParVedu-eht7OKKb_RpEkVcB-Wh_P_jRW3TJk,30933
-sglang-0.1.24.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
-sglang-0.1.24.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
-sglang-0.1.24.dist-info/RECORD,,
+sglang-0.1.25.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sglang-0.1.25.dist-info/METADATA,sha256=Ifwh2YdZqQXMe2UCOklWFIGeM0KLkfLjBQHv98gS8Pw,30928
+sglang-0.1.25.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
+sglang-0.1.25.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+sglang-0.1.25.dist-info/RECORD,,
sglang/backend/anthropic.py DELETED
@@ -1,77 +0,0 @@
-from typing import List, Optional, Union
-
-import numpy as np
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-try:
-    import anthropic
-except ImportError as e:
-    anthropic = e
-
-
-class Anthropic(BaseBackend):
-    def __init__(self, model_name, *args, **kwargs):
-        super().__init__()
-
-        if isinstance(anthropic, Exception):
-            raise anthropic
-
-        self.model_name = model_name
-        self.chat_template = get_chat_template("claude")
-        self.client = anthropic.Anthropic(*args, **kwargs)
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        if messages and messages[0]["role"] == "system":
-            system = messages.pop(0)["content"]
-        else:
-            system = ""
-
-        ret = self.client.messages.create(
-            model=self.model_name,
-            system=system,
-            messages=messages,
-            **sampling_params.to_anthropic_kwargs(),
-        )
-        comp = ret.content[0].text
-
-        return comp, {}
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        if messages and messages[0]["role"] == "system":
-            system = messages.pop(0)["content"]
-        else:
-            system = ""
-
-        with self.client.messages.stream(
-            model=self.model_name,
-            system=system,
-            messages=messages,
-            **sampling_params.to_anthropic_kwargs(),
-        ) as stream:
-            for text in stream.text_stream:
-                yield text, {}
sglang/backend/base_backend.py DELETED
@@ -1,80 +0,0 @@
-from typing import Callable, List, Optional, Union
-
-from sglang.lang.chat_template import get_chat_template
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-
-class BaseBackend:
-    def __init__(self) -> None:
-        self.support_concate_and_append = False
-        self.chat_template = get_chat_template("default")
-
-    def get_model_name(self):
-        raise NotImplementedError()
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def cache_prefix(self, prefix_str: str):
-        pass
-
-    def uncache_prefix(self, rid: str):
-        pass
-
-    def end_request(self, rid: Union[str, List[str]]):
-        pass
-
-    def begin_program(self, s: StreamExecutor):
-        pass
-
-    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
-        pass
-
-    def commit_lazy_operations(self, s: StreamExecutor):
-        pass
-
-    def fork_program(
-        self,
-        src: StreamExecutor,
-        dst: List[StreamExecutor],
-        position_ids_offset: Optional[List[int]] = None,
-    ):
-        pass
-
-    def fill_image(self, s: StreamExecutor):
-        pass
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        raise NotImplementedError()
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        raise NotImplementedError()
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        temperature: float,
-    ):
-        raise NotImplementedError()
-
-    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
-        raise NotImplementedError()
-
-    def shutdown(self):
-        pass
-
-    def flush_cache(self):
-        pass
-
-    def get_server_args(self):
-        pass
sglang/backend/litellm.py DELETED
@@ -1,90 +0,0 @@
-from typing import Mapping, Optional
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-
-try:
-    import litellm
-except ImportError as e:
-    litellm = e
-litellm.num_retries = 1
-
-
-class LiteLLM(BaseBackend):
-    def __init__(
-        self,
-        model_name,
-        chat_template=None,
-        api_key=None,
-        organization: Optional[str] = None,
-        base_url: Optional[str] = None,
-        timeout: Optional[float] = 600,
-        max_retries: Optional[int] = litellm.num_retries,
-        default_headers: Optional[Mapping[str, str]] = None,
-    ):
-        super().__init__()
-
-        if isinstance(litellm, Exception):
-            raise litellm
-
-        self.model_name = model_name
-
-        self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name
-        )
-
-        self.client_params = {
-            "api_key": api_key,
-            "organization": organization,
-            "base_url": base_url,
-            "timeout": timeout,
-            "max_retries": max_retries,
-            "default_headers": default_headers,
-        }
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        ret = litellm.completion(
-            model=self.model_name,
-            messages=messages,
-            **self.client_params,
-            **sampling_params.to_anthropic_kwargs(),
-        )
-        comp = ret.choices[0].message.content
-
-        return comp, {}
-
-    def generate_stream(
-        self,
-        s: StreamExecutor,
-        sampling_params: SglSamplingParams,
-    ):
-        if s.messages_:
-            messages = s.messages_
-        else:
-            messages = [{"role": "user", "content": s.text_}]
-
-        ret = litellm.completion(
-            model=self.model_name,
-            messages=messages,
-            stream=True,
-            **self.client_params,
-            **sampling_params.to_litellm_kwargs(),
-        )
-        for chunk in ret:
-            text = chunk.choices[0].delta.content
-            if text is not None:
-                yield text, {}
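
These deletions remove the legacy sglang.backend package; per this diff's __init__.py context, backends are now imported from sglang.lang.backend instead. A hedged migration sketch (only the OpenAI path is confirmed by the diff's imports; the Anthropic path is assumed by analogy):

```python
# Old (removed in 0.1.25):
#   from sglang.backend.anthropic import Anthropic
# New layout, following the sglang.lang.backend.* imports seen above:
from sglang.lang.backend.openai import OpenAI        # confirmed by the diff
from sglang.lang.backend.anthropic import Anthropic  # assumed by analogy
```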