xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (124) hide show
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,139 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ from io import BytesIO
16
+ from typing import TYPE_CHECKING, Optional
17
+
18
+ import numpy as np
19
+
20
+ from ...device_utils import get_available_device, is_device_available
21
+
22
+ if TYPE_CHECKING:
23
+ from .core import AudioModelFamilyV1
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class KokoroModel:
    """Text-to-speech model backed by the ``kokoro`` package.

    ``load()`` builds a ``kokoro.KPipeline`` from the files found under
    ``model_path``; ``speech()`` runs that pipeline and returns the
    generated audio encoded in the requested container format.
    """

    # The available voices, should keep sync with https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
    VOICES = [
        "af_alloy",
        "af_aoede",
        "af_bella",
        "af_jessica",
        "af_kore",
        "af_nicole",
        "af_nova",
        "af_river",
        "af_sarah",
        "af_sky",
        "am_adam",
        "am_echo",
        "am_eric",
        "am_fenrir",
        "am_liam",
        "am_michael",
        "am_onyx",
        "am_puck",
        "bf_alice",
        "bf_emma",
        "bf_isabella",
        "bf_lily",
        "bm_daniel",
        "bm_fable",
        "bm_george",
        "bm_lewis",
    ]

    # Sample rate (Hz) handed to soundfile when encoding the generated audio.
    SAMPLE_RATE = 24000

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration; the model itself is created by ``load()``.

        Args:
            model_uid: Identifier of this model instance.
            model_path: Directory that holds ``config.json`` and
                ``kokoro-v1_0.pth``.
            model_spec: Audio model family description.
            device: Device to run on; ``None`` lets ``load()`` pick one.
            **kwargs: Extra options. ``lang_code`` is read by ``load()``.
        """
        self._model_uid = model_uid
        self._model_path = model_path
        self._model_spec = model_spec
        self._device = device
        self._model = None
        self._kwargs = kwargs

    @property
    def model_ability(self):
        """Ability declared by the model spec."""
        return self._model_spec.model_ability

    def load(self):
        """Resolve the device and build the Kokoro pipeline.

        Raises:
            ValueError: If an explicitly requested device is not available.
        """
        if self._device is None:
            self._device = get_available_device()
        elif not is_device_available(self._device):
            raise ValueError(f"Device {self._device} is not available!")

        import os

        from kokoro import KModel, KPipeline

        config_path = os.path.join(self._model_path, "config.json")
        model_path = os.path.join(self._model_path, "kokoro-v1_0.pth")
        # Language codes: "a" = American English (default), "b" = British English.
        lang_code = self._kwargs.get("lang_code", "a")
        self._model = KPipeline(
            lang_code=lang_code,
            model=KModel(config=config_path, model=model_path),
            device=self._device,
        )

    def speech(
        self,
        input: str,
        voice: str,
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        **kwargs,
    ):
        """Synthesize ``input`` and return the encoded audio as bytes.

        Args:
            input: Text to synthesize.
            voice: One of ``VOICES`` or a custom voice file name ending in
                ``.pt``. An empty value selects the first built-in voice.
            response_format: Container format; upper-cased and passed to
                soundfile.
            speed: Speed factor forwarded to the pipeline.
            stream: Must be false; streaming is not supported.
            **kwargs: Forwarded unchanged to the pipeline call.

        Raises:
            Exception: If ``stream`` is true.
            ValueError: If ``voice`` is neither built-in nor a ``.pt`` name,
                or if the pipeline yields no audio at all.
        """
        if stream:
            raise Exception("Kokoro does not support stream mode.")
        assert self._model is not None, "Kokoro model is not loaded."

        if not voice:
            voice = self.VOICES[0]
            logger.info("Auto select speaker: %s", voice)
        elif voice in self.VOICES:
            # Built-in voices used to be reported as custom ".pt" voices.
            logger.info("Using built-in voice: %s", voice)
        elif voice.endswith(".pt"):
            logger.info("Using custom voice pt: %s", voice)
        else:
            raise ValueError(
                f"Invalid voice: {voice}, available speakers: {self.VOICES}"
            )

        logger.info("Speech kwargs: %s", kwargs)
        # The third element of every pipeline result is the audio chunk.
        chunks = [
            result[2]
            for result in self._model(text=input, voice=voice, speed=speed, **kwargs)
        ]
        if not chunks:
            # np.concatenate([]) would raise an opaque ValueError; be explicit.
            raise ValueError("Kokoro generated no audio for the given input.")
        audio = np.concatenate(chunks)

        # Imported at the point of use so that request validation above does
        # not depend on the optional soundfile dependency.
        import soundfile

        # Save the generated audio
        with BytesIO() as out:
            with soundfile.SoundFile(
                out,
                "w",
                self.SAMPLE_RATE,
                1,
                format=response_format.upper(),
            ) as f:
                f.write(audio)
            return out.getvalue()
@@ -0,0 +1,110 @@
1
+ # Copyright 2022-2023 XProbe Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ from io import BytesIO
16
+ from typing import TYPE_CHECKING, Optional
17
+
18
+ from ...device_utils import get_available_device, is_device_available
19
+
20
+ if TYPE_CHECKING:
21
+ from .core import AudioModelFamilyV1
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class MeloTTSModel:
    """Text-to-speech model backed by the vendored ``melo`` package.

    ``load()`` creates a ``melo.api.TTS`` instance from the files under
    ``model_path``; ``speech()`` synthesizes text with one of the model's
    speakers and returns the audio encoded in the requested format.
    """

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration; the model itself is created by ``load()``.

        Args:
            model_uid: Identifier of this model instance.
            model_path: Directory that holds ``config.json`` and
                ``checkpoint.pth``.
            model_spec: Audio model family description; its ``language``
                is passed to the TTS constructor.
            device: Device to run on; ``None`` lets ``load()`` pick one.
            **kwargs: Extra options, stored as-is.
        """
        self._model_uid = model_uid
        self._model_path = model_path
        self._model_spec = model_spec
        self._device = device
        self._model = None
        self._kwargs = kwargs

    @property
    def model_ability(self):
        """Ability declared by the model spec."""
        return self._model_spec.model_ability

    def load(self):
        """Resolve the device and build the MeloTTS model.

        Raises:
            ValueError: If an explicitly requested device is not available.
        """
        if self._device is None:
            self._device = get_available_device()
        elif not is_device_available(self._device):
            raise ValueError(f"Device {self._device} is not available!")

        import os
        import sys

        import nltk

        # English language requires download averaged_perceptron_tagger_eng
        nltk.download("averaged_perceptron_tagger_eng")

        # ``melo`` is imported as a top-level package from the bundled
        # thirdparty directory, so that directory has to be on sys.path.
        # Normalise the path and add it only once, so repeated load() calls
        # do not keep prepending duplicate entries.
        thirdparty_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "../../thirdparty")
        )
        if thirdparty_dir not in sys.path:
            sys.path.insert(0, thirdparty_dir)

        from melo.api import TTS

        config_path = os.path.join(self._model_path, "config.json")
        ckpt_path = os.path.join(self._model_path, "checkpoint.pth")
        self._model = TTS(
            language=self._model_spec.language,
            device=self._device,
            config_path=config_path,
            ckpt_path=ckpt_path,
        )

    def speech(
        self,
        input: str,
        voice: str,
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        **kwargs,
    ):
        """Synthesize ``input`` and return the encoded audio as bytes.

        Args:
            input: Text to synthesize.
            voice: Speaker name, a key of the model's ``spk2id`` mapping.
                An empty value selects the first speaker.
            response_format: Container format; upper-cased and passed to
                soundfile.
            speed: Speed factor forwarded to the model.
            stream: Must be false; streaming is not supported.
            **kwargs: Forwarded unchanged to ``tts_to_file``.

        Raises:
            Exception: If ``stream`` is true.
            ValueError: If ``voice`` is not a known speaker.
        """
        if stream:
            raise Exception("MeloTTS does not support stream mode.")
        assert self._model is not None, "MeloTTS model is not loaded."

        speaker_ids = self._model.hps.data.spk2id
        if not voice:
            voice = next(iter(speaker_ids.keys()))
            logger.info("Auto select speaker: %s", voice)
        elif voice not in speaker_ids:
            raise ValueError(
                f"Invalid voice: {voice}, available speakers: {speaker_ids}"
            )

        audio = self._model.tts_to_file(
            text=input, speaker_id=speaker_ids[voice], speed=speed, **kwargs
        )

        # Imported at the point of use so that request validation above does
        # not depend on the optional soundfile dependency.
        import soundfile

        # Save the generated audio
        with BytesIO() as out:
            with soundfile.SoundFile(
                out,
                "w",
                self._model.hps.data.sampling_rate,
                1,
                format=response_format.upper(),
            ) as f:
                f.write(audio)
            return out.getvalue()
