xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (75)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +143 -6
  3. xinference/client/restful/restful_client.py +144 -5
  4. xinference/constants.py +5 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +160 -19
  7. xinference/core/scheduler.py +446 -0
  8. xinference/core/supervisor.py +99 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/chattts.py +84 -0
  15. xinference/model/audio/core.py +22 -4
  16. xinference/model/audio/custom.py +6 -4
  17. xinference/model/audio/model_spec.json +20 -0
  18. xinference/model/audio/model_spec_modelscope.json +20 -0
  19. xinference/model/llm/__init__.py +38 -2
  20. xinference/model/llm/llm_family.json +509 -1
  21. xinference/model/llm/llm_family.py +86 -1
  22. xinference/model/llm/llm_family_csghub.json +66 -0
  23. xinference/model/llm/llm_family_modelscope.json +411 -2
  24. xinference/model/llm/pytorch/chatglm.py +20 -13
  25. xinference/model/llm/pytorch/cogvlm2.py +76 -17
  26. xinference/model/llm/pytorch/core.py +141 -6
  27. xinference/model/llm/pytorch/glm4v.py +268 -0
  28. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  29. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  30. xinference/model/llm/pytorch/utils.py +405 -8
  31. xinference/model/llm/utils.py +14 -13
  32. xinference/model/llm/vllm/core.py +16 -4
  33. xinference/model/utils.py +8 -2
  34. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  35. xinference/thirdparty/ChatTTS/core.py +200 -0
  36. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  38. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  39. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  40. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  41. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  42. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  43. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  44. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  45. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  46. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  47. xinference/types.py +3 -0
  48. xinference/web/ui/build/asset-manifest.json +6 -6
  49. xinference/web/ui/build/index.html +1 -1
  50. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  51. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  52. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  53. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  59. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
  60. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
  61. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  62. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  63. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  64. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  72. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  73. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  74. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/supervisor.py CHANGED
@@ -982,34 +982,59 @@ class SupervisorActor(xo.StatelessActor):
         )
 
     @log_async(logger=logger)
-    async def list_cached_models(self) -> List[Dict[str, Any]]:
+    async def list_cached_models(
+        self, model_name: Optional[str] = None, worker_ip: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        # search assigned worker and return
+        if target_ip_worker_ref:
+            cached_models = await target_ip_worker_ref.list_cached_models(model_name)
+            cached_models = sorted(cached_models, key=lambda x: x["model_name"])
+            return cached_models
+
+        # search all worker
         cached_models = []
         for worker in self._worker_address_to_worker.values():
-            ret = await worker.list_cached_models()
-            for model_version in ret:
-                model_name = model_version.get("model_name", None)
-                model_format = model_version.get("model_format", None)
-                model_size_in_billions = model_version.get(
-                    "model_size_in_billions", None
-                )
-                quantizations = model_version.get("quantization", None)
-                actor_ip_address = model_version.get("actor_ip_address", None)
-                path = model_version.get("path", None)
-                real_path = model_version.get("real_path", None)
-
-                cache_entry = {
-                    "model_name": model_name,
-                    "model_format": model_format,
-                    "model_size_in_billions": model_size_in_billions,
-                    "quantizations": quantizations,
-                    "path": path,
-                    "Actor IP Address": actor_ip_address,
-                    "real_path": real_path,
-                }
-
-                cached_models.append(cache_entry)
+            res = await worker.list_cached_models(model_name)
+            cached_models.extend(res)
+        cached_models = sorted(cached_models, key=lambda x: x["model_name"])
         return cached_models
 
+    @log_async(logger=logger)
+    async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+        from .scheduler import AbortRequestMessage
+
+        res = {"msg": AbortRequestMessage.NO_OP.name}
+        replica_info = self._model_uid_to_replica_info.get(model_uid, None)
+        if not replica_info:
+            return res
+        replica_cnt = replica_info.replica
+
+        # Query all replicas
+        for rep_mid in iter_replica_model_uid(model_uid, replica_cnt):
+            worker_ref = self._replica_model_uid_to_worker.get(rep_mid, None)
+            if worker_ref is None:
+                continue
+            model_ref = await worker_ref.get_model(model_uid=rep_mid)
+            result_info = await model_ref.abort_request(request_id)
+            res["msg"] = result_info
+            if result_info == AbortRequestMessage.DONE.name:
+                break
+            elif result_info == AbortRequestMessage.NOT_FOUND.name:
+                logger.debug(f"Request id: {request_id} not found for model {rep_mid}")
+            else:
+                logger.debug(f"No-op for model {rep_mid}")
+        return res
+
     @log_async(logger=logger)
     async def add_worker(self, worker_address: str):
         from .worker import WorkerActor
@@ -1057,6 +1082,56 @@ class SupervisorActor(xo.StatelessActor):
             worker_status.update_time = time.time()
             worker_status.status = status
 
+    async def list_deletable_models(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> List[str]:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        ret = []
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.list_deletable_models(
+                model_version=model_version,
+            )
+            return ret
+
+        for worker in self._worker_address_to_worker.values():
+            path = await worker.list_deletable_models(model_version=model_version)
+            ret.extend(path)
+        return ret
+
+    async def confirm_and_remove_model(
+        self, model_version: str, worker_ip: Optional[str] = None
+    ) -> bool:
+        target_ip_worker_ref = (
+            self._get_worker_ref_by_ip(worker_ip) if worker_ip is not None else None
+        )
+        if (
+            worker_ip is not None
+            and not self.is_local_deployment()
+            and target_ip_worker_ref is None
+        ):
+            raise ValueError(f"Worker ip address {worker_ip} is not in the cluster.")
+
+        if target_ip_worker_ref:
+            ret = await target_ip_worker_ref.confirm_and_remove_model(
+                model_version=model_version,
+            )
+            return ret
+        ret = True
+        for worker in self._worker_address_to_worker.values():
+            ret = ret and await worker.confirm_and_remove_model(
+                model_version=model_version,
+            )
+        return ret
+
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)
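
The new SupervisorActor.abort_request fans an abort out across every replica of a model and stops at the first replica that reports DONE. Below is a minimal, self-contained sketch of that fan-out pattern; the enum and the fake model refs are stand-ins defined here purely for illustration, not the actual xinference classes.

    import asyncio
    import enum


    class AbortRequestMessage(enum.Enum):  # stand-in for the enum in xinference.core.scheduler
        DONE = 1
        NOT_FOUND = 2
        NO_OP = 3


    class FakeModelRef:  # stand-in for a replica's model actor
        def __init__(self, known_requests):
            self._known = known_requests

        async def abort_request(self, request_id):
            if request_id in self._known:
                return AbortRequestMessage.DONE.name
            return AbortRequestMessage.NOT_FOUND.name


    async def abort_on_any_replica(replicas, request_id):
        # Mirrors the supervisor loop: try each replica until one aborts the request.
        res = AbortRequestMessage.NO_OP.name
        for model_ref in replicas:
            res = await model_ref.abort_request(request_id)
            if res == AbortRequestMessage.DONE.name:
                break
        return res


    replicas = [FakeModelRef(set()), FakeModelRef({"req-1"})]
    print(asyncio.run(abort_on_any_replica(replicas, "req-1")))  # -> DONE
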
xinference/core/worker.py CHANGED
@@ -16,6 +16,7 @@ import asyncio
 import os
 import platform
 import queue
+import shutil
 import signal
 import threading
 import time
@@ -786,8 +787,73 @@ class WorkerActor(xo.StatelessActor):
             except asyncio.CancelledError:  # pragma: no cover
                 break
 
-    async def list_cached_models(self) -> List[Dict[Any, Any]]:
-        return self._cache_tracker_ref.list_cached_models()
+    async def list_cached_models(
+        self, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
+        lists = await self._cache_tracker_ref.list_cached_models(
+            self.address, model_name
+        )
+        cached_models = []
+        for list in lists:
+            cached_model = {
+                "model_name": list.get("model_name"),
+                "model_size_in_billions": list.get("model_size_in_billions"),
+                "model_format": list.get("model_format"),
+                "quantization": list.get("quantization"),
+                "model_version": list.get("model_version"),
+            }
+            path = list.get("model_file_location")
+            cached_model["path"] = path
+            # parsing soft links
+            if os.path.isdir(path):
+                files = os.listdir(path)
+                # dir has files
+                if files:
+                    resolved_file = os.path.realpath(os.path.join(path, files[0]))
+                    if resolved_file:
+                        cached_model["real_path"] = os.path.dirname(resolved_file)
+            else:
+                cached_model["real_path"] = os.path.realpath(path)
+            cached_model["actor_ip_address"] = self.address
+            cached_models.append(cached_model)
+        return cached_models
+
+    async def list_deletable_models(self, model_version: str) -> List[str]:
+        paths = set()
+        path = await self._cache_tracker_ref.list_deletable_models(
+            model_version, self.address
+        )
+        if os.path.isfile(path):
+            path = os.path.dirname(path)
+
+        if os.path.isdir(path):
+            files = os.listdir(path)
+            paths.update([os.path.join(path, file) for file in files])
+        # search real path
+        if paths:
+            paths.update([os.path.realpath(path) for path in paths])
+
+        return list(paths)
+
+    async def confirm_and_remove_model(self, model_version: str) -> bool:
+        paths = await self.list_deletable_models(model_version)
+        for path in paths:
+            try:
+                if os.path.islink(path):
+                    os.unlink(path)
+                elif os.path.isfile(path):
+                    os.remove(path)
+                elif os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    logger.debug(f"{path} is not a valid path.")
+            except Exception as e:
+                logger.error(f"Fail to delete {path} with error:{e}.")
+                return False
+        await self._cache_tracker_ref.confirm_and_remove_model(
+            model_version, self.address
+        )
+        return True
 
     @staticmethod
     def record_metrics(name, op, kwargs):
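
In WorkerActor.list_cached_models above, the cached path is typically a directory of symlinks, and the real storage location is recovered with os.path.realpath on the first entry. A small, self-contained illustration of that resolution (all paths are temporary placeholders; os.symlink may require extra privileges on Windows):

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as root:
        real_dir = os.path.join(root, "real_store")   # where the weights actually live
        cache_dir = os.path.join(root, "cache")       # what the cache tracker records
        os.makedirs(real_dir)
        os.makedirs(cache_dir)
        with open(os.path.join(real_dir, "model.bin"), "wb") as f:
            f.write(b"\x00")
        os.symlink(os.path.join(real_dir, "model.bin"), os.path.join(cache_dir, "model.bin"))

        files = os.listdir(cache_dir)
        resolved_file = os.path.realpath(os.path.join(cache_dir, files[0]))
        print(os.path.dirname(resolved_file))  # prints the real_store directory
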
xinference/deploy/cmdline.py CHANGED
@@ -577,6 +577,18 @@ def list_model_registrations(
     type=str,
     help="Xinference endpoint.",
 )
+@click.option(
+    "--model_name",
+    "-n",
+    type=str,
+    help="Provide the name of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
 @click.option(
     "--api-key",
     "-ak",
@@ -587,6 +599,8 @@ def list_model_registrations(
 def list_cached_models(
     endpoint: Optional[str],
     api_key: Optional[str],
+    model_name: Optional[str],
+    worker_ip: Optional[str],
 ):
     from tabulate import tabulate
 
@@ -595,10 +609,13 @@ def list_cached_models(
     if api_key is None:
         client._set_token(get_stored_token(endpoint, client))
 
-    cached_models = client.list_cached_models()
+    cached_models = client.list_cached_models(model_name, worker_ip)
+    if not cached_models:
+        print("There are no cache files.")
+        return
+    headers = list(cached_models[0].keys())
 
     print("cached_model: ")
-    headers = list(cached_models[0].keys())
     table_data = []
     for model in cached_models:
         row_data = [
@@ -608,6 +625,73 @@ def list_cached_models(
     print(tabulate(table_data, headers=headers, tablefmt="pretty"))
 
 
+@cli.command("remove-cache", help="Remove selected cached models in Xinference.")
+@click.option(
+    "--endpoint",
+    "-e",
+    type=str,
+    help="Xinference endpoint.",
+)
+@click.option(
+    "--model_version",
+    "-n",
+    type=str,
+    help="Provide the version of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
+def remove_cache(
+    endpoint: Optional[str],
+    model_version: str,
+    api_key: Optional[str],
+    check: bool,
+    worker_ip: Optional[str] = None,
+):
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    if not check:
+        response = client.list_deletable_models(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        paths: List[str] = response.get("paths", [])
+        if not paths:
+            click.echo(f"There is no model version named {model_version}.")
+            return
+        click.echo(f"Model {model_version} cache directory to be deleted:")
+        for path in response.get("paths", []):
+            click.echo(f"{path}")
+
+        if click.confirm("Do you want to proceed with the deletion?", abort=True):
+            check = True
+    try:
+        result = client.confirm_and_remove_model(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        if result:
+            click.echo(f"Cache directory {model_version} has been deleted.")
+        else:
+            click.echo(
+                f"Cache directory {model_version} fail to be deleted. Please check the log."
+            )
+    except Exception as e:
+        click.echo(f"An error occurred while deleting the cache: {e}")
+
+
 @cli.command(
     "launch",
     help="Launch a model with the Xinference framework with the given parameters.",
xinference/deploy/test/test_cmdline.py CHANGED
@@ -26,6 +26,7 @@ from ..cmdline import (
     model_list,
     model_terminate,
     register_model,
+    remove_cache,
     unregister_model,
 )
 
@@ -287,18 +288,26 @@ def test_list_cached_models(setup):
 
     result = runner.invoke(
         list_cached_models,
-        [
-            "--endpoint",
-            endpoint,
-        ],
+        ["--endpoint", endpoint, "--model_name", "orca"],
     )
-    assert result.exit_code == 0
-    assert "cached_model: " in result.stdout
-
-    # check if the output is in tabular format
     assert "model_name" in result.stdout
     assert "model_format" in result.stdout
     assert "model_size_in_billions" in result.stdout
-    assert "quantizations" in result.stdout
+    assert "quantization" in result.stdout
+    assert "model_version" in result.stdout
     assert "path" in result.stdout
-    assert "Actor IP Address" in result.stdout
+    assert "actor_ip_address" in result.stdout
+
+
+def test_remove_cache(setup):
+    endpoint, _ = setup
+    runner = CliRunner()
+
+    result = runner.invoke(
+        remove_cache,
+        ["--endpoint", endpoint, "--model_version", "orca"],
+        input="y\n",
+    )
+
+    assert result.exit_code == 0
+    assert "Cache directory orca has been deleted."
xinference/isolation.py CHANGED
@@ -19,13 +19,19 @@ from typing import Any, Coroutine
 
 class Isolation:
     # TODO: better move isolation to xoscar.
-    def __init__(self, loop: asyncio.AbstractEventLoop, threaded: bool = True):
+    def __init__(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        threaded: bool = True,
+        daemon: bool = True,
+    ):
         self._loop = loop
         self._threaded = threaded
 
         self._stopped = None
         self._thread = None
        self._thread_ident = None
+        self._daemon = daemon
 
     def _run(self):
         asyncio.set_event_loop(self._loop)
@@ -35,7 +41,8 @@
     def start(self):
         if self._threaded:
             self._thread = thread = threading.Thread(target=self._run)
-            thread.daemon = True
+            if self._daemon:
+                thread.daemon = True
             thread.start()
             self._thread_ident = thread.ident
 
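
The new daemon flag matters because a daemon thread is killed as soon as the main thread exits, while a non-daemon thread keeps the interpreter alive; Isolation(loop, threaded=True, daemon=False) therefore lets the background event-loop thread outlive the code that started it. A stdlib-only illustration of the difference:

    import threading
    import time


    def background_work():
        time.sleep(0.2)
        print("background thread finished")


    t = threading.Thread(target=background_work)
    t.daemon = False  # non-daemon: the interpreter waits for this thread before exiting
    t.start()
    # With t.daemon = True the process may exit immediately and the print above never runs.
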
xinference/model/audio/__init__.py CHANGED
@@ -32,6 +32,9 @@ from .custom import (
 )
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+    os.path.dirname(__file__), "model_spec_modelscope.json"
+)
 BUILTIN_AUDIO_MODELS = dict(
     (spec["model_name"], AudioModelFamilyV1(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
@@ -39,8 +42,17 @@ BUILTIN_AUDIO_MODELS = dict(
 for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
+MODELSCOPE_AUDIO_MODELS = dict(
+    (spec["model_name"], AudioModelFamilyV1(**spec))
+    for spec in json.load(
+        codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+    )
+)
+for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 # register model description after recording model revision
-for model_spec_info in [BUILTIN_AUDIO_MODELS]:
+for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
     for model_name, model_spec in model_spec_info.items():
         if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
             AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
@@ -64,3 +76,4 @@ for ud_audio in get_user_defined_audios():
     AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
 
 del _model_spec_json
+del _model_spec_modelscope_json
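
With the ModelScope spec file wired in, both registries are importable from xinference.model.audio, so a quick membership check shows where a given built-in audio model is hosted. The model name below is only a guess for illustration; the real keys come from model_spec.json and model_spec_modelscope.json:

    from xinference.model.audio import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS

    name = "ChatTTS"  # hypothetical key; inspect the spec JSON files for the actual names
    print(name in BUILTIN_AUDIO_MODELS, name in MODELSCOPE_AUDIO_MODELS)
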
xinference/model/audio/chattts.py ADDED
@@ -0,0 +1,84 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class ChatTTSModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import torch
+
+        from xinference.thirdparty import ChatTTS
+
+        torch._dynamo.config.cache_size_limit = 64
+        torch._dynamo.config.suppress_errors = True
+        torch.set_float32_matmul_precision("high")
+        self._model = ChatTTS.Chat()
+        self._model.load_models(
+            source="local", local_path=self._model_path, compile=True
+        )
+
+    def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        import numpy as np
+        import torch
+        import torchaudio
+        import xxhash
+
+        seed = xxhash.xxh32_intdigest(voice)
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        assert self._model is not None
+        rnd_spk_emb = self._model.sample_random_speaker()
+
+        default = 5
+        infer_speed = int(default * speed)
+        params_infer_code = {"spk_emb": rnd_spk_emb, "prompt": f"[speed_{infer_speed}]"}
+
+        assert self._model is not None
+        wavs = self._model.infer([input], params_infer_code=params_infer_code)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(wavs[0]), 24000, format=response_format
+            )
+            return out.getvalue()
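
A rough local-usage sketch of the new ChatTTSModel, based only on the methods shown above; the model path is a placeholder, and model_spec is passed as None here because the code shown never dereferences it (a real AudioModelFamilyV1 spec would normally be supplied by xinference itself):

    from xinference.model.audio.chattts import ChatTTSModel

    model = ChatTTSModel(
        model_uid="chattts-demo",        # arbitrary uid for illustration
        model_path="/path/to/ChatTTS",   # assumed local directory with ChatTTS weights
        model_spec=None,                 # placeholder; normally an AudioModelFamilyV1 instance
    )
    model.load()
    # `voice` seeds the random speaker embedding, so the same string gives the same voice.
    audio_bytes = model.speech("Hello from ChatTTS!", voice="1234", response_format="mp3")
    with open("hello.mp3", "wb") as f:
        f.write(audio_bytes)
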
xinference/model/audio/core.py CHANGED
@@ -14,11 +14,12 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
+from .chattts import ChatTTSModel
 from .whisper import WhisperModel
 
 MAX_ATTEMPTS = 3
@@ -94,13 +95,24 @@ def generate_audio_description(
 
 
 def match_audio(model_name: str) -> AudioModelFamilyV1:
-    from . import BUILTIN_AUDIO_MODELS
+    from ..utils import download_from_modelscope
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
     from .custom import get_user_defined_audios
 
     for model_spec in get_user_defined_audios():
         if model_spec.model_name == model_name:
             return model_spec
 
+    if download_from_modelscope():
+        if model_name in MODELSCOPE_AUDIO_MODELS:
+            logger.debug(f"Audio model {model_name} found in ModelScope.")
+            return MODELSCOPE_AUDIO_MODELS[model_name]
+        else:
+            logger.debug(
+                f"Audio model {model_name} not found in ModelScope, "
+                f"now try to load it via builtin way."
+            )
+
     if model_name in BUILTIN_AUDIO_MODELS:
         return BUILTIN_AUDIO_MODELS[model_name]
     else:
@@ -130,10 +142,16 @@ def get_cache_status(
 
 def create_audio_model_instance(
     subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
-) -> Tuple[WhisperModel, AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
     model_spec = match_audio(model_name)
     model_path = cache(model_spec)
-    model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+    model: Union[WhisperModel, ChatTTSModel]
+    if model_spec.model_family == "whisper":
+        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "ChatTTS":
+        model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    else:
+        raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
     )
xinference/model/audio/custom.py CHANGED
@@ -83,15 +83,17 @@ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
 def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     from ...constants import XINFERENCE_MODEL_DIR
     from ..utils import is_valid_model_name, is_valid_model_uri
-    from . import BUILTIN_AUDIO_MODELS
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
 
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
     with UD_AUDIO_LOCK:
-        for model_name in list(BUILTIN_AUDIO_MODELS.keys()) + [
-            spec.model_name for spec in UD_AUDIOS
-        ]:
+        for model_name in (
+            list(BUILTIN_AUDIO_MODELS.keys())
+            + list(MODELSCOPE_AUDIO_MODELS.keys())
+            + [spec.model_name for spec in UD_AUDIOS]
+        ):
             if model_spec.model_name == model_name:
                 raise ValueError(
                     f"Model name conflicts with existing model {model_spec.model_name}"