xinference 0.10.3__py3-none-any.whl → 0.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (101)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +1 -1
  3. xinference/api/restful_api.py +53 -61
  4. xinference/client/restful/restful_client.py +52 -57
  5. xinference/conftest.py +1 -1
  6. xinference/core/cache_tracker.py +1 -1
  7. xinference/core/chat_interface.py +10 -4
  8. xinference/core/event.py +1 -1
  9. xinference/core/model.py +17 -6
  10. xinference/core/status_guard.py +1 -1
  11. xinference/core/supervisor.py +58 -72
  12. xinference/core/worker.py +68 -101
  13. xinference/deploy/cmdline.py +166 -1
  14. xinference/deploy/test/test_cmdline.py +2 -0
  15. xinference/deploy/utils.py +1 -1
  16. xinference/device_utils.py +29 -3
  17. xinference/fields.py +7 -1
  18. xinference/model/audio/whisper.py +88 -12
  19. xinference/model/core.py +2 -2
  20. xinference/model/image/__init__.py +29 -0
  21. xinference/model/image/core.py +6 -0
  22. xinference/model/image/custom.py +109 -0
  23. xinference/model/llm/__init__.py +92 -32
  24. xinference/model/llm/core.py +57 -102
  25. xinference/model/llm/ggml/chatglm.py +98 -13
  26. xinference/model/llm/ggml/llamacpp.py +49 -2
  27. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +2 -2
  28. xinference/model/llm/llm_family.json +438 -7
  29. xinference/model/llm/llm_family.py +45 -41
  30. xinference/model/llm/llm_family_modelscope.json +258 -5
  31. xinference/model/llm/pytorch/chatglm.py +48 -0
  32. xinference/model/llm/pytorch/core.py +23 -6
  33. xinference/model/llm/pytorch/deepseek_vl.py +115 -33
  34. xinference/model/llm/pytorch/internlm2.py +32 -1
  35. xinference/model/llm/pytorch/qwen_vl.py +94 -12
  36. xinference/model/llm/pytorch/utils.py +38 -1
  37. xinference/model/llm/pytorch/yi_vl.py +96 -51
  38. xinference/model/llm/sglang/core.py +31 -9
  39. xinference/model/llm/utils.py +54 -20
  40. xinference/model/llm/vllm/core.py +101 -7
  41. xinference/thirdparty/omnilmm/chat.py +2 -1
  42. xinference/thirdparty/omnilmm/model/omnilmm.py +2 -1
  43. xinference/types.py +11 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -3
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.54bca460.css +2 -0
  47. xinference/web/ui/build/static/css/main.54bca460.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.551aa479.js +3 -0
  49. xinference/web/ui/build/static/js/{main.26fdbfbe.js.LICENSE.txt → main.551aa479.js.LICENSE.txt} +7 -0
  50. xinference/web/ui/build/static/js/main.551aa479.js.map +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/23caf6f1e52c43e983ca3bfd4189f41dbd645fa78f2dfdcd7f6b69bc41678665.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/29dda700ab913cf7f2cfabe450ddabfb283e96adfa3ec9d315b2fa6c63cd375c.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/53f6c0c0afb51265cd8fb940daeb65523501879ac2a8c03a1ead22b9793c5041.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/8ccbb839002bc5bc03e0a0e7612362bf92f6ae64f87e094f8682d6a6fe4619bb.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/97ed30d6e22cf76f0733651e2c18364689a01665d0b5fe811c1b7ca3eb713c82.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/9c0c70f1838913aaa792a0d2260f17f90fd177b95698ed46b7bc3050eb712c1c.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/ada71518a429f821a9b1dea38bc951447f03c8db509887e0980b893acac938f3.json +1 -0
  66. xinference/web/ui/node_modules/.cache/babel-loader/b6c9558d28b5972bb8b2691c5a76a2c8814a815eb3443126da9f49f7d6a0c118.json +1 -0
  67. xinference/web/ui/node_modules/.cache/babel-loader/bb0f721c084a4d85c09201c984f02ee8437d3b6c5c38a57cb4a101f653daef1b.json +1 -0
  68. xinference/web/ui/node_modules/.package-lock.json +33 -0
  69. xinference/web/ui/node_modules/clipboard/.babelrc.json +11 -0
  70. xinference/web/ui/node_modules/clipboard/.eslintrc.json +24 -0
  71. xinference/web/ui/node_modules/clipboard/.prettierrc.json +9 -0
  72. xinference/web/ui/node_modules/clipboard/bower.json +18 -0
  73. xinference/web/ui/node_modules/clipboard/composer.json +25 -0
  74. xinference/web/ui/node_modules/clipboard/package.json +63 -0
  75. xinference/web/ui/node_modules/delegate/package.json +31 -0
  76. xinference/web/ui/node_modules/good-listener/bower.json +11 -0
  77. xinference/web/ui/node_modules/good-listener/package.json +35 -0
  78. xinference/web/ui/node_modules/select/bower.json +13 -0
  79. xinference/web/ui/node_modules/select/package.json +29 -0
  80. xinference/web/ui/node_modules/tiny-emitter/package.json +53 -0
  81. xinference/web/ui/package-lock.json +34 -0
  82. xinference/web/ui/package.json +1 -0
  83. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/METADATA +13 -12
  84. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/RECORD +88 -67
  85. xinference/client/oscar/__init__.py +0 -13
  86. xinference/client/oscar/actor_client.py +0 -611
  87. xinference/model/llm/pytorch/spec_decoding_utils.py +0 -531
  88. xinference/model/llm/pytorch/spec_model.py +0 -186
  89. xinference/web/ui/build/static/js/main.26fdbfbe.js +0 -3
  90. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +0 -1
  93. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +0 -1
  94. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +0 -1
  95. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +0 -1
  96. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +0 -1
  98. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/LICENSE +0 -0
  99. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/WHEEL +0 -0
  100. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/entry_points.txt +0 -0
  101. {xinference-0.10.3.dist-info → xinference-0.11.1.dist-info}/top_level.txt +0 -0
xinference/model/image/custom.py (new file)
@@ -0,0 +1,109 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+ from threading import Lock
+ from typing import List, Optional
+
+ from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
+ from .core import ImageModelFamilyV1
+
+ logger = logging.getLogger(__name__)
+
+ UD_IMAGE_LOCK = Lock()
+
+
+ class CustomImageModelFamilyV1(ImageModelFamilyV1):
+     model_id: Optional[str]  # type: ignore
+     model_revision: Optional[str]  # type: ignore
+     model_uri: Optional[str]
+     controlnet: Optional[List["CustomImageModelFamilyV1"]]
+
+
+ UD_IMAGES: List[CustomImageModelFamilyV1] = []
+
+
+ def get_user_defined_images() -> List[ImageModelFamilyV1]:
+     with UD_IMAGE_LOCK:
+         return UD_IMAGES.copy()
+
+
+ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
+     from ..utils import is_valid_model_name, is_valid_model_uri
+     from . import BUILTIN_IMAGE_MODELS, MODELSCOPE_IMAGE_MODELS
+
+     if not is_valid_model_name(model_spec.model_name):
+         raise ValueError(f"Invalid model name {model_spec.model_name}.")
+
+     with UD_IMAGE_LOCK:
+         for model_name in (
+             list(BUILTIN_IMAGE_MODELS.keys())
+             + list(MODELSCOPE_IMAGE_MODELS.keys())
+             + [spec.model_name for spec in UD_IMAGES]
+         ):
+             if model_spec.model_name == model_name:
+                 raise ValueError(
+                     f"Model name conflicts with existing model {model_spec.model_name}"
+                 )
+         UD_IMAGES.append(model_spec)
+
+     if persist:
+         # We only validate model URL when persist is True.
+         model_uri = model_spec.model_uri
+         if model_uri and not is_valid_model_uri(model_uri):
+             raise ValueError(f"Invalid model URI {model_uri}")
+
+         persist_path = os.path.join(
+             XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_id}.json"
+         )
+         os.makedirs(os.path.dirname(persist_path), exist_ok=True)
+         with open(persist_path, "w") as f:
+             f.write(model_spec.json())
+
+
+ def unregister_image(model_name: str, raise_error: bool = True):
+     with UD_IMAGE_LOCK:
+         model_spec = None
+         for i, f in enumerate(UD_IMAGES):
+             if f.model_name == model_name:
+                 model_spec = f
+                 break
+         if model_spec:
+             UD_IMAGES.remove(model_spec)
+
+             persist_path = os.path.join(
+                 XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_id}.json"
+             )
+
+             if os.path.exists(persist_path):
+                 os.remove(persist_path)
+
+             cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
+             if os.path.exists(cache_dir):
+                 logger.warning(
+                     f"Remove the cache of user-defined model {model_spec.model_name}. "
+                     f"Cache directory: {cache_dir}"
+                 )
+                 if os.path.islink(cache_dir):
+                     os.remove(cache_dir)
+                 else:
+                     logger.warning(
+                         f"Cache directory is not a soft link, please remove it manually."
+                     )
+         else:
+             if raise_error:
+                 raise ValueError(f"Model {model_name} not found.")
+             else:
+                 logger.warning(f"Custom image model {model_name} not found.")
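
The new module mirrors the existing user-defined LLM registration flow. Below is a minimal sketch of how a custom image model might be registered through it; the spec values, and any field not visible in this diff (such as `model_family`, inherited from `ImageModelFamilyV1`), are illustrative assumptions rather than part of the release.

```python
from xinference.model.image.custom import (
    CustomImageModelFamilyV1,
    register_image,
    unregister_image,
)

# Hypothetical spec pointing at a locally stored checkpoint.
spec = CustomImageModelFamilyV1(
    model_name="my-sd-checkpoint",
    model_family="stable_diffusion",  # assumed field from the parent class
    model_uri="file:///data/models/my-sd-checkpoint",
    controlnet=None,
)

# persist=True validates the URI and writes the spec JSON under
# XINFERENCE_MODEL_DIR/image so the registration survives restarts.
register_image(spec, persist=True)

# Remove it again; with raise_error=False a missing model only logs a warning.
unregister_image("my-sd-checkpoint", raise_error=False)
```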

xinference/model/llm/__init__.py
@@ -15,6 +15,7 @@
  import codecs
  import json
  import os
+ import warnings

  from .core import (
      LLM,
@@ -30,8 +31,12 @@ from .llm_family import (
      BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
      BUILTIN_LLM_PROMPT_STYLE,
      BUILTIN_MODELSCOPE_LLM_FAMILIES,
-     LLM_CLASSES,
-     PEFT_SUPPORTED_CLASSES,
+     LLAMA_CLASSES,
+     LLM_ENGINES,
+     SGLANG_CLASSES,
+     SUPPORTED_ENGINES,
+     TRANSFORMERS_CLASSES,
+     VLLM_CLASSES,
      CustomLLMFamilyV1,
      GgmlLLMSpecV1,
      LLMFamilyV1,
@@ -41,12 +46,68 @@ from .llm_family import (
      get_cache_status,
      get_user_defined_llm_families,
      match_llm,
-     match_llm_cls,
      register_llm,
      unregister_llm,
  )


+ def check_format_with_engine(model_format, engine):
+     # only llama-cpp-python support and only support ggufv2 and ggmlv3
+     if model_format in ["ggufv2", "ggmlv3"] and engine != "llama.cpp":
+         return False
+     if model_format not in ["ggufv2", "ggmlv3"] and engine == "llama.cpp":
+         return False
+     return True
+
+
+ def generate_engine_config_by_model_family(model_family):
+     model_name = model_family.model_name
+     specs = model_family.model_specs
+     engines = LLM_ENGINES.get(model_name, {})  # structure for engine query
+     for spec in specs:
+         model_format = spec.model_format
+         model_size_in_billions = spec.model_size_in_billions
+         quantizations = spec.quantizations
+         for quantization in quantizations:
+             # traverse all supported engines to match the name, format, size in billions and quatization of model
+             for engine in SUPPORTED_ENGINES:
+                 if not check_format_with_engine(
+                     model_format, engine
+                 ):  # match the format of model with engine
+                     continue
+                 CLASSES = SUPPORTED_ENGINES[engine]
+                 for cls in CLASSES:
+                     if cls.match(model_family, spec, quantization):
+                         engine_params = engines.get(engine, [])
+                         already_exists = False
+                         # if the name, format and size in billions of model already exists in the structure, add the new quantization
+                         for param in engine_params:
+                             if (
+                                 model_name == param["model_name"]
+                                 and model_format == param["model_format"]
+                                 and model_size_in_billions
+                                 == param["model_size_in_billions"]
+                             ):
+                                 if quantization not in param["quantizations"]:
+                                     param["quantizations"].append(quantization)
+                                 already_exists = True
+                                 break
+                         # successfully match the params for the first time, add to the structure
+                         if not already_exists:
+                             engine_params.append(
+                                 {
+                                     "model_name": model_name,
+                                     "model_format": model_format,
+                                     "model_size_in_billions": model_size_in_billions,
+                                     "quantizations": [quantization],
+                                     "llm_class": cls,
+                                 }
+                             )
+                         engines[engine] = engine_params
+                         break
+     LLM_ENGINES[model_name] = engines
+
+
  def _install():
      from .ggml.chatglm import ChatglmCppChatModel
      from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
@@ -57,28 +118,31 @@ def _install():
      from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
      from .pytorch.internlm2 import Internlm2PytorchChatModel
      from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
-     from .pytorch.omnilmm import OmniLMMModel
      from .pytorch.qwen_vl import QwenVLChatModel
      from .pytorch.vicuna import VicunaPytorchChatModel
      from .pytorch.yi_vl import YiVLChatModel
      from .sglang.core import SGLANGChatModel, SGLANGModel
      from .vllm.core import VLLMChatModel, VLLMModel

+     try:
+         from .pytorch.omnilmm import OmniLMMModel
+     except ImportError as e:
+         # For quite old transformers version,
+         # import will generate error
+         OmniLMMModel = None
+         warnings.warn(f"Cannot import OmniLLMModel due to reason: {e}")
+
      # register llm classes.
-     LLM_CLASSES.extend(
+     LLAMA_CLASSES.extend(
          [
+             ChatglmCppChatModel,
              LlamaCppChatModel,
              LlamaCppModel,
          ]
      )
-     LLM_CLASSES.extend(
-         [
-             ChatglmCppChatModel,
-         ]
-     )
-     LLM_CLASSES.extend([SGLANGModel, SGLANGChatModel])
-     LLM_CLASSES.extend([VLLMModel, VLLMChatModel])
-     LLM_CLASSES.extend(
+     SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
+     VLLM_CLASSES.extend([VLLMModel, VLLMChatModel])
+     TRANSFORMERS_CLASSES.extend(
          [
              BaichuanPytorchChatModel,
              VicunaPytorchChatModel,
@@ -90,28 +154,19 @@ def _install():
              FalconPytorchModel,
              Internlm2PytorchChatModel,
              QwenVLChatModel,
-             OmniLMMModel,
              YiVLChatModel,
              DeepSeekVLChatModel,
              PytorchModel,
          ]
      )
-     PEFT_SUPPORTED_CLASSES.extend(
-         [
-             BaichuanPytorchChatModel,
-             VicunaPytorchChatModel,
-             FalconPytorchChatModel,
-             ChatglmPytorchChatModel,
-             LlamaPytorchModel,
-             LlamaPytorchChatModel,
-             PytorchChatModel,
-             FalconPytorchModel,
-             Internlm2PytorchChatModel,
-             QwenVLChatModel,
-             YiVLChatModel,
-             PytorchModel,
-         ]
-     )
+     if OmniLMMModel:  # type: ignore
+         TRANSFORMERS_CLASSES.append(OmniLMMModel)
+
+     # support 4 engines for now
+     SUPPORTED_ENGINES["vLLM"] = VLLM_CLASSES
+     SUPPORTED_ENGINES["SGLang"] = SGLANG_CLASSES
+     SUPPORTED_ENGINES["Transformers"] = TRANSFORMERS_CLASSES
+     SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CLASSES

      json_path = os.path.join(
          os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
@@ -132,7 +187,7 @@ def _install():
              BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
          else:
              BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-         if "tool_call" in model_spec.model_ability:
+         if "tools" in model_spec.model_ability:
              BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)

      modelscope_json_path = os.path.join(
@@ -155,7 +210,7 @@ def _install():
              BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
          else:
              BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
-         if "tool_call" in model_spec.model_ability:
+         if "tools" in model_spec.model_ability:
              BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)

      for llm_specs in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
@@ -163,6 +218,11 @@ def _install():
          if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
              LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))

+     # traverse all families and add engine parameters corresponding to the model name
+     for families in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
+         for family in families:
+             generate_engine_config_by_model_family(family)
+
      from ...constants import XINFERENCE_MODEL_DIR

      user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
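
Taken together, `_install()` now registers each backend's classes into its own list and `generate_engine_config_by_model_family` flattens them into the `LLM_ENGINES` lookup table keyed by model name and engine. The sketch below shows one way that table could be inspected once `_install()` has run; the model name "llama-2-chat" is illustrative, and the exact entries depend on which backends are importable in the environment.

```python
from xinference.model.llm.llm_family import LLM_ENGINES

# LLM_ENGINES maps: model_name -> engine name ("vLLM", "SGLang",
# "Transformers", "llama.cpp") -> list of dicts with keys
# "model_name", "model_format", "model_size_in_billions",
# "quantizations", and "llm_class" (the matched LLM subclass).
for engine, params in LLM_ENGINES.get("llama-2-chat", {}).items():
    for p in params:
        print(
            engine,
            p["model_format"],
            p["model_size_in_billions"],
            p["quantizations"],
            p["llm_class"].__name__,
        )
```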

xinference/model/llm/core.py
@@ -13,11 +13,13 @@
  # limitations under the License.

  import abc
+ import inspect
  import logging
  import os
  import platform
  from abc import abstractmethod
  from collections import defaultdict
+ from functools import lru_cache
  from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

  from ...core.utils import parse_replica_model_uid
@@ -62,16 +64,6 @@ class LLM(abc.ABC):
          if kwargs:
              raise ValueError(f"Unrecognized keyword arguments: {kwargs}")

-     @staticmethod
-     def handle_model_size(model_size_in_billions: Union[str, int]) -> Union[int, float]:
-         if isinstance(model_size_in_billions, str):
-             if "_" in model_size_in_billions:
-                 ms = model_size_in_billions.replace("_", ".")
-                 return float(ms)
-             else:
-                 raise ValueError("Invalid format for `model_size_in_billions`")
-         return model_size_in_billions
-
      @staticmethod
      def _is_darwin_and_apple_silicon():
          return platform.system() == "Darwin" and platform.processor() == "arm"
@@ -81,12 +73,30 @@ class LLM(abc.ABC):
          return platform.system() == "Linux"

      @staticmethod
+     @lru_cache
      def _has_cuda_device():
-         from ...utils import cuda_count
-
-         return cuda_count() > 0
+         """
+         Use pynvml to impl this interface.
+         DO NOT USE torch to impl this, which will lead to some unexpected errors.
+         """
+         from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
+
+         device_count = 0
+         try:
+             nvmlInit()
+             device_count = nvmlDeviceGetCount()
+         except:
+             pass
+         finally:
+             try:
+                 nvmlShutdown()
+             except:
+                 pass
+
+         return device_count > 0

      @staticmethod
+     @lru_cache
      def _get_cuda_count():
          from ...utils import cuda_count

@@ -178,47 +188,60 @@ def create_llm_model_instance(
      devices: List[str],
      model_uid: str,
      model_name: str,
+     model_engine: Optional[str],
      model_format: Optional[str] = None,
      model_size_in_billions: Optional[Union[int, str]] = None,
      quantization: Optional[str] = None,
      peft_model_config: Optional[PeftModelConfig] = None,
-     is_local_deployment: bool = False,
      **kwargs,
  ) -> Tuple[LLM, LLMDescription]:
-     from . import match_llm, match_llm_cls
-     from .llm_family import cache
+     from .llm_family import cache, check_engine_by_spec_parameters, match_llm

+     if model_engine is None:
+         raise ValueError("model_engine is required for LLM model")
      match_result = match_llm(
-         model_name,
-         model_format,
-         model_size_in_billions,
-         quantization,
-         is_local_deployment,
+         model_name, model_format, model_size_in_billions, quantization
      )
+
      if not match_result:
          raise ValueError(
              f"Model not found, name: {model_name}, format: {model_format},"
              f" size: {model_size_in_billions}, quantization: {quantization}"
          )
      llm_family, llm_spec, quantization = match_result
-
      assert quantization is not None
-     save_path = cache(llm_family, llm_spec, quantization)
-
-     peft_model = peft_model_config.peft_model if peft_model_config else None

-     llm_cls = match_llm_cls(llm_family, llm_spec, quantization, peft_model=peft_model)
-     if not llm_cls:
-         raise ValueError(
-             f"Model not supported, name: {model_name}, format: {model_format},"
-             f" size: {model_size_in_billions}, quantization: {quantization}"
-         )
+     llm_cls = check_engine_by_spec_parameters(
+         model_engine,
+         llm_family.model_name,
+         llm_spec.model_format,
+         llm_spec.model_size_in_billions,
+         quantization,
+     )
      logger.debug(f"Launching {model_uid} with {llm_cls.__name__}")

+     save_path = cache(llm_family, llm_spec, quantization)
+
+     peft_model = peft_model_config.peft_model if peft_model_config else None
      if peft_model is not None:
-         model = llm_cls(
-             model_uid, llm_family, llm_spec, quantization, save_path, kwargs, peft_model
-         )
+         if "peft_model" in inspect.signature(llm_cls.__init__).parameters:
+             model = llm_cls(
+                 model_uid,
+                 llm_family,
+                 llm_spec,
+                 quantization,
+                 save_path,
+                 kwargs,
+                 peft_model,
+             )
+         else:
+             logger.warning(
+                 f"Model not supported with lora, name: {model_name}, format: {model_format}, engine: {model_engine}. "
+                 f"Load this without lora."
+             )
+             model = llm_cls(
+                 model_uid, llm_family, llm_spec, quantization, save_path, kwargs
+             )
      else:
          model = llm_cls(
              model_uid, llm_family, llm_spec, quantization, save_path, kwargs
@@ -226,71 +249,3 @@
      return model, LLMDescription(
          subpool_addr, devices, llm_family, llm_spec, quantization
      )
-
-
- def create_speculative_llm_model_instance(
-     subpool_addr: str,
-     devices: List[str],
-     model_uid: str,
-     model_name: str,
-     model_size_in_billions: Optional[Union[int, str]],
-     quantization: Optional[str],
-     draft_model_name: str,
-     draft_model_size_in_billions: Optional[int],
-     draft_quantization: Optional[str],
-     is_local_deployment: bool = False,
- ) -> Tuple[LLM, LLMDescription]:
-     from . import match_llm
-     from .llm_family import cache
-
-     match_result = match_llm(
-         model_name,
-         "pytorch",
-         model_size_in_billions,
-         quantization,
-         is_local_deployment,
-     )
-
-     if not match_result:
-         raise ValueError(
-             f"Model not found, name: {model_name}, format: pytorch,"
-             f" size: {model_size_in_billions}, quantization: {quantization}"
-         )
-     llm_family, llm_spec, quantization = match_result
-     assert quantization is not None
-     save_path = cache(llm_family, llm_spec, quantization)
-
-     draft_match_result = match_llm(
-         draft_model_name,
-         "pytorch",
-         draft_model_size_in_billions,
-         draft_quantization,
-         is_local_deployment,
-     )
-
-     if not draft_match_result:
-         raise ValueError(
-             f"Model not found, name: {draft_model_name}, format: pytorch,"
-             f" size: {draft_model_size_in_billions}, quantization: {draft_quantization}"
-         )
-     draft_llm_family, draft_llm_spec, draft_quantization = draft_match_result
-     assert draft_quantization is not None
-     draft_save_path = cache(draft_llm_family, draft_llm_spec, draft_quantization)
-
-     from .pytorch.spec_model import SpeculativeModel
-
-     model = SpeculativeModel(
-         model_uid,
-         model_family=llm_family,
-         model_spec=llm_spec,
-         quantization=quantization,
-         model_path=save_path,
-         draft_model_family=draft_llm_family,
-         draft_model_spec=draft_llm_spec,
-         draft_quantization=draft_quantization,
-         draft_model_path=draft_save_path,
-     )
-
-     return model, LLMDescription(
-         subpool_addr, devices, llm_family, llm_spec, quantization
-     )
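
Two behavioural changes in core.py stand out: `model_engine` is now a required argument of `create_llm_model_instance`, and LoRA (PEFT) support is decided by introspecting the selected engine class instead of consulting the removed `PEFT_SUPPORTED_CLASSES` list. Below is a self-contained sketch of that introspection pattern; the two classes are stand-ins for illustration, not real xinference engine classes.

```python
import inspect


def accepts_peft_model(llm_cls) -> bool:
    # Same check as in the diff above: only pass LoRA weights if the engine's
    # __init__ actually declares a `peft_model` parameter; otherwise fall back
    # to a plain load and log a warning.
    return "peft_model" in inspect.signature(llm_cls.__init__).parameters


class EngineWithLora:  # stand-in class, not part of xinference
    def __init__(self, model_uid, family, spec, quantization, path, kwargs, peft_model=None):
        self.peft_model = peft_model


class EngineWithoutLora:  # stand-in class, not part of xinference
    def __init__(self, model_uid, family, spec, quantization, path, kwargs):
        pass


assert accepts_peft_model(EngineWithLora)
assert not accepts_peft_model(EngineWithoutLora)
```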