xinference 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of xinference might be problematic.

Files changed (104)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -1
  3. xinference/client/restful/restful_client.py +82 -2
  4. xinference/constants.py +3 -0
  5. xinference/core/chat_interface.py +297 -83
  6. xinference/core/model.py +1 -0
  7. xinference/core/progress_tracker.py +16 -8
  8. xinference/core/supervisor.py +45 -1
  9. xinference/core/worker.py +262 -37
  10. xinference/deploy/cmdline.py +33 -1
  11. xinference/model/audio/core.py +11 -1
  12. xinference/model/audio/megatts.py +105 -0
  13. xinference/model/audio/model_spec.json +24 -1
  14. xinference/model/audio/model_spec_modelscope.json +26 -1
  15. xinference/model/core.py +14 -0
  16. xinference/model/embedding/core.py +6 -1
  17. xinference/model/flexible/core.py +6 -1
  18. xinference/model/image/core.py +6 -1
  19. xinference/model/image/model_spec.json +17 -1
  20. xinference/model/image/model_spec_modelscope.json +17 -1
  21. xinference/model/llm/__init__.py +0 -4
  22. xinference/model/llm/core.py +4 -0
  23. xinference/model/llm/llama_cpp/core.py +40 -16
  24. xinference/model/llm/llm_family.json +413 -84
  25. xinference/model/llm/llm_family.py +24 -1
  26. xinference/model/llm/llm_family_modelscope.json +447 -0
  27. xinference/model/llm/mlx/core.py +16 -2
  28. xinference/model/llm/transformers/__init__.py +14 -0
  29. xinference/model/llm/transformers/core.py +30 -6
  30. xinference/model/llm/transformers/gemma3.py +17 -2
  31. xinference/model/llm/transformers/intern_vl.py +28 -18
  32. xinference/model/llm/transformers/minicpmv26.py +21 -2
  33. xinference/model/llm/transformers/qwen-omni.py +308 -0
  34. xinference/model/llm/transformers/qwen2_audio.py +1 -1
  35. xinference/model/llm/transformers/qwen2_vl.py +20 -4
  36. xinference/model/llm/utils.py +11 -1
  37. xinference/model/llm/vllm/core.py +35 -0
  38. xinference/model/llm/vllm/distributed_executor.py +8 -2
  39. xinference/model/rerank/core.py +6 -1
  40. xinference/model/utils.py +118 -1
  41. xinference/model/video/core.py +6 -1
  42. xinference/thirdparty/megatts3/__init__.py +0 -0
  43. xinference/thirdparty/megatts3/tts/frontend_function.py +175 -0
  44. xinference/thirdparty/megatts3/tts/gradio_api.py +93 -0
  45. xinference/thirdparty/megatts3/tts/infer_cli.py +277 -0
  46. xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py +318 -0
  47. xinference/thirdparty/megatts3/tts/modules/ar_dur/ar_dur_predictor.py +362 -0
  48. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/layers.py +64 -0
  49. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/nar_tts_modules.py +73 -0
  50. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rel_transformer.py +403 -0
  51. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/rot_transformer.py +649 -0
  52. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/seq_utils.py +342 -0
  53. xinference/thirdparty/megatts3/tts/modules/ar_dur/commons/transformer.py +767 -0
  54. xinference/thirdparty/megatts3/tts/modules/llm_dit/cfm.py +309 -0
  55. xinference/thirdparty/megatts3/tts/modules/llm_dit/dit.py +180 -0
  56. xinference/thirdparty/megatts3/tts/modules/llm_dit/time_embedding.py +44 -0
  57. xinference/thirdparty/megatts3/tts/modules/llm_dit/transformer.py +230 -0
  58. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/diag_gaussian.py +67 -0
  59. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/hifigan_modules.py +283 -0
  60. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/seanet_encoder.py +38 -0
  61. xinference/thirdparty/megatts3/tts/modules/wavvae/decoder/wavvae_v3.py +60 -0
  62. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/conv.py +154 -0
  63. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/lstm.py +51 -0
  64. xinference/thirdparty/megatts3/tts/modules/wavvae/encoder/common_modules/seanet.py +126 -0
  65. xinference/thirdparty/megatts3/tts/utils/audio_utils/align.py +36 -0
  66. xinference/thirdparty/megatts3/tts/utils/audio_utils/io.py +95 -0
  67. xinference/thirdparty/megatts3/tts/utils/audio_utils/plot.py +90 -0
  68. xinference/thirdparty/megatts3/tts/utils/commons/ckpt_utils.py +171 -0
  69. xinference/thirdparty/megatts3/tts/utils/commons/hparams.py +215 -0
  70. xinference/thirdparty/megatts3/tts/utils/text_utils/dict.json +1 -0
  71. xinference/thirdparty/megatts3/tts/utils/text_utils/ph_tone_convert.py +94 -0
  72. xinference/thirdparty/megatts3/tts/utils/text_utils/split_text.py +90 -0
  73. xinference/thirdparty/megatts3/tts/utils/text_utils/text_encoder.py +280 -0
  74. xinference/types.py +10 -0
  75. xinference/utils.py +54 -0
  76. xinference/web/ui/build/asset-manifest.json +6 -6
  77. xinference/web/ui/build/index.html +1 -1
  78. xinference/web/ui/build/static/css/main.0f6523be.css +2 -0
  79. xinference/web/ui/build/static/css/main.0f6523be.css.map +1 -0
  80. xinference/web/ui/build/static/js/main.58bd483c.js +3 -0
  81. xinference/web/ui/build/static/js/main.58bd483c.js.map +1 -0
  82. xinference/web/ui/node_modules/.cache/babel-loader/3bff8cbe9141f937f4d98879a9771b0f48e0e4e0dbee8e647adbfe23859e7048.json +1 -0
  83. xinference/web/ui/node_modules/.cache/babel-loader/4500b1a622a031011f0a291701e306b87e08cbc749c50e285103536b85b6a914.json +1 -0
  84. xinference/web/ui/node_modules/.cache/babel-loader/51709f5d3e53bcf19e613662ef9b91fb9174942c5518987a248348dd4e1e0e02.json +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/69081049f0c7447544b7cfd73dd13d8846c02fe5febe4d81587e95c89a412d5b.json +1 -0
  86. xinference/web/ui/node_modules/.cache/babel-loader/b8551e9775a01b28ae674125c688febe763732ea969ae344512e64ea01bf632e.json +1 -0
  87. xinference/web/ui/node_modules/.cache/babel-loader/bf2b211b0d1b6465eff512d64c869d748f803c5651a7c24e48de6ea3484a7bfe.json +1 -0
  88. xinference/web/ui/src/locales/en.json +2 -1
  89. xinference/web/ui/src/locales/zh.json +2 -1
  90. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/METADATA +127 -114
  91. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/RECORD +96 -60
  92. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/WHEEL +1 -1
  93. xinference/web/ui/build/static/css/main.b494ae7e.css +0 -2
  94. xinference/web/ui/build/static/css/main.b494ae7e.css.map +0 -1
  95. xinference/web/ui/build/static/js/main.5ca4eea1.js +0 -3
  96. xinference/web/ui/build/static/js/main.5ca4eea1.js.map +0 -1
  97. xinference/web/ui/node_modules/.cache/babel-loader/0f0967acaec5df1d45b80010949c258d64297ebbb0f44b8bb3afcbd45c6f0ec4.json +0 -1
  98. xinference/web/ui/node_modules/.cache/babel-loader/27bcada3ee8f89d21184b359f022fc965f350ffaca52c9814c29f1fc37121173.json +0 -1
  99. xinference/web/ui/node_modules/.cache/babel-loader/68249645124f37d01eef83b1d897e751f895bea919b6fb466f907c1f87cebc84.json +0 -1
  100. xinference/web/ui/node_modules/.cache/babel-loader/e547bbb18abb4a474b675a8d5782d25617566bea0af8caa9b836ce5649e2250a.json +0 -1
  101. /xinference/web/ui/build/static/js/{main.5ca4eea1.js.LICENSE.txt → main.58bd483c.js.LICENSE.txt} +0 -0
  102. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/entry_points.txt +0 -0
  103. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info/licenses}/LICENSE +0 -0
  104. {xinference-1.4.1.dist-info → xinference-1.5.0.dist-info}/top_level.txt +0 -0
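
The most visible addition in this release is the MegaTTS3 text-to-speech model (the new xinference/model/audio/megatts.py plus the vendored xinference/thirdparty/megatts3 package listed above). As rough orientation only, here is a hedged sketch of launching an audio model through the xinference RESTful client; the model name "MegaTTS3" and the idea that voice-cloning inputs go through extra kwargs are assumptions, not confirmed details of this release — the registered name lives in the new model_spec.json entries.

# Hedged sketch: launching the newly added audio model via the xinference client.
# Assumptions: the server runs on localhost:9997 and the model is registered as "MegaTTS3".
from xinference.client import Client

client = Client("http://localhost:9997")
model_uid = client.launch_model(model_name="MegaTTS3", model_type="audio")
model = client.get_model(model_uid)

# speech() returns raw audio bytes; MegaTTS3 is a voice-cloning TTS, so a reference
# prompt is presumably passed through model-specific kwargs (an assumption here).
audio_bytes = model.speech("Hello from xinference 1.5.0.")
with open("output.wav", "wb") as f:
    f.write(audio_bytes)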

xinference/thirdparty/megatts3/tts/infer_cli.py
@@ -0,0 +1,277 @@
+ # Copyright 2025 ByteDance and/or its affiliates.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import os
+ import argparse
+ import librosa
+ import numpy as np
+ import torch
+
+ from tn.chinese.normalizer import Normalizer as ZhNormalizer
+ from tn.english.normalizer import Normalizer as EnNormalizer
+ from langdetect import detect as classify_language
+ from pydub import AudioSegment
+ import pyloudnorm as pyln
+
+ from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
+ from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
+ from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
+ from tts.utils.commons.ckpt_utils import load_ckpt
+ from tts.utils.commons.hparams import set_hparams, hparams
+ from tts.utils.text_utils.text_encoder import TokenTextEncoder
+ from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english
+ from tts.utils.commons.hparams import hparams, set_hparams
+
+
+ if "TOKENIZERS_PARALLELISM" not in os.environ:
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ def convert_to_wav(wav_path):
+     # Check if the file exists
+     if not os.path.exists(wav_path):
+         print(f"The file '{wav_path}' does not exist.")
+         return
+
+     # Check if the file already has a .wav extension
+     if not wav_path.endswith(".wav"):
+         # Define the output path with a .wav extension
+         out_path = os.path.splitext(wav_path)[0] + ".wav"
+
+         # Load the audio file using pydub and convert it to WAV
+         audio = AudioSegment.from_file(wav_path)
+         audio.export(out_path, format="wav")
+
+         print(f"Converted '{wav_path}' to '{out_path}'")
+
+
+ def cut_wav(wav_path, max_len=28):
+     audio = AudioSegment.from_file(wav_path)
+     audio = audio[:int(max_len * 1000)]
+     audio.export(wav_path, format="wav")
+
+ class MegaTTS3DiTInfer():
+     def __init__(
+             self,
+             device=None,
+             ckpt_root='./checkpoints',
+             dit_exp_name='diffusion_transformer',
+             frontend_exp_name='aligner_lm',
+             wavvae_exp_name='wavvae',
+             dur_ckpt_path='duration_lm',
+             g2p_exp_name='g2p',
+             precision=torch.float16,
+             **kwargs
+         ):
+         self.sr = 24000
+         self.fm = 8
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         self.precision = precision
+
+         # build models
+         self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
+         self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
+         self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
+         self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
+         self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
+         self.build_model(self.device)
+
+         # init text normalizer
+         self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
+         self.en_normalizer = EnNormalizer(overwrite_cache=False)
+         # loudness meter
+         self.loudness_meter = pyln.Meter(self.sr)
+
+     def build_model(self, device):
+         set_hparams(exp_name=self.dit_exp_name, print_hparams=False)
+
+         ''' Load Dict '''
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         ling_dict = json.load(open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig'))
+         self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='<UNK>') for k in ['phone', 'tone']}
+         self.token_encoder = token_encoder = self.ling_dict['phone']
+         ph_dict_size = len(token_encoder)
+
+         ''' Load Duration LM '''
+         from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
+         hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
+         hp_dur_model['frames_multiple'] = hparams['frames_multiple']
+         self.dur_model = ARDurPredictor(
+             hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
+             hp_dur_model['dur_model_layers'], ph_dict_size,
+             hp_dur_model['dur_code_size'],
+             use_rot_embed=hp_dur_model.get('use_rot_embed', False))
+         self.length_regulator = LengthRegulator()
+         load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
+         self.dur_model.eval()
+         self.dur_model.to(device)
+
+         ''' Load Diffusion Transformer '''
+         from tts.modules.llm_dit.dit import Diffusion
+         self.dit = Diffusion()
+         load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
+         self.dit.eval()
+         self.dit.to(device)
+         self.cfg_mask_token_phone = 302 - 1
+         self.cfg_mask_token_tone = 32 - 1
+
+         ''' Load Frontend LM '''
+         from tts.modules.aligner.whisper_small import Whisper
+         self.aligner_lm = Whisper()
+         load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
+         self.aligner_lm.eval()
+         self.aligner_lm.to(device)
+         self.kv_cache = None
+         self.hooks = None
+
+         ''' Load G2P LM'''
+         from transformers import AutoTokenizer, AutoModelForCausalLM
+         g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
+         g2p_tokenizer.padding_side = "right"
+         self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
+         self.g2p_tokenizer = g2p_tokenizer
+         self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]
+
+         ''' Wav VAE '''
+         self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
+         from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
+         self.wavvae = WavVAE_V3(hparams=hp_wavvae)
+         if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
+             load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
+             self.has_vae_encoder = True
+         else:
+             load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
+             self.has_vae_encoder = False
+         self.wavvae.eval()
+         self.wavvae.to(device)
+         self.vae_stride = hp_wavvae.get('vae_stride', 4)
+         self.hop_size = hp_wavvae.get('hop_size', 4)
+
+     def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
+         wav_bytes = convert_to_wav_bytes(audio_bytes)
+
+         ''' Load wav '''
+         wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
+         # Pad wav if necessary
+         ws = hparams['win_size']
+         if len(wav) % ws < ws - 1:
+             wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
+         wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
+         self.loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
+
+         ''' obtain alignments with aligner_lm '''
+         ph_ref, tone_ref, mel2ph_ref = align(self, wav)
+
+         with torch.inference_mode():
+             ''' Forward WaveVAE to obtain: prompt latent '''
+             if self.has_vae_encoder:
+                 wav = torch.FloatTensor(wav)[None].to(self.device)
+                 vae_latent = self.wavvae.encode_latent(wav)
+                 vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
+             else:
+                 assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
+                 vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
+                 vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
+
+             ''' Duration Prompting '''
+             self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
+             incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)
+
+         return {
+             'ph_ref': ph_ref,
+             'tone_ref': tone_ref,
+             'mel2ph_ref': mel2ph_ref,
+             'vae_latent': vae_latent,
+             'incremental_state_dur_prompt': incremental_state_dur_prompt,
+             'ctx_dur_tokens': ctx_dur_tokens,
+         }
+
+     def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
+         device = self.device
+
+         ph_ref = resource_context['ph_ref'].to(device)
+         tone_ref = resource_context['tone_ref'].to(device)
+         mel2ph_ref = resource_context['mel2ph_ref'].to(device)
+         vae_latent = resource_context['vae_latent'].to(device)
+         ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
+         incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
+
+         with torch.inference_mode():
+             ''' Generating '''
+             wav_pred_ = []
+             language_type = classify_language(input_text)
+             if language_type == 'en':
+                 input_text = self.en_normalizer.normalize(input_text)
+                 text_segs = chunk_text_english(input_text, max_chars=130)
+             else:
+                 input_text = self.zh_normalizer.normalize(input_text)
+                 text_segs = chunk_text_chinese(input_text, limit=60)
+
+             for seg_i, text in enumerate(text_segs):
+                 ''' G2P '''
+                 ph_pred, tone_pred = g2p(self, text)
+
+                 ''' Duration Prediction '''
+                 mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt, ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha, is_first=seg_i==0, is_final=seg_i==len(text_segs)-1)
+
+                 inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
+                 # Speech dit inference
+                 with torch.cuda.amp.autocast(dtype=self.precision, enabled=True):
+                     x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()
+
+                 # WavVAE decode
+                 x[:, :vae_latent.size(1)] = vae_latent
+                 wav_pred = self.wavvae.decode(x)[0,0].to(torch.float32)
+
+                 ''' Post-processing '''
+                 # Trim prompt wav
+                 wav_pred = wav_pred[vae_latent.size(1)*self.vae_stride*self.hop_size:].cpu().numpy()
+                 # Norm generated wav to prompt wav's level
+                 meter = pyln.Meter(self.sr)  # create BS.1770 meter
+                 loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
+                 wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, self.loudness_prompt)
+                 if np.abs(wav_pred).max() >= 1:
+                     wav_pred = wav_pred / np.abs(wav_pred).max() * 0.95
+
+                 # Apply hamming window
+                 wav_pred_.append(wav_pred)
+
+             return combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input_wav', type=str)
+     parser.add_argument('--input_text', type=str)
+     parser.add_argument('--output_dir', type=str)
+     parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
+     parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
+     parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
+     args = parser.parse_args()
+     wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w
+
+     infer_ins = MegaTTS3DiTInfer()
+
+     with open(wav_path, 'rb') as file:
+         file_content = file.read()
+
+     print(f"| Start processing {wav_path}+{input_text}")
+     resource_context = infer_ins.preprocess(file_content, latent_file=wav_path.replace('.wav', '.npy'))
+     wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)
+
+     print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
+     os.makedirs(out_path, exist_ok=True)
+     save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')
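
The new infer_cli.py above is both a command-line entry point and an importable class. For orientation, the sketch below simply mirrors the script's own __main__ block; the prompt paths are placeholders, and inside the wheel the module is vendored under xinference.thirdparty.megatts3.tts.infer_cli.

# Minimal sketch of driving MegaTTS3DiTInfer programmatically, mirroring the __main__ block above.
# Placeholder paths; checkpoints are expected under ./checkpoints by default.
from tts.infer_cli import MegaTTS3DiTInfer
from tts.utils.audio_utils.io import save_wav

infer_ins = MegaTTS3DiTInfer()            # loads duration LM, DiT, aligner, G2P LM and WavVAE

with open("prompt.wav", "rb") as f:       # reference speaker audio (placeholder)
    audio_bytes = f.read()

# In WavVAE decoder-only mode a precomputed latent (.npy) must accompany the prompt wav.
ctx = infer_ins.preprocess(audio_bytes, latent_file="prompt.npy")
wav = infer_ins.forward(ctx, "Text to synthesize.", time_step=32, p_w=1.6, t_w=2.5)
save_wav(wav, "output.wav")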

xinference/thirdparty/megatts3/tts/modules/aligner/whisper_small.py
@@ -0,0 +1,318 @@
+ # MIT License
+
+ # Copyright (c) 2022 OpenAI
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ # Copyright (c) [2022] [OpenAI]
+ # Copyright (c) [2025] [Ziyue Jiang]
+ # SPDX-License-Identifier: MIT
+ # This file has been modified by Ziyue Jiang on 2025/03/19
+ # Original file was released under MIT, with the full license text # available at https://github.com/openai/whisper/blob/v20240930/LICENSE.
+ # This modified file is released under the same license.
+
+ from contextlib import contextmanager
+ from typing import Dict, Iterable, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from torch.nn.functional import scaled_dot_product_attention
+ SDPA_AVAILABLE = True
+
+
+ class LayerNorm(nn.LayerNorm):
+     def forward(self, x: Tensor) -> Tensor:
+         return super().forward(x.float()).type(x.dtype)
+
+
+ class Linear(nn.Linear):
+     def forward(self, x: Tensor) -> Tensor:
+         return F.linear(
+             x,
+             self.weight.to(x.dtype),
+             None if self.bias is None else self.bias.to(x.dtype),
+         )
+
+
+ class Conv1d(nn.Conv1d):
+     def _conv_forward(
+         self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+     ) -> Tensor:
+         return super()._conv_forward(
+             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+         )
+
+
+ def sinusoids(length, channels, max_timescale=10000):
+     """Returns sinusoids for positional embedding"""
+     assert channels % 2 == 0
+     log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+     inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+ @contextmanager
+ def disable_sdpa():
+     prev_state = MultiHeadAttention.use_sdpa
+     try:
+         MultiHeadAttention.use_sdpa = False
+         yield
+     finally:
+         MultiHeadAttention.use_sdpa = prev_state
+
+
+ class MultiHeadAttention(nn.Module):
+     use_sdpa = True
+
+     def __init__(self, n_state: int, n_head: int):
+         super().__init__()
+         self.n_head = n_head
+         self.query = Linear(n_state, n_state)
+         self.key = Linear(n_state, n_state, bias=False)
+         self.value = Linear(n_state, n_state)
+         self.out = Linear(n_state, n_state)
+
+     def forward(
+         self,
+         x: Tensor,
+         xa: Optional[Tensor] = None,
+         mask: Optional[Tensor] = None,
+         kv_cache: Optional[dict] = None,
+         casual: Optional[bool] = None
+     ):
+         q = self.query(x)
+
+         if kv_cache is None or xa is None or self.key not in kv_cache:
+             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+             # otherwise, perform key/value projections for self- or cross-attention as usual.
+             k = self.key(x if xa is None else xa)
+             v = self.value(x if xa is None else xa)
+         else:
+             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+             k = kv_cache[self.key]
+             v = kv_cache[self.value]
+
+         wv = self.qkv_attention(q, k, v, mask, casual)
+         return self.out(wv)
+
+     def qkv_attention(
+         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, casual: Optional[bool] = None
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         n_batch, n_ctx, n_state = q.shape
+         scale = (n_state // self.n_head) ** -0.25
+         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+         k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+         a = scaled_dot_product_attention(
+             q, k, v, is_causal=casual and n_ctx > 1, attn_mask=mask[:, None, None, :] if mask is not None else None
+         )
+         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+         return out
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+         super().__init__()
+
+         self.attn = MultiHeadAttention(n_state, n_head)
+         self.attn_ln = LayerNorm(n_state)
+
+         self.cross_attn = (
+             MultiHeadAttention(n_state, n_head) if cross_attention else None
+         )
+         self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+         n_mlp = n_state * 4
+         self.mlp = nn.Sequential(
+             Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+         )
+         self.mlp_ln = LayerNorm(n_state)
+
+     def forward(
+         self,
+         x: Tensor,
+         xa: Optional[Tensor] = None,
+         mask: Optional[Tensor] = None,
+         kv_cache: Optional[dict] = None,
+         casual: Optional[bool] = None,
+     ):
+         x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, casual=casual)
+         if self.cross_attn:
+             # TODO: Cross attention mask
+             x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, casual=False)
+         x = x + self.mlp(self.mlp_ln(x))
+         return x
+
+
+ class AudioEncoder(nn.Module):
+     def __init__(
+         self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+     ):
+         super().__init__()
+         self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+         self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+             [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+         )
+         self.ln_post = LayerNorm(n_state)
+
+     def forward(self, x: Tensor, attn_mask: Tensor):
+         """
+         x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+             the mel spectrogram of the audio
+         """
+         x = F.gelu(self.conv1(x))
+         x = F.gelu(self.conv2(x))
+         x = x.permute(0, 2, 1)
+
+         # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+         x = (x + self.positional_embedding[:x.size(1)]).to(x.dtype)
+
+         for block in self.blocks:
+             x = block(x, mask=attn_mask, casual=False)
+
+         x = self.ln_post(x)
+         return x
+
+
+ class TextDecoder(nn.Module):
+     def __init__(
+         self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+     ):
+         super().__init__()
+
+         self.token_embedding = nn.Embedding(n_vocab, n_state)
+         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                 for _ in range(n_layer)
+             ]
+         )
+         self.ln = LayerNorm(n_state)
+
+         self.out_proj = nn.Linear(n_state, n_vocab)
+
+     def forward(self, x: Tensor, attn_mask: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+         """
+         x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+             the text tokens
+         xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+             the encoded audio features to be attended on
+         """
+         offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+         x = (
+             self.token_embedding(x)
+             + self.positional_embedding[offset : offset + x.shape[-1]]
+         )
+         x = x.to(xa.dtype)
+
+         for block in self.blocks:
+             x = block(x, xa, mask=attn_mask, kv_cache=kv_cache, casual=True)
+
+         x = self.ln(x)
+         # logits = (
+         #     x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+         # ).float()
+         logits = self.out_proj(x)
+
+         return logits
+
+
+ class Whisper(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.n_vocab = 6800
+         self.n_text_layer = 6
+         self.n_text_head = 8
+         self.n_text_ctx = 2048
+
+         self.encoder = AudioEncoder(
+             n_mels=80, n_ctx=3000, n_state=512, n_head=8, n_layer=6,
+         )
+         self.decoder = TextDecoder(
+             n_vocab=6800, n_ctx=2048, n_state=512, n_head=8, n_layer=6,
+         )
+
+     def embed_audio(self, mel: torch.Tensor):
+         return self.encoder(mel, None)
+
+     def logits(self, tokens, audio_features, kv_cache=None):
+         return self.decoder(tokens, None, audio_features, kv_cache=kv_cache)
+
+     def forward(
+         self, mel, mel_len, token, token_len
+     ) -> Dict[str, torch.Tensor]:
+         attn_mask_enc = self.sequence_mask(mel_len//2, device=mel.device) > 0
+         attn_mask_dec = self.sequence_mask(token_len, device=mel.device) > 0
+         return self.decoder(token, attn_mask_dec, self.encoder(mel, attn_mask_enc))
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+         """
+         The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+         tensors calculated for the previous positions. This method returns a dictionary that stores
+         all caches, and the necessary hooks for the key and value projection modules that save the
+         intermediate tensors to be reused during later calculations.
+
+         Returns
+         -------
+         cache : Dict[nn.Module, torch.Tensor]
+             A dictionary object mapping the key/value projection modules to its cache
+         hooks : List[RemovableHandle]
+             List of PyTorch RemovableHandle objects to stop the hooks to be called
+         """
+         cache = {**cache} if cache is not None else {}
+         hooks = []
+
+         def save_to_cache(module, _, output):
+             if module not in cache or output.shape[1] > self.n_text_ctx:
+                 # save as-is, for the first token or cross attention
+                 cache[module] = output
+             else:
+                 cache[module] = torch.cat([cache[module], output], dim=1).detach()
+             return cache[module]
+
+         def install_hooks(layer: nn.Module):
+             if isinstance(layer, MultiHeadAttention):
+                 hooks.append(layer.key.register_forward_hook(save_to_cache))
+                 hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+         self.decoder.apply(install_hooks)
+         return cache, hooks
+
+     def sequence_mask(self, seq_lens, max_len=None, device='cpu'):
+         b = seq_lens.shape[0]
+         if max_len is None:
+             max_len = seq_lens.max()
+         mask = torch.arange(max_len).unsqueeze(0).to(device)  # [1, t]
+         mask = mask < (seq_lens.unsqueeze(1))  # [1, t] + [b, 1] = [b, t]
+         mask = mask.float()
+         return mask
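
This module is the frontend aligner that infer_cli.py loads into MegaTTS3DiTInfer.aligner_lm via load_ckpt. For orientation, a minimal sketch of exercising the encoder/decoder and the kv-cache hooks follows; the mel and token tensors are dummy placeholders, and a real run would load trained weights first.

# Minimal sketch of the Whisper aligner's kv-cache hook usage (dummy inputs, untrained weights).
import torch
from tts.modules.aligner.whisper_small import Whisper

model = Whisper().eval()                              # in practice: load_ckpt(model, ..., 'model')
mel = torch.randn(1, 80, 3000)                        # (batch, n_mels, frames)
tokens = torch.randint(0, model.n_vocab, (1, 12))     # (batch, text tokens)

with torch.inference_mode():
    audio_features = model.embed_audio(mel)           # (1, 1500, 512) after the stride-2 conv
    cache, hooks = model.install_kv_cache_hooks()     # hooks capture per-layer k/v projections
    logits = model.logits(tokens, audio_features, kv_cache=cache)  # (1, 12, 6800)

for hook in hooks:                                    # detach the forward hooks when done
    hook.remove()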