@@ -266,5 +266,85 @@
266
266
  "model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
267
267
  "model_ability": "text-to-audio",
268
268
  "multilingual": true
269
+ },
270
+ {
271
+ "model_name": "MeloTTS-English",
272
+ "model_family": "MeloTTS",
273
+ "model_id": "myshell-ai/MeloTTS-English",
274
+ "model_revision": "bb4fb7346d566d277ba8c8c7dbfdf6786139b8ef",
275
+ "model_ability": "text-to-audio",
276
+ "multilingual": false,
277
+ "language": "EN"
278
+ },
279
+ {
280
+ "model_name": "MeloTTS-English-v2",
281
+ "model_family": "MeloTTS",
282
+ "model_id": "myshell-ai/MeloTTS-English-v2",
283
+ "model_revision": "a53e3509c4ee4ff16d79272feb2474ff864e18f3",
284
+ "model_ability": "text-to-audio",
285
+ "multilingual": false,
286
+ "language": "EN"
287
+ },
288
+ {
289
+ "model_name": "MeloTTS-English-v3",
290
+ "model_family": "MeloTTS",
291
+ "model_id": "myshell-ai/MeloTTS-English-v3",
292
+ "model_revision": "f7c4a35392c0e9be24a755f1edb4c3f63040f759",
293
+ "model_ability": "text-to-audio",
294
+ "multilingual": false,
295
+ "language": "EN"
296
+ },
297
+ {
298
+ "model_name": "MeloTTS-French",
299
+ "model_family": "MeloTTS",
300
+ "model_id": "myshell-ai/MeloTTS-French",
301
+ "model_revision": "1e9bf590262392d8bffb679b0a3b0c16b0f9fdaf",
302
+ "model_ability": "text-to-audio",
303
+ "multilingual": false,
304
+ "language": "FR"
305
+ },
306
+ {
307
+ "model_name": "MeloTTS-Japanese",
308
+ "model_family": "MeloTTS",
309
+ "model_id": "myshell-ai/MeloTTS-Japanese",
310
+ "model_revision": "367f8795464b531b4e97c1515bddfc1243e60891",
311
+ "model_ability": "text-to-audio",
312
+ "multilingual": false,
313
+ "language": "JP"
314
+ },
315
+ {
316
+ "model_name": "MeloTTS-Spanish",
317
+ "model_family": "MeloTTS",
318
+ "model_id": "myshell-ai/MeloTTS-Spanish",
319
+ "model_revision": "dbb5496df39d11a66c1d5f5a9ca357c3c9fb95fb",
320
+ "model_ability": "text-to-audio",
321
+ "multilingual": false,
322
+ "language": "ES"
323
+ },
324
+ {
325
+ "model_name": "MeloTTS-Chinese",
326
+ "model_family": "MeloTTS",
327
+ "model_id": "myshell-ai/MeloTTS-Chinese",
328
+ "model_revision": "af5d207a364ea4208c6f589c89f57f88414bdd16",
329
+ "model_ability": "text-to-audio",
330
+ "multilingual": false,
331
+ "language": "ZH"
332
+ },
333
+ {
334
+ "model_name": "MeloTTS-Korean",
335
+ "model_family": "MeloTTS",
336
+ "model_id": "myshell-ai/MeloTTS-Korean",
337
+ "model_revision": "0207e5adfc90129a51b6b03d89be6d84360ed323",
338
+ "model_ability": "text-to-audio",
339
+ "multilingual": false,
340
+ "language": "KR"
341
+ },
342
+ {
343
+ "model_name": "Kokoro-82M",
344
+ "model_family": "Kokoro",
345
+ "model_id": "hexgrad/Kokoro-82M",
346
+ "model_revision": "7a29fcdf8e997bac6d6f5f6f0c2f0b92912f6102",
347
+ "model_ability": "text-to-audio",
348
+ "multilingual": true
269
349
  }
