xinference 0.12.0__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (67)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +74 -6
  3. xinference/client/restful/restful_client.py +74 -5
  4. xinference/constants.py +1 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +54 -42
  7. xinference/core/scheduler.py +34 -16
  8. xinference/core/supervisor.py +73 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/model/audio/__init__.py +14 -1
  13. xinference/model/audio/core.py +12 -1
  14. xinference/model/audio/custom.py +6 -4
  15. xinference/model/audio/model_spec_modelscope.json +20 -0
  16. xinference/model/llm/__init__.py +34 -2
  17. xinference/model/llm/llm_family.json +2 -0
  18. xinference/model/llm/llm_family.py +86 -1
  19. xinference/model/llm/llm_family_csghub.json +66 -0
  20. xinference/model/llm/llm_family_modelscope.json +2 -0
  21. xinference/model/llm/pytorch/chatglm.py +18 -12
  22. xinference/model/llm/pytorch/core.py +92 -42
  23. xinference/model/llm/pytorch/glm4v.py +13 -3
  24. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  25. xinference/model/llm/pytorch/utils.py +27 -14
  26. xinference/model/llm/utils.py +14 -13
  27. xinference/model/llm/vllm/core.py +10 -4
  28. xinference/model/utils.py +8 -2
  29. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  30. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  31. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  32. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  33. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  34. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  35. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  36. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  38. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  39. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  40. xinference/web/ui/build/asset-manifest.json +6 -6
  41. xinference/web/ui/build/index.html +1 -1
  42. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  43. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  44. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  45. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  46. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  47. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  48. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  49. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  50. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  51. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/METADATA +1 -1
  52. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/RECORD +57 -45
  53. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  54. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  55. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  56. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  64. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.12.0.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/worker.py CHANGED
@@ -16,6 +16,7 @@ import asyncio
 import os
 import platform
 import queue
+import shutil
 import signal
 import threading
 import time
@@ -786,8 +787,73 @@ class WorkerActor(xo.StatelessActor):
             except asyncio.CancelledError: # pragma: no cover
                 break
 
-    async def list_cached_models(self) -> List[Dict[Any, Any]]:
-        return self._cache_tracker_ref.list_cached_models()
+    async def list_cached_models(
+        self, model_name: Optional[str] = None
+    ) -> List[Dict[Any, Any]]:
+        lists = await self._cache_tracker_ref.list_cached_models(
+            self.address, model_name
+        )
+        cached_models = []
+        for list in lists:
+            cached_model = {
+                "model_name": list.get("model_name"),
+                "model_size_in_billions": list.get("model_size_in_billions"),
+                "model_format": list.get("model_format"),
+                "quantization": list.get("quantization"),
+                "model_version": list.get("model_version"),
+            }
+            path = list.get("model_file_location")
+            cached_model["path"] = path
+            # parsing soft links
+            if os.path.isdir(path):
+                files = os.listdir(path)
+                # dir has files
+                if files:
+                    resolved_file = os.path.realpath(os.path.join(path, files[0]))
+                    if resolved_file:
+                        cached_model["real_path"] = os.path.dirname(resolved_file)
+            else:
+                cached_model["real_path"] = os.path.realpath(path)
+            cached_model["actor_ip_address"] = self.address
+            cached_models.append(cached_model)
+        return cached_models
+
+    async def list_deletable_models(self, model_version: str) -> List[str]:
+        paths = set()
+        path = await self._cache_tracker_ref.list_deletable_models(
+            model_version, self.address
+        )
+        if os.path.isfile(path):
+            path = os.path.dirname(path)
+
+        if os.path.isdir(path):
+            files = os.listdir(path)
+            paths.update([os.path.join(path, file) for file in files])
+            # search real path
+            if paths:
+                paths.update([os.path.realpath(path) for path in paths])
+
+        return list(paths)
+
+    async def confirm_and_remove_model(self, model_version: str) -> bool:
+        paths = await self.list_deletable_models(model_version)
+        for path in paths:
+            try:
+                if os.path.islink(path):
+                    os.unlink(path)
+                elif os.path.isfile(path):
+                    os.remove(path)
+                elif os.path.isdir(path):
+                    shutil.rmtree(path)
+                else:
+                    logger.debug(f"{path} is not a valid path.")
+            except Exception as e:
+                logger.error(f"Fail to delete {path} with error:{e}.")
+                return False
+        await self._cache_tracker_ref.confirm_and_remove_model(
+            model_version, self.address
+        )
+        return True
 
     @staticmethod
     def record_metrics(name, op, kwargs):
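
The new WorkerActor methods, together with the cache_tracker changes, back three new calls exposed through the supervisor, REST API and client (list_cached_models, list_deletable_models, confirm_and_remove_model). A minimal sketch of reading the cache listing from the RESTful client follows; the endpoint is hypothetical and the keyword names are assumptions inferred from the CLI call sites below, since restful_client.py's new signatures are not reproduced in this excerpt:

    from xinference.client import RESTfulClient

    client = RESTfulClient(base_url="http://127.0.0.1:9997")  # hypothetical endpoint

    # Each entry carries model_name, model_format, quantization, model_version,
    # path, real_path and actor_ip_address, as assembled by
    # WorkerActor.list_cached_models above.
    for entry in client.list_cached_models(model_name=None, worker_ip=None):
        print(entry["model_version"], entry["path"], entry["actor_ip_address"])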

xinference/deploy/cmdline.py CHANGED
@@ -577,6 +577,18 @@ def list_model_registrations(
     type=str,
     help="Xinference endpoint.",
 )
+@click.option(
+    "--model_name",
+    "-n",
+    type=str,
+    help="Provide the name of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
 @click.option(
     "--api-key",
     "-ak",
@@ -587,6 +599,8 @@ def list_model_registrations(
 def list_cached_models(
     endpoint: Optional[str],
     api_key: Optional[str],
+    model_name: Optional[str],
+    worker_ip: Optional[str],
 ):
     from tabulate import tabulate
 
@@ -595,10 +609,13 @@ def list_model_registrations(
     if api_key is None:
         client._set_token(get_stored_token(endpoint, client))
 
-    cached_models = client.list_cached_models()
+    cached_models = client.list_cached_models(model_name, worker_ip)
+    if not cached_models:
+        print("There are no cache files.")
+        return
+    headers = list(cached_models[0].keys())
 
     print("cached_model: ")
-    headers = list(cached_models[0].keys())
     table_data = []
     for model in cached_models:
         row_data = [
@@ -608,6 +625,73 @@ def list_model_registrations(
     print(tabulate(table_data, headers=headers, tablefmt="pretty"))
 
 
+@cli.command("remove-cache", help="Remove selected cached models in Xinference.")
+@click.option(
+    "--endpoint",
+    "-e",
+    type=str,
+    help="Xinference endpoint.",
+)
+@click.option(
+    "--model_version",
+    "-n",
+    type=str,
+    help="Provide the version of the models to be removed.",
+)
+@click.option(
+    "--worker-ip",
+    default=None,
+    type=str,
+    help="Specify which worker this model runs on by ip, for distributed situation.",
+)
+@click.option(
+    "--api-key",
+    "-ak",
+    default=None,
+    type=str,
+    help="Api-Key for access xinference api with authorization.",
+)
+@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.")
+def remove_cache(
+    endpoint: Optional[str],
+    model_version: str,
+    api_key: Optional[str],
+    check: bool,
+    worker_ip: Optional[str] = None,
+):
+    endpoint = get_endpoint(endpoint)
+    client = RESTfulClient(base_url=endpoint, api_key=api_key)
+    if api_key is None:
+        client._set_token(get_stored_token(endpoint, client))
+
+    if not check:
+        response = client.list_deletable_models(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        paths: List[str] = response.get("paths", [])
+        if not paths:
+            click.echo(f"There is no model version named {model_version}.")
+            return
+        click.echo(f"Model {model_version} cache directory to be deleted:")
+        for path in response.get("paths", []):
+            click.echo(f"{path}")
+
+        if click.confirm("Do you want to proceed with the deletion?", abort=True):
+            check = True
+    try:
+        result = client.confirm_and_remove_model(
+            model_version=model_version, worker_ip=worker_ip
+        )
+        if result:
+            click.echo(f"Cache directory {model_version} has been deleted.")
+        else:
+            click.echo(
+                f"Cache directory {model_version} fail to be deleted. Please check the log."
+            )
+    except Exception as e:
+        click.echo(f"An error occurred while deleting the cache: {e}")
+
+
 @cli.command(
     "launch",
     help="Launch a model with the Xinference framework with the given parameters.",
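
The new remove-cache command wraps a two-step client flow: list_deletable_models to preview the affected paths, then confirm_and_remove_model to delete them on the worker that holds them. A sketch of driving the same flow from the client directly (endpoint hypothetical; the model_version string is a placeholder and must be one reported by list_cached_models):

    from xinference.client import RESTfulClient

    client = RESTfulClient(base_url="http://127.0.0.1:9997")  # hypothetical endpoint

    version = "some-model-version"  # placeholder; take a real value from list_cached_models
    preview = client.list_deletable_models(model_version=version, worker_ip=None)
    for path in preview.get("paths", []):
        print("would delete:", path)

    # Irreversible: removes the cached files on the target worker.
    if client.confirm_and_remove_model(model_version=version, worker_ip=None):
        print(f"Cache for {version} removed.")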

xinference/deploy/test/test_cmdline.py CHANGED
@@ -26,6 +26,7 @@ from ..cmdline import (
     model_list,
     model_terminate,
     register_model,
+    remove_cache,
     unregister_model,
 )
 
@@ -287,18 +288,26 @@ def test_list_cached_models(setup):
 
     result = runner.invoke(
         list_cached_models,
-        [
-            "--endpoint",
-            endpoint,
-        ],
+        ["--endpoint", endpoint, "--model_name", "orca"],
     )
-    assert result.exit_code == 0
-    assert "cached_model: " in result.stdout
-
-    # check if the output is in tabular format
     assert "model_name" in result.stdout
     assert "model_format" in result.stdout
     assert "model_size_in_billions" in result.stdout
-    assert "quantizations" in result.stdout
+    assert "quantization" in result.stdout
+    assert "model_version" in result.stdout
     assert "path" in result.stdout
-    assert "Actor IP Address" in result.stdout
+    assert "actor_ip_address" in result.stdout
+
+
+def test_remove_cache(setup):
+    endpoint, _ = setup
+    runner = CliRunner()
+
+    result = runner.invoke(
+        remove_cache,
+        ["--endpoint", endpoint, "--model_version", "orca"],
+        input="y\n",
+    )
+
+    assert result.exit_code == 0
+    assert "Cache directory orca has been deleted."

xinference/model/audio/__init__.py CHANGED
@@ -32,6 +32,9 @@ from .custom import (
 )
 
 _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+    os.path.dirname(__file__), "model_spec_modelscope.json"
+)
 BUILTIN_AUDIO_MODELS = dict(
     (spec["model_name"], AudioModelFamilyV1(**spec))
     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
@@ -39,8 +42,17 @@ BUILTIN_AUDIO_MODELS = dict(
 for model_name, model_spec in BUILTIN_AUDIO_MODELS.items():
     MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
 
+MODELSCOPE_AUDIO_MODELS = dict(
+    (spec["model_name"], AudioModelFamilyV1(**spec))
+    for spec in json.load(
+        codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+    )
+)
+for model_name, model_spec in MODELSCOPE_AUDIO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
 # register model description after recording model revision
-for model_spec_info in [BUILTIN_AUDIO_MODELS]:
+for model_spec_info in [BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS]:
     for model_name, model_spec in model_spec_info.items():
         if model_spec.model_name not in AUDIO_MODEL_DESCRIPTIONS:
             AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(model_spec))
@@ -64,3 +76,4 @@ for ud_audio in get_user_defined_audios():
     AUDIO_MODEL_DESCRIPTIONS.update(generate_audio_description(ud_audio))
 
 del _model_spec_json
+del _model_spec_modelscope_json

xinference/model/audio/core.py CHANGED
@@ -95,13 +95,24 @@ def generate_audio_description(
 
 
 def match_audio(model_name: str) -> AudioModelFamilyV1:
-    from . import BUILTIN_AUDIO_MODELS
+    from ..utils import download_from_modelscope
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
     from .custom import get_user_defined_audios
 
     for model_spec in get_user_defined_audios():
         if model_spec.model_name == model_name:
             return model_spec
 
+    if download_from_modelscope():
+        if model_name in MODELSCOPE_AUDIO_MODELS:
+            logger.debug(f"Audio model {model_name} found in ModelScope.")
+            return MODELSCOPE_AUDIO_MODELS[model_name]
+        else:
+            logger.debug(
+                f"Audio model {model_name} not found in ModelScope, "
+                f"now try to load it via builtin way."
+            )
+
     if model_name in BUILTIN_AUDIO_MODELS:
         return BUILTIN_AUDIO_MODELS[model_name]
     else:

xinference/model/audio/custom.py CHANGED
@@ -83,15 +83,17 @@ def get_user_defined_audios() -> List[CustomAudioModelFamilyV1]:
 def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     from ...constants import XINFERENCE_MODEL_DIR
     from ..utils import is_valid_model_name, is_valid_model_uri
-    from . import BUILTIN_AUDIO_MODELS
+    from . import BUILTIN_AUDIO_MODELS, MODELSCOPE_AUDIO_MODELS
 
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
     with UD_AUDIO_LOCK:
-        for model_name in list(BUILTIN_AUDIO_MODELS.keys()) + [
-            spec.model_name for spec in UD_AUDIOS
-        ]:
+        for model_name in (
+            list(BUILTIN_AUDIO_MODELS.keys())
+            + list(MODELSCOPE_AUDIO_MODELS.keys())
+            + [spec.model_name for spec in UD_AUDIOS]
+        ):
             if model_spec.model_name == model_name:
                 raise ValueError(
                     f"Model name conflicts with existing model {model_spec.model_name}"

xinference/model/audio/model_spec_modelscope.json ADDED
@@ -0,0 +1,20 @@
+[
+    {
+        "model_name": "whisper-large-v3",
+        "model_family": "whisper",
+        "model_hub": "modelscope",
+        "model_id": "AI-ModelScope/whisper-large-v3",
+        "model_revision": "master",
+        "ability": "audio-to-text",
+        "multilingual": true
+    },
+    {
+        "model_name": "ChatTTS",
+        "model_family": "ChatTTS",
+        "model_hub": "modelscope",
+        "model_id": "pzc163/chatTTS",
+        "model_revision": "master",
+        "ability": "text-to-audio",
+        "multilingual": true
+    }
+]
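
With these two specs registered, whisper-large-v3 and ChatTTS can be pulled from ModelScope whenever the server resolves audio models through download_from_modelscope() (in xinference that switch is normally the XINFERENCE_MODEL_SRC environment variable on the supervisor/worker side; treat that as an assumption, since the helper itself is not shown in this diff). Launching one of them through the client is unchanged; a sketch:

    from xinference.client import RESTfulClient

    client = RESTfulClient(base_url="http://127.0.0.1:9997")  # hypothetical endpoint

    # match_audio() on the server picks the ModelScope spec above when the
    # ModelScope model source is enabled, otherwise it falls back to the
    # built-in spec of the same name.
    uid = client.launch_model(model_name="ChatTTS", model_type="audio")
    print("launched audio model:", uid)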

xinference/model/llm/__init__.py CHANGED
@@ -25,6 +25,7 @@ from .core import (
     get_llm_model_descriptions,
 )
 from .llm_family import (
+    BUILTIN_CSGHUB_LLM_FAMILIES,
     BUILTIN_LLM_FAMILIES,
     BUILTIN_LLM_MODEL_CHAT_FAMILIES,
     BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
@@ -221,13 +222,44 @@ def _install():
         if "tools" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
 
-    for llm_specs in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
+    csghub_json_path = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
+    )
+    for json_obj in json.load(codecs.open(csghub_json_path, "r", encoding="utf-8")):
+        model_spec = LLMFamilyV1.parse_obj(json_obj)
+        BUILTIN_CSGHUB_LLM_FAMILIES.append(model_spec)
+
+        # register prompt style, in case that we have something missed
+        # if duplicated with huggingface json, keep it as the huggingface style
+        if (
+            "chat" in model_spec.model_ability
+            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
+        ):
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+        # register model family
+        if "chat" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
+        else:
+            BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
+        if "tools" in model_spec.model_ability:
+            BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+
+    for llm_specs in [
+        BUILTIN_LLM_FAMILIES,
+        BUILTIN_MODELSCOPE_LLM_FAMILIES,
+        BUILTIN_CSGHUB_LLM_FAMILIES,
+    ]:
         for llm_spec in llm_specs:
             if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
                 LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))
 
     # traverse all families and add engine parameters corresponding to the model name
-    for families in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
+    for families in [
+        BUILTIN_LLM_FAMILIES,
+        BUILTIN_MODELSCOPE_LLM_FAMILIES,
+        BUILTIN_CSGHUB_LLM_FAMILIES,
+    ]:
         for family in families:
             generate_engine_config_by_model_family(family)
 

xinference/model/llm/llm_family.json CHANGED
@@ -939,6 +939,8 @@
             "model_format": "pytorch",
             "model_size_in_billions": 9,
             "quantizations": [
+                "4-bit",
+                "8-bit",
                 "none"
             ],
             "model_id": "THUDM/glm-4v-9b",

xinference/model/llm/llm_family.py CHANGED
@@ -32,10 +32,15 @@ from ..._compat import (
     load_str_bytes,
     validator,
 )
-from ...constants import XINFERENCE_CACHE_DIR, XINFERENCE_MODEL_DIR
+from ...constants import (
+    XINFERENCE_CACHE_DIR,
+    XINFERENCE_ENV_CSG_TOKEN,
+    XINFERENCE_MODEL_DIR,
+)
 from ..utils import (
     IS_NEW_HUGGINGFACE_HUB,
     create_symlink,
+    download_from_csghub,
     download_from_modelscope,
     is_valid_model_uri,
     parse_uri,
@@ -232,6 +237,7 @@ LLAMA_CLASSES: List[Type[LLM]] = []
 
 BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
 BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []
+BUILTIN_CSGHUB_LLM_FAMILIES: List["LLMFamilyV1"] = []
 
 SGLANG_CLASSES: List[Type[LLM]] = []
 TRANSFORMERS_CLASSES: List[Type[LLM]] = []
@@ -292,6 +298,9 @@ def cache(
     elif llm_spec.model_hub == "modelscope":
         logger.info(f"Caching from Modelscope: {llm_spec.model_id}")
         return cache_from_modelscope(llm_family, llm_spec, quantization)
+    elif llm_spec.model_hub == "csghub":
+        logger.info(f"Caching from CSGHub: {llm_spec.model_id}")
+        return cache_from_csghub(llm_family, llm_spec, quantization)
     else:
         raise ValueError(f"Unknown model hub: {llm_spec.model_hub}")
 
@@ -566,6 +575,7 @@ def _skip_download(
         "modelscope": _get_meta_path(
             cache_dir, model_format, "modelscope", quantization
         ),
+        "csghub": _get_meta_path(cache_dir, model_format, "csghub", quantization),
     }
     if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
         logger.info(f"Cache {cache_dir} exists")
@@ -650,6 +660,75 @@ def _merge_cached_files(
     logger.info(f"Merge complete.")
 
 
+def cache_from_csghub(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    """
+    Cache model from CSGHub. Return the cache directory.
+    """
+    from pycsghub.file_download import file_download
+    from pycsghub.snapshot_download import snapshot_download
+
+    cache_dir = _get_cache_dir(llm_family, llm_spec)
+
+    if _skip_download(
+        cache_dir,
+        llm_spec.model_format,
+        llm_spec.model_hub,
+        llm_spec.model_revision,
+        quantization,
+    ):
+        return cache_dir
+
+    if llm_spec.model_format in ["pytorch", "gptq", "awq"]:
+        download_dir = retry_download(
+            snapshot_download,
+            llm_family.model_name,
+            {
+                "model_size": llm_spec.model_size_in_billions,
+                "model_format": llm_spec.model_format,
+            },
+            llm_spec.model_id,
+            endpoint="https://hub-stg.opencsg.com",
+            token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+        )
+        create_symlink(download_dir, cache_dir)
+
+    elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
+        file_names, final_file_name, need_merge = _generate_model_file_names(
+            llm_spec, quantization
+        )
+
+        for filename in file_names:
+            download_path = retry_download(
+                file_download,
+                llm_family.model_name,
+                {
+                    "model_size": llm_spec.model_size_in_billions,
+                    "model_format": llm_spec.model_format,
+                },
+                llm_spec.model_id,
+                file_name=filename,
+                endpoint="https://hub-stg.opencsg.com",
+                token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
+            )
+            symlink_local_file(download_path, cache_dir, filename)
+
+        if need_merge:
+            _merge_cached_files(cache_dir, file_names, final_file_name)
+    else:
+        raise ValueError(f"Unsupported format: {llm_spec.model_format}")
+
+    meta_path = _get_meta_path(
+        cache_dir, llm_spec.model_format, llm_spec.model_hub, quantization
+    )
+    _generate_meta_file(meta_path, llm_family, llm_spec, quantization)
+
+    return cache_dir
+
+
 def cache_from_modelscope(
     llm_family: LLMFamilyV1,
     llm_spec: "LLMSpecV1",
@@ -931,6 +1010,12 @@ def match_llm(
             + BUILTIN_LLM_FAMILIES
             + user_defined_llm_families
         )
+    elif download_from_csghub():
+        all_families = (
+            BUILTIN_CSGHUB_LLM_FAMILIES
+            + BUILTIN_LLM_FAMILIES
+            + user_defined_llm_families
+        )
     else:
         all_families = BUILTIN_LLM_FAMILIES + user_defined_llm_families
 
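
The new cache_from_csghub path closely mirrors the ModelScope one: snapshot_download for pytorch/gptq/awq repos and per-file file_download for ggmlv3/ggufv2, both pointed at the OpenCSG endpoint hard-coded above and authenticated with the token read from the environment variable named by the new XINFERENCE_ENV_CSG_TOKEN constant (its value lives in constants.py, not shown here). Stripped of the retry, symlink and meta-file plumbing, the pytorch branch boils down to roughly this sketch:

    import os

    from pycsghub.snapshot_download import snapshot_download
    from xinference.constants import XINFERENCE_ENV_CSG_TOKEN

    # model_id taken from llm_family_csghub.json below; endpoint and token
    # handling copied from the cache_from_csghub hunk above.
    download_dir = snapshot_download(
        "Qwen/Qwen2-0.5B-Instruct",
        endpoint="https://hub-stg.opencsg.com",
        token=os.environ.get(XINFERENCE_ENV_CSG_TOKEN),
    )
    print(download_dir)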

xinference/model/llm/llm_family_csghub.json ADDED
@@ -0,0 +1,66 @@
+[
+    {
+        "version": 1,
+        "context_length": 32768,
+        "model_name": "qwen2-instruct",
+        "model_lang": [
+            "en",
+            "zh"
+        ],
+        "model_ability": [
+            "chat",
+            "tools"
+        ],
+        "model_description": "Qwen2 is the new series of Qwen large language models",
+        "model_specs": [
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "4-bit",
+                    "8-bit",
+                    "none"
+                ],
+                "model_id": "Qwen/Qwen2-0.5B-Instruct",
+                "model_hub": "csghub"
+            },
+            {
+                "model_format": "ggufv2",
+                "model_size_in_billions": "0_5",
+                "quantizations": [
+                    "q2_k",
+                    "q3_k_m",
+                    "q4_0",
+                    "q4_k_m",
+                    "q5_0",
+                    "q5_k_m",
+                    "q6_k",
+                    "q8_0",
+                    "fp16"
+                ],
+                "model_id": "qwen/Qwen2-0.5B-Instruct-GGUF",
+                "model_file_name_template": "qwen2-0_5b-instruct-{quantization}.gguf",
+                "model_hub": "csghub"
+            }
+        ],
+        "prompt_style": {
+            "style_name": "QWEN",
+            "system_prompt": "You are a helpful assistant.",
+            "roles": [
+                "user",
+                "assistant"
+            ],
+            "intra_message_sep": "\n",
+            "stop_token_ids": [
+                151643,
+                151644,
+                151645
+            ],
+            "stop": [
+                "<|endoftext|>",
+                "<|im_start|>",
+                "<|im_end|>"
+            ]
+        }
+    }
+]

xinference/model/llm/llm_family_modelscope.json CHANGED
@@ -632,6 +632,8 @@
             "model_format": "pytorch",
             "model_size_in_billions": 9,
             "quantizations": [
+                "4-bit",
+                "8-bit",
                 "none"
             ],
             "model_hub": "modelscope",

xinference/model/llm/pytorch/chatglm.py CHANGED
@@ -89,24 +89,30 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             return False
         return True
 
-    @staticmethod
-    def _handle_tools(generate_config) -> Optional[dict]:
+    def _handle_tools(self, generate_config) -> Optional[dict]:
         """Convert openai tools to ChatGLM tools."""
         if generate_config is None:
             return None
         tools = generate_config.pop("tools", None)
         if tools is None:
             return None
-        chatglm_tools = []
-        for elem in tools:
-            if elem.get("type") != "function" or "function" not in elem:
-                raise ValueError("ChatGLM tools only support function type.")
-            chatglm_tools.append(elem["function"])
-        return {
-            "role": "system",
-            "content": f"Answer the following questions as best as you can. You have access to the following tools:",
-            "tools": chatglm_tools,
-        }
+        if self.model_family.model_name == "glm4-chat":
+            return {
+                "role": "system",
+                "content": None,
+                "tools": tools,
+            }
+        else:
+            chatglm_tools = []
+            for elem in tools:
+                if elem.get("type") != "function" or "function" not in elem:
+                    raise ValueError("ChatGLM tools only support function type.")
+                chatglm_tools.append(elem["function"])
+            return {
+                "role": "system",
+                "content": f"Answer the following questions as best as you can. You have access to the following tools:",
+                "tools": chatglm_tools,
+            }
 
     def chat(
         self,
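
The rewritten _handle_tools keeps the old ChatGLM3-style conversion but special-cases glm4-chat: the OpenAI-format tools list is forwarded untouched inside the injected system message, with content left as None. A sketch of the payload that reaches this code path (the weather tool is hypothetical):

    # OpenAI-style tools as they arrive in generate_config before
    # _handle_tools() pops them.
    generate_config = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",  # hypothetical tool
                    "description": "Look up the current weather for a city.",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ]
    }
    # glm4-chat: forwarded as-is -> {"role": "system", "content": None, "tools": [...]}
    # other ChatGLM families: reduced to the inner "function" dicts, as before.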