xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/supervisor.py
CHANGED
|
@@ -982,34 +982,59 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
982
982
|
)
|
|
983
983
|
|
|
984
984
|
@log_async(logger=logger)
|
|
985
|
-
async def list_cached_models(
|
|
985
|
+
async def list_cached_models(
|
|
986
|
+
self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
|
|
987
|
+
) -> List[Dict[str, Any]]:
|
|
988
|
+
target_ip_worker_ref = (
|
|
989
|
+
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
990
|
+
)
|
|
991
|
+
if (
|
|
992
|
+
worker_ip is not None
|
|
993
|
+
and not self.is_local_deployment()
|
|
994
|
+
and target_ip_worker_ref is None
|
|
995
|
+
):
|
|
996
|
+
raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
|
|
997
|
+
|
|
998
|
+
# search assigned worker and return
|
|
999
|
+
if target_ip_worker_ref:
|
|
1000
|
+
cached_models = await target_ip_worker_ref.list_cached_models(model_name)
|
|
1001
|
+
cached_models = sorted(cached_models, key=lambda x: x["model_name"])
|
|
1002
|
+
return cached_models
|
|
1003
|
+
|
|
1004
|
+
# search all worker
|
|
986
1005
|
cached_models = []
|
|
987
1006
|
for worker in self._worker_address_to_worker.values():
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
model_format = model_version.get("model_format", None)
|
|
992
|
-
model_size_in_billions = model_version.get(
|
|
993
|
-
"model_size_in_billions", None
|
|
994
|
-
)
|
|
995
|
-
quantizations = model_version.get("quantization", None)
|
|
996
|
-
actor_ip_address = model_version.get("actor_ip_address", None)
|
|
997
|
-
path = model_version.get("path", None)
|
|
998
|
-
real_path = model_version.get("real_path", None)
|
|
999
|
-
|
|
1000
|
-
cache_entry = {
|
|
1001
|
-
"model_name": model_name,
|
|
1002
|
-
"model_format": model_format,
|
|
1003
|
-
"model_size_in_billions": model_size_in_billions,
|
|
1004
|
-
"quantizations": quantizations,
|
|
1005
|
-
"path": path,
|
|
1006
|
-
"Actor IP Address": actor_ip_address,
|
|
1007
|
-
"real_path": real_path,
|
|
1008
|
-
}
|
|
1009
|
-
|
|
1010
|
-
cached_models.append(cache_entry)
|
|
1007
|
+
res = await worker.list_cached_models(model_name)
|
|
1008
|
+
cached_models.extend(res)
|
|
1009
|
+
cached_models = sorted(cached_models, key=lambda x: x["model_name"])
|
|
1011
1010
|
return cached_models
|
|
1012
1011
|
|
|
1012
|
+
@log_async(logger=logger)
|
|
1013
|
+
async def abort_request(self, model_uid: str, request_id: str) -> Dict:
|
|
1014
|
+
from .scheduler import AbortRequestMessage
|
|
1015
|
+
|
|
1016
|
+
res = {"msg": AbortRequestMessage.NO_OP.name}
|
|
1017
|
+
replica_info = self._model_uid_to_replica_info.get(model_uid, None)
|
|
1018
|
+
if not replica_info:
|
|
1019
|
+
return res
|
|
1020
|
+
replica_cnt = replica_info.replica
|
|
1021
|
+
|
|
1022
|
+
# Query all replicas
|
|
1023
|
+
for rep_mid in iter_replica_model_uid(model_uid, replica_cnt):
|
|
1024
|
+
worker_ref = self._replica_model_uid_to_worker.get(rep_mid, None)
|
|
1025
|
+
if worker_ref is None:
|
|
1026
|
+
continue
|
|
1027
|
+
model_ref = await worker_ref.get_model(model_uid=rep_mid)
|
|
1028
|
+
result_info = await model_ref.abort_request(request_id)
|
|
1029
|
+
res["msg"] = result_info
|
|
1030
|
+
if result_info == AbortRequestMessage.DONE.name:
|
|
1031
|
+
break
|
|
1032
|
+
elif result_info == AbortRequestMessage.NOT_FOUND.name:
|
|
1033
|
+
logger.debug(f"Request id: {request_id} not found for model {rep_mid}")
|
|
1034
|
+
else:
|
|
1035
|
+
logger.debug(f"No-op for model {rep_mid}")
|
|
1036
|
+
return res
|
|
1037
|
+
|
|
1013
1038
|
@log_async(logger=logger)
|
|
1014
1039
|
async def add_worker(self, worker_address: str):
|
|
1015
1040
|
from .worker import WorkerActor
|
|
@@ -1057,6 +1082,56 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1057
1082
|
worker_status.update_time = time.time()
|
|
1058
1083
|
worker_status.status = status
|
|
1059
1084
|
|
|
1085
|
+
async def list_deletable_models(
|
|
1086
|
+
self, model_version: str, worker_ip: Optional[str] = None
|
|
1087
|
+
) -> List[str]:
|
|
1088
|
+
target_ip_worker_ref = (
|
|
1089
|
+
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
1090
|
+
)
|
|
1091
|
+
if (
|
|
1092
|
+
worker_ip is not None
|
|
1093
|
+
and not self.is_local_deployment()
|
|
1094
|
+
and target_ip_worker_ref is None
|
|
1095
|
+
):
|
|
1096
|
+
raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
|
|
1097
|
+
|
|
1098
|
+
ret = []
|
|
1099
|
+
if target_ip_worker_ref:
|
|
1100
|
+
ret = await target_ip_worker_ref.list_deletable_models(
|
|
1101
|
+
model_version=model_version,
|
|
1102
|
+
)
|
|
1103
|
+
return ret
|
|
1104
|
+
|
|
1105
|
+
for worker in self._worker_address_to_worker.values():
|
|
1106
|
+
path = await worker.list_deletable_models(model_version=model_version)
|
|
1107
|
+
ret.extend(path)
|
|
1108
|
+
return ret
|
|
1109
|
+
|
|
1110
|
+
async def confirm_and_remove_model(
|
|
1111
|
+
self, model_version: str, worker_ip: Optional[str] = None
|
|
1112
|
+
) -> bool:
|
|
1113
|
+
target_ip_worker_ref = (
|
|
1114
|
+
self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
|
|
1115
|
+
)
|
|
1116
|
+
if (
|
|
1117
|
+
worker_ip is not None
|
|
1118
|
+
and not self.is_local_deployment()
|
|
1119
|
+
and target_ip_worker_ref is None
|
|
1120
|
+
):
|
|
1121
|
+
raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
|
|
1122
|
+
|
|
1123
|
+
if target_ip_worker_ref:
|
|
1124
|
+
ret = await target_ip_worker_ref.confirm_and_remove_model(
|
|
1125
|
+
model_version=model_version,
|
|
1126
|
+
)
|
|
1127
|
+
return ret
|
|
1128
|
+
ret = True
|
|
1129
|
+
for worker in self._worker_address_to_worker.values():
|
|
1130
|
+
ret = ret and await worker.confirm_and_remove_model(
|
|
1131
|
+
model_version=model_version,
|
|
1132
|
+
)
|
|
1133
|
+
return ret
|
|
1134
|
+
|
|
1060
1135
|
@staticmethod
|
|
1061
1136
|
def record_metrics(name, op, kwargs):
|
|
1062
1137
|
record_metrics(name, op, kwargs)
|
xinference/core/worker.py
CHANGED
|
@@ -16,6 +16,7 @@ import asyncio
|
|
|
16
16
|
import os
|
|
17
17
|
import platform
|
|
18
18
|
import queue
|
|
19
|
+
import shutil
|
|
19
20
|
import signal
|
|
20
21
|
import threading
|
|
21
22
|
import time
|
|
@@ -786,8 +787,73 @@ class WorkerActor(xo.StatelessActor):
|
|
|
786
787
|
except asyncio.CancelledError: # pragma: no cover
|
|
787
788
|
break
|
|
788
789
|
|
|
789
|
-
async def list_cached_models(
|
|
790
|
-
|
|
790
|
+
async def list_cached_models(
|
|
791
|
+
self, model_name: Optional[str] = None
|
|
792
|
+
) -> List[Dict[Any, Any]]:
|
|
793
|
+
lists = await self._cache_tracker_ref.list_cached_models(
|
|
794
|
+
self.address, model_name
|
|
795
|
+
)
|
|
796
|
+
cached_models = []
|
|
797
|
+
for list in lists:
|
|
798
|
+
cached_model = {
|
|
799
|
+
"model_name": list.get("model_name"),
|
|
800
|
+
"model_size_in_billions": list.get("model_size_in_billions"),
|
|
801
|
+
"model_format": list.get("model_format"),
|
|
802
|
+
"quantization": list.get("quantization"),
|
|
803
|
+
"model_version": list.get("model_version"),
|
|
804
|
+
}
|
|
805
|
+
path = list.get("model_file_location")
|
|
806
|
+
cached_model["path"] = path
|
|
807
|
+
# parsing soft links
|
|
808
|
+
if os.path.isdir(path):
|
|
809
|
+
files = os.listdir(path)
|
|
810
|
+
# dir has files
|
|
811
|
+
if files:
|
|
812
|
+
resolved_file = os.path.realpath(os.path.join(path, files[0]))
|
|
813
|
+
if resolved_file:
|
|
814
|
+
cached_model["real_path"] = os.path.dirname(resolved_file)
|
|
815
|
+
else:
|
|
816
|
+
cached_model["real_path"] = os.path.realpath(path)
|
|
817
|
+
cached_model["actor_ip_address"] = self.address
|
|
818
|
+
cached_models.append(cached_model)
|
|
819
|
+
return cached_models
|
|
820
|
+
|
|
821
|
+
async def list_deletable_models(self, model_version: str) -> List[str]:
|
|
822
|
+
paths = set()
|
|
823
|
+
path = await self._cache_tracker_ref.list_deletable_models(
|
|
824
|
+
model_version, self.address
|
|
825
|
+
)
|
|
826
|
+
if os.path.isfile(path):
|
|
827
|
+
path = os.path.dirname(path)
|
|
828
|
+
|
|
829
|
+
if os.path.isdir(path):
|
|
830
|
+
files = os.listdir(path)
|
|
831
|
+
paths.update([os.path.join(path, file) for file in files])
|
|
832
|
+
# search real path
|
|
833
|
+
if paths:
|
|
834
|
+
paths.update([os.path.realpath(path) for path in paths])
|
|
835
|
+
|
|
836
|
+
return list(paths)
|
|
837
|
+
|
|
838
|
+
async def confirm_and_remove_model(self, model_version: str) -> bool:
|
|
839
|
+
paths = await self.list_deletable_models(model_version)
|
|
840
|
+
for path in paths:
|
|
841
|
+
try:
|
|
842
|
+
if os.path.islink(path):
|
|
843
|
+
os.unlink(path)
|
|
844
|
+
elif os.path.isfile(path):
|
|
845
|
+
os.remove(path)
|
|
846
|
+
elif os.path.isdir(path):
|
|
847
|
+
shutil.rmtree(path)
|
|
848
|
+
else:
|
|
849
|
+
logger.debug(f"{path} is not a valid path.")
|
|
850
|
+
except Exception as e:
|
|
851
|
+
logger.error(f"Fail to delete {path} with error:{e}.")
|
|
852
|
+
return False
|
|
853
|
+
await self._cache_tracker_ref.confirm_and_remove_model(
|
|
854
|
+
model_version, self.address
|
|
855
|
+
)
|
|
856
|
+
return True
|
|
791
857
|
|
|
792
858
|
@staticmethod
|
|
793
859
|
def record_metrics(name, op, kwargs):
|
xinference/deploy/cmdline.py
CHANGED
|
@@ -577,6 +577,18 @@ def list_model_registrations(
|
|
|
577
577
|
type=str,
|
|
578
578
|
help="Xinference endpoint.",
|
|
579
579
|
)
|
|
580
|
+
@click.option(
|
|
581
|
+
"--model_name",
|
|
582
|
+
"-n",
|
|
583
|
+
type=str,
|
|
584
|
+
help="Provide the name of the models to be removed.",
|
|
585
|
+
)
|
|
586
|
+
@click.option(
|
|
587
|
+
"--worker-ip",
|
|
588
|
+
default=None,
|
|
589
|
+
type=str,
|
|
590
|
+
help="Specify which worker this model runs on by ip, for distributed situation.",
|
|
591
|
+
)
|
|
580
592
|
@click.option(
|
|
581
593
|
"--api-key",
|
|
582
594
|
"-ak",
|
|
@@ -587,6 +599,8 @@ def list_model_registrations(
|
|
|
587
599
|
def list_cached_models(
|
|
588
600
|
endpoint: Optional[str],
|
|
589
601
|
api_key: Optional[str],
|
|
602
|
+
model_name: Optional[str],
|
|
603
|
+
worker_ip: Optional[str],
|
|
590
604
|
):
|
|
591
605
|
from tabulate import tabulate
|
|
592
606
|
|
|
@@ -595,10 +609,13 @@ def list_cached_models(
|
|
|
595
609
|
if api_key is None:
|
|
596
610
|
client._set_token(get_stored_token(endpoint, client))
|
|
597
611
|
|
|
598
|
-
cached_models = client.list_cached_models()
|
|
612
|
+
cached_models = client.list_cached_models(model_name, worker_ip)
|
|
613
|
+
if not cached_models:
|
|
614
|
+
print("There are no cache files.")
|
|
615
|
+
return
|
|
616
|
+
headers = list(cached_models[0].keys())
|
|
599
617
|
|
|
600
618
|
print("cached_model: ")
|
|
601
|
-
headers = list(cached_models[0].keys())
|
|
602
619
|
table_data = []
|
|
603
620
|
for model in cached_models:
|
|
604
621
|
row_data = [
|
|
@@ -608,6 +625,73 @@ def list_cached_models(
|
|
|
608
625
|
print(tabulate(table_data, headers=headers, tablefmt="pretty"))
|
|
609
626
|
|
|
610
627
|
|
|
628
|
+
@cli.command("remove-cache", help="Remove selected cached models in Xinference.")
|
|
629
|
+
@click.option(
|
|
630
|
+
"--endpoint",
|
|
631
|
+
"-e",
|
|
632
|
+
type=str,
|
|
633
|
+
help="Xinference endpoint.",
|
|
634
|
+
)
|
|
635
|
+
@click.option(
|
|
636
|
+
"--model_version",
|
|
637
|
+
"-n",
|
|
638
|
+
type=str,
|
|
639
|
+
help="Provide the version of the models to be removed.",
|
|
640
|
+
)
|
|
641
|
+
@click.option(
|
|
642
|
+
"--worker-ip",
|
|
643
|
+
default=None,
|
|
644
|
+
type=str,
|
|
645
|
+
help="Specify which worker this model runs on by ip, for distributed situation.",
|
|
646
|
+
)
|
|
647
|
+
@click.option(
|
|
648
|
+
"--api-key",
|
|
649
|
+
"-ak",
|
|
650
|
+
default=None,
|
|
651
|
+
type=str,
|
|
652
|
+
help="Api-Key for access xinference api with authorization.",
|
|
653
|
+
)
|
|
654
|
+
@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
|
|
655
|
+
def remove_cache(
|
|
656
|
+
endpoint: Optional[str],
|
|
657
|
+
model_version: str,
|
|
658
|
+
api_key: Optional[str],
|
|
659
|
+
check: bool,
|
|
660
|
+
worker_ip: Optional[str] = None,
|
|
661
|
+
):
|
|
662
|
+
endpoint = get_endpoint(endpoint)
|
|
663
|
+
client = RESTfulClient(base_url=endpoint, api_key=api_key)
|
|
664
|
+
if api_key is None:
|
|
665
|
+
client._set_token(get_stored_token(endpoint, client))
|
|
666
|
+
|
|
667
|
+
if not check:
|
|
668
|
+
response = client.list_deletable_models(
|
|
669
|
+
model_version=model_version, worker_ip=worker_ip
|
|
670
|
+
)
|
|
671
|
+
paths: List[str] = response.get("paths", [])
|
|
672
|
+
if not paths:
|
|
673
|
+
click.echo(f"There is no model version named {model_version}.")
|
|
674
|
+
return
|
|
675
|
+
click.echo(f"Model {model_version} cache directory to be deleted:")
|
|
676
|
+
for path in response.get("paths", []):
|
|
677
|
+
click.echo(f"{path}")
|
|
678
|
+
|
|
679
|
+
if click.confirm("Do you want to proceed with the deletion?", abort=True):
|
|
680
|
+
check = True
|
|
681
|
+
try:
|
|
682
|
+
result = client.confirm_and_remove_model(
|
|
683
|
+
model_version=model_version, worker_ip=worker_ip
|
|
684
|
+
)
|
|
685
|
+
if result:
|
|
686
|
+
click.echo(f"Cache directory {model_version} has been deleted.")
|
|
687
|
+
else:
|
|
688
|
+
click.echo(
|
|
689
|
+
f"Cache directory {model_version} fail to be deleted. Please check the log."
|
|
690
|
+
)
|
|
691
|
+
except Exception as e:
|
|
692
|
+
click.echo(f"An error occurred while deleting the cache: {e}")
|
|
693
|
+
|
|
694
|
+
|
|
611
695
|
@cli.command(
|
|
612
696
|
"launch",
|
|
613
697
|
help="Launch a model with the Xinference framework with the given parameters.",
|
|
@@ -26,6 +26,7 @@ from ..cmdline import (
|
|
|
26
26
|
model_list,
|
|
27
27
|
model_terminate,
|
|
28
28
|
register_model,
|
|
29
|
+
remove_cache,
|
|
29
30
|
unregister_model,
|
|
30
31
|
)
|
|
31
32
|
|
|
@@ -287,18 +288,26 @@ def test_list_cached_models(setup):
|
|
|
287
288
|
|
|
288
289
|
result = runner.invoke(
|
|
289
290
|
list_cached_models,
|
|
290
|
-
[
|
|
291
|
-
"--endpoint",
|
|
292
|
-
endpoint,
|
|
293
|
-
],
|
|
291
|
+
["--endpoint", endpoint, "--model_name", "orca"],
|
|
294
292
|
)
|
|
295
|
-
assert result.exit_code == 0
|
|
296
|
-
assert "cached_model: " in result.stdout
|
|
297
|
-
|
|
298
|
-
# check if the output is in tabular format
|
|
299
293
|
assert "model_name" in result.stdout
|
|
300
294
|
assert "model_format" in result.stdout
|
|
301
295
|
assert "model_size_in_billions" in result.stdout
|
|
302
|
-
assert "
|
|
296
|
+
assert "quantization" in result.stdout
|
|
297
|
+
assert "model_version" in result.stdout
|
|
303
298
|
assert "path" in result.stdout
|
|
304
|
-
assert "
|
|
299
|
+
assert "actor_ip_address" in result.stdout
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def test_remove_cache(setup):
|
|
303
|
+
endpoint, _ = setup
|
|
304
|
+
runner = CliRunner()
|
|
305
|
+
|
|
306
|
+
result = runner.invoke(
|
|
307
|
+
remove_cache,
|
|
308
|
+
["--endpoint", endpoint, "--model_version", "orca"],
|
|
309
|
+
input="y\n",
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
assert result.exit_code == 0
|
|
313
|
+
assert "Cache directory orca has been deleted."
|
xinference/isolation.py
CHANGED
|
@@ -19,13 +19,19 @@ from typing import Any, Coroutine
|
|
|
19
19
|
|
|
20
20
|
class Isolation:
|
|
21
21
|
# TODO: better move isolation to xoscar.
|
|
22
|
-
def __init__(
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
loop: asyncio.AbstractEventLoop,
|
|
25
|
+
threaded: bool = True,
|
|
26
|
+
daemon: bool = True,
|
|
27
|
+
):
|
|
23
28
|
self._loop = loop
|
|
24
29
|
self._threaded = threaded
|
|
25
30
|
|
|
26
31
|
self._stopped = None
|
|
27
32
|
self._thread = None
|
|
28
33
|
self._thread_ident = None
|
|
34
|
+
self._daemon = daemon
|
|
29
35
|
|
|
30
36
|
def _run(self):
|
|
31
37
|
asyncio.set_event_loop(self._loop)
|
|
@@ -35,7 +41,8 @@ class Isolation:
|
|
|
35
41
|
def start(self):
|
|
36
42
|
if self._threaded:
|
|
37
43
|
self._thread = thread = threading.Thread(target=self._run)
|
|
38
|
-
|
|
44
|
+
if self._daemon:
|
|
45
|
+
thread.daemon = True
|
|
39
46
|
thread.start()
|
|
40
47
|
self._thread_ident = thread.ident
|
|
41
48
|
|
|
@@ -32,6 +32,9 @@ from .custom import (
|
|
|
32
32
|
)
|
|
33
33
|
|
|
34
34
|
_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
|
|
35
|
+
_model_spec_modelscope_json = os.path.join(
|
|
36
|
+
os.path.dirname(__file__), "model_spec_modelscope.json"
|
|
37
|
+
)
|
|
35
38
|
BUILTIN_AUDIO_MODELS = dict(
|
|
36
39
|
(spec["model_name"], AudioModelFamilyV1(**spec))
|
|
37
40
|
for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
|
|
@@ -39,8 +42,17 @@ BUILTIN_AUDIO_MODELS = dict(
|
|
|
39
42
|
for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
|
|
40
43
|
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
|
|
41
44
|
|
|
45
|
+
MODELSCOPE_AUDIO_MODELS = dict(
|
|
46
|
+
(spec["model_name"], AudioModelFamilyV1(**spec))
|
|
47
|
+
for spec in json.load(
|
|
48
|
+
codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
|
|
52
|
+
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
|
|
53
|
+
|
|
42
54
|
# register model description after recording model revision
|
|
43
|
-
for model_spec_info in [BUILTIN_AUDIO_MODELS]:
|
|
55
|
+
for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
|
|
44
56
|
for model_name, model_spec in model_spec_info.items():
|
|
45
57
|
if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
|
|
46
58
|
AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
|
|
@@ -64,3 +76,4 @@ for ud_audio in get_user_defined_audios():
|
|
|
64
76
|
AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
|
|
65
77
|
|
|
66
78
|
del _model_spec_json
|
|
79
|
+
del _model_spec_modelscope_json
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import logging
|
|
15
|
+
from io import BytesIO
|
|
16
|
+
from typing import TYPE_CHECKING, Optional
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from .core import AudioModelFamilyV1
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ChatTTSModel:
|
|
25
|
+
def __init__(
|
|
26
|
+
self,
|
|
27
|
+
model_uid: str,
|
|
28
|
+
model_path: str,
|
|
29
|
+
model_spec: "AudioModelFamilyV1",
|
|
30
|
+
device: Optional[str] = None,
|
|
31
|
+
**kwargs,
|
|
32
|
+
):
|
|
33
|
+
self._model_uid = model_uid
|
|
34
|
+
self._model_path = model_path
|
|
35
|
+
self._model_spec = model_spec
|
|
36
|
+
self._device = device
|
|
37
|
+
self._model = None
|
|
38
|
+
self._kwargs = kwargs
|
|
39
|
+
|
|
40
|
+
def load(self):
|
|
41
|
+
import torch
|
|
42
|
+
|
|
43
|
+
from xinference.thirdparty import ChatTTS
|
|
44
|
+
|
|
45
|
+
torch._dynamo.config.cache_size_limit = 64
|
|
46
|
+
torch._dynamo.config.suppress_errors = True
|
|
47
|
+
torch.set_float32_matmul_precision("high")
|
|
48
|
+
self._model = ChatTTS.Chat()
|
|
49
|
+
self._model.load_models(
|
|
50
|
+
source="local", local_path=self._model_path, compile=True
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
def speech(
|
|
54
|
+
self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
|
|
55
|
+
):
|
|
56
|
+
import numpy as np
|
|
57
|
+
import torch
|
|
58
|
+
import torchaudio
|
|
59
|
+
import xxhash
|
|
60
|
+
|
|
61
|
+
seed = xxhash.xxh32_intdigest(voice)
|
|
62
|
+
|
|
63
|
+
torch.manual_seed(seed)
|
|
64
|
+
np.random.seed(seed)
|
|
65
|
+
torch.cuda.manual_seed(seed)
|
|
66
|
+
torch.backends.cudnn.deterministic = True
|
|
67
|
+
torch.backends.cudnn.benchmark = False
|
|
68
|
+
|
|
69
|
+
assert self._model is not None
|
|
70
|
+
rnd_spk_emb = self._model.sample_random_speaker()
|
|
71
|
+
|
|
72
|
+
default = 5
|
|
73
|
+
infer_speed = int(default * speed)
|
|
74
|
+
params_infer_code = {"spk_emb": rnd_spk_emb, "prompt": f"[speed_{infer_speed}]"}
|
|
75
|
+
|
|
76
|
+
assert self._model is not None
|
|
77
|
+
wavs = self._model.infer([input], params_infer_code=params_infer_code)
|
|
78
|
+
|
|
79
|
+
# Save the generated audio
|
|
80
|
+
with BytesIO() as out:
|
|
81
|
+
torchaudio.save(
|
|
82
|
+
out, torch.from_numpy(wavs[0]), 24000, format=response_format
|
|
83
|
+
)
|
|
84
|
+
return out.getvalue()
|
xinference/model/audio/core.py
CHANGED
|
@@ -14,11 +14,12 @@
|
|
|
14
14
|
import logging
|
|
15
15
|
import os
|
|
16
16
|
from collections import defaultdict
|
|
17
|
-
from typing import Dict, List, Optional, Tuple
|
|
17
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
18
18
|
|
|
19
19
|
from ...constants import XINFERENCE_CACHE_DIR
|
|
20
20
|
from ..core import CacheableModelSpec, ModelDescription
|
|
21
21
|
from ..utils import valid_model_revision
|
|
22
|
+
from .chattts import ChatTTSModel
|
|
22
23
|
from .whisper import WhisperModel
|
|
23
24
|
|
|
24
25
|
MAX_ATTEMPTS = 3
|
|
@@ -94,13 +95,24 @@ def generate_audio_description(
|
|
|
94
95
|
|
|
95
96
|
|
|
96
97
|
def match_audio(model_name: str) -> AudioModelFamilyV1:
|
|
97
|
-
from
|
|
98
|
+
from ..utils import download_from_modelscope
|
|
99
|
+
from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
|
|
98
100
|
from .custom import get_user_defined_audios
|
|
99
101
|
|
|
100
102
|
for model_spec in get_user_defined_audios():
|
|
101
103
|
if model_spec.model_name == model_name:
|
|
102
104
|
return model_spec
|
|
103
105
|
|
|
106
|
+
if download_from_modelscope():
|
|
107
|
+
if model_name in MODELSCOPE_AUDIO_MODELS:
|
|
108
|
+
logger.debug(f"Audio model {model_name} found in ModelScope.")
|
|
109
|
+
return MODELSCOPE_AUDIO_MODELS[model_name]
|
|
110
|
+
else:
|
|
111
|
+
logger.debug(
|
|
112
|
+
f"Audio model {model_name} not found in ModelScope, "
|
|
113
|
+
f"now try to load it via builtin way."
|
|
114
|
+
)
|
|
115
|
+
|
|
104
116
|
if model_name in BUILTIN_AUDIO_MODELS:
|
|
105
117
|
return BUILTIN_AUDIO_MODELS[model_name]
|
|
106
118
|
else:
|
|
@@ -130,10 +142,16 @@ def get_cache_status(
|
|
|
130
142
|
|
|
131
143
|
def create_audio_model_instance(
|
|
132
144
|
subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
|
|
133
|
-
) -> Tuple[WhisperModel, AudioModelDescription]:
|
|
145
|
+
) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
|
|
134
146
|
model_spec = match_audio(model_name)
|
|
135
147
|
model_path = cache(model_spec)
|
|
136
|
-
model
|
|
148
|
+
model: Union[WhisperModel, ChatTTSModel]
|
|
149
|
+
if model_spec.model_family == "whisper":
|
|
150
|
+
model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
|
|
151
|
+
elif model_spec.model_family == "ChatTTS":
|
|
152
|
+
model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
|
|
153
|
+
else:
|
|
154
|
+
raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
|
|
137
155
|
model_description = AudioModelDescription(
|
|
138
156
|
subpool_addr, devices, model_spec, model_path=model_path
|
|
139
157
|
)
|
xinference/model/audio/custom.py
CHANGED
|
@@ -83,15 +83,17 @@ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
|
|
|
83
83
|
def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
|
|
84
84
|
from ...constants import XINFERENCE_MODEL_DIR
|
|
85
85
|
from ..utils import is_valid_model_name, is_valid_model_uri
|
|
86
|
-
from . import BUILTIN_AUDIO_MODELS
|
|
86
|
+
from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
|
|
87
87
|
|
|
88
88
|
if not is_valid_model_name(model_spec.model_name):
|
|
89
89
|
raise ValueError(f"Invalid model name {model_spec.model_name}.")
|
|
90
90
|
|
|
91
91
|
with UD_AUDIO_LOCK:
|
|
92
|
-
for model_name in
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
for model_name in (
|
|
93
|
+
list(BUILTIN_AUDIO_MODELS.keys())
|
|
94
|
+
+ list(MODELSCOPE_AUDIO_MODELS.keys())
|
|
95
|
+
+ [spec.model_name for spec in UD_AUDIOS]
|
|
96
|
+
):
|
|
95
97
|
if model_spec.model_name == model_name:
|
|
96
98
|
raise ValueError(
|
|
97
99
|
f"Model name conflicts with existing model {model_spec.model_name}"
|