270
350
  ]
@@ -17,6 +17,15 @@
17
17
  "model_ability": "audio-to-text",
18
18
  "multilingual": true
19
19
  },
20
+ {
21
+ "model_name": "Belle-whisper-large-v3-zh",
22
+ "model_family": "whisper",
23
+ "model_hub": "modelscope",
24
+ "model_id": "Xorbits/Belle-whisper-large-v3-zh",
25
+ "model_revision": "master",
26
+ "model_ability": "audio-to-text",
27
+ "multilingual": false
28
+ },
20
29
  {
21
30
  "model_name": "SenseVoiceSmall",
22
31
  "model_family": "funasr",
@@ -91,5 +100,14 @@
91
100
  "model_revision": "master",
92
101
  "model_ability": "text-to-audio",
93
102
  "multilingual": true
103
+ },
104
+ {
105
+ "model_name": "Kokoro-82M",
106
+ "model_family": "Kokoro",
107
+ "model_hub": "modelscope",
108
+ "model_id": "AI-ModelScope/Kokoro-82M",
109
+ "model_revision": "master",
110
+ "model_ability": "text-to-audio",
111
+ "multilingual": true
94
112
  }
95
113
  ]
@@ -13,9 +13,12 @@
13
13
  # limitations under the License.
14
14
  import logging
