xinference 0.12.0__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (85)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +108 -14
  3. xinference/client/restful/restful_client.py +78 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/event.py +5 -6
  7. xinference/core/model.py +59 -42
  8. xinference/core/scheduler.py +46 -18
  9. xinference/core/supervisor.py +73 -24
  10. xinference/core/worker.py +68 -2
  11. xinference/deploy/cmdline.py +86 -2
  12. xinference/deploy/test/test_cmdline.py +19 -10
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/core.py +12 -1
  15. xinference/model/audio/custom.py +6 -4
  16. xinference/model/audio/model_spec_modelscope.json +20 -0
  17. xinference/model/llm/__init__.py +34 -2
  18. xinference/model/llm/llm_family.json +8 -2
  19. xinference/model/llm/llm_family.py +86 -1
  20. xinference/model/llm/llm_family_csghub.json +66 -0
  21. xinference/model/llm/llm_family_modelscope.json +8 -2
  22. xinference/model/llm/pytorch/chatglm.py +41 -12
  23. xinference/model/llm/pytorch/core.py +128 -88
  24. xinference/model/llm/pytorch/glm4v.py +24 -3
  25. xinference/model/llm/pytorch/internlm2.py +15 -0
  26. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  27. xinference/model/llm/pytorch/utils.py +69 -189
  28. xinference/model/llm/utils.py +27 -14
  29. xinference/model/llm/vllm/core.py +10 -4
  30. xinference/model/rerank/core.py +35 -6
  31. xinference/model/utils.py +8 -2
  32. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  33. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  34. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  35. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  36. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  38. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  39. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  40. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  41. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  42. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  43. xinference/types.py +28 -0
  44. xinference/web/ui/build/asset-manifest.json +6 -6
  45. xinference/web/ui/build/index.html +1 -1
  46. xinference/web/ui/build/static/css/main.4bafd904.css +2 -0
  47. xinference/web/ui/build/static/css/main.4bafd904.css.map +1 -0
  48. xinference/web/ui/build/static/js/main.b80d9c08.js +3 -0
  49. xinference/web/ui/build/static/js/main.b80d9c08.js.map +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/0c2fb5375667931c4a331c99e0d87dc145e8f327cea3f44d6e56f54c7c1d4020.json +1 -0
  51. xinference/web/ui/node_modules/.cache/babel-loader/131091b25d26b17cdca187d7542a21475c211138d900cf667682260e76ef9463.json +1 -0
  52. xinference/web/ui/node_modules/.cache/babel-loader/16537795de12c61903b6110c241f62a7855b2d0fc1e7c3d1faa347267f3a6893.json +1 -0
  53. xinference/web/ui/node_modules/.cache/babel-loader/17b8f071491402d70b146532358b1a612226e5dc7b3e8755a1322d27b4680cee.json +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/395409bd005e19d48b437c48d88e5126c7865ba9631fe98535333c952e383dc5.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/3da7d55e87882a4af923e187b1351160e34ca102f589086439c15131a227fb6e.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/43991bb67c3136863e6fb37f796466b12eb547a1465408cc77820fddafb3bed3.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/72bcecc71c5267250edeb89608859d449b586f13ff9923a5e70e7172976ec403.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/{15e2cf8cd8d0989719b6349428ff576f9009ff4c2dcc52378be0bd938e82495e.json → 935efd2867664c58230378fdf2ff1ea85e58d853b7214014e20dfbca8dab7b05.json} +1 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/a7109d4425e3d94ca2726fc7020fd33bf5030afd4c9cf4bf71e21776cd70646a.json +1 -0
  60. xinference/web/ui/node_modules/.cache/babel-loader/c2abe75f04ad82fba68f35ed9cbe2e287762c876684fddccccfa73f739489b65.json +1 -0
  61. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +1 -0
  62. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  63. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/METADATA +1 -1
  64. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/RECORD +69 -56
  65. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  66. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  67. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  68. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  71. xinference/web/ui/node_modules/.cache/babel-loader/2c63e940b945fd5817157e08a42b889b30d668ea4c91332f48ef2b1b9d26f520.json +0 -1
  72. xinference/web/ui/node_modules/.cache/babel-loader/3c2f277c93c5f1638e08db38df0d0fb4e58d1c5571aea03241a5c04ff4094704.json +0 -1
  73. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  74. xinference/web/ui/node_modules/.cache/babel-loader/4135fe8745434cbce6438d1ebfa47422e0c77d884db4edc75c8bf32ea1d50621.json +0 -1
  75. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  76. xinference/web/ui/node_modules/.cache/babel-loader/4de0a71074f9cbe1e7862750dcdd08cbc1bae7d9d9849a78b1783ca670017b3c.json +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/9cfd33238ca43e5bf9fc7e442690e8cc6027c73553db36de87e3597ed524ee4b.json +0 -1
  79. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  80. xinference/web/ui/node_modules/.cache/babel-loader/e6eccc9aa641e7da833492e27846dc965f9750281420977dc84654ca6ed221e4.json +0 -1
  81. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.b80d9c08.js.LICENSE.txt} +0 -0
  82. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/LICENSE +0 -0
  83. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/WHEEL +0 -0
  84. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/entry_points.txt +0 -0
  85. {xinference-0.12.0.dist-info → xinference-0.12.2.dist-info}/top_level.txt +0 -0
xinference/core/scheduler.py CHANGED
@@ -15,9 +15,10 @@
 import asyncio
 import functools
 import logging
+import uuid
 from collections import deque
 from enum import Enum
-from typing import List, Optional, Set
+from typing import List, Optional, Set, Tuple

 import xoscar as xo

@@ -50,9 +51,10 @@ class InferenceRequest:
         self._new_tokens = []
         # kv_cache used in decode phase
         self._kv_cache = None
-        # use passed args from `chat` interface
+        # use passed args from upstream interface
         self._inference_args = args
-        # use passed kwargs from `chat` interface, basically not used for now
+        # use passed kwargs from upstream interface, currently for getting raw generate config from upstream,
+        # which is useful for some special models
         self._inference_kwargs = kwargs
         # should this request be stopped
         self._stopped = False
@@ -63,6 +65,10 @@ class InferenceRequest:
         self._aborted = False
         # sanitized generate config
         self._sanitized_generate_config = None
+        # Chunk id for results. In stream mode, all the chunk ids should be same.
+        self._stream_chunk_id = str(uuid.uuid4())
+        # For calculate attention mask if needed
+        self.padding_len = 0
         # Use in stream mode
         self.last_output_length = 0
         # inference results,
@@ -81,19 +87,26 @@ class InferenceRequest:
         self._check_args()

     def _check_args(self):
-        assert len(self._inference_args) == 3
-        # system prompt
-        assert self._inference_args[0] is None or isinstance(
-            self._inference_args[0], str
-        )
-        # chat history
-        assert self._inference_args[1] is None or isinstance(
-            self._inference_args[1], list
-        )
-        # generate config
-        assert self._inference_args[2] is None or isinstance(
-            self._inference_args[2], dict
-        )
+        # chat
+        if len(self._inference_args) == 3:
+            # system prompt
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], str
+            )
+            # chat history
+            assert self._inference_args[1] is None or isinstance(
+                self._inference_args[1], list
+            )
+            # generate config
+            assert self._inference_args[2] is None or isinstance(
+                self._inference_args[2], dict
+            )
+        else:  # generate
+            assert len(self._inference_args) == 1
+            # generate config
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], dict
+            )

     @property
     def prompt(self):
@@ -148,7 +161,11 @@ class InferenceRequest:

     @property
     def generate_config(self):
-        return self._inference_args[2]
+        return (
+            self._inference_args[2]
+            if len(self._inference_args) == 3
+            else self._inference_args[0]
+        )

     @property
     def sanitized_generate_config(self):
@@ -158,6 +175,10 @@ class InferenceRequest:
     def sanitized_generate_config(self, value: dict):
         self._sanitized_generate_config = value

+    @property
+    def inference_kwargs(self):
+        return self._inference_kwargs
+
     @property
     def stopped(self):
         return self._stopped
@@ -174,6 +195,10 @@ class InferenceRequest:
     def finish_reason(self, value: Optional[str]):
         self._finish_reason = value

+    @property
+    def chunk_id(self):
+        return self._stream_chunk_id
+
     @property
     def stream(self) -> bool:
         return (
@@ -213,7 +238,9 @@ class InferenceRequest:
         )

     @functools.lru_cache
-    def get_generate_configs(self, eos_token_id: int):
+    def get_generate_configs(
+        self, eos_token_id: int, builtin_stop_token_ids: Optional[Tuple[int]] = None
+    ):
         from ..types import max_tokens_field

         max_new_tokens = int(
@@ -227,6 +254,7 @@ class InferenceRequest:
         )
         stop_token_ids = set(stop_token_ids)
         stop_token_ids.add(eos_token_id)
+        stop_token_ids.update(builtin_stop_token_ids or [])
         temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
         repetition_penalty = float(
             self.sanitized_generate_config.get("repetition_penalty", 1.0)
xinference/core/supervisor.py CHANGED
@@ -982,32 +982,31 @@ class SupervisorActor(xo.StatelessActor):
         )

     @log_async(logger=logger)
-    async def list_cached_models(self) -> List[Dict[str, Any]]:
+    async def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        # search assigned worker and return
+        if target_ip_worker_ref:
+            cached_models = await target_ip_worker_ref.list_cached_models(model_name)
+            cached_models = sorted(cached_models, key=lambda x: x["model_name"])
+            return cached_models
+
+        # search all worker
         cached_models = []
         for worker in self._worker_address_to_worker.values():
-            ret = await worker.list_cached_models()
-            for model_version in ret:
-                model_name = model_version.get("model_name", None)
-                model_format = model_version.get("model_format", None)
-                model_size_in_billions = model_version.get(
-                    "model_size_in_billions", None
-                )
-                quantizations = model_version.get("quantization", None)
-                actor_ip_address = model_version.get("actor_ip_address", None)
-                path = model_version.get("path", None)
-                real_path = model_version.get("real_path", None)
-
-                cache_entry = {
-                    "model_name": model_name,
-                    "model_format": model_format,
-                    "model_size_in_billions": model_size_in_billions,
-                    "quantizations": quantizations,
-                    "path": path,
-                    "Actor IP Address": actor_ip_address,
-                    "real_path": real_path,
-                }
-
-                cached_models.append(cache_entry)
+            res = await worker.list_cached_models(model_name)
+            cached_models.extend(res)
+        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
         return cached_models

     @log_async(logger=logger)
@@ -1083,6 +1082,56 @@ class SupervisorActor(xo.StatelessActor):
         worker_status.update_time = time.time()
         worker_status.status = status

+    async def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> List[str]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        ret = []
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.list_deletable_models(
+                model_version=model_version,
+            )
+            return ret
+
+        for worker in self._worker_address_to_worker.values():
+            path = await worker.list_deletable_models(model_version=model_version)
+            ret.extend(path)
+        return ret
+
+    async def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.confirm_and_remove_model(
+                model_version=model_version,
+            )
+            return ret
+        ret = True
+        for worker in self._worker_address_to_worker.values():
+            ret = ret and await worker.confirm_and_remove_model(
+                model_version=model_version,
+            )
+        return ret
+
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)
xinference/core/worker.py CHANGED
@@ -16,6 +16,7 @@ import asyncio
 import os
 import platform
 import queue
+import shutil
 import signal
 import threading
 import time
@@ -786,8 +787,73 @@ class WorkerActor(xo.StatelessActor):
             except asyncio.CancelledError:  # pragma: no cover
                 break

-    async def list_cached_models(self) -> List[Dict[Any, Any]]:
-        return self._cache_tracker_ref.list_cached_models()
+    async def list_cached_models(
+        self, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
+        lists = await self._cache_tracker_ref.list_cached_models(
+            self.address, model_name
+        )
+        cached_models = []
+        for list in lists:
+            cached_model = {
+                "model_name": list.get("model_name"),
+                "model_size_in_billions": list.get("model_size_in_billions"),
+                "model_format": list.get("model_format"),
+                "quantization": list.get("quantization"),
+                "model_version": list.get("model_version"),
+            }
+            path = list.get("model_file_location")
+            cached_model["path"] = path
+            # parsing soft links
+            if os.path.isdir(path):
+                files = os.listdir(path)
+                # dir has files
+                if files:
+                    resolved_file = os.path.realpath(os.path.join(path, files[0]))
+                    if resolved_file:
+                        cached_model["real_path"] = os.path.dirname(resolved_file)
+            else:
+                cached_model["real_path"] = os.path.realpath(path)
+            cached_model["actor_ip_address"] = self.address
+            cached_models.append(cached_model)
+        return cached_models
+
+    async def list_deletable_models(self, model_version: str) -> List[str]:
+        paths = set()
+        path = await self._cache_tracker_ref.list_deletable_models(
+            model_version, self.address
+        )
+        if os.path.isfile(path):
+            path = os.path.dirname(path)
+
+        if os.path.isdir(path):
+            files = os.listdir(path)
+            paths.update([os.path.join(path, file) for file in files])
+            # search real path
+            if paths:
+                paths.update([os.path.realpath(path) for path in paths])
+
+        return list(paths)
+
+    async def confirm_and_remove_model(self, model_version: str) -> bool:
+        paths = await self.list_deletable_models(model_version)
+        for path in paths:
+            try:
+                if os.path.islink(path):
+                    os.unlink(path)
+                elif os.path.isfile(path):
+                    os.remove(path)
+                elif os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    logger.debug(f"{path} is not a valid path.")
+            except Exception as e:
+                logger.error(f"Fail to delete {path} with error:{e}.")
+                return False
+        await self._cache_tracker_ref.confirm_and_remove_model(
+            model_version, self.address
+        )
+        return True

     @staticmethod
     def record_metrics(name, op, kwargs):
xinference/deploy/cmdline.py CHANGED
@@ -577,6 +577,18 @@ def list_model_registrations(
     type=str,
     help="Xinference endpoint.",
 )
+@click.option(
+    "--model_name",
+    "-n",
+    type=str,
+    help="Provide the name of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
 @click.option(
     "--api-key",
     "-ak",
@@ -587,6 +599,8 @@ def list_model_registrations(
 def list_cached_models(
     endpoint: Optional[str],
     api_key: Optional[str],
+    model_name: Optional[str],
+    worker_ip: Optional[str],
 ):
     from tabulate import tabulate

@@ -595,10 +609,13 @@
     if api_key is None:
         client._set_token(get_stored_token(endpoint, client))

-    cached_models = client.list_cached_models()
+    cached_models = client.list_cached_models(model_name, worker_ip)
+    if not cached_models:
+        print("There are no cache files.")
+        return
+    headers = list(cached_models[0].keys())

     print("cached_model: ")
-    headers = list(cached_models[0].keys())
     table_data = []
     for model in cached_models:
         row_data = [
@@ -608,6 +625,73 @@
     print(tabulate(table_data, headers=headers, tablefmt="pretty"))


+@cli.command("remove-cache", help="Remove selected cached models in Xinference.")
+@click.option(
+    "--endpoint",
+    "-e",
+    type=str,
+    help="Xinference endpoint.",
+)
+@click.option(
+    "--model_version",
+    "-n",
+    type=str,
+    help="Provide the version of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
+def remove_cache(
+    endpoint: Optional[str],
+    model_version: str,
+    api_key: Optional[str],
+    check: bool,
+    worker_ip: Optional[str] = None,
+):
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    if not check:
+        response = client.list_deletable_models(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        paths: List[str] = response.get("paths", [])
+        if not paths:
+            click.echo(f"There is no model version named {model_version}.")
+            return
+        click.echo(f"Model {model_version} cache directory to be deleted:")
+        for path in response.get("paths", []):
+            click.echo(f"{path}")
+
+        if click.confirm("Do you want to proceed with the deletion?", abort=True):
+            check = True
+    try:
+        result = client.confirm_and_remove_model(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        if result:
+            click.echo(f"Cache directory {model_version} has been deleted.")
+        else:
+            click.echo(
+                f"Cache directory {model_version} fail to be deleted. Please check the log."
+            )
+    except Exception as e:
+        click.echo(f"An error occurred while deleting the cache: {e}")
+
+
 @cli.command(
     "launch",
     help="Launch a model with the Xinference framework with the given parameters.",
xinference/deploy/test/test_cmdline.py CHANGED
@@ -26,6 +26,7 @@ from ..cmdline import (
     model_list,
     model_terminate,
     register_model,
+    remove_cache,
     unregister_model,
 )

@@ -287,18 +288,26 @@ def test_list_cached_models(setup):

     result = runner.invoke(
         list_cached_models,
-        [
-            "--endpoint",
-            endpoint,
-        ],
+        ["--endpoint", endpoint, "--model_name", "orca"],
     )
-    assert result.exit_code == 0
-    assert "cached_model: " in result.stdout
-
-    # check if the output is in tabular format
     assert "model_name" in result.stdout
     assert "model_format" in result.stdout
     assert "model_size_in_billions" in result.stdout
-    assert "quantizations" in result.stdout
+    assert "quantization" in result.stdout
+    assert "model_version" in result.stdout
     assert "path" in result.stdout
-    assert "Actor IP Address" in result.stdout
+    assert "actor_ip_address" in result.stdout
+
+
+def test_remove_cache(setup):
+    endpoint, _ = setup
+    runner = CliRunner()
+
+    result = runner.invoke(
+        remove_cache,
+        ["--endpoint", endpoint, "--model_version", "orca"],
+        input="y\n",
+    )
+
+    assert result.exit_code == 0
+    assert "Cache directory orca has been deleted."
xinference/model/audio/__init__.py CHANGED
@@ -32,6 +32,9 @@ from .custom import (
 )

 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+    os.path.dirname(__file__), "model_spec_modelscope.json"
+)
 BUILTIN_AUDIO_MODELS = dict(
     (spec["model_name"], AudioModelFamilyV1(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
@@ -39,8 +42,17 @@ BUILTIN_AUDIO_MODELS = dict(
 for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)

+MODELSCOPE_AUDIO_MODELS = dict(
+    (spec["model_name"], AudioModelFamilyV1(**spec))
+    for spec in json.load(
+        codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+    )
+)
+for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 # register model description after recording model revision
-for model_spec_info in [BUILTIN_AUDIO_MODELS]:
+for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
     for model_name, model_spec in model_spec_info.items():
         if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
             AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
@@ -64,3 +76,4 @@ for ud_audio in get_user_defined_audios():
     AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))

 del _model_spec_json
+del _model_spec_modelscope_json
xinference/model/audio/core.py CHANGED
@@ -95,13 +95,24 @@ def generate_audio_description(


 def match_audio(model_name: str) -> AudioModelFamilyV1:
-    from . import BUILTIN_AUDIO_MODELS
+    from ..utils import download_from_modelscope
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
     from .custom import get_user_defined_audios

     for model_spec in get_user_defined_audios():
         if model_spec.model_name == model_name:
             return model_spec

+    if download_from_modelscope():
+        if model_name in MODELSCOPE_AUDIO_MODELS:
+            logger.debug(f"Audio model {model_name} found in ModelScope.")
+            return MODELSCOPE_AUDIO_MODELS[model_name]
+        else:
+            logger.debug(
+                f"Audio model {model_name} not found in ModelScope, "
+                f"now try to load it via builtin way."
+            )
+
     if model_name in BUILTIN_AUDIO_MODELS:
         return BUILTIN_AUDIO_MODELS[model_name]
     else:
xinference/model/audio/custom.py CHANGED
@@ -83,15 +83,17 @@ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
 def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     from ...constants import XINFERENCE_MODEL_DIR
     from ..utils import is_valid_model_name, is_valid_model_uri
-    from . import BUILTIN_AUDIO_MODELS
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS

     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")

     with UD_AUDIO_LOCK:
-        for model_name in list(BUILTIN_AUDIO_MODELS.keys()) + [
-            spec.model_name for spec in UD_AUDIOS
-        ]:
+        for model_name in (
+            list(BUILTIN_AUDIO_MODELS.keys())
+            + list(MODELSCOPE_AUDIO_MODELS.keys())
+            + [spec.model_name for spec in UD_AUDIOS]
+        ):
             if model_spec.model_name == model_name:
                 raise ValueError(
                     f"Model name conflicts with existing model {model_spec.model_name}"
xinference/model/audio/model_spec_modelscope.json ADDED
@@ -0,0 +1,20 @@
+[
+  {
+    "model_name": "whisper-large-v3",
+    "model_family": "whisper",
+    "model_hub": "modelscope",
+    "model_id": "AI-ModelScope/whisper-large-v3",
+    "model_revision": "master",
+    "ability": "audio-to-text",
+    "multilingual": true
+  },
+  {
+    "model_name": "ChatTTS",
+    "model_family": "ChatTTS",
+    "model_hub": "modelscope",
+    "model_id": "pzc163/chatTTS",
+    "model_revision": "master",
+    "ability": "text-to-audio",
+    "multilingual": true
+  }
+]
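These two entries make whisper-large-v3 and ChatTTS resolvable when models are pulled from ModelScope instead of Hugging Face. A rough sketch of the expected lookup behaviour, assuming the standard XINFERENCE_MODEL_SRC toggle that download_from_modelscope() consults (the toggle is pre-existing, not introduced by this diff):

    import os

    # Assumed toggle; must be set before calling match_audio so that
    # download_from_modelscope() returns True.
    os.environ["XINFERENCE_MODEL_SRC"] = "modelscope"

    from xinference.model.audio.core import match_audio

    spec = match_audio("ChatTTS")
    # Expected to resolve to the ModelScope spec above:
    # model_hub "modelscope", model_id "pzc163/chatTTS".
    print(spec.model_hub, spec.model_id)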
xinference/model/llm/__init__.py CHANGED
@@ -25,6 +25,7 @@ from .core import (
     get_llm_model_descriptions,
 )
 from .llm_family import (
+    BUILTIN_CSGHUB_LLM_FAMILIES,
     BUILTIN_LLM_FAMILIES,
     BUILTIN_LLM_MODEL_CHAT_FAMILIES,
     BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
@@ -221,13 +222,44 @@ def _install():
         if "tools" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)

-    for llm_specs in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
+    csghub_json_path = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
+    )
+    for json_obj in json.load(codecs.open(csghub_json_path, "r", encoding="utf-8")):
+        model_spec = LLMFamilyV1.parse_obj(json_obj)
+        BUILTIN_CSGHUB_LLM_FAMILIES.append(model_spec)
+
+        # register prompt style, in case that we have something missed
+        # if duplicated with huggingface json, keep it as the huggingface style
+        if (
+            "chat" in model_spec.model_ability
+            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
+        ):
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
+        if "tools" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+
+    for llm_specs in [
+        BUILTIN_LLM_FAMILIES,
+        BUILTIN_MODELSCOPE_LLM_FAMILIES,
+        BUILTIN_CSGHUB_LLM_FAMILIES,
+    ]:
         for llm_spec in llm_specs:
             if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
                 LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))

     # traverse all families and add engine parameters corresponding to the model name
-    for families in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
+    for families in [
+        BUILTIN_LLM_FAMILIES,
+        BUILTIN_MODELSCOPE_LLM_FAMILIES,
+        BUILTIN_CSGHUB_LLM_FAMILIES,
+    ]:
         for family in families:
             generate_engine_config_by_model_family(family)

xinference/model/llm/llm_family.json CHANGED
@@ -939,6 +939,8 @@
         "model_format": "pytorch",
         "model_size_in_billions": 9,
         "quantizations": [
+          "4-bit",
+          "8-bit",
           "none"
         ],
         "model_id": "THUDM/glm-4v-9b",
@@ -2288,7 +2290,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
     ],
     "model_description": "Qwen1.5-MoE is a transformer-based MoE decoder-only language model pretrained on a large amount of data.",
     "model_specs": [
@@ -2593,7 +2596,8 @@
       "zh"
     ],
     "model_ability": [
-      "chat"
+      "chat",
+      "tools"
    ],
     "model_description": "Qwen2 is the new series of Qwen large language models. ",
     "model_specs": [
@@ -5673,9 +5677,11 @@
       ],
       "intra_message_sep": "<|im_end|>",
       "stop_token_ids": [
+        2,
         92542
       ],
       "stop": [
+        "</s>",
         "<|im_end|>"
       ]
     }