xinference 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (87)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +5 -5
  3. xinference/core/model.py +6 -1
  4. xinference/deploy/cmdline.py +3 -1
  5. xinference/deploy/test/test_cmdline.py +56 -0
  6. xinference/isolation.py +24 -0
  7. xinference/model/audio/core.py +5 -0
  8. xinference/model/audio/f5tts.py +195 -0
  9. xinference/model/audio/fish_speech.py +2 -1
  10. xinference/model/audio/model_spec.json +8 -0
  11. xinference/model/audio/model_spec_modelscope.json +9 -0
  12. xinference/model/embedding/core.py +203 -142
  13. xinference/model/embedding/model_spec.json +7 -0
  14. xinference/model/embedding/model_spec_modelscope.json +8 -0
  15. xinference/model/llm/__init__.py +2 -2
  16. xinference/model/llm/llm_family.json +172 -53
  17. xinference/model/llm/llm_family_modelscope.json +118 -20
  18. xinference/model/llm/mlx/core.py +230 -49
  19. xinference/model/llm/sglang/core.py +1 -0
  20. xinference/model/llm/transformers/chatglm.py +9 -5
  21. xinference/model/llm/transformers/utils.py +16 -8
  22. xinference/model/llm/utils.py +4 -1
  23. xinference/model/llm/vllm/core.py +5 -0
  24. xinference/thirdparty/f5_tts/__init__.py +0 -0
  25. xinference/thirdparty/f5_tts/api.py +166 -0
  26. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  27. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  28. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  29. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  30. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  31. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  32. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  33. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  34. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  35. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  36. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  37. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  38. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  39. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  40. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  41. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  42. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  43. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  44. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  45. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  46. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  48. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  49. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  50. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  51. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  52. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  53. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  54. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  55. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  56. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  57. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  58. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  59. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  60. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  61. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  62. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  63. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  64. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  65. xinference/thirdparty/f5_tts/train/README.md +77 -0
  66. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  67. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  68. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  69. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  70. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  71. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  72. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  73. xinference/thirdparty/f5_tts/train/train.py +75 -0
  74. xinference/web/ui/build/asset-manifest.json +3 -3
  75. xinference/web/ui/build/index.html +1 -1
  76. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  77. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  78. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  79. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/METADATA +33 -14
  80. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/RECORD +85 -34
  81. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  82. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  83. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  84. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  85. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/WHEEL +0 -0
  86. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  87. {xinference-1.0.1.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/f5_tts/infer/speech_edit.py
@@ -0,0 +1,193 @@
+ import os
+
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # for MPS device compatibility
+
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+
+ from f5_tts.infer.utils_infer import load_checkpoint, load_vocoder, save_spectrogram
+ from f5_tts.model import CFM, DiT, UNetT
+ from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer
+
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+
+ # --------------------- Dataset Settings -------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ win_length = 1024
+ n_fft = 1024
+ mel_spec_type = "vocos"  # 'vocos' or 'bigvgan'
+ target_rms = 0.1
+
+ tokenizer = "pinyin"
+ dataset_name = "Emilia_ZH_EN"
+
+
+ # ---------------------- infer setting ---------------------- #
+
+ seed = None  # int | None
+
+ exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
+ ckpt_step = 1200000
+
+ nfe_step = 32  # 16, 32
+ cfg_strength = 2.0
+ ode_method = "euler"  # euler | midpoint
+ sway_sampling_coef = -1.0
+ speed = 1.0
+
+ if exp_name == "F5TTS_Base":
+     model_cls = DiT
+     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+
+ elif exp_name == "E2TTS_Base":
+     model_cls = UNetT
+     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.safetensors"
+ output_dir = "tests"
+
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char-level alignment]
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
+ # ctc-forced-aligner --audio_path "src/f5_tts/infer/examples/basic/basic_ref_en.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
+ # [result will be saved at the same path as the audio file]
+ # [--language "zho" for Chinese, "eng" for English]
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
+
+ audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_en.wav"
+ origin_text = "Some call me nature, others call me mother nature."
+ target_text = "Some call me optimist, others call me realist."
+ parts_to_edit = [
+     [1.42, 2.44],
+     [4.04, 4.9],
+ ]  # start/end times of "nature" & "mother nature", in seconds
+ fix_duration = [
+     1.2,
+     1,
+ ]  # fixed durations for "optimist" & "realist", in seconds
+
+ # audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav"
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
+ # fix_duration = None  # use origin text duration
+
+
+ # -------------------------------------------------#
+
+ use_ema = True
+
+ if not os.path.exists(output_dir):
+     os.makedirs(output_dir)
+
+ # Vocoder model
+ local = False
+ if mel_spec_type == "vocos":
+     vocoder_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+ elif mel_spec_type == "bigvgan":
+     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
+ vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=local, local_path=vocoder_local_path)
+
+ # Tokenizer
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+ # Model
+ model = CFM(
+     transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
+     mel_spec_kwargs=dict(
+         n_fft=n_fft,
+         hop_length=hop_length,
+         win_length=win_length,
+         n_mel_channels=n_mel_channels,
+         target_sample_rate=target_sample_rate,
+         mel_spec_type=mel_spec_type,
+     ),
+     odeint_kwargs=dict(
+         method=ode_method,
+     ),
+     vocab_char_map=vocab_char_map,
+ ).to(device)
+
+ dtype = torch.float32 if mel_spec_type == "bigvgan" else None
+ model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
+
+ # Audio
+ audio, sr = torchaudio.load(audio_to_edit)
+ if audio.shape[0] > 1:
+     audio = torch.mean(audio, dim=0, keepdim=True)
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
+ if rms < target_rms:
+     audio = audio * target_rms / rms
+ if sr != target_sample_rate:
+     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+     audio = resampler(audio)
+ offset = 0
+ audio_ = torch.zeros(1, 0)
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
+ for part in parts_to_edit:
+     start, end = part
+     part_dur = end - start if fix_duration is None else fix_duration.pop(0)
+     part_dur = part_dur * target_sample_rate
+     start = start * target_sample_rate
+     audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1)
+     edit_mask = torch.cat(
+         (
+             edit_mask,
+             torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool),
+             torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool),
+         ),
+         dim=-1,
+     )
+     offset = end * target_sample_rate
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True)
+ audio = audio.to(device)
+ edit_mask = edit_mask.to(device)
+
+ # Text
+ text_list = [target_text]
+ if tokenizer == "pinyin":
+     final_text_list = convert_char_to_pinyin(text_list)
+ else:
+     final_text_list = [text_list]
+ print(f"text  : {text_list}")
+ print(f"pinyin: {final_text_list}")
+
+ # Duration
+ ref_audio_len = 0
+ duration = audio.shape[-1] // hop_length
+
+ # Inference
+ with torch.inference_mode():
+     generated, trajectory = model.sample(
+         cond=audio,
+         text=final_text_list,
+         duration=duration,
+         steps=nfe_step,
+         cfg_strength=cfg_strength,
+         sway_sampling_coef=sway_sampling_coef,
+         seed=seed,
+         edit_mask=edit_mask,
+     )
+     print(f"Generated mel: {generated.shape}")
+
+     # Final result
+     generated = generated.to(torch.float32)
+     generated = generated[:, ref_audio_len:, :]
+     gen_mel_spec = generated.permute(0, 2, 1)
+     if mel_spec_type == "vocos":
+         generated_wave = vocoder.decode(gen_mel_spec).cpu()
+     elif mel_spec_type == "bigvgan":
+         generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
+
+     if rms < target_rms:
+         generated_wave = generated_wave * rms / target_rms
+
+     save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
+     torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
+     print(f"Generated wav: {generated_wave.shape}")
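The headline change in this release is F5-TTS support: a new audio model wrapper (xinference/model/audio/f5tts.py), new entries in model_spec.json and model_spec_modelscope.json, and the vendored xinference/thirdparty/f5_tts tree shown above. Below is a minimal usage sketch, assuming a running Xinference server on the default endpoint and that the new spec registers the model under the name "F5-TTS"; the speech() parameters follow the existing audio client API rather than anything shown in this diff.

# Hypothetical usage sketch for the F5-TTS audio model added in 1.1.0.
# The endpoint URL and the model name "F5-TTS" are assumptions.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")

# Launch the new audio model and get a handle to it.
model_uid = client.launch_model(model_name="F5-TTS", model_type="audio")
model = client.get_model(model_uid)

# Audio model handles expose an OpenAI-style speech() call returning raw bytes.
audio_bytes = model.speech(input="Hello from F5-TTS.")
with open("f5tts_sample.mp3", "wb") as f:
    f.write(audio_bytes)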