15
15
  import os
16
+ import typing
16
17
  from glob import glob
17
18
  from typing import TYPE_CHECKING, Dict, List, Optional, Union
18
19
 
20
+ from typing_extensions import TypedDict
21
+
19
22
  from ...device_utils import (
20
23
  get_available_device,
21
24
  get_device_preferred_dtype,
@@ -28,6 +31,13 @@ if TYPE_CHECKING:
28
31
  logger = logging.getLogger(__name__)
29
32
 
30
33
 
34
+ class WhisperModelConfig(TypedDict, total=False):
35
+ chunk_length_s: Optional[float]
36
+ stride_length_s: Optional[float]
37
+ return_timestamps: Optional[bool]
38
+ batch_size: Optional[int]
39
+
40
+
31
41
  class WhisperModel:
32
42
  def __init__(
33
43
  self,
@@ -35,6 +45,7 @@ class WhisperModel:
35
45
  model_path: str,
36
46
  model_spec: "AudioModelFamilyV1",
37
47
  device: Optional[str] = None,
48
+ max_new_tokens: Optional[int] = 128,
38
49
  **kwargs,
39
50
  ):
40
51
  self._model_uid = model_uid
@@ -42,7 +53,21 @@ class WhisperModel:
42
53
  self._model_spec = model_spec
43
54
  self._device = device
44
55
  self._model = None
45
- self._kwargs = kwargs
56
+ self._max_new_tokens = max_new_tokens
57
+ self._model_config: WhisperModelConfig = self._sanitize_model_config(
58
+ typing.cast(WhisperModelConfig, kwargs)
59
+ )
60
+
61
+ def _sanitize_model_config(
62
+ self, model_config: Optional[WhisperModelConfig]
63
+ ) -> WhisperModelConfig:
64
+ if model_config is None:
65
+ model_config = WhisperModelConfig()
66
+ model_config.setdefault("chunk_length_s", 30)
67
+ model_config.setdefault("stride_length_s", None)
68
+ model_config.setdefault("return_timestamps", False)
69
+ model_config.setdefault("batch_size", 16)
70
+ return model_config
46
71
 
47
72
  @property
48
73
  def model_ability(self):
@@ -75,10 +100,10 @@ class WhisperModel:
75
100
  model=model,
76
101
  tokenizer=processor.tokenizer,
77
102
  feature_extractor=processor.feature_extractor,
78
- max_new_tokens=128,
79
- chunk_length_s=30,
80
- batch_size=16,
81
- return_timestamps=False,
103
+ chunk_length_s=self._model_config.get("chunk_length_s"),
104
+ stride_length_s=self._model_config.get("stride_length_s"),
105
+ return_timestamps=self._model_config.get("return_timestamps"),
106
+ batch_size=self._model_config.get("batch_size"),
82
107
  torch_dtype=torch_dtype,
83
108
  device=self._device,
84
109
  )
