xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (95)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +132 -0
  3. xinference/api/restful_api.py +282 -78
  4. xinference/client/handlers.py +3 -0
  5. xinference/client/restful/restful_client.py +108 -75
  6. xinference/constants.py +14 -4
  7. xinference/core/cache_tracker.py +102 -0
  8. xinference/core/chat_interface.py +10 -4
  9. xinference/core/event.py +56 -0
  10. xinference/core/model.py +44 -0
  11. xinference/core/resource.py +19 -12
  12. xinference/core/status_guard.py +4 -0
  13. xinference/core/supervisor.py +278 -87
  14. xinference/core/utils.py +68 -3
  15. xinference/core/worker.py +98 -8
  16. xinference/deploy/cmdline.py +6 -3
  17. xinference/deploy/local.py +2 -2
  18. xinference/deploy/supervisor.py +2 -2
  19. xinference/model/audio/__init__.py +27 -0
  20. xinference/model/audio/core.py +161 -0
  21. xinference/model/audio/model_spec.json +79 -0
  22. xinference/model/audio/utils.py +18 -0
  23. xinference/model/audio/whisper.py +132 -0
  24. xinference/model/core.py +18 -13
  25. xinference/model/embedding/__init__.py +27 -2
  26. xinference/model/embedding/core.py +43 -3
  27. xinference/model/embedding/model_spec.json +24 -0
  28. xinference/model/embedding/model_spec_modelscope.json +24 -0
  29. xinference/model/embedding/utils.py +18 -0
  30. xinference/model/image/__init__.py +12 -1
  31. xinference/model/image/core.py +63 -9
  32. xinference/model/image/utils.py +26 -0
  33. xinference/model/llm/__init__.py +20 -1
  34. xinference/model/llm/core.py +43 -2
  35. xinference/model/llm/ggml/chatglm.py +15 -6
  36. xinference/model/llm/llm_family.json +197 -6
  37. xinference/model/llm/llm_family.py +9 -7
  38. xinference/model/llm/llm_family_modelscope.json +189 -4
  39. xinference/model/llm/pytorch/chatglm.py +3 -3
  40. xinference/model/llm/pytorch/core.py +4 -2
  41. xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
  42. xinference/model/llm/pytorch/utils.py +21 -9
  43. xinference/model/llm/pytorch/yi_vl.py +246 -0
  44. xinference/model/llm/utils.py +57 -4
  45. xinference/model/llm/vllm/core.py +5 -4
  46. xinference/model/rerank/__init__.py +25 -2
  47. xinference/model/rerank/core.py +51 -9
  48. xinference/model/rerank/model_spec.json +6 -0
  49. xinference/model/rerank/model_spec_modelscope.json +7 -0
  50. xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
  51. xinference/model/utils.py +5 -3
  52. xinference/thirdparty/__init__.py +0 -0
  53. xinference/thirdparty/llava/__init__.py +1 -0
  54. xinference/thirdparty/llava/conversation.py +205 -0
  55. xinference/thirdparty/llava/mm_utils.py +122 -0
  56. xinference/thirdparty/llava/model/__init__.py +1 -0
  57. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  58. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  59. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  60. xinference/thirdparty/llava/model/constants.py +6 -0
  61. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  62. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  63. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  64. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  65. xinference/types.py +1 -1
  66. xinference/web/ui/build/asset-manifest.json +3 -3
  67. xinference/web/ui/build/index.html +1 -1
  68. xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
  69. xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
  75. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
  76. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
  77. xinference/api/oauth2/core.py +0 -93
  78. xinference/model/multimodal/__init__.py +0 -52
  79. xinference/model/multimodal/core.py +0 -467
  80. xinference/model/multimodal/model_spec.json +0 -43
  81. xinference/model/multimodal/model_spec_modelscope.json +0 -45
  82. xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
  83. xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  92. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  93. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  94. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  95. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
xinference/core/utils.py CHANGED
@@ -11,26 +11,35 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import copy
  import logging
  import os
  import random
  import string
- from typing import Generator, Tuple
+ from typing import Dict, Generator, List, Tuple, Union

  import orjson
  from pydantic import BaseModel
+ from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown

  logger = logging.getLogger(__name__)


- def log_async(logger):
+ def log_async(logger, args_formatter=None):
      import time
      from functools import wraps

      def decorator(func):
          @wraps(func)
          async def wrapped(*args, **kwargs):
-             logger.debug(f"Enter {func.__name__}, args: {args}, kwargs: {kwargs}")
+             if args_formatter is not None:
+                 formatted_args, formatted_kwargs = copy.copy(args), copy.copy(kwargs)
+                 args_formatter(formatted_args, formatted_kwargs)
+             else:
+                 formatted_args, formatted_kwargs = args, kwargs
+             logger.debug(
+                 f"Enter {func.__name__}, args: {formatted_args}, kwargs: {formatted_kwargs}"
+             )
              start = time.time()
              ret = await func(*args, **kwargs)
              logger.debug(
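
The new args_formatter hook lets callers scrub bulky or sensitive arguments before they reach the debug log; the formatter receives a shallow copy of kwargs, so mutating that dict does not affect the real call. A minimal sketch of how it might be used (redact_prompt and generate are hypothetical, not part of xinference):

    import asyncio
    import logging

    from xinference.core.utils import log_async

    logger = logging.getLogger(__name__)


    def redact_prompt(args, kwargs):
        # Hypothetical formatter: shorten the prompt in the copied kwargs
        # so large payloads do not flood the debug log.
        if "prompt" in kwargs:
            kwargs["prompt"] = f"<{len(kwargs['prompt'])} chars>"


    @log_async(logger=logger, args_formatter=redact_prompt)
    async def generate(prompt: str) -> int:
        await asyncio.sleep(0)  # stand-in for real inference work
        return len(prompt)

Because the wrapper only sees *args/**kwargs, the formatter applies when the argument is passed by keyword, e.g. asyncio.run(generate(prompt="...")).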
@@ -125,3 +134,59 @@ def purge_dir(d):
                  os.rmdir(subdir)
          except Exception:
              pass
+
+
+ def parse_model_version(model_version: str, model_type: str) -> Tuple:
+     results: List[str] = model_version.split("--")
+     if model_type == "LLM":
+         if len(results) != 4:
+             raise ValueError(
+                 f"LLM model_version parses failed! model_version: {model_version}"
+             )
+         model_name = results[0]
+         size = results[1]
+         if not size.endswith("B"):
+             raise ValueError(f"Cannot parse model_size_in_billions: {size}")
+         size = size.rstrip("B")
+         size_in_billions: Union[int, str] = size if "_" in size else int(size)
+         model_format = results[2]
+         quantization = results[3]
+         return model_name, size_in_billions, model_format, quantization
+     elif model_type == "embedding":
+         assert len(results) > 0, "Embedding model_version parses failed!"
+         return (results[0],)
+     elif model_type == "rerank":
+         assert len(results) > 0, "Rerank model_version parses failed!"
+         return (results[0],)
+     elif model_type == "image":
+         assert 2 >= len(results) >= 1, "Image model_version parses failed!"
+         return tuple(results)
+     else:
+         raise ValueError(f"Not supported model_type: {model_type}")
+
+
+ def _get_nvidia_gpu_mem_info(gpu_id: int) -> Dict[str, float]:
+     from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+
+     handler = nvmlDeviceGetHandleByIndex(gpu_id)
+     mem_info = nvmlDeviceGetMemoryInfo(handler)
+     return {"total": mem_info.total, "used": mem_info.used, "free": mem_info.free}
+
+
+ def get_nvidia_gpu_info() -> Dict:
+     try:
+         nvmlInit()
+         device_count = nvmlDeviceGetCount()
+         res = {}
+         for i in range(device_count):
+             res[f"gpu-{i}"] = _get_nvidia_gpu_mem_info(i)
+         return res
+     except:
+         # TODO: add log here
+         # logger.debug(f"Cannot init nvml. Maybe due to lack of NVIDIA GPUs or incorrect installation of CUDA.")
+         return {}
+     finally:
+         try:
+             nvmlShutdown()
+         except:
+             pass
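
For reference, a model_version string is just the version fields joined with "--", so an LLM round trip looks like this (a minimal sketch, assuming xinference 0.8.3 is installed; the version string is illustrative):

    from xinference.core.utils import parse_model_version

    # LLM versions carry four "--"-separated fields:
    # <model_name>--<size_in_billions>B--<model_format>--<quantization>
    name, size, fmt, quant = parse_model_version(
        "llama-2-chat--13B--ggufv2--q4_0", "LLM"
    )
    assert (name, size, fmt, quant) == ("llama-2-chat", 13, "ggufv2", "q4_0")

    # Embedding and rerank versions reduce to the model name alone.
    assert parse_model_version("bge-base-en", "embedding") == ("bge-base-en",)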
xinference/core/worker.py CHANGED
@@ -18,11 +18,13 @@ import platform
  import queue
  import signal
  import threading
+ import time
  from collections import defaultdict
  from logging import getLogger
  from typing import Any, Dict, List, Optional, Set, Tuple, Union

  import xoscar as xo
+ from async_timeout import timeout
  from xoscar import MainActorPoolType

  from ..constants import XINFERENCE_CACHE_DIR
@@ -30,6 +32,7 @@ from ..core import ModelActor
  from ..core.status_guard import LaunchStatus
  from ..model.core import ModelDescription, create_model_instance
  from ..utils import cuda_count
+ from .event import Event, EventCollectorActor, EventType
  from .metrics import launch_metrics_export_server, record_metrics
  from .resource import gather_node_info
  from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
@@ -125,6 +128,15 @@ class WorkerActor(xo.StatelessActor):
                              model_uid,
                              recover_count - 1,
                          )
+                         event_model_uid, _, __ = parse_replica_model_uid(model_uid)
+                         await self._event_collector_ref.report_event(
+                             event_model_uid,
+                             Event(
+                                 event_type=EventType.WARNING,
+                                 event_ts=int(time.time()),
+                                 event_content="Recreate model",
+                             ),
+                         )
                          self._model_uid_to_recover_count[model_uid] = (
                              recover_count - 1
                          )
@@ -141,6 +153,8 @@
          return "worker"

      async def __post_create__(self):
+         from ..isolation import Isolation
+         from .cache_tracker import CacheTrackerActor
          from .status_guard import StatusGuardActor
          from .supervisor import SupervisorActor

@@ -149,24 +163,46 @@
          ] = await xo.actor_ref(
              address=self._supervisor_address, uid=StatusGuardActor.uid()
          )
+         self._event_collector_ref: xo.ActorRefType[
+             EventCollectorActor
+         ] = await xo.actor_ref(
+             address=self._supervisor_address, uid=EventCollectorActor.uid()
+         )
+         self._cache_tracker_ref: xo.ActorRefType[
+             "CacheTrackerActor"
+         ] = await xo.actor_ref(
+             address=self._supervisor_address, uid=CacheTrackerActor.uid()
+         )
          self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(
              address=self._supervisor_address, uid=SupervisorActor.uid()
          )
          await self._supervisor_ref.add_worker(self.address)
-         self._upload_task = asyncio.create_task(self._periodical_report_status())
+         # Run _periodical_report_status() in a dedicated thread.
+         self._isolation = Isolation(asyncio.new_event_loop(), threaded=True)
+         self._isolation.start()
+         asyncio.run_coroutine_threadsafe(
+             self._periodical_report_status(), loop=self._isolation.loop
+         )
          logger.info(f"Xinference worker {self.address} started")
          logger.info("Purge cache directory: %s", XINFERENCE_CACHE_DIR)
          purge_dir(XINFERENCE_CACHE_DIR)

          from ..model.embedding import (
              CustomEmbeddingModelSpec,
+             get_embedding_model_descriptions,
              register_embedding,
              unregister_embedding,
          )
-         from ..model.llm import register_llm, unregister_llm
-         from ..model.llm.llm_family import CustomLLMFamilyV1
-         from ..model.rerank.custom import (
+         from ..model.image import get_image_model_descriptions
+         from ..model.llm import (
+             CustomLLMFamilyV1,
+             get_llm_model_descriptions,
+             register_llm,
+             unregister_llm,
+         )
+         from ..model.rerank import (
              CustomRerankModelSpec,
+             get_rerank_model_descriptions,
              register_rerank,
              unregister_rerank,
          )
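
Replacing asyncio.create_task with a threaded Isolation loop means the heartbeat keeps running even when the worker's main event loop is blocked, e.g. by a heavyweight model launch. The underlying pattern, sketched with plain asyncio and threading (Isolation is xinference's internal helper; the names below are illustrative):

    import asyncio
    import threading


    async def periodical_report_status():
        # Stand-in for WorkerActor._periodical_report_status().
        while True:
            print("report status")
            await asyncio.sleep(5)


    # A dedicated event loop on its own thread, so a blocked main loop
    # cannot starve the periodic report.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    asyncio.run_coroutine_threadsafe(periodical_report_status(), loop)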
@@ -181,6 +217,16 @@
              "rerank": (CustomRerankModelSpec, register_rerank, unregister_rerank),
          }

+         # record model version
+         model_version_infos: Dict[str, List[Dict]] = {}
+         model_version_infos.update(get_llm_model_descriptions())
+         model_version_infos.update(get_embedding_model_descriptions())
+         model_version_infos.update(get_rerank_model_descriptions())
+         model_version_infos.update(get_image_model_descriptions())
+         await self._cache_tracker_ref.record_model_version(
+             model_version_infos, self.address
+         )
+
          # Windows does not have signal handler
          if os.name != "nt":

@@ -194,7 +240,7 @@
          )

      async def __pre_destroy__(self):
-         self._upload_task.cancel()
+         self._isolation.stop()

      @staticmethod
      def get_devices_count():
@@ -407,13 +453,30 @@
              return ["rerank"]
          elif model_type == "image":
              return ["text_to_image"]
-         elif model_type == "multimodal":
-             return ["multimodal"]
+         elif model_type == "audio":
+             return ["audio_to_text"]
          else:
              assert model_type == "LLM"
              assert isinstance(model, LLM)
              return model.model_family.model_ability  # type: ignore

+     async def update_cache_status(
+         self, model_name: str, model_description: ModelDescription
+     ):
+         version_info = model_description.to_version_info()
+         if isinstance(version_info, list):  # image model
+             model_path = version_info[0]["model_file_location"]
+             await self._cache_tracker_ref.update_cache_status(
+                 self.address, model_name, None, model_path
+             )
+         else:
+             await self._cache_tracker_ref.update_cache_status(
+                 self.address,
+                 model_name,
+                 version_info["model_version"],
+                 version_info["model_file_location"],
+             )
+
      @log_async(logger=logger)
      async def launch_builtin_model(
          self,
@@ -427,6 +490,15 @@
          request_limits: Optional[int] = None,
          **kwargs,
      ):
+         event_model_uid, _, __ = parse_replica_model_uid(model_uid)
+         await self._event_collector_ref.report_event(
+             event_model_uid,
+             Event(
+                 event_type=EventType.INFO,
+                 event_ts=int(time.time()),
+                 event_content="Launch model",
+             ),
+         )
          launch_args = locals()
          launch_args.pop("self")
          launch_args.pop("kwargs")
@@ -464,6 +536,7 @@
              is_local_deployment,
              **kwargs,
          )
+         await self.update_cache_status(model_name, model_description)
          model_ref = await xo.create_actor(
              ModelActor,
              address=subpool_address,
@@ -497,6 +570,15 @@

      @log_async(logger=logger)
      async def terminate_model(self, model_uid: str):
+         event_model_uid, _, __ = parse_replica_model_uid(model_uid)
+         await self._event_collector_ref.report_event(
+             event_model_uid,
+             Event(
+                 event_type=EventType.INFO,
+                 event_ts=int(time.time()),
+                 event_content="Terminate model",
+             ),
+         )
          origin_uid, _, _ = parse_replica_model_uid(model_uid)
          await self._status_guard_ref.update_instance_info(
              origin_uid, {"status": LaunchStatus.TERMINATING.name}
@@ -553,7 +635,15 @@
          return model_desc.to_dict()

      async def report_status(self):
-         status = await asyncio.to_thread(gather_node_info)
+         status = dict()
+         try:
+             # asyncio.timeout is only available in Python >= 3.11
+             async with timeout(2):
+                 status = await asyncio.to_thread(gather_node_info)
+         except asyncio.CancelledError:
+             raise
+         except Exception:
+             logger.exception("Report status got error.")
          await self._supervisor_ref.report_worker_status(self.address, status)

      async def _periodical_report_status(self):
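
report_status now bounds gather_node_info with async_timeout, so a hung probe cannot stall the heartbeat: on timeout (or any other failure) an empty status dict is reported, while CancelledError is re-raised so shutdown still propagates. The same guard in isolation (a sketch; slow_gather is a stand-in for gather_node_info):

    import asyncio

    from async_timeout import timeout


    def slow_gather():
        # Stand-in for gather_node_info(): pretend the probe hangs.
        import time

        time.sleep(10)
        return {"cpu": 0.1}


    async def report_once():
        status = dict()
        try:
            # asyncio.timeout would do the same, but needs Python >= 3.11.
            async with timeout(2):
                status = await asyncio.to_thread(slow_gather)
        except asyncio.CancelledError:
            raise
        except Exception:
            print("Report status got error.")
        return status  # -> {} once the 2-second guard fires


    asyncio.run(report_once())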
xinference/deploy/cmdline.py CHANGED
@@ -499,7 +499,7 @@ def list_model_registrations(
              tabulate(table, headers=["Type", "Name", "Family", "Is-built-in"]),
              file=sys.stderr,
          )
-     elif model_type == "multimodal":
+     elif model_type == "audio":
          for registration in registrations:
              model_name = registration["model_name"]
              model_family = client.get_model_registration(model_type, model_name)
@@ -507,12 +507,15 @@
                  [
                      model_type,
                      model_family["model_name"],
-                     model_family["model_lang"],
+                     model_family["model_family"],
+                     model_family["multilingual"],
                      registration["is_builtin"],
                  ]
              )
          print(
-             tabulate(table, headers=["Type", "Name", "Language", "Is-built-in"]),
+             tabulate(
+                 table, headers=["Type", "Name", "Family", "Multilingual", "Is-built-in"]
+             ),
              file=sys.stderr,
          )
      else:
xinference/deploy/local.py CHANGED
@@ -23,7 +23,7 @@ import xoscar as xo
  from xoscar.utils import get_next_port

  from ..constants import (
-     XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
      XINFERENCE_HEALTH_CHECK_INTERVAL,
  )
  from ..core.supervisor import SupervisorActor
@@ -116,7 +116,7 @@ def main(

      if not health_check(
          address=supervisor_address,
-         max_attempts=XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+         max_attempts=XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
          sleep_interval=XINFERENCE_HEALTH_CHECK_INTERVAL,
      ):
          raise RuntimeError("Cluster is not available after multiple attempts")
xinference/deploy/supervisor.py CHANGED
@@ -23,7 +23,7 @@ import xoscar as xo
  from xoscar.utils import get_next_port

  from ..constants import (
-     XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
      XINFERENCE_HEALTH_CHECK_INTERVAL,
  )
  from ..core.supervisor import SupervisorActor
@@ -82,7 +82,7 @@ def main(

      if not health_check(
          address=supervisor_address,
-         max_attempts=XINFERENCE_HEALTH_CHECK_ATTEMPTS,
+         max_attempts=XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
          sleep_interval=XINFERENCE_HEALTH_CHECK_INTERVAL,
      ):
          raise RuntimeError("Supervisor is not available after multiple attempts")
xinference/model/audio/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import codecs
+ import json
+ import os
+
+ from .core import AudioModelFamilyV1, generate_audio_description, get_cache_status
+
+ _model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+ BUILTIN_AUDIO_MODELS = dict(
+     (spec["model_name"], AudioModelFamilyV1(**spec))
+     for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+ )
+
+ del _model_spec_json
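
With the registry in place, each builtin audio model resolves to a typed pydantic spec. A minimal lookup, assuming xinference 0.8.3 is installed:

    from xinference.model.audio import BUILTIN_AUDIO_MODELS

    spec = BUILTIN_AUDIO_MODELS["whisper-tiny"]
    print(spec.model_family)  # "whisper"
    print(spec.model_id)      # "openai/whisper-tiny"
    print(spec.multilingual)  # True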
xinference/model/audio/core.py ADDED
@@ -0,0 +1,161 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ import os
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Tuple
+
+ from pydantic import BaseModel
+
+ from ...constants import XINFERENCE_CACHE_DIR
+ from ..core import ModelDescription
+ from ..utils import valid_model_revision
+ from .whisper import WhisperModel
+
+ MAX_ATTEMPTS = 3
+
+ logger = logging.getLogger(__name__)
+
+
+ class AudioModelFamilyV1(BaseModel):
+     model_family: str
+     model_name: str
+     model_id: str
+     model_revision: str
+     multilingual: bool
+
+
+ class AudioModelDescription(ModelDescription):
+     def __init__(
+         self,
+         address: Optional[str],
+         devices: Optional[List[str]],
+         model_spec: AudioModelFamilyV1,
+         model_path: Optional[str] = None,
+     ):
+         super().__init__(address, devices, model_path=model_path)
+         self._model_spec = model_spec
+
+     def to_dict(self):
+         return {
+             "model_type": "audio",
+             "address": self.address,
+             "accelerators": self.devices,
+             "model_name": self._model_spec.model_name,
+             "model_family": self._model_spec.model_family,
+             "model_revision": self._model_spec.model_revision,
+         }
+
+     def to_version_info(self):
+         from .utils import get_model_version
+
+         if self._model_path is None:
+             is_cached = get_cache_status(self._model_spec)
+             file_location = get_cache_dir(self._model_spec)
+         else:
+             is_cached = True
+             file_location = self._model_path
+
+         return {
+             "model_version": get_model_version(self._model_spec),
+             "model_file_location": file_location,
+             "cache_status": is_cached,
+         }
+
+
+ def generate_audio_description(
+     image_model: AudioModelFamilyV1,
+ ) -> Dict[str, List[Dict]]:
+     res = defaultdict(list)
+     res[image_model.model_name].extend(
+         AudioModelDescription(None, None, image_model).to_dict()
+     )
+     return res
+
+
+ def match_model(model_name: str) -> AudioModelFamilyV1:
+     from . import BUILTIN_AUDIO_MODELS
+
+     if model_name in BUILTIN_AUDIO_MODELS:
+         return BUILTIN_AUDIO_MODELS[model_name]
+     else:
+         raise ValueError(
+             f"Image model {model_name} not found, available"
+             f"model list: {BUILTIN_AUDIO_MODELS.keys()}"
+         )
+
+
+ def cache(model_spec: AudioModelFamilyV1):
+     # TODO: cache from uri
+     import huggingface_hub
+
+     cache_dir = get_cache_dir(model_spec)
+     if not os.path.exists(cache_dir):
+         os.makedirs(cache_dir, exist_ok=True)
+
+     meta_path = os.path.join(cache_dir, "__valid_download")
+     if valid_model_revision(meta_path, model_spec.model_revision):
+         return cache_dir
+
+     for current_attempt in range(1, MAX_ATTEMPTS + 1):
+         try:
+             huggingface_hub.snapshot_download(
+                 model_spec.model_id,
+                 revision=model_spec.model_revision,
+                 local_dir=cache_dir,
+                 local_dir_use_symlinks=True,
+                 resume_download=True,
+             )
+             break
+         except huggingface_hub.utils.LocalEntryNotFoundError:
+             remaining_attempts = MAX_ATTEMPTS - current_attempt
+             logger.warning(
+                 f"Attempt {current_attempt} failed. Remaining attempts: {remaining_attempts}"
+             )
+     else:
+         raise RuntimeError(
+             f"Failed to download model '{model_spec.model_name}' after {MAX_ATTEMPTS} attempts"
+         )
+
+     with open(meta_path, "w") as f:
+         import json
+
+         desc = AudioModelDescription(None, None, model_spec)
+         json.dump(desc.to_dict(), f)
+
+     return cache_dir
+
+
+ def get_cache_dir(model_spec: AudioModelFamilyV1):
+     return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
+ def get_cache_status(
+     model_spec: AudioModelFamilyV1,
+ ) -> bool:
+     cache_dir = get_cache_dir(model_spec)
+     meta_path = os.path.join(cache_dir, "__valid_download")
+     return valid_model_revision(meta_path, model_spec.model_revision)
+
+
+ def create_audio_model_instance(
+     subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
+ ) -> Tuple[WhisperModel, AudioModelDescription]:
+     model_spec = match_model(model_name)
+     model_path = cache(model_spec)
+     model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+     model_description = AudioModelDescription(
+         subpool_addr, devices, model_spec, model_path=model_path
+     )
+     return model, model_description
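
End to end, resolving and caching a builtin audio model follows the functions above (a sketch, assuming xinference 0.8.3 is installed; cache() needs network access to Hugging Face on a cold cache):

    from xinference.model.audio.core import cache, get_cache_status, match_model

    spec = match_model("whisper-base")
    print(get_cache_status(spec))  # False until a verified download marker exists
    model_dir = cache(spec)        # snapshot_download with up to MAX_ATTEMPTS retries,
                                   # or an immediate return if the revision is cached
    print(model_dir)               # <XINFERENCE_CACHE_DIR>/whisper-base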
xinference/model/audio/model_spec.json ADDED
@@ -0,0 +1,79 @@
+ [
+     {
+         "model_name": "whisper-tiny",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-tiny",
+         "model_revision": "167c219b21f11ef214220b8fdb7536b8a88c2475",
+         "multilingual": true
+     },
+     {
+         "model_name": "whisper-tiny.en",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-tiny.en",
+         "model_revision": "87c7102498dcde7456f24cfd30239ca606ed9063",
+         "multilingual": false
+     },
+     {
+         "model_name": "whisper-base",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-base",
+         "model_revision": "8c1db9b51951100007a96a525d83a8ec81b3c237",
+         "multilingual": true
+     },
+     {
+         "model_name": "whisper-base.en",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-base.en",
+         "model_revision": "911407f4214e0e1d82085af863093ec0b66f9cd6",
+         "multilingual": false
+     },
+     {
+         "model_name": "whisper-small",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-small",
+         "model_revision": "998cb1a777c20db53d6033a61b977ed4c3792cac",
+         "multilingual": true
+     },
+     {
+         "model_name": "whisper-small.en",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-small.en",
+         "model_revision": "e8727524f962ee844a7319d92be39ac1bd25655a",
+         "multilingual": false
+     },
+     {
+         "model_name": "whisper-medium",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-medium",
+         "model_revision": "16688beb1294bedd0a6f5cd86fe7eec57bce41ed",
+         "multilingual": true
+     },
+     {
+         "model_name": "whisper-medium.en",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-medium.en",
+         "model_revision": "2e98eb6279edf5095af0c8dedb36bdec0acd172b",
+         "multilingual": false
+     },
+     {
+         "model_name": "whisper-large-v3",
+         "model_family": "whisper",
+         "model_id": "openai/whisper-large-v3",
+         "model_revision": "6cdf07a7e3ec3806e5d55f787915b85d4cd020b1",
+         "multilingual": true
+     },
+     {
+         "model_name": "Belle-distilwhisper-large-v2-zh",
+         "model_family": "whisper",
+         "model_id": "BELLE-2/Belle-distilwhisper-large-v2-zh",
+         "model_revision": "ed25d13498fa5bac758b2fc479435b698532dfe8",
+         "multilingual": false
+     },
+     {
+         "model_name": "Belle-whisper-large-v2-zh",
+         "model_family": "whisper",
+         "model_id": "BELLE-2/Belle-whisper-large-v2-zh",
+         "model_revision": "ec5bd5d78598545b7585814edde86dac2002b5b9",
+         "multilingual": false
+     }
+ ]
xinference/model/audio/utils.py ADDED
@@ -0,0 +1,18 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from .core import AudioModelFamilyV1
+
+
+ def get_model_version(audio_model: AudioModelFamilyV1) -> str:
+     return audio_model.model_name