xinference 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (95)
  1. xinference/_version.py +3 -3
  2. xinference/api/oauth2/auth_service.py +132 -0
  3. xinference/api/restful_api.py +282 -78
  4. xinference/client/handlers.py +3 -0
  5. xinference/client/restful/restful_client.py +108 -75
  6. xinference/constants.py +14 -4
  7. xinference/core/cache_tracker.py +102 -0
  8. xinference/core/chat_interface.py +10 -4
  9. xinference/core/event.py +56 -0
  10. xinference/core/model.py +44 -0
  11. xinference/core/resource.py +19 -12
  12. xinference/core/status_guard.py +4 -0
  13. xinference/core/supervisor.py +278 -87
  14. xinference/core/utils.py +68 -3
  15. xinference/core/worker.py +98 -8
  16. xinference/deploy/cmdline.py +6 -3
  17. xinference/deploy/local.py +2 -2
  18. xinference/deploy/supervisor.py +2 -2
  19. xinference/model/audio/__init__.py +27 -0
  20. xinference/model/audio/core.py +161 -0
  21. xinference/model/audio/model_spec.json +79 -0
  22. xinference/model/audio/utils.py +18 -0
  23. xinference/model/audio/whisper.py +132 -0
  24. xinference/model/core.py +18 -13
  25. xinference/model/embedding/__init__.py +27 -2
  26. xinference/model/embedding/core.py +43 -3
  27. xinference/model/embedding/model_spec.json +24 -0
  28. xinference/model/embedding/model_spec_modelscope.json +24 -0
  29. xinference/model/embedding/utils.py +18 -0
  30. xinference/model/image/__init__.py +12 -1
  31. xinference/model/image/core.py +63 -9
  32. xinference/model/image/utils.py +26 -0
  33. xinference/model/llm/__init__.py +20 -1
  34. xinference/model/llm/core.py +43 -2
  35. xinference/model/llm/ggml/chatglm.py +15 -6
  36. xinference/model/llm/llm_family.json +197 -6
  37. xinference/model/llm/llm_family.py +9 -7
  38. xinference/model/llm/llm_family_modelscope.json +189 -4
  39. xinference/model/llm/pytorch/chatglm.py +3 -3
  40. xinference/model/llm/pytorch/core.py +4 -2
  41. xinference/model/{multimodal → llm/pytorch}/qwen_vl.py +10 -8
  42. xinference/model/llm/pytorch/utils.py +21 -9
  43. xinference/model/llm/pytorch/yi_vl.py +246 -0
  44. xinference/model/llm/utils.py +57 -4
  45. xinference/model/llm/vllm/core.py +5 -4
  46. xinference/model/rerank/__init__.py +25 -2
  47. xinference/model/rerank/core.py +51 -9
  48. xinference/model/rerank/model_spec.json +6 -0
  49. xinference/model/rerank/model_spec_modelscope.json +7 -0
  50. xinference/{api/oauth2/common.py → model/rerank/utils.py} +6 -2
  51. xinference/model/utils.py +5 -3
  52. xinference/thirdparty/__init__.py +0 -0
  53. xinference/thirdparty/llava/__init__.py +1 -0
  54. xinference/thirdparty/llava/conversation.py +205 -0
  55. xinference/thirdparty/llava/mm_utils.py +122 -0
  56. xinference/thirdparty/llava/model/__init__.py +1 -0
  57. xinference/thirdparty/llava/model/clip_encoder/__init__.py +0 -0
  58. xinference/thirdparty/llava/model/clip_encoder/builder.py +11 -0
  59. xinference/thirdparty/llava/model/clip_encoder/clip_encoder.py +86 -0
  60. xinference/thirdparty/llava/model/constants.py +6 -0
  61. xinference/thirdparty/llava/model/llava_arch.py +385 -0
  62. xinference/thirdparty/llava/model/llava_llama.py +163 -0
  63. xinference/thirdparty/llava/model/multimodal_projector/__init__.py +0 -0
  64. xinference/thirdparty/llava/model/multimodal_projector/builder.py +64 -0
  65. xinference/types.py +1 -1
  66. xinference/web/ui/build/asset-manifest.json +3 -3
  67. xinference/web/ui/build/index.html +1 -1
  68. xinference/web/ui/build/static/js/main.15822aeb.js +3 -0
  69. xinference/web/ui/build/static/js/main.15822aeb.js.map +1 -0
  70. xinference/web/ui/node_modules/.cache/babel-loader/139e5e4adf436923107d2b02994c7ff6dba2aac1989e9b6638984f0dfe782c4a.json +1 -0
  71. xinference/web/ui/node_modules/.cache/babel-loader/52aa27272b4b9968f62666262b47661cb1992336a2aff3b13994cc36877b3ec3.json +1 -0
  72. xinference/web/ui/node_modules/.cache/babel-loader/64accc515dc6cd584a2873796cd7da6f93de57f7e465eb5423cca9a2f3fe3eff.json +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/65ca3ba225b8c8dac907210545b51f2fcdb2591f0feeb7195f1c037f2bc956a0.json +1 -0
  74. xinference/web/ui/node_modules/.cache/babel-loader/b80db1012318b97c329c4e3e72454f7512fb107e57c444b437dbe4ba1a3faa5a.json +1 -0
  75. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/METADATA +33 -23
  76. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/RECORD +81 -64
  77. xinference/api/oauth2/core.py +0 -93
  78. xinference/model/multimodal/__init__.py +0 -52
  79. xinference/model/multimodal/core.py +0 -467
  80. xinference/model/multimodal/model_spec.json +0 -43
  81. xinference/model/multimodal/model_spec_modelscope.json +0 -45
  82. xinference/web/ui/build/static/js/main.b83095c2.js +0 -3
  83. xinference/web/ui/build/static/js/main.b83095c2.js.map +0 -1
  84. xinference/web/ui/node_modules/.cache/babel-loader/101923c539819f26ad11fbcbd6f6e56436b285efbb090dcc7dd648c6e924c4a8.json +0 -1
  85. xinference/web/ui/node_modules/.cache/babel-loader/4942da6bc03bf7373af068e22f916341aabc5b5df855d73c1d348c696724ce37.json +0 -1
  86. xinference/web/ui/node_modules/.cache/babel-loader/52a6136cb2dbbf9c51d461724d9b283ebe74a73fb19d5df7ba8e13c42bd7174d.json +0 -1
  87. xinference/web/ui/node_modules/.cache/babel-loader/71493aadd34d568fbe605cacaba220aa69bd09273251ee4ba27930f8d01fccd8.json +0 -1
  88. xinference/web/ui/node_modules/.cache/babel-loader/8b071db2a5a9ef68dc14d5f606540bd23d9785e365a11997c510656764d2dccf.json +0 -1
  89. xinference/web/ui/node_modules/.cache/babel-loader/a4d72d3b806ba061919115f0c513738726872e3c79cf258f007519d3f91d1a16.json +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/f037ffef5992af0892d6d991053c1dace364cd39a3f11f1a41f92776e8a59459.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.b83095c2.js.LICENSE.txt → main.15822aeb.js.LICENSE.txt} +0 -0
  92. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/LICENSE +0 -0
  93. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/WHEEL +0 -0
  94. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/entry_points.txt +0 -0
  95. {xinference-0.8.1.dist-info → xinference-0.8.3.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,6 @@
+from .restful.restful_client import (  # noqa: F401
+    RESTfulAudioModelHandle as AudioModelHandle,
+)
 from .restful.restful_client import (  # noqa: F401
     RESTfulChatglmCppChatModelHandle as ChatglmCppChatModelHandle,
 )
@@ -400,19 +400,17 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
         return response_data
 
 
-class RESTfulMultimodalModelHandle(RESTfulModelHandle):
+class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
     def chat(
         self,
-        prompt: Any,
+        prompt: str,
         system_prompt: Optional[str] = None,
         chat_history: Optional[List["ChatCompletionMessage"]] = None,
         tools: Optional[List[Dict]] = None,
-        generate_config: Optional[
-            Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
-        ] = None,
+        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
     ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
         """
-        Given a list of messages comprising a conversation, the model will return a response via RESTful APIs.
+        Given a list of messages comprising a conversation, the ChatGLM model will return a response via RESTful APIs.
 
         Parameters
         ----------
@@ -424,10 +422,8 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
             A list of messages comprising the conversation so far.
         tools: Optional[List[Dict]]
             A tool list.
-        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
-            Additional configuration for the chat generation.
-            "LlamaCppGenerateConfig" -> configuration for ggml model
-            "PytorchGenerateConfig" -> configuration for pytorch model
+        generate_config: Optional["ChatglmCppGenerateConfig"]
+            Additional configuration for ChatGLM chat generation.
 
         Returns
         -------
@@ -451,7 +447,6 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
         if chat_history and chat_history[0]["role"] == "system":
             if system_prompt is not None:
                 chat_history[0]["content"] = system_prompt
-
         else:
             if system_prompt is not None:
                 chat_history.insert(0, {"role": "system", "content": system_prompt})
@@ -463,8 +458,7 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
             "messages": chat_history,
         }
         if tools is not None:
-            raise RuntimeError("Multimodal does not support function call.")
-
+            request_body["tools"] = tools
         if generate_config is not None:
             for key, value in generate_config.items():
                 request_body[key] = value
@@ -486,67 +480,51 @@ class RESTfulMultimodalModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
-    def chat(
+class RESTfulChatglmCppGenerateModelHandle(RESTfulChatglmCppChatModelHandle):
+    def generate(
         self,
         prompt: str,
-        chat_history: Optional[List["ChatCompletionMessage"]] = None,
-        tools: Optional[List[Dict]] = None,
         generate_config: Optional["ChatglmCppGenerateConfig"] = None,
-    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
+    ) -> Union["Completion", Iterator["CompletionChunk"]]:
         """
-        Given a list of messages comprising a conversation, the ChatGLM model will return a response via RESTful APIs.
+        Given a prompt, the ChatGLM model will generate a response via RESTful APIs.
 
         Parameters
         ----------
         prompt: str
             The user's input.
-        chat_history: Optional[List["ChatCompletionMessage"]]
-            A list of messages comprising the conversation so far.
-        tools: Optional[List[Dict]]
-            A tool list.
         generate_config: Optional["ChatglmCppGenerateConfig"]
             Additional configuration for ChatGLM chat generation.
 
         Returns
         -------
-        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
+        Union["Completion", Iterator["CompletionChunk"]]
             Stream is a parameter in generate_config.
-            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
-            When stream is set to False, the function will return "ChatCompletion".
+            When stream is set to True, the function will return Iterator["CompletionChunk"].
+            When stream is set to False, the function will return "Completion".
 
         Raises
         ------
         RuntimeError
-            Report the failure to generate the chat from the server. Detailed information provided in error message.
+            Report the failure to generate the content from the server. Detailed information provided in error message.
 
         """
 
-        url = f"{self._base_url}/v1/chat/completions"
-
-        if chat_history is None:
-            chat_history = []
-
-        chat_history.append({"role": "user", "content": prompt})
+        url = f"{self._base_url}/v1/completions"
 
-        request_body: Dict[str, Any] = {
-            "model": self._model_uid,
-            "messages": chat_history,
-        }
-        if tools is not None:
-            request_body["tools"] = tools
+        request_body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
         if generate_config is not None:
             for key, value in generate_config.items():
                 request_body[key] = value
 
         stream = bool(generate_config and generate_config.get("stream"))
+
         response = requests.post(
             url, json=request_body, stream=stream, headers=self.auth_headers
         )
-
         if response.status_code != 200:
             raise RuntimeError(
-                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
+                f"Failed to generate completion, detail: {response.json()['detail']}"
             )
 
         if stream:
@@ -556,56 +534,111 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulChatglmCppGenerateModelHandle(RESTfulChatglmCppChatModelHandle):
-    def generate(
+class RESTfulAudioModelHandle(RESTfulModelHandle):
+    def transcriptions(
         self,
-        prompt: str,
-        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
-    ) -> Union["Completion", Iterator["CompletionChunk"]]:
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: Optional[str] = "json",
+        temperature: Optional[float] = 0,
+    ):
         """
-        Given a prompt, the ChatGLM model will generate a response via RESTful APIs.
+        Transcribes audio into the input language.
 
         Parameters
         ----------
-        prompt: str
-            The user's input.
-        generate_config: Optional["ChatglmCppGenerateConfig"]
-            Additional configuration for ChatGLM chat generation.
+
+        audio: bytes
+            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
+            mpga, m4a, ogg, wav, or webm.
+        language: Optional[str]
+            The language of the input audio. Supplying the input language in ISO-639-1
+            (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format will improve accuracy and latency.
+        prompt: Optional[str]
+            An optional text to guide the model's style or continue a previous audio segment.
+            The prompt should match the audio language.
+        response_format: Optional[str], defaults to json
+            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
+        temperature: Optional[float], defaults to 0
+            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
+            while lower values like 0.2 will make it more focused and deterministic.
+            If set to 0, the model will use log probability to automatically increase the temperature
+            until certain thresholds are hit.
 
         Returns
         -------
-        Union["Completion", Iterator["CompletionChunk"]]
-            Stream is a parameter in generate_config.
-            When stream is set to True, the function will return Iterator["CompletionChunk"].
-            When stream is set to False, the function will return "Completion".
+            The transcribed text.
+        """
+        url = f"{self._base_url}/v1/audio/transcriptions"
+        params = {
+            "model": self._model_uid,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature,
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("file", ("file", audio, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to transcribe the audio, detail: {_get_error_string(response)}"
+            )
 
-        Raises
-        ------
-        RuntimeError
-            Report the failure to generate the content from the server. Detailed information provided in error message.
+        response_data = response.json()
+        return response_data
 
+    def translations(
+        self,
+        audio: bytes,
+        prompt: Optional[str] = None,
+        response_format: Optional[str] = "json",
+        temperature: Optional[float] = 0,
+    ):
         """
+        Translates audio into English.
 
-        url = f"{self._base_url}/v1/completions"
-
-        request_body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
-        if generate_config is not None:
-            for key, value in generate_config.items():
-                request_body[key] = value
+        Parameters
+        ----------
 
-        stream = bool(generate_config and generate_config.get("stream"))
+        audio: bytes
+            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
+            mpga, m4a, ogg, wav, or webm.
+        prompt: Optional[str]
+            An optional text to guide the model's style or continue a previous audio segment.
+            The prompt should match the audio language.
+        response_format: Optional[str], defaults to json
+            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
+        temperature: Optional[float], defaults to 0
+            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
+            while lower values like 0.2 will make it more focused and deterministic.
+            If set to 0, the model will use log probability to automatically increase the temperature
+            until certain thresholds are hit.
 
-        response = requests.post(
-            url, json=request_body, stream=stream, headers=self.auth_headers
-        )
+        Returns
+        -------
+            The translated text.
+        """
+        url = f"{self._base_url}/v1/audio/translations"
+        params = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature,
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("file", ("file", audio, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
         if response.status_code != 200:
             raise RuntimeError(
-                f"Failed to generate completion, detail: {response.json()['detail']}"
+                f"Failed to translate the audio, detail: {_get_error_string(response)}"
            )
 
-        if stream:
-            return streaming_response_iterator(response.iter_lines())
-
         response_data = response.json()
         return response_data
 
@@ -889,8 +922,8 @@ class Client:
             return RESTfulRerankModelHandle(
                 model_uid, self.base_url, auth_headers=self._headers
             )
-        elif desc["model_type"] == "multimodal":
-            return RESTfulMultimodalModelHandle(
+        elif desc["model_type"] == "audio":
+            return RESTfulAudioModelHandle(
                 model_uid, self.base_url, auth_headers=self._headers
             )
         else:
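
Example (not part of the diff): a minimal usage sketch of the new audio handle above. Per the Client change, get_model returns a RESTfulAudioModelHandle when the model's model_type is "audio". The endpoint URL, model UID, and audio file name are placeholders, and authentication is assumed to be disabled.

from xinference.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")       # placeholder endpoint
model = client.get_model("my-whisper-model")   # placeholder UID of a launched audio model

with open("speech.wav", "rb") as f:            # placeholder audio file
    result = model.transcriptions(f.read(), language="en", response_format="json")
print(result)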
xinference/constants.py CHANGED
@@ -18,8 +18,12 @@ from pathlib import Path
 XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
 XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
 XINFERENCE_ENV_HOME_PATH = "XINFERENCE_HOME"
-XINFERENCE_ENV_HEALTH_CHECK_ATTEMPTS = "XINFERENCE_HEALTH_CHECK_ATTEMPTS"
+XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD = (
+    "XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"
+)
 XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
+XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
+XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
 
 
@@ -47,10 +51,16 @@ XINFERENCE_DEFAULT_ENDPOINT_PORT = 9997
 XINFERENCE_DEFAULT_LOG_FILE_NAME = "xinference.log"
 XINFERENCE_LOG_MAX_BYTES = 100 * 1024 * 1024
 XINFERENCE_LOG_BACKUP_COUNT = 30
-XINFERENCE_HEALTH_CHECK_ATTEMPTS = int(
-    os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_ATTEMPTS, 3)
+XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD = int(
+    os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_FAILURE_THRESHOLD, 5)
 )
 XINFERENCE_HEALTH_CHECK_INTERVAL = int(
-    os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_INTERVAL, 3)
+    os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_INTERVAL, 5)
+)
+XINFERENCE_HEALTH_CHECK_TIMEOUT = int(
+    os.environ.get(XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT, 10)
+)
+XINFERENCE_DISABLE_HEALTH_CHECK = bool(
+    int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
 )
 XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
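
Example (not part of the diff): a sketch of overriding the renamed and newly added health-check settings through the environment variables above. The values are arbitrary, and since they are read at module import time they must be set before xinference.constants is first imported.

import os

os.environ["XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD"] = "10"
os.environ["XINFERENCE_HEALTH_CHECK_INTERVAL"] = "5"
os.environ["XINFERENCE_HEALTH_CHECK_TIMEOUT"] = "10"
os.environ["XINFERENCE_DISABLE_HEALTH_CHECK"] = "0"  # "1" disables the health check

from xinference import constants

print(constants.XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD)  # -> 10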
@@ -0,0 +1,102 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from logging import getLogger
+from typing import Dict, List, Optional
+
+import xoscar as xo
+
+logger = getLogger(__name__)
+
+
+class CacheTrackerActor(xo.Actor):
+    def __init__(self):
+        super().__init__()
+        self._model_name_to_version_info: Dict[str, List[Dict]] = {}
+
+    @classmethod
+    def uid(cls) -> str:
+        return "cache_tracker"
+
+    @staticmethod
+    def _map_address_to_file_location(
+        model_version: Dict[str, List[Dict]], address: str
+    ):
+        for model_name, model_versions in model_version.items():
+            for info_dict in model_versions:
+                info_dict["model_file_location"] = (
+                    {address: info_dict["model_file_location"]}
+                    if info_dict["cache_status"]
+                    else None
+                )
+
+    @staticmethod
+    def _update_file_location(data: Dict, origin_version_info: Dict):
+        if origin_version_info["model_file_location"] is None:
+            origin_version_info["model_file_location"] = data
+        else:
+            assert isinstance(origin_version_info["model_file_location"], dict)
+            origin_version_info["model_file_location"].update(data)
+
+    def record_model_version(self, version_info: Dict[str, List[Dict]], address: str):
+        self._map_address_to_file_location(version_info, address)
+        for model_name, model_versions in version_info.items():
+            if model_name not in self._model_name_to_version_info:
+                self._model_name_to_version_info[model_name] = model_versions
+            else:
+                assert len(model_versions) == len(
+                    self._model_name_to_version_info[model_name]
+                ), "Model version info inconsistency between supervisor and worker"
+                for version, origin_version in zip(
+                    model_versions, self._model_name_to_version_info[model_name]
+                ):
+                    if (
+                        version["cache_status"]
+                        and version["model_file_location"] is not None
+                    ):
+                        origin_version["cache_status"] = True
+                        self._update_file_location(
+                            version["model_file_location"], origin_version
+                        )
+
+    def update_cache_status(
+        self,
+        address: str,
+        model_name: str,
+        model_version: Optional[str],
+        model_path: str,
+    ):
+        if model_name not in self._model_name_to_version_info:
+            logger.warning(f"Not record version info for {model_name} for now.")
+        else:
+            for version_info in self._model_name_to_version_info[model_name]:
+                if model_version is None:  # image model
+                    self._update_file_location({address: model_path}, version_info)
+                    version_info["cache_status"] = True
+                else:
+                    if version_info["model_version"] == model_version:
+                        self._update_file_location({address: model_path}, version_info)
+                        version_info["cache_status"] = True
+
+    def unregister_model_version(self, model_name: str):
+        self._model_name_to_version_info.pop(model_name, None)
+
+    def get_model_versions(self, model_name: str) -> List[Dict]:
+        if model_name not in self._model_name_to_version_info:
+            logger.warning(f"Not record version info for model_name: {model_name}")
+            return []
+        else:
+            return self._model_name_to_version_info[model_name]
+
+    def get_model_version_count(self, model_name: str) -> int:
+        return len(self.get_model_versions(model_name))
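
Example (not part of the diff): a sketch of the per-address bookkeeping the new CacheTrackerActor performs, exercised only through its static helpers so no actor pool is involved; the model name, version string, addresses, and paths below are made up.

from xinference.core.cache_tracker import CacheTrackerActor

version_info = {
    "demo-model": [
        {
            "model_version": "demo-model--7b",  # made-up version string
            "cache_status": True,
            "model_file_location": "/data/models/demo",
        }
    ]
}

# When a worker reports its cache, the plain path is keyed by the worker address.
CacheTrackerActor._map_address_to_file_location(version_info, "10.0.0.5:30001")
print(version_info["demo-model"][0]["model_file_location"])
# -> {'10.0.0.5:30001': '/data/models/demo'}

# A second worker caching the same version is merged into the same record.
CacheTrackerActor._update_file_location(
    {"10.0.0.6:30001": "/data/models/demo"}, version_info["demo-model"][0]
)
print(version_info["demo-model"][0]["model_file_location"])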
@@ -27,7 +27,6 @@ from ..client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,
-    RESTfulMultimodalModelHandle,
 )
 from ..types import ChatCompletionMessage
 
@@ -66,7 +65,7 @@ class GradioInterface:
         )
 
     def build(self) -> "gr.Blocks":
-        if self.model_type == "multimodal":
+        if "vision" in self.model_ability:
             interface = self.build_chat_vl_interface()
         elif "chat" in self.model_ability:
             interface = self.build_chat_interface()
@@ -99,9 +98,16 @@ class GradioInterface:
             return flat_list
 
         def to_chat(lst: List[str]) -> List[ChatCompletionMessage]:
+            from ..model.llm import BUILTIN_LLM_PROMPT_STYLE
+
             res = []
+            prompt_style = BUILTIN_LLM_PROMPT_STYLE.get(self.model_name)
+            if prompt_style is None:
+                roles = ["assistant", "user"]
+            else:
+                roles = prompt_style.roles
             for i in range(len(lst)):
-                role = "assistant" if i % 2 == 1 else "user"
+                role = roles[0] if i % 2 == 1 else roles[1]
                 res.append(ChatCompletionMessage(role=role, content=lst[i]))
             return res
 
@@ -191,7 +197,7 @@ class GradioInterface:
             client = RESTfulClient(self.endpoint)
             client._set_token(self._access_token)
             model = client.get_model(self.model_uid)
-            assert isinstance(model, RESTfulMultimodalModelHandle)
+            assert isinstance(model, RESTfulChatModelHandle)
 
             prompt = history[-1]
             assert prompt["role"] == "user"
@@ -0,0 +1,56 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import queue
+from collections import defaultdict
+from enum import Enum
+from typing import Dict, List, TypedDict
+
+import xoscar as xo
+
+MAX_EVENT_COUNT_PER_MODEL = 100
+
+
+class EventType(Enum):
+    INFO = 1
+    WARNING = 2
+    ERROR = 3
+
+
+class Event(TypedDict):
+    event_type: EventType
+    event_ts: int
+    event_content: str
+
+
+class EventCollectorActor(xo.StatelessActor):
+    def __init__(self):
+        super().__init__()
+        self._model_uid_to_events: Dict[str, queue.Queue] = defaultdict(
+            lambda: queue.Queue(maxsize=MAX_EVENT_COUNT_PER_MODEL)
+        )
+
+    @classmethod
+    def uid(cls) -> str:
+        return "event_collector"
+
+    def get_model_events(self, model_uid: str) -> List[Dict]:
+        event_queue = self._model_uid_to_events.get(model_uid)
+        if event_queue is None:
+            return []
+        else:
+            return [dict(e, event_type=e["event_type"].name) for e in event_queue.queue]
+
+    def report_event(self, model_uid: str, event: Event):
+        self._model_uid_to_events[model_uid].put(event)
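
Example (not part of the diff): Event is a plain TypedDict, so an event record is just a dict with these keys. The content below is a made-up illustration of what a component might pass to EventCollectorActor.report_event; get_model_events later reports event_type by name.

import time

from xinference.core.event import Event, EventType

event = Event(
    event_type=EventType.WARNING,
    event_ts=int(time.time()),
    event_content="demo: model restarted after a failed health check",  # made-up content
)
print(event["event_type"].name, event["event_content"])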
xinference/core/model.py CHANGED
@@ -426,6 +426,50 @@ class ModelActor(xo.StatelessActor):
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")
 
+    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
+    @request_limit
+    async def transcriptions(
+        self,
+        audio: bytes,
+        language: Optional[str] = None,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+    ):
+        if hasattr(self._model, "transcriptions"):
+            return await self._call_wrapper(
+                self._model.transcriptions,
+                audio,
+                language,
+                prompt,
+                response_format,
+                temperature,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating transcriptions."
+        )
+
+    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
+    @request_limit
+    async def translations(
+        self,
+        audio: bytes,
+        prompt: Optional[str] = None,
+        response_format: str = "json",
+        temperature: float = 0,
+    ):
+        if hasattr(self._model, "translations"):
+            return await self._call_wrapper(
+                self._model.translations,
+                audio,
+                prompt,
+                response_format,
+                temperature,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating translations."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(
@@ -13,10 +13,12 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Dict
+from typing import Dict, Union
 
 import psutil
 
+from .utils import get_nvidia_gpu_info
+
 
 @dataclass
 class ResourceStatus:
@@ -26,7 +28,14 @@ class ResourceStatus:
     memory_total: float
 
 
-def gather_node_info() -> Dict[str, ResourceStatus]:
+@dataclass
+class GPUStatus:
+    mem_total: float
+    mem_free: float
+    mem_used: float
+
+
+def gather_node_info() -> Dict[str, Union[ResourceStatus, GPUStatus]]:
     node_resource = dict()
     mem_info = psutil.virtual_memory()
     node_resource["cpu"] = ResourceStatus(
@@ -35,13 +44,11 @@ def gather_node_info() -> Dict[str, ResourceStatus]:
         memory_available=mem_info.available,
         memory_total=mem_info.total,
     )
-    # TODO: record GPU stats
-    # for idx, gpu_card_stat in enumerate(resource.cuda_card_stats()):
-    #     node_resource[f"gpu-{idx}"] = ResourceStatus(
-    #         available=gpu_card_stat.gpu_usage / 100.0,
-    #         total=1,
-    #         memory_available=gpu_card_stat.fb_mem_info.available,
-    #         memory_total=gpu_card_stat.fb_mem_info.total,
-    #     )
-
-    return node_resource
+    for gpu_idx, gpu_info in get_nvidia_gpu_info().items():
+        node_resource[gpu_idx] = GPUStatus(  # type: ignore
+            mem_total=gpu_info["total"],
+            mem_used=gpu_info["used"],
+            mem_free=gpu_info["free"],
+        )
+
+    return node_resource  # type: ignore
@@ -33,6 +33,7 @@ class LaunchStatus(Enum):
 class InstanceInfo(BaseModel):
     model_name: str
     model_uid: str
+    model_version: Optional[str]
     model_ability: List[str]
     replica: int
     status: str
@@ -82,5 +83,8 @@ class StatusGuardActor(xo.StatelessActor):
             else self._drop_terminated_info(all_infos)
         )
 
+    def get_instance_count(self, model_name: str) -> int:
+        return len(self.get_instance_info(model_name=model_name))
+
     def update_instance_info(self, model_uid: str, info: Dict):
         self._model_uid_to_info[model_uid].update(**info)