@@ -185,13 +210,13 @@ class WhisperModel:
185
210
  logger.warning(
186
211
  "Prompt for whisper transcriptions will be ignored: %s", prompt
187
212
  )
213
+ generate_kwargs = {"max_new_tokens": self._max_new_tokens, "task": "transcribe"}
214
+ if language is not None:
215
+ generate_kwargs["language"] = language
216
+
188
217
  return self._call_model(
189
218
  audio=audio,
190
- generate_kwargs=(
191
- {"language": language, "task": "transcribe"}
192
- if language is not None
193
- else {"task": "transcribe"}
194
- ),
219
+ generate_kwargs=generate_kwargs,
195
220
  response_format=response_format,
196
221
  temperature=temperature,
197
222
  timestamp_granularities=timestamp_granularities,
@@ -28,7 +28,7 @@ from ....types import (
28
28
  )
29
29
  from ..core import LLM
30
30
  from ..llm_family import LLMFamilyV1, LLMSpecV1
31
- from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
31
+ from ..utils import DEEPSEEK_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY, ChatModelMixin
32
32
 
33
33
  logger = logging.getLogger(__name__)
34
34
 
@@ -123,18 +123,22 @@ class LlamaCppModel(LLM):
123
123
 
124
124
  raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
125
125
 
126
- # handle legacy cache.
127
- model_path = os.path.realpath(
128
- os.path.join(
129
- self.model_path,
130
- self.model_spec.model_file_name_template.format(
131
- quantization=self.quantization
132
- ),
126
+ if os.path.isfile(self.model_path):
127
+ # mostly passed from --model_path
128
+ model_path = os.path.realpath(self.model_path)
129
+ else:
130
+ # handle legacy cache.
131
+ model_path = os.path.realpath(
132
+ os.path.join(
133
+ self.model_path,
134
+ self.model_spec.model_file_name_template.format(
135
+ quantization=self.quantization
136
+ ),
137
+ )
133
138
  )
134
- )
135
- legacy_model_file_path = os.path.join(self.model_path, "model.bin")
136
- if os.path.exists(legacy_model_file_path):
137
- model_path = legacy_model_file_path
139
+ legacy_model_file_path = os.path.join(self.model_path, "model.bin")
140
+ if os.path.exists(legacy_model_file_path):
141
+ model_path = legacy_model_file_path
138
142
 
139
143
  try:
140
144
  self._llm = Llama(
@@ -272,8 +276,11 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
272
276
  model_family = self.model_family.model_family or self.model_family.model_name
273
277
  tools = generate_config.pop("tools", []) if generate_config else None
274
278
  full_context_kwargs = {}
275
- if tools and model_family in QWEN_TOOL_CALL_FAMILY:
276
- full_context_kwargs["tools"] = tools
279
+ if tools:
280
+ if model_family in QWEN_TOOL_CALL_FAMILY:
281
+ full_context_kwargs["tools"] = tools
282
+ elif model_family in DEEPSEEK_TOOL_CALL_FAMILY:
283
+ self._tools_to_messages_for_deepseek(messages, tools)
277
284
  assert self.model_family.chat_template is not None
278
285
  full_prompt = self.get_full_context(
279
286
  messages, self.model_family.chat_template, **full_context_kwargs