xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +74 -6
- xinference/client/restful/restful_client.py +74 -5
- xinference/constants.py +1 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +54 -42
- xinference/core/scheduler.py +34 -16
- xinference/core/supervisor.py +73 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/core.py +12 -1
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +34 -2
- xinference/model/llm/llm_family.json +2 -0
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +2 -0
- xinference/model/llm/pytorch/chatglm.py +18 -12
- xinference/model/llm/pytorch/core.py +92 -42
- xinference/model/llm/pytorch/glm4v.py +13 -3
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +27 -14
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +10 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py
CHANGED
|
@@ -16,6 +16,7 @@ import asyncio
|
|
|
16
16
|
import os
|
|
17
17
|
import platform
|
|
18
18
|
import queue
|
|
19
|
+
import shutil
|
|
19
20
|
import signal
|
|
20
21
|
import threading
|
|
21
22
|
import time
|
|
@@ -786,8 +787,73 @@ class WorkerActor(xo.StatelessActor):
|
|
|
786
787
|
except asyncio.CancelledError: # pragma: no cover
|
|
787
788
|
break
|
|
788
789
|
|
|
789
|
-
async def list_cached_models(
|
|
790
|
-
|
|
790
|
+
async def list_cached_models(
|
|
791
|
+
self, model_name: Optional[str] = None
|
|
792
|
+
) -> List[Dict[Any, Any]]:
|
|
793
|
+
lists = await self._cache_tracker_ref.list_cached_models(
|
|
794
|
+
self.address, model_name
|
|
795
|
+
)
|
|
796
|
+
cached_models = []
|
|
797
|
+
for list in lists:
|
|
798
|
+
cached_model = {
|
|
799
|
+
"model_name": list.get("model_name"),
|
|
800
|
+
"model_size_in_billions": list.get("model_size_in_billions"),
|
|
801
|
+
"model_format": list.get("model_format"),
|
|
802
|
+
"quantization": list.get("quantization"),
|
|
803
|
+
"model_version": list.get("model_version"),
|
|
804
|
+
}
|
|
805
|
+
path = list.get("model_file_location")
|
|
806
|
+
cached_model["path"] = path
|
|
807
|
+
# parsing soft links
|
|
808
|
+
if os.path.isdir(path):
|
|
809
|
+
files = os.listdir(path)
|
|
810
|
+
# dir has files
|
|
811
|
+
if files:
|
|
812
|
+
resolved_file = os.path.realpath(os.path.join(path, files[0]))
|
|
813
|
+
if resolved_file:
|
|
814
|
+
cached_model["real_path"] = os.path.dirname(resolved_file)
|
|
815
|
+
else:
|
|
816
|
+
cached_model["real_path"] = os.path.realpath(path)
|
|
817
|
+
cached_model["actor_ip_address"] = self.address
|
|
818
|
+
cached_models.append(cached_model)
|
|
819
|
+
return cached_models
|
|
820
|
+
|
|
821
|
+
async def list_deletable_models(self, model_version: str) -> List[str]:
|
|
822
|
+
paths = set()
|
|
823
|
+
path = await self._cache_tracker_ref.list_deletable_models(
|
|
824
|
+
model_version, self.address
|
|
825
|
+
)
|
|
826
|
+
if os.path.isfile(path):
|
|
827
|
+
path = os.path.dirname(path)
|
|
828
|
+
|
|
829
|
+
if os.path.isdir(path):
|
|
830
|
+
files = os.listdir(path)
|
|
831
|
+
paths.update([os.path.join(path, file) for file in files])
|
|
832
|
+
# search real path
|
|
833
|
+
if paths:
|
|
834
|
+
paths.update([os.path.realpath(path) for path in paths])
|
|
835
|
+
|
|
836
|
+
return list(paths)
|
|
837
|
+
|
|
838
|
+
async def confirm_and_remove_model(self, model_version: str) -> bool:
|
|
839
|
+
paths = await self.list_deletable_models(model_version)
|
|
840
|
+
for path in paths:
|
|
841
|
+
try:
|
|
842
|
+
if os.path.islink(path):
|
|
843
|
+
os.unlink(path)
|
|
844
|
+
elif os.path.isfile(path):
|
|
845
|
+
os.remove(path)
|
|
846
|
+
elif os.path.isdir(path):
|
|
847
|
+
shutil.rmtree(path)
|
|
848
|
+
else:
|
|
849
|
+
logger.debug(f"{path} is not a valid path.")
|
|
850
|
+
except Exception as e:
|
|
851
|
+
logger.error(f"Fail to delete {path} with error:{e}.")
|
|
852
|
+
return False
|
|
853
|
+
await self._cache_tracker_ref.confirm_and_remove_model(
|
|
854
|
+
model_version, self.address
|
|
855
|
+
)
|
|
856
|
+
return True
|
|
791
857
|
|
|
792
858
|
@staticmethod
|
|
793
859
|
def record_metrics(name, op, kwargs):
|
xinference/deploy/cmdline.py
CHANGED
|
@@ -577,6 +577,18 @@ def list_model_registrations(
|
|
|
577
577
|
type=str,
|
|
578
578
|
help="Xinference endpoint.",
|
|
579
579
|
)
|
|
580
|
+
@click.option(
|
|
581
|
+
"--model_name",
|
|
582
|
+
"-n",
|
|
583
|
+
type=str,
|
|
584
|
+
help="Provide the name of the models to be removed.",
|
|
585
|
+
)
|
|
586
|
+
@click.option(
|
|
587
|
+
"--worker-ip",
|
|
588
|
+
default=None,
|
|
589
|
+
type=str,
|
|
590
|
+
help="Specify which worker this model runs on by ip, for distributed situation.",
|
|
591
|
+
)
|
|
580
592
|
@click.option(
|
|
581
593
|
"--api-key",
|
|
582
594
|
"-ak",
|
|
@@ -587,6 +599,8 @@ def list_model_registrations(
|
|
|
587
599
|
def list_cached_models(
|
|
588
600
|
endpoint: Optional[str],
|
|
589
601
|
api_key: Optional[str],
|
|
602
|
+
model_name: Optional[str],
|
|
603
|
+
worker_ip: Optional[str],
|
|
590
604
|
):
|
|
591
605
|
from tabulate import tabulate
|
|
592
606
|
|
|
@@ -595,10 +609,13 @@ def list_cached_models(
|
|
|
595
609
|
if api_key is None:
|
|
596
610
|
client._set_token(get_stored_token(endpoint, client))
|
|
597
611
|
|
|
598
|
-
cached_models = client.list_cached_models()
|
|
612
|
+
cached_models = client.list_cached_models(model_name, worker_ip)
|
|
613
|
+
if not cached_models:
|
|
614
|
+
print("There are no cache files.")
|
|
615
|
+
return
|
|
616
|
+
headers = list(cached_models[0].keys())
|
|
599
617
|
|
|
600
618
|
print("cached_model: ")
|
|
601
|
-
headers = list(cached_models[0].keys())
|
|
602
619
|
table_data = []
|
|
603
620
|
for model in cached_models:
|
|
604
621
|
row_data = [
|
|
@@ -608,6 +625,73 @@ def list_cached_models(
|
|
|
608
625
|
print(tabulate(table_data, headers=headers, tablefmt="pretty"))
|
|
609
626
|
|
|
610
627
|
|
|
628
|
+
@cli.command("remove-cache", help="Remove selected cached models in Xinference.")
|
|
629
|
+
@click.option(
|
|
630
|
+
"--endpoint",
|
|
631
|
+
"-e",
|
|
632
|
+
type=str,
|
|
633
|
+
help="Xinference endpoint.",
|
|
634
|
+
)
|
|
635
|
+
@click.option(
|
|
636
|
+
"--model_version",
|
|
637
|
+
"-n",
|
|
638
|
+
type=str,
|
|
639
|
+
help="Provide the version of the models to be removed.",
|
|
640
|
+
)
|
|
641
|
+
@click.option(
|
|
642
|
+
"--worker-ip",
|
|
643
|
+
default=None,
|
|
644
|
+
type=str,
|
|
645
|
+
help="Specify which worker this model runs on by ip, for distributed situation.",
|
|
646
|
+
)
|
|
647
|
+
@click.option(
|
|
648
|
+
"--api-key",
|
|
649
|
+
"-ak",
|
|
650
|
+
default=None,
|
|
651
|
+
type=str,
|
|
652
|
+
help="Api-Key for access xinference api with authorization.",
|
|
653
|
+
)
|
|
654
|
+
@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
|
|
655
|
+
def remove_cache(
|
|
656
|
+
endpoint: Optional[str],
|
|
657
|
+
model_version: str,
|
|
658
|
+
api_key: Optional[str],
|
|
659
|
+
check: bool,
|
|
660
|
+
worker_ip: Optional[str] = None,
|
|
661
|
+
):
|
|
662
|
+
endpoint = get_endpoint(endpoint)
|
|
663
|
+
client = RESTfulClient(base_url=endpoint, api_key=api_key)
|
|
664
|
+
if api_key is None:
|
|
665
|
+
client._set_token(get_stored_token(endpoint, client))
|
|
666
|
+
|
|
667
|
+
if not check:
|
|
668
|
+
response = client.list_deletable_models(
|
|
669
|
+
model_version=model_version, worker_ip=worker_ip
|
|
670
|
+
)
|
|
671
|
+
paths: List[str] = response.get("paths", [])
|
|
672
|
+
if not paths:
|
|
673
|
+
click.echo(f"There is no model version named {model_version}.")
|
|
674
|
+
return
|
|
675
|
+
click.echo(f"Model {model_version} cache directory to be deleted:")
|
|
676
|
+
for path in response.get("paths", []):
|
|
677
|
+
click.echo(f"{path}")
|
|
678
|
+
|
|
679
|
+
if click.confirm("Do you want to proceed with the deletion?", abort=True):
|
|
680
|
+
check = True
|
|
681
|
+
try:
|
|
682
|
+
result = client.confirm_and_remove_model(
|
|
683
|
+
model_version=model_version, worker_ip=worker_ip
|
|
684
|
+
)
|
|
685
|
+
if result:
|
|
686
|
+
click.echo(f"Cache directory {model_version} has been deleted.")
|
|
687
|
+
else:
|
|
688
|
+
click.echo(
|
|
689
|
+
f"Cache directory {model_version} fail to be deleted. Please check the log."
|
|
690
|
+
)
|
|
691
|
+
except Exception as e:
|
|
692
|
+
click.echo(f"An error occurred while deleting the cache: {e}")
|
|
693
|
+
|
|
694
|
+
|
|
611
695
|
@cli.command(
|
|
612
696
|
"launch",
|
|
613
697
|
help="Launch a model with the Xinference framework with the given parameters.",
|
|
@@ -26,6 +26,7 @@ from ..cmdline import (
|
|
|
26
26
|
model_list,
|
|
27
27
|
model_terminate,
|
|
28
28
|
register_model,
|
|
29
|
+
remove_cache,
|
|
29
30
|
unregister_model,
|
|
30
31
|
)
|
|
31
32
|
|
|
@@ -287,18 +288,26 @@ def test_list_cached_models(setup):
|
|
|
287
288
|
|
|
288
289
|
result = runner.invoke(
|
|
289
290
|
list_cached_models,
|
|
290
|
-
[
|
|
291
|
-
"--endpoint",
|
|
292
|
-
endpoint,
|
|
293
|
-
],
|
|
291
|
+
["--endpoint", endpoint, "--model_name", "orca"],
|
|
294
292
|
)
|
|
295
|
-
assert result.exit_code == 0
|
|
296
|
-
assert "cached_model: " in result.stdout
|
|
297
|
-
|
|
298
|
-
# check if the output is in tabular format
|
|
299
293
|
assert "model_name" in result.stdout
|
|
300
294
|
assert "model_format" in result.stdout
|
|
301
295
|
assert "model_size_in_billions" in result.stdout
|
|
302
|
-
assert "
|
|
296
|
+
assert "quantization" in result.stdout
|
|
297
|
+
assert "model_version" in result.stdout
|
|
303
298
|
assert "path" in result.stdout
|
|
304
|
-
assert "
|
|
299
|
+
assert "actor_ip_address" in result.stdout
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def test_remove_cache(setup):
|
|
303
|
+
endpoint, _ = setup
|
|
304
|
+
runner = CliRunner()
|
|
305
|
+
|
|
306
|
+
result = runner.invoke(
|
|
307
|
+
remove_cache,
|
|
308
|
+
["--endpoint", endpoint, "--model_version", "orca"],
|
|
309
|
+
input="y\n",
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
assert result.exit_code == 0
|
|
313
|
+
assert "Cache directory orca has been deleted."
|
|
@@ -32,6 +32,9 @@ from .custom import (
|
|
|
32
32
|
)
|
|
33
33
|
|
|
34
34
|
_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
|
|
35
|
+
_model_spec_modelscope_json = os.path.join(
|
|
36
|
+
os.path.dirname(__file__), "model_spec_modelscope.json"
|
|
37
|
+
)
|
|
35
38
|
BUILTIN_AUDIO_MODELS = dict(
|
|
36
39
|
(spec["model_name"], AudioModelFamilyV1(**spec))
|
|
37
40
|
for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
|
|
@@ -39,8 +42,17 @@ BUILTIN_AUDIO_MODELS = dict(
|
|
|
39
42
|
for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
|
|
40
43
|
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
|
|
41
44
|
|
|
45
|
+
MODELSCOPE_AUDIO_MODELS = dict(
|
|
46
|
+
(spec["model_name"], AudioModelFamilyV1(**spec))
|
|
47
|
+
for spec in json.load(
|
|
48
|
+
codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
|
|
52
|
+
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
|
|
53
|
+
|
|
42
54
|
# register model description after recording model revision
|
|
43
|
-
for model_spec_info in [BUILTIN_AUDIO_MODELS]:
|
|
55
|
+
for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
|
|
44
56
|
for model_name, model_spec in model_spec_info.items():
|
|
45
57
|
if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
|
|
46
58
|
AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
|
|
@@ -64,3 +76,4 @@ for ud_audio in get_user_defined_audios():
|
|
|
64
76
|
AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
|
|
65
77
|
|
|
66
78
|
del _model_spec_json
|
|
79
|
+
del _model_spec_modelscope_json
|
xinference/model/audio/core.py
CHANGED
|
@@ -95,13 +95,24 @@ def generate_audio_description(
|
|
|
95
95
|
|
|
96
96
|
|
|
97
97
|
def match_audio(model_name: str) -> AudioModelFamilyV1:
|
|
98
|
-
from
|
|
98
|
+
from ..utils import download_from_modelscope
|
|
99
|
+
from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
|
|
99
100
|
from .custom import get_user_defined_audios
|
|
100
101
|
|
|
101
102
|
for model_spec in get_user_defined_audios():
|
|
102
103
|
if model_spec.model_name == model_name:
|
|
103
104
|
return model_spec
|
|
104
105
|
|
|
106
|
+
if download_from_modelscope():
|
|
107
|
+
if model_name in MODELSCOPE_AUDIO_MODELS:
|
|
108
|
+
logger.debug(f"Audio model {model_name} found in ModelScope.")
|
|
109
|
+
return MODELSCOPE_AUDIO_MODELS[model_name]
|
|
110
|
+
else:
|
|
111
|
+
logger.debug(
|
|
112
|
+
f"Audio model {model_name} not found in ModelScope, "
|
|
113
|
+
f"now try to load it via builtin way."
|
|
114
|
+
)
|
|
115
|
+
|
|
105
116
|
if model_name in BUILTIN_AUDIO_MODELS:
|
|
106
117
|
return BUILTIN_AUDIO_MODELS[model_name]
|
|
107
118
|
else:
|
xinference/model/audio/custom.py
CHANGED
|
@@ -83,15 +83,17 @@ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
|
|
|
83
83
|
def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
|
|
84
84
|
from ...constants import XINFERENCE_MODEL_DIR
|
|
85
85
|
from ..utils import is_valid_model_name, is_valid_model_uri
|
|
86
|
-
from . import BUILTIN_AUDIO_MODELS
|
|
86
|
+
from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
|
|
87
87
|
|
|
88
88
|
if not is_valid_model_name(model_spec.model_name):
|
|
89
89
|
raise ValueError(f"Invalid model name {model_spec.model_name}.")
|
|
90
90
|
|
|
91
91
|
with UD_AUDIO_LOCK:
|
|
92
|
-
for model_name in
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
for model_name in (
|
|
93
|
+
list(BUILTIN_AUDIO_MODELS.keys())
|
|
94
|
+
+ list(MODELSCOPE_AUDIO_MODELS.keys())
|
|
95
|
+
+ [spec.model_name for spec in UD_AUDIOS]
|
|
96
|
+
):
|
|
95
97
|
if model_spec.model_name == model_name:
|
|
96
98
|
raise ValueError(
|
|
97
99
|
f"Model name conflicts with existing model {model_spec.model_name}"
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[
|
|
2
|
+
{
|
|
3
|
+
"model_name": "whisper-large-v3",
|
|
4
|
+
"model_family": "whisper",
|
|
5
|
+
"model_hub": "modelscope",
|
|
6
|
+
"model_id": "AI-ModelScope/whisper-large-v3",
|
|
7
|
+
"model_revision": "master",
|
|
8
|
+
"ability": "audio-to-text",
|
|
9
|
+
"multilingual": true
|
|
10
|
+
},
|
|
11
|
+
{
|
|
12
|
+
"model_name": "ChatTTS",
|
|
13
|
+
"model_family": "ChatTTS",
|
|
14
|
+
"model_hub": "modelscope",
|
|
15
|
+
"model_id": "pzc163/chatTTS",
|
|
16
|
+
"model_revision": "master",
|
|
17
|
+
"ability": "text-to-audio",
|
|
18
|
+
"multilingual": true
|
|
19
|
+
}
|
|
20
|
+
]
|
xinference/model/llm/__init__.py
CHANGED
|
@@ -25,6 +25,7 @@ from .core import (
|
|
|
25
25
|
get_llm_model_descriptions,
|
|
26
26
|
)
|
|
27
27
|
from .llm_family import (
|
|
28
|
+
BUILTIN_CSGHUB_LLM_FAMILIES,
|
|
28
29
|
BUILTIN_LLM_FAMILIES,
|
|
29
30
|
BUILTIN_LLM_MODEL_CHAT_FAMILIES,
|
|
30
31
|
BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
|
|
@@ -221,13 +222,44 @@ def _install():
|
|
|
221
222
|
if "tools" in model_spec.model_ability:
|
|
222
223
|
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
|
|
223
224
|
|
|
224
|
-
|
|
225
|
+
csghub_json_path = os.path.join(
|
|
226
|
+
os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
|
|
227
|
+
)
|
|
228
|
+
for json_obj in json.load(codecs.open(csghub_json_path, "r", encoding="utf-8")):
|
|
229
|
+
model_spec = LLMFamilyV1.parse_obj(json_obj)
|
|
230
|
+
BUILTIN_CSGHUB_LLM_FAMILIES.append(model_spec)
|
|
231
|
+
|
|
232
|
+
# register prompt style, in case that we have something missed
|
|
233
|
+
# if duplicated with huggingface json, keep it as the huggingface style
|
|
234
|
+
if (
|
|
235
|
+
"chat" in model_spec.model_ability
|
|
236
|
+
and isinstance(model_spec.prompt_style, PromptStyleV1)
|
|
237
|
+
and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
|
|
238
|
+
):
|
|
239
|
+
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
|
|
240
|
+
# register model family
|
|
241
|
+
if "chat" in model_spec.model_ability:
|
|
242
|
+
BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
|
|
243
|
+
else:
|
|
244
|
+
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
|
|
245
|
+
if "tools" in model_spec.model_ability:
|
|
246
|
+
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
|
|
247
|
+
|
|
248
|
+
for llm_specs in [
|
|
249
|
+
BUILTIN_LLM_FAMILIES,
|
|
250
|
+
BUILTIN_MODELSCOPE_LLM_FAMILIES,
|
|
251
|
+
BUILTIN_CSGHUB_LLM_FAMILIES,
|
|
252
|
+
]:
|
|
225
253
|
for llm_spec in llm_specs:
|
|
226
254
|
if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
|
|
227
255
|
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))
|
|
228
256
|
|
|
229
257
|
# traverse all families and add engine parameters corresponding to the model name
|
|
230
|
-
for families in [
|
|
258
|
+
for families in [
|
|
259
|
+
BUILTIN_LLM_FAMILIES,
|
|
260
|
+
BUILTIN_MODELSCOPE_LLM_FAMILIES,
|
|
261
|
+
BUILTIN_CSGHUB_LLM_FAMILIES,
|
|
262
|
+
]:
|
|
231
263
|
for family in families:
|
|
232
264
|
generate_engine_config_by_model_family(family)
|
|
233
265
|
|
|
@@ -32,10 +32,15 @@ from ..._compat import (
|
|
|
32
32
|
load_str_bytes,
|
|
33
33
|
validator,
|
|
34
34
|
)
|
|
35
|
-
from ...constants import
|
|
35
|
+
from ...constants import (
|
|
36
|
+
XINFERENCE_CACHE_DIR,
|
|
37
|
+
XINFERENCE_ENV_CSG_TOKEN,
|
|
38
|
+
XINFERENCE_MODEL_DIR,
|
|
39
|
+
)
|
|
36
40
|
from ..utils import (
|
|
37
41
|
IS_NEW_HUGGINGFACE_HUB,
|
|
38
42
|
create_symlink,
|
|
43
|
+
download_from_csghub,
|
|
39
44
|
download_from_modelscope,
|
|
40
45
|
is_valid_model_uri,
|
|
41
46
|
parse_uri,
|
|
@@ -232,6 +237,7 @@ LLAMA_CLASSES: List[Type[LLM]] = []
|
|
|
232
237
|
|
|
233
238
|
BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
|
|
234
239
|
BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []
|
|
240
|
+
BUILTIN_CSGHUB_LLM_FAMILIES: List["LLMFamilyV1"] = []
|
|
235
241
|
|
|
236
242
|
SGLANG_CLASSES: List[Type[LLM]] = []
|
|
237
243
|
TRANSFORMERS_CLASSES: List[Type[LLM]] = []
|
|
@@ -292,6 +298,9 @@ def cache(
|
|
|
292
298
|
elif llm_spec.model_hub == "modelscope":
|
|
293
299
|
logger.info(f"Caching from Modelscope: {llm_spec.model_id}")
|
|
294
300
|
return cache_from_modelscope(llm_family, llm_spec, quantization)
|
|
301
|
+
elif llm_spec.model_hub == "csghub":
|
|
302
|
+
logger.info(f"Caching from CSGHub: {llm_spec.model_id}")
|
|
303
|
+
return cache_from_csghub(llm_family, llm_spec, quantization)
|
|
295
304
|
else:
|
|
296
305
|
raise ValueError(f"Unknown model hub: {llm_spec.model_hub}")
|
|
297
306
|
|
|
@@ -566,6 +575,7 @@ def _skip_download(
|
|
|
566
575
|
"modelscope": _get_meta_path(
|
|
567
576
|
cache_dir, model_format, "modelscope", quantization
|
|
568
577
|
),
|
|
578
|
+
"csghub": _get_meta_path(cache_dir, model_format, "csghub", quantization),
|
|
569
579
|
}
|
|
570
580
|
if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
|
|
571
581
|
logger.info(f"Cache {cache_dir} exists")
|
|
@@ -650,6 +660,75 @@ def _merge_cached_files(
|
|
|
650
660
|
logger.info(f"Merge complete.")
|
|
651
661
|
|
|
652
662
|
|
|
663
|
+
def cache_from_csghub(
|
|
664
|
+
llm_family: LLMFamilyV1,
|
|
665
|
+
llm_spec: "LLMSpecV1",
|
|
666
|
+
quantization: Optional[str] = None,
|
|
667
|
+
) -> str:
|
|
668
|
+
"""
|
|
669
|
+
Cache model from CSGHub. Return the cache directory.
|
|
670
|
+
"""
|
|
671
|
+
from pycsghub.file_download import file_download
|
|
672
|
+
from pycsghub.snapshot_download import snapshot_download
|
|
673
|
+
|
|
674
|
+
cache_dir = _get_cache_dir(llm_family, llm_spec)
|
|
675
|
+
|
|
676
|
+
if _skip_download(
|
|
677
|
+
cache_dir,
|
|
678
|
+
llm_spec.model_format,
|
|
679
|
+
llm_spec.model_hub,
|
|
680
|
+
llm_spec.model_revision,
|
|
681
|
+
quantization,
|
|
682
|
+
):
|
|
683
|
+
return cache_dir
|
|
684
|
+
|
|
685
|
+
if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
|
|
686
|
+
download_dir = retry_download(
|
|
687
|
+
snapshot_download,
|
|
688
|
+
llm_family.model_name,
|
|
689
|
+
{
|
|
690
|
+
"model_size": llm_spec.model_size_in_billions,
|
|
691
|
+
"model_format": llm_spec.model_format,
|
|
692
|
+
},
|
|
693
|
+
llm_spec.model_id,
|
|
694
|
+
endpoint="https://hub-stg.opencsg.com",
|
|
695
|
+
token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
|
|
696
|
+
)
|
|
697
|
+
create_symlink(download_dir, cache_dir)
|
|
698
|
+
|
|
699
|
+
elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
|
|
700
|
+
file_names, final_file_name, need_merge = _generate_model_file_names(
|
|
701
|
+
llm_spec, quantization
|
|
702
|
+
)
|
|
703
|
+
|
|
704
|
+
for filename in file_names:
|
|
705
|
+
download_path = retry_download(
|
|
706
|
+
file_download,
|
|
707
|
+
llm_family.model_name,
|
|
708
|
+
{
|
|
709
|
+
"model_size": llm_spec.model_size_in_billions,
|
|
710
|
+
"model_format": llm_spec.model_format,
|
|
711
|
+
},
|
|
712
|
+
llm_spec.model_id,
|
|
713
|
+
file_name=filename,
|
|
714
|
+
endpoint="https://hub-stg.opencsg.com",
|
|
715
|
+
token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
|
|
716
|
+
)
|
|
717
|
+
symlink_local_file(download_path, cache_dir, filename)
|
|
718
|
+
|
|
719
|
+
if need_merge:
|
|
720
|
+
_merge_cached_files(cache_dir, file_names, final_file_name)
|
|
721
|
+
else:
|
|
722
|
+
raise ValueError(f"Unsupported format: {llm_spec.model_format}")
|
|
723
|
+
|
|
724
|
+
meta_path = _get_meta_path(
|
|
725
|
+
cache_dir, llm_spec.model_format, llm_spec.model_hub, quantization
|
|
726
|
+
)
|
|
727
|
+
_generate_meta_file(meta_path, llm_family, llm_spec, quantization)
|
|
728
|
+
|
|
729
|
+
return cache_dir
|
|
730
|
+
|
|
731
|
+
|
|
653
732
|
def cache_from_modelscope(
|
|
654
733
|
llm_family: LLMFamilyV1,
|
|
655
734
|
llm_spec: "LLMSpecV1",
|
|
@@ -931,6 +1010,12 @@ def match_llm(
|
|
|
931
1010
|
+ BUILTIN_LLM_FAMILIES
|
|
932
1011
|
+ user_defined_llm_families
|
|
933
1012
|
)
|
|
1013
|
+
elif download_from_csghub():
|
|
1014
|
+
all_families = (
|
|
1015
|
+
BUILTIN_CSGHUB_LLM_FAMILIES
|
|
1016
|
+
+ BUILTIN_LLM_FAMILIES
|
|
1017
|
+
+ user_defined_llm_families
|
|
1018
|
+
)
|
|
934
1019
|
else:
|
|
935
1020
|
all_families = BUILTIN_LLM_FAMILIES + user_defined_llm_families
|
|
936
1021
|
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
[
|
|
2
|
+
{
|
|
3
|
+
"version": 1,
|
|
4
|
+
"context_length": 32768,
|
|
5
|
+
"model_name": "qwen2-instruct",
|
|
6
|
+
"model_lang": [
|
|
7
|
+
"en",
|
|
8
|
+
"zh"
|
|
9
|
+
],
|
|
10
|
+
"model_ability": [
|
|
11
|
+
"chat",
|
|
12
|
+
"tools"
|
|
13
|
+
],
|
|
14
|
+
"model_description": "Qwen2 is the new series of Qwen large language models",
|
|
15
|
+
"model_specs": [
|
|
16
|
+
{
|
|
17
|
+
"model_format": "pytorch",
|
|
18
|
+
"model_size_in_billions": "0_5",
|
|
19
|
+
"quantizations": [
|
|
20
|
+
"4-bit",
|
|
21
|
+
"8-bit",
|
|
22
|
+
"none"
|
|
23
|
+
],
|
|
24
|
+
"model_id": "Qwen/Qwen2-0.5B-Instruct",
|
|
25
|
+
"model_hub": "csghub"
|
|
26
|
+
},
|
|
27
|
+
{
|
|
28
|
+
"model_format": "ggufv2",
|
|
29
|
+
"model_size_in_billions": "0_5",
|
|
30
|
+
"quantizations": [
|
|
31
|
+
"q2_k",
|
|
32
|
+
"q3_k_m",
|
|
33
|
+
"q4_0",
|
|
34
|
+
"q4_k_m",
|
|
35
|
+
"q5_0",
|
|
36
|
+
"q5_k_m",
|
|
37
|
+
"q6_k",
|
|
38
|
+
"q8_0",
|
|
39
|
+
"fp16"
|
|
40
|
+
],
|
|
41
|
+
"model_id": "qwen/Qwen2-0.5B-Instruct-GGUF",
|
|
42
|
+
"model_file_name_template": "qwen2-0_5b-instruct-{quantization}.gguf",
|
|
43
|
+
"model_hub": "csghub"
|
|
44
|
+
}
|
|
45
|
+
],
|
|
46
|
+
"prompt_style": {
|
|
47
|
+
"style_name": "QWEN",
|
|
48
|
+
"system_prompt": "You are a helpful assistant.",
|
|
49
|
+
"roles": [
|
|
50
|
+
"user",
|
|
51
|
+
"assistant"
|
|
52
|
+
],
|
|
53
|
+
"intra_message_sep": "\n",
|
|
54
|
+
"stop_token_ids": [
|
|
55
|
+
151643,
|
|
56
|
+
151644,
|
|
57
|
+
151645
|
|
58
|
+
],
|
|
59
|
+
"stop": [
|
|
60
|
+
"<|endoftext|>",
|
|
61
|
+
"<|im_start|>",
|
|
62
|
+
"<|im_end|>"
|
|
63
|
+
]
|
|
64
|
+
}
|
|
65
|
+
}
|
|
66
|
+
]
|
|
@@ -89,24 +89,30 @@ class ChatglmPytorchChatModel(PytorchChatModel):
|
|
|
89
89
|
return False
|
|
90
90
|
return True
|
|
91
91
|
|
|
92
|
-
|
|
93
|
-
def _handle_tools(generate_config) -> Optional[dict]:
|
|
92
|
+
def _handle_tools(self, generate_config) -> Optional[dict]:
|
|
94
93
|
"""Convert openai tools to ChatGLM tools."""
|
|
95
94
|
if generate_config is None:
|
|
96
95
|
return None
|
|
97
96
|
tools = generate_config.pop("tools", None)
|
|
98
97
|
if tools is None:
|
|
99
98
|
return None
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
99
|
+
if self.model_family.model_name == "glm4-chat":
|
|
100
|
+
return {
|
|
101
|
+
"role": "system",
|
|
102
|
+
"content": None,
|
|
103
|
+
"tools": tools,
|
|
104
|
+
}
|
|
105
|
+
else:
|
|
106
|
+
chatglm_tools = []
|
|
107
|
+
for elem in tools:
|
|
108
|
+
if elem.get("type") != "function" or "function" not in elem:
|
|
109
|
+
raise ValueError("ChatGLM tools only support function type.")
|
|
110
|
+
chatglm_tools.append(elem["function"])
|
|
111
|
+
return {
|
|
112
|
+
"role": "system",
|
|
113
|
+
"content": f"Answer the following questions as best as you can. You have access to the following tools:",
|
|
114
|
+
"tools": chatglm_tools,
|
|
115
|
+
}
|
|
110
116
|
|
|
111
117
|
def chat(
|
|
112
118
|
self,
|