xinference 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the registry's advisory page for more details.

Files changed (97)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +34 -15
  3. xinference/client/oscar/actor_client.py +4 -3
  4. xinference/client/restful/restful_client.py +40 -18
  5. xinference/core/supervisor.py +48 -9
  6. xinference/core/worker.py +13 -8
  7. xinference/deploy/cmdline.py +22 -9
  8. xinference/model/audio/__init__.py +40 -1
  9. xinference/model/audio/core.py +25 -45
  10. xinference/model/audio/custom.py +148 -0
  11. xinference/model/core.py +6 -9
  12. xinference/model/embedding/core.py +1 -2
  13. xinference/model/embedding/model_spec.json +24 -0
  14. xinference/model/embedding/model_spec_modelscope.json +24 -0
  15. xinference/model/image/core.py +12 -4
  16. xinference/model/image/stable_diffusion/core.py +8 -7
  17. xinference/model/llm/__init__.py +0 -6
  18. xinference/model/llm/core.py +9 -14
  19. xinference/model/llm/ggml/llamacpp.py +2 -10
  20. xinference/model/llm/llm_family.json +507 -7
  21. xinference/model/llm/llm_family.py +41 -4
  22. xinference/model/llm/llm_family_modelscope.json +260 -0
  23. xinference/model/llm/pytorch/baichuan.py +4 -3
  24. xinference/model/llm/pytorch/chatglm.py +5 -2
  25. xinference/model/llm/pytorch/core.py +37 -41
  26. xinference/model/llm/pytorch/falcon.py +6 -5
  27. xinference/model/llm/pytorch/internlm2.py +5 -2
  28. xinference/model/llm/pytorch/llama_2.py +6 -5
  29. xinference/model/llm/pytorch/qwen_vl.py +2 -0
  30. xinference/model/llm/pytorch/vicuna.py +4 -3
  31. xinference/model/llm/pytorch/yi_vl.py +4 -2
  32. xinference/model/llm/utils.py +42 -4
  33. xinference/model/llm/vllm/core.py +54 -6
  34. xinference/model/rerank/core.py +26 -12
  35. xinference/model/rerank/model_spec.json +24 -0
  36. xinference/model/rerank/model_spec_modelscope.json +25 -1
  37. xinference/model/utils.py +12 -1
  38. xinference/thirdparty/omnilmm/chat.py +1 -1
  39. xinference/types.py +70 -19
  40. xinference/utils.py +1 -0
  41. xinference/web/ui/build/asset-manifest.json +3 -3
  42. xinference/web/ui/build/index.html +1 -1
  43. xinference/web/ui/build/static/js/main.26fdbfbe.js +3 -0
  44. xinference/web/ui/build/static/js/main.26fdbfbe.js.map +1 -0
  45. xinference/web/ui/node_modules/.cache/babel-loader/15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/1870cd6f7054d04e049e363c0a85526584fe25519378609d2838e28d7492bbf1.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/3fa1f69162f9c6dc0f6a6e21b64d49d6b8e6fa8dfa59a82cf829931c5f97d99f.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/5393569d846332075b93b55656716a34f50e0a8c970be789502d7e6c49755fd7.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/62e257ed9016471035fa1a7da57c9e2a4250974ed566b4d1295873d747c68eb2.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/63a4c48f0326d071c7772c46598215c006ae41fd3d4ff3577fe717de66ad6e89.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/b9cbcb6d77ba21b22c6950b6fb5b305d23c19cf747f99f7d48b6b046f8f7b1b0.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/de0299226173b0662b573f49e3992220f6611947073bd66ac079728a8bc8837d.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/e606671420d2937102c3c34b4b04056c11736408c1d3347b8cf42dfe61fb394b.json +1 -0
  59. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/e9b52d171223bb59fb918316297a051cdfd42dd453e8260fd918e90bc0a4ebdf.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/f4d5d1a41892a754c1ee0237450d804b20612d1b657945b59e564161ea47aa7a.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/f9290c0738db50065492ceedc6a4af25083fe18399b7c44d942273349ad9e643.json +1 -0
  63. xinference/web/ui/node_modules/.cache/babel-loader/fad4cd70de36ef6e6d5f8fd74a10ded58d964a8a91ef7681693fbb8376552da7.json +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/feabb04b4aa507102da0a64398a40818e878fd1df9b75dda8461b3e1e7ff3f11.json +1 -0
  65. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/METADATA +13 -10
  66. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/RECORD +71 -74
  67. xinference/model/llm/ggml/ctransformers.py +0 -281
  68. xinference/model/llm/ggml/ctransformers_util.py +0 -161
  69. xinference/web/ui/build/static/js/main.98516614.js +0 -3
  70. xinference/web/ui/build/static/js/main.98516614.js.map +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/0bd70b1ecf307e2681318e864f4692305b6350c8683863007f4caf2f9ac33b6e.json +0 -1
  72. xinference/web/ui/node_modules/.cache/babel-loader/0db651c046ef908f45cde73af0dbea0a797d3e35bb57f4a0863b481502103a64.json +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/139969fd25258eb7decc9505f30b779089bba50c402bb5c663008477c7bff73b.json +0 -1
  74. xinference/web/ui/node_modules/.cache/babel-loader/18e5d5422e2464abf4a3e6d38164570e2e426e0a921e9a2628bbae81b18da353.json +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/3d93bd9a74a1ab0cec85af40f9baa5f6a8e7384b9e18c409b95a81a7b45bb7e2.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/3e055de705e397e1d413d7f429589b1a98dd78ef378b97f0cdb462c5f2487d5e.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/3f357ab57b8e7fade54c667f0e0ebf2787566f72bfdca0fea14e395b5c203753.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/4fd24800544873512b540544ae54601240a5bfefd9105ff647855c64f8ad828f.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/60c4b98d8ea7479fb0c94cfd19c8128f17bd7e27a1e73e6dd9adf6e9d88d18eb.json +0 -1
  81. xinference/web/ui/node_modules/.cache/babel-loader/7e094845f611802b024b57439cbf911038169d06cdf6c34a72a7277f35aa71a4.json +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/95c8cc049fadd23085d8623e1d43d70b614a4e52217676f186a417dca894aa09.json +0 -1
  83. xinference/web/ui/node_modules/.cache/babel-loader/98b7ef307f436affe13d75a4f265b27e828ccc2b10ffae6513abe2681bc11971.json +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/9d7c49815d97539207e5aab2fb967591b5fed7791218a0762539efc9491f36af.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/a8070ce4b780b4a044218536e158a9e7192a6c80ff593fdc126fee43f46296b5.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/b400cfc9db57fa6c70cd2bad055b73c5079fde0ed37974009d898083f6af8cd8.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/bd04667474fd9cac2983b03725c218908a6cc0ee9128a5953cd00d26d4877f60.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/c230a727b8f68f0e62616a75e14a3d33026dc4164f2e325a9a8072d733850edb.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/d0d0b591d9adaf42b83ad6633f8b7c118541a4b80ea957c303d3bf9b86fbad0a.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/d44a6eb6106e09082b691a315c9f6ce17fcfe25beb7547810e0d271ce3301cd2.json +0 -1
  91. xinference/web/ui/node_modules/.cache/babel-loader/e1d9b2ae4e1248658704bc6bfc5d6160dcd1a9e771ea4ae8c1fed0aaddeedd29.json +0 -1
  92. xinference/web/ui/node_modules/.cache/babel-loader/fe5db70859503a54cbe71f9637e5a314cda88b1f0eecb733b6e6f837697db1ef.json +0 -1
  93. /xinference/web/ui/build/static/js/{main.98516614.js.LICENSE.txt → main.26fdbfbe.js.LICENSE.txt} +0 -0
  94. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/LICENSE +0 -0
  95. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/WHEEL +0 -0
  96. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/entry_points.txt +0 -0
  97. {xinference-0.10.0.dist-info → xinference-0.10.2.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,8 @@ import os
16
16
  from collections import defaultdict
17
17
  from typing import Dict, List, Optional, Tuple
18
18
 
19
- from ..._compat import BaseModel
20
19
  from ...constants import XINFERENCE_CACHE_DIR
21
- from ..core import ModelDescription
20
+ from ..core import CacheableModelSpec, ModelDescription
22
21
  from ..utils import valid_model_revision
23
22
  from .whisper import WhisperModel
24
23
 
@@ -26,8 +25,19 @@ MAX_ATTEMPTS = 3
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
28
+ # Used for check whether the model is cached.
29
+ # Init when registering all the builtin models.
30
+ MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
31
+ AUDIO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
29
32
 
30
- class AudioModelFamilyV1(BaseModel):
33
+
34
+ def get_audio_model_descriptions():
35
+ import copy
36
+
37
+ return copy.deepcopy(AUDIO_MODEL_DESCRIPTIONS)
38
+
39
+
40
+ class AudioModelFamilyV1(CacheableModelSpec):
31
41
  model_family: str
32
42
  model_name: str
33
43
  model_id: str
@@ -77,63 +87,33 @@ def generate_audio_description(
77
87
  image_model: AudioModelFamilyV1,
78
88
  ) -> Dict[str, List[Dict]]:
79
89
  res = defaultdict(list)
80
- res[image_model.model_name].extend(
81
- AudioModelDescription(None, None, image_model).to_dict()
90
+ res[image_model.model_name].append(
91
+ AudioModelDescription(None, None, image_model).to_version_info()
82
92
  )
83
93
  return res
84
94
 
85
95
 
86
- def match_model(model_name: str) -> AudioModelFamilyV1:
96
+ def match_audio(model_name: str) -> AudioModelFamilyV1:
87
97
  from . import BUILTIN_AUDIO_MODELS
98
+ from .custom import get_user_defined_audios
99
+
100
+ for model_spec in get_user_defined_audios():
101
+ if model_spec.model_name == model_name:
102
+ return model_spec
88
103
 
89
104
  if model_name in BUILTIN_AUDIO_MODELS:
90
105
  return BUILTIN_AUDIO_MODELS[model_name]
91
106
  else:
92
107
  raise ValueError(
93
- f"Image model {model_name} not found, available"
108
+ f"Audio model {model_name} not found, available"
94
109
  f"model list: {BUILTIN_AUDIO_MODELS.keys()}"
95
110
  )
96
111
 
97
112
 
98
113
  def cache(model_spec: AudioModelFamilyV1):
99
- # TODO: cache from uri
100
- import huggingface_hub
101
-
102
- cache_dir = get_cache_dir(model_spec)
103
- if not os.path.exists(cache_dir):
104
- os.makedirs(cache_dir, exist_ok=True)
105
-
106
- meta_path = os.path.join(cache_dir, "__valid_download")
107
- if valid_model_revision(meta_path, model_spec.model_revision):
108
- return cache_dir
109
-
110
- for current_attempt in range(1, MAX_ATTEMPTS + 1):
111
- try:
112
- huggingface_hub.snapshot_download(
113
- model_spec.model_id,
114
- revision=model_spec.model_revision,
115
- local_dir=cache_dir,
116
- local_dir_use_symlinks=True,
117
- resume_download=True,
118
- )
119
- break
120
- except huggingface_hub.utils.LocalEntryNotFoundError:
121
- remaining_attempts = MAX_ATTEMPTS - current_attempt
122
- logger.warning(
123
- f"Attempt {current_attempt} failed. Remaining attempts: {remaining_attempts}"
124
- )
125
- else:
126
- raise RuntimeError(
127
- f"Failed to download model '{model_spec.model_name}' after {MAX_ATTEMPTS} attempts"
128
- )
129
-
130
- with open(meta_path, "w") as f:
131
- import json
132
-
133
- desc = AudioModelDescription(None, None, model_spec)
134
- json.dump(desc.to_dict(), f)
114
+ from ..utils import cache
135
115
 
136
- return cache_dir
116
+ return cache(model_spec, AudioModelDescription)
137
117
 
138
118
 
139
119
  def get_cache_dir(model_spec: AudioModelFamilyV1):
@@ -151,7 +131,7 @@ def get_cache_status(
151
131
  def create_audio_model_instance(
152
132
  subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
153
133
  ) -> Tuple[WhisperModel, AudioModelDescription]:
154
- model_spec = match_model(model_name)
134
+ model_spec = match_audio(model_name)
155
135
  model_path = cache(model_spec)
156
136
  model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
157
137
  model_description = AudioModelDescription(
@@ -0,0 +1,148 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ from threading import Lock
18
+ from typing import Any, List, Optional
19
+
20
+ from ..._compat import (
21
+ ROOT_KEY,
22
+ ErrorWrapper,
23
+ Protocol,
24
+ StrBytes,
25
+ ValidationError,
26
+ load_str_bytes,
27
+ )
28
+ from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
29
+ from .core import AudioModelFamilyV1
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ UD_AUDIO_LOCK = Lock()
34
+
35
+
36
+ class CustomAudioModelFamilyV1(AudioModelFamilyV1):
37
+ model_id: Optional[str] # type: ignore
38
+ model_revision: Optional[str] # type: ignore
39
+ model_uri: Optional[str]
40
+
41
+ @classmethod
42
+ def parse_raw(
43
+ cls: Any,
44
+ b: StrBytes,
45
+ *,
46
+ content_type: Optional[str] = None,
47
+ encoding: str = "utf8",
48
+ proto: Protocol = None,
49
+ allow_pickle: bool = False,
50
+ ) -> AudioModelFamilyV1:
51
+ # See source code of BaseModel.parse_raw
52
+ try:
53
+ obj = load_str_bytes(
54
+ b,
55
+ proto=proto,
56
+ content_type=content_type,
57
+ encoding=encoding,
58
+ allow_pickle=allow_pickle,
59
+ json_loads=cls.__config__.json_loads,
60
+ )
61
+ except (ValueError, TypeError, UnicodeDecodeError) as e:
62
+ raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls)
63
+
64
+ audio_spec: AudioModelFamilyV1 = cls.parse_obj(obj)
65
+
66
+ # check model_family
67
+ if audio_spec.model_family is None:
68
+ raise ValueError(
69
+ f"You must specify `model_family` when registering custom Audio models."
70
+ )
71
+ assert isinstance(audio_spec.model_family, str)
72
+ return audio_spec
73
+
74
+
75
+ UD_AUDIOS: List[CustomAudioModelFamilyV1] = []
76
+
77
+
78
+ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
79
+ with UD_AUDIO_LOCK:
80
+ return UD_AUDIOS.copy()
81
+
82
+
83
+ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
84
+ from ...constants import XINFERENCE_MODEL_DIR
85
+ from ..utils import is_valid_model_name, is_valid_model_uri
86
+ from . import BUILTIN_AUDIO_MODELS
87
+
88
+ if not is_valid_model_name(model_spec.model_name):
89
+ raise ValueError(f"Invalid model name {model_spec.model_name}.")
90
+
91
+ with UD_AUDIO_LOCK:
92
+ for model_name in list(BUILTIN_AUDIO_MODELS.keys()) + [
93
+ spec.model_name for spec in UD_AUDIOS
94
+ ]:
95
+ if model_spec.model_name == model_name:
96
+ raise ValueError(
97
+ f"Model name conflicts with existing model {model_spec.model_name}"
98
+ )
99
+
100
+ UD_AUDIOS.append(model_spec)
101
+
102
+ if persist:
103
+ # We only validate model URL when persist is True.
104
+ model_uri = model_spec.model_uri
105
+ if model_uri and not is_valid_model_uri(model_uri):
106
+ raise ValueError(f"Invalid model URI {model_uri}.")
107
+
108
+ persist_path = os.path.join(
109
+ XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
110
+ )
111
+ os.makedirs(os.path.dirname(persist_path), exist_ok=True)
112
+ with open(persist_path, mode="w") as fd:
113
+ fd.write(model_spec.json())
114
+
115
+
116
+ def unregister_audio(model_name: str, raise_error: bool = True):
117
+ with UD_AUDIO_LOCK:
118
+ model_spec = None
119
+ for i, f in enumerate(UD_AUDIOS):
120
+ if f.model_name == model_name:
121
+ model_spec = f
122
+ break
123
+ if model_spec:
124
+ UD_AUDIOS.remove(model_spec)
125
+
126
+ persist_path = os.path.join(
127
+ XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
128
+ )
129
+ if os.path.exists(persist_path):
130
+ os.remove(persist_path)
131
+
132
+ cache_dir = os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name)
133
+ if os.path.exists(cache_dir):
134
+ logger.warning(
135
+ f"Remove the cache of user-defined model {model_spec.model_name}. "
136
+ f"Cache directory: {cache_dir}"
137
+ )
138
+ if os.path.isdir(cache_dir):
139
+ os.rmdir(cache_dir)
140
+ else:
141
+ logger.warning(
142
+ f"Cache directory is not a soft link, please remove it manually."
143
+ )
144
+ else:
145
+ if raise_error:
146
+ raise ValueError(f"Model {model_name} not found")
147
+ else:
148
+ logger.warning(f"Custom audio model {model_name} not found")
xinference/model/core.py CHANGED
@@ -13,9 +13,10 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from abc import ABC, abstractmethod
16
- from typing import Any, Dict, List, Optional, Tuple
16
+ from typing import Any, List, Optional, Tuple, Union
17
17
 
18
18
  from .._compat import BaseModel
19
+ from ..types import PeftModelConfig
19
20
 
20
21
 
21
22
  class ModelDescription(ABC):
@@ -50,11 +51,9 @@ def create_model_instance(
50
51
  model_type: str,
51
52
  model_name: str,
52
53
  model_format: Optional[str] = None,
53
- model_size_in_billions: Optional[int] = None,
54
+ model_size_in_billions: Optional[Union[int, str]] = None,
54
55
  quantization: Optional[str] = None,
55
- peft_model_path: Optional[str] = None,
56
- image_lora_load_kwargs: Optional[Dict] = None,
57
- image_lora_fuse_kwargs: Optional[Dict] = None,
56
+ peft_model_config: Optional[PeftModelConfig] = None,
58
57
  is_local_deployment: bool = False,
59
58
  **kwargs,
60
59
  ) -> Tuple[Any, ModelDescription]:
@@ -73,7 +72,7 @@ def create_model_instance(
73
72
  model_format,
74
73
  model_size_in_billions,
75
74
  quantization,
76
- peft_model_path,
75
+ peft_model_config,
77
76
  is_local_deployment,
78
77
  **kwargs,
79
78
  )
@@ -90,9 +89,7 @@ def create_model_instance(
90
89
  devices,
91
90
  model_uid,
92
91
  model_name,
93
- lora_model_path=peft_model_path,
94
- lora_load_kwargs=image_lora_load_kwargs,
95
- lora_fuse_kwargs=image_lora_fuse_kwargs,
92
+ peft_model_config,
96
93
  **kwargs,
97
94
  )
98
95
  elif model_type == "rerank":
@@ -136,7 +136,7 @@ class EmbeddingModel:
136
136
  def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
137
137
  from sentence_transformers import SentenceTransformer
138
138
 
139
- normalize_embeddings = kwargs.pop("normalize_embeddings", True)
139
+ kwargs.setdefault("normalize_embeddings", True)
140
140
 
141
141
  # copied from sentence-transformers, and modify it to return tokens num
142
142
  @no_type_check
@@ -272,7 +272,6 @@ class EmbeddingModel:
272
272
  self._model,
273
273
  sentences,
274
274
  convert_to_numpy=False,
275
- normalize_embeddings=normalize_embeddings,
276
275
  **kwargs,
277
276
  )
278
277
  if isinstance(sentences, str):
@@ -206,5 +206,29 @@
206
206
  "language": ["zh", "en"],
207
207
  "model_id": "maidalun1020/bce-embedding-base_v1",
208
208
  "model_revision": "236d9024fc1b4046f03848723f934521a66a9323"
209
+ },
210
+ {
211
+ "model_name": "m3e-small",
212
+ "dimensions": 512,
213
+ "max_tokens": 512,
214
+ "language": ["zh", "en"],
215
+ "model_id": "moka-ai/m3e-small",
216
+ "model_revision": "44c696631b2a8c200220aaaad5f987f096e986df"
217
+ },
218
+ {
219
+ "model_name": "m3e-base",
220
+ "dimensions": 768,
221
+ "max_tokens": 512,
222
+ "language": ["zh", "en"],
223
+ "model_id": "moka-ai/m3e-base",
224
+ "model_revision": "764b537a0e50e5c7d64db883f2d2e051cbe3c64c"
225
+ },
226
+ {
227
+ "model_name": "m3e-large",
228
+ "dimensions": 1024,
229
+ "max_tokens": 512,
230
+ "language": ["zh", "en"],
231
+ "model_id": "moka-ai/m3e-large",
232
+ "model_revision": "12900375086c37ba5d83d1e417b21dc7d1d1f388"
209
233
  }
210
234
  ]
@@ -208,5 +208,29 @@
208
208
  "language": ["zh", "en"],
209
209
  "model_id": "maidalun/bce-embedding-base_v1",
210
210
  "model_hub": "modelscope"
211
+ },
212
+ {
213
+ "model_name": "m3e-small",
214
+ "dimensions": 512,
215
+ "max_tokens": 512,
216
+ "language": ["zh", "en"],
217
+ "model_id": "AI-ModelScope/m3e-small",
218
+ "model_hub": "modelscope"
219
+ },
220
+ {
221
+ "model_name": "m3e-base",
222
+ "dimensions": 768,
223
+ "max_tokens": 512,
224
+ "language": ["zh", "en"],
225
+ "model_id": "AI-ModelScope/m3e-base",
226
+ "model_hub": "modelscope"
227
+ },
228
+ {
229
+ "model_name": "m3e-large",
230
+ "dimensions": 1024,
231
+ "max_tokens": 512,
232
+ "language": ["zh", "en"],
233
+ "model_id": "AI-ModelScope/m3e-large",
234
+ "model_hub": "modelscope"
211
235
  }
212
236
  ]
@@ -18,6 +18,7 @@ from collections import defaultdict
18
18
  from typing import Dict, List, Optional, Tuple
19
19
 
20
20
  from ...constants import XINFERENCE_CACHE_DIR
21
+ from ...types import PeftModelConfig
21
22
  from ..core import CacheableModelSpec, ModelDescription
22
23
  from ..utils import valid_model_revision
23
24
  from .stable_diffusion.core import DiffusionModel
@@ -175,9 +176,7 @@ def create_image_model_instance(
175
176
  devices: List[str],
176
177
  model_uid: str,
177
178
  model_name: str,
178
- lora_model_path: Optional[str] = None,
179
- lora_load_kwargs: Optional[Dict] = None,
180
- lora_fuse_kwargs: Optional[Dict] = None,
179
+ peft_model_config: Optional[PeftModelConfig] = None,
181
180
  **kwargs,
182
181
  ) -> Tuple[DiffusionModel, ImageModelDescription]:
183
182
  model_spec = match_diffusion(model_name)
@@ -210,10 +209,19 @@ def create_image_model_instance(
210
209
  else:
211
210
  kwargs["controlnet"] = controlnet_model_paths
212
211
  model_path = cache(model_spec)
212
+ if peft_model_config is not None:
213
+ lora_model = peft_model_config.peft_model
214
+ lora_load_kwargs = peft_model_config.image_lora_load_kwargs
215
+ lora_fuse_kwargs = peft_model_config.image_lora_fuse_kwargs
216
+ else:
217
+ lora_model = None
218
+ lora_load_kwargs = None
219
+ lora_fuse_kwargs = None
220
+
213
221
  model = DiffusionModel(
214
222
  model_uid,
215
223
  model_path,
216
- lora_model_path=lora_model_path,
224
+ lora_model_paths=lora_model,
217
225
  lora_load_kwargs=lora_load_kwargs,
218
226
  lora_fuse_kwargs=lora_fuse_kwargs,
219
227
  **kwargs,
@@ -25,7 +25,7 @@ from typing import Dict, List, Optional, Union
25
25
 
26
26
  from ....constants import XINFERENCE_IMAGE_DIR
27
27
  from ....device_utils import move_model_to_available_device
28
- from ....types import Image, ImageList
28
+ from ....types import Image, ImageList, LoRA
29
29
 
30
30
  logger = logging.getLogger(__name__)
31
31
 
@@ -36,7 +36,7 @@ class DiffusionModel:
36
36
  model_uid: str,
37
37
  model_path: str,
38
38
  device: Optional[str] = None,
39
- lora_model_path: Optional[str] = None,
39
+ lora_model: Optional[List[LoRA]] = None,
40
40
  lora_load_kwargs: Optional[Dict] = None,
41
41
  lora_fuse_kwargs: Optional[Dict] = None,
42
42
  **kwargs,
@@ -45,20 +45,21 @@ class DiffusionModel:
45
45
  self._model_path = model_path
46
46
  self._device = device
47
47
  self._model = None
48
- self._lora_model_path = lora_model_path
48
+ self._lora_model = lora_model
49
49
  self._lora_load_kwargs = lora_load_kwargs or {}
50
50
  self._lora_fuse_kwargs = lora_fuse_kwargs or {}
51
51
  self._kwargs = kwargs
52
52
 
53
53
  def _apply_lora(self):
54
- if self._lora_model_path is not None:
54
+ if self._lora_model is not None:
55
55
  logger.info(
56
56
  f"Loading the LoRA with load kwargs: {self._lora_load_kwargs}, fuse kwargs: {self._lora_fuse_kwargs}."
57
57
  )
58
58
  assert self._model is not None
59
- self._model.load_lora_weights(
60
- self._lora_model_path, **self._lora_load_kwargs
61
- )
59
+ for lora_model in self._lora_model:
60
+ self._model.load_lora_weights(
61
+ lora_model.local_path, **self._lora_load_kwargs
62
+ )
62
63
  self._model.fuse_lora(**self._lora_fuse_kwargs)
63
64
  logger.info(f"Successfully loaded the LoRA for model {self._model_uid}.")
64
65
 
@@ -49,7 +49,6 @@ from .llm_family import (
49
49
 
50
50
  def _install():
51
51
  from .ggml.chatglm import ChatglmCppChatModel
52
- from .ggml.ctransformers import CtransformersModel
53
52
  from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
54
53
  from .pytorch.baichuan import BaichuanPytorchChatModel
55
54
  from .pytorch.chatglm import ChatglmPytorchChatModel
@@ -77,11 +76,6 @@ def _install():
77
76
  ChatglmCppChatModel,
78
77
  ]
79
78
  )
80
- LLM_CLASSES.extend(
81
- [
82
- CtransformersModel,
83
- ]
84
- )
85
79
  LLM_CLASSES.extend([SGLANGModel, SGLANGChatModel])
86
80
  LLM_CLASSES.extend([VLLMModel, VLLMChatModel])
87
81
  LLM_CLASSES.extend(
@@ -21,6 +21,7 @@ from collections import defaultdict
21
21
  from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
22
22
 
23
23
  from ...core.utils import parse_replica_model_uid
24
+ from ...types import PeftModelConfig
24
25
  from ..core import ModelDescription
25
26
 
26
27
  if TYPE_CHECKING:
@@ -178,9 +179,9 @@ def create_llm_model_instance(
178
179
  model_uid: str,
179
180
  model_name: str,
180
181
  model_format: Optional[str] = None,
181
- model_size_in_billions: Optional[int] = None,
182
+ model_size_in_billions: Optional[Union[int, str]] = None,
182
183
  quantization: Optional[str] = None,
183
- peft_model_path: Optional[str] = None,
184
+ peft_model_config: Optional[PeftModelConfig] = None,
184
185
  is_local_deployment: bool = False,
185
186
  **kwargs,
186
187
  ) -> Tuple[LLM, LLMDescription]:
@@ -204,9 +205,9 @@ def create_llm_model_instance(
204
205
  assert quantization is not None
205
206
  save_path = cache(llm_family, llm_spec, quantization)
206
207
 
207
- llm_cls = match_llm_cls(
208
- llm_family, llm_spec, quantization, peft_model_path=peft_model_path
209
- )
208
+ peft_model = peft_model_config.peft_model if peft_model_config else None
209
+
210
+ llm_cls = match_llm_cls(llm_family, llm_spec, quantization, peft_model=peft_model)
210
211
  if not llm_cls:
211
212
  raise ValueError(
212
213
  f"Model not supported, name: {model_name}, format: {model_format},"
@@ -214,15 +215,9 @@ def create_llm_model_instance(
214
215
  )
215
216
  logger.debug(f"Launching {model_uid} with {llm_cls.__name__}")
216
217
 
217
- if peft_model_path is not None:
218
+ if peft_model is not None:
218
219
  model = llm_cls(
219
- model_uid,
220
- llm_family,
221
- llm_spec,
222
- quantization,
223
- save_path,
224
- kwargs,
225
- peft_model_path,
220
+ model_uid, llm_family, llm_spec, quantization, save_path, kwargs, peft_model
226
221
  )
227
222
  else:
228
223
  model = llm_cls(
@@ -238,7 +233,7 @@ def create_speculative_llm_model_instance(
238
233
  devices: List[str],
239
234
  model_uid: str,
240
235
  model_name: str,
241
- model_size_in_billions: Optional[int],
236
+ model_size_in_billions: Optional[Union[int, str]],
242
237
  quantization: Optional[str],
243
238
  draft_model_name: str,
244
239
  draft_model_size_in_billions: Optional[int],
@@ -30,7 +30,6 @@ from ....types import (
30
30
  from ..core import LLM
31
31
  from ..llm_family import LLMFamilyV1, LLMSpecV1
32
32
  from ..utils import ChatModelMixin
33
- from .ctransformers import CTRANSFORMERS_SUPPORTED_MODEL
34
33
 
35
34
  logger = logging.getLogger(__name__)
36
35
 
@@ -182,11 +181,7 @@ class LlamaCppModel(LLM):
182
181
  ) -> bool:
183
182
  if llm_spec.model_format not in ["ggmlv3", "ggufv2"]:
184
183
  return False
185
- if (
186
- "chatglm" in llm_family.model_name
187
- or "qwen" in llm_family.model_name
188
- or llm_family.model_name in CTRANSFORMERS_SUPPORTED_MODEL
189
- ):
184
+ if "chatglm" in llm_family.model_name or "qwen" in llm_family.model_name:
190
185
  return False
191
186
  if "generate" not in llm_family.model_ability:
192
187
  return False
@@ -250,10 +245,7 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
250
245
  ) -> bool:
251
246
  if llm_spec.model_format not in ["ggmlv3", "ggufv2"]:
252
247
  return False
253
- if (
254
- "chatglm" in llm_family.model_name
255
- or llm_family.model_name in CTRANSFORMERS_SUPPORTED_MODEL
256
- ):
248
+ if "chatglm" in llm_family.model_name:
257
249
  return False
258
250
  if "chat" not in llm_family.model_ability:
259
251
  return False