xinference 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of xinference might be problematic.

Files changed (170)
  1. xinference/_compat.py +2 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +28 -6
  4. xinference/core/utils.py +10 -6
  5. xinference/deploy/cmdline.py +3 -1
  6. xinference/deploy/test/test_cmdline.py +56 -0
  7. xinference/isolation.py +24 -0
  8. xinference/model/audio/core.py +10 -0
  9. xinference/model/audio/cosyvoice.py +25 -3
  10. xinference/model/audio/f5tts.py +200 -0
  11. xinference/model/audio/f5tts_mlx.py +260 -0
  12. xinference/model/audio/fish_speech.py +36 -111
  13. xinference/model/audio/model_spec.json +27 -3
  14. xinference/model/audio/model_spec_modelscope.json +18 -0
  15. xinference/model/audio/utils.py +32 -0
  16. xinference/model/embedding/core.py +203 -142
  17. xinference/model/embedding/model_spec.json +7 -0
  18. xinference/model/embedding/model_spec_modelscope.json +8 -0
  19. xinference/model/image/core.py +69 -1
  20. xinference/model/image/model_spec.json +127 -4
  21. xinference/model/image/model_spec_modelscope.json +130 -4
  22. xinference/model/image/stable_diffusion/core.py +45 -13
  23. xinference/model/llm/__init__.py +2 -2
  24. xinference/model/llm/llm_family.json +219 -53
  25. xinference/model/llm/llm_family.py +15 -36
  26. xinference/model/llm/llm_family_modelscope.json +167 -20
  27. xinference/model/llm/mlx/core.py +287 -51
  28. xinference/model/llm/sglang/core.py +1 -0
  29. xinference/model/llm/transformers/chatglm.py +9 -5
  30. xinference/model/llm/transformers/core.py +1 -0
  31. xinference/model/llm/transformers/qwen2_vl.py +2 -0
  32. xinference/model/llm/transformers/utils.py +16 -8
  33. xinference/model/llm/utils.py +5 -1
  34. xinference/model/llm/vllm/core.py +16 -2
  35. xinference/thirdparty/cosyvoice/bin/average_model.py +92 -0
  36. xinference/thirdparty/cosyvoice/bin/export_jit.py +12 -2
  37. xinference/thirdparty/cosyvoice/bin/export_onnx.py +112 -0
  38. xinference/thirdparty/cosyvoice/bin/export_trt.sh +9 -0
  39. xinference/thirdparty/cosyvoice/bin/inference.py +5 -7
  40. xinference/thirdparty/cosyvoice/bin/train.py +42 -8
  41. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +96 -25
  42. xinference/thirdparty/cosyvoice/cli/frontend.py +77 -30
  43. xinference/thirdparty/cosyvoice/cli/model.py +330 -80
  44. xinference/thirdparty/cosyvoice/dataset/dataset.py +6 -2
  45. xinference/thirdparty/cosyvoice/dataset/processor.py +76 -14
  46. xinference/thirdparty/cosyvoice/flow/decoder.py +92 -13
  47. xinference/thirdparty/cosyvoice/flow/flow.py +99 -9
  48. xinference/thirdparty/cosyvoice/flow/flow_matching.py +110 -13
  49. xinference/thirdparty/cosyvoice/flow/length_regulator.py +5 -4
  50. xinference/thirdparty/cosyvoice/hifigan/discriminator.py +140 -0
  51. xinference/thirdparty/cosyvoice/hifigan/generator.py +58 -42
  52. xinference/thirdparty/cosyvoice/hifigan/hifigan.py +67 -0
  53. xinference/thirdparty/cosyvoice/llm/llm.py +139 -6
  54. xinference/thirdparty/cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  55. xinference/thirdparty/cosyvoice/tokenizer/tokenizer.py +279 -0
  56. xinference/thirdparty/cosyvoice/transformer/embedding.py +2 -2
  57. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +7 -7
  58. xinference/thirdparty/cosyvoice/transformer/upsample_encoder.py +318 -0
  59. xinference/thirdparty/cosyvoice/utils/common.py +28 -1
  60. xinference/thirdparty/cosyvoice/utils/executor.py +69 -7
  61. xinference/thirdparty/cosyvoice/utils/file_utils.py +2 -12
  62. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +9 -5
  63. xinference/thirdparty/cosyvoice/utils/losses.py +20 -0
  64. xinference/thirdparty/cosyvoice/utils/scheduler.py +1 -2
  65. xinference/thirdparty/cosyvoice/utils/train_utils.py +101 -45
  66. xinference/thirdparty/f5_tts/api.py +166 -0
  67. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  68. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  69. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  70. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  71. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  72. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  73. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  74. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  75. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  76. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  77. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  78. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  79. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  80. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  81. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  82. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  83. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  84. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  85. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  86. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  87. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  88. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  89. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  90. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  91. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  92. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  93. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  94. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  95. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  96. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  97. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  98. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  99. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  100. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  101. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  102. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  103. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  104. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  105. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  106. xinference/thirdparty/f5_tts/train/README.md +77 -0
  107. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  108. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  109. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  110. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  111. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  112. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  113. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  114. xinference/thirdparty/f5_tts/train/train.py +75 -0
  115. xinference/thirdparty/fish_speech/fish_speech/conversation.py +94 -83
  116. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +63 -20
  117. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +1 -26
  118. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +1 -1
  119. xinference/thirdparty/fish_speech/fish_speech/tokenizer.py +152 -0
  120. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -2
  121. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  122. xinference/thirdparty/fish_speech/tools/{post_api.py → api_client.py} +7 -13
  123. xinference/thirdparty/fish_speech/tools/api_server.py +98 -0
  124. xinference/thirdparty/fish_speech/tools/download_models.py +5 -5
  125. xinference/thirdparty/fish_speech/tools/fish_e2e.py +2 -2
  126. xinference/thirdparty/fish_speech/tools/inference_engine/__init__.py +192 -0
  127. xinference/thirdparty/fish_speech/tools/inference_engine/reference_loader.py +125 -0
  128. xinference/thirdparty/fish_speech/tools/inference_engine/utils.py +39 -0
  129. xinference/thirdparty/fish_speech/tools/inference_engine/vq_manager.py +57 -0
  130. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +2 -2
  131. xinference/thirdparty/fish_speech/tools/llama/generate.py +117 -89
  132. xinference/thirdparty/fish_speech/tools/run_webui.py +104 -0
  133. xinference/thirdparty/fish_speech/tools/schema.py +11 -28
  134. xinference/thirdparty/fish_speech/tools/server/agent/__init__.py +57 -0
  135. xinference/thirdparty/fish_speech/tools/server/agent/generate.py +119 -0
  136. xinference/thirdparty/fish_speech/tools/server/agent/generation_utils.py +122 -0
  137. xinference/thirdparty/fish_speech/tools/server/agent/pre_generation_utils.py +72 -0
  138. xinference/thirdparty/fish_speech/tools/server/api_utils.py +75 -0
  139. xinference/thirdparty/fish_speech/tools/server/exception_handler.py +27 -0
  140. xinference/thirdparty/fish_speech/tools/server/inference.py +45 -0
  141. xinference/thirdparty/fish_speech/tools/server/model_manager.py +122 -0
  142. xinference/thirdparty/fish_speech/tools/server/model_utils.py +129 -0
  143. xinference/thirdparty/fish_speech/tools/server/views.py +246 -0
  144. xinference/thirdparty/fish_speech/tools/webui/__init__.py +173 -0
  145. xinference/thirdparty/fish_speech/tools/webui/inference.py +91 -0
  146. xinference/thirdparty/fish_speech/tools/webui/variables.py +14 -0
  147. xinference/thirdparty/matcha/utils/utils.py +2 -2
  148. xinference/web/ui/build/asset-manifest.json +3 -3
  149. xinference/web/ui/build/index.html +1 -1
  150. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  151. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  153. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/METADATA +41 -17
  154. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/RECORD +160 -88
  155. xinference/thirdparty/cosyvoice/bin/export_trt.py +0 -8
  156. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  157. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  158. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  159. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  160. xinference/thirdparty/fish_speech/tools/api.py +0 -943
  161. xinference/thirdparty/fish_speech/tools/msgpack_api.py +0 -95
  162. xinference/thirdparty/fish_speech/tools/webui.py +0 -548
  163. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  164. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  165. /xinference/thirdparty/{cosyvoice/bin → f5_tts}/__init__.py +0 -0
  166. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  167. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/LICENSE +0 -0
  168. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/WHEEL +0 -0
  169. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/entry_points.txt +0 -0
  170. {xinference-1.0.1.dist-info → xinference-1.1.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/cosyvoice/utils/common.py:

@@ -15,8 +15,10 @@
 # Modified from ESPnet(https://github.com/espnet/espnet)
 """Unility functions for Transformer."""
 
+import random
 from typing import List
 
+import numpy as np
 import torch
 
 IGNORE_ID = -1
@@ -102,6 +104,7 @@ def init_weights(m, mean=0.0, std=0.01):
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
+
 # Repetition Aware Sampling in VALL-E 2
 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
     top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
@@ -110,6 +113,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
         top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
     return top_ids
 
+
 def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     prob, indices = [], []
     cum_prob = 0.0
@@ -127,13 +131,36 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     top_ids = indices[prob.multinomial(1, replacement=True)]
     return top_ids
 
+
 def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids
 
+
 def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+    if fade_in_mel.device == torch.device('cpu'):
+        fade_in_mel = fade_in_mel.clone()
+    fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
+
+
+def set_all_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    assert mask.dtype == torch.bool
+    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
+    mask = mask.to(dtype)
+    # attention mask bias
+    # NOTE(Mddct): torch.finfo jit issues
+    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
+    mask = (1.0 - mask) * torch.finfo(dtype).min
+    return mask
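
For orientation, the two new helpers in common.py are self-contained and easy to exercise; a minimal sketch of what they do (the import path is the vendored copy inside the wheel, and the tensor values are illustrative only):

    import torch

    from xinference.thirdparty.cosyvoice.utils.common import mask_to_bias, set_all_random_seed

    # Seeds python `random`, numpy and torch (CPU and all CUDA devices) in one call.
    set_all_random_seed(42)

    # Boolean attention mask: True = attend, False = masked out.
    mask = torch.tensor([[True, True, False]])
    bias = mask_to_bias(mask, torch.float32)
    # 0.0 where attention is allowed, torch.finfo(float32).min where it is masked,
    # so the result can be added directly to attention logits before the softmax.
    print(bias)  # tensor([[ 0.0000e+00,  0.0000e+00, -3.4028e+38]])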
xinference/thirdparty/cosyvoice/utils/executor.py:

@@ -25,13 +25,14 @@ from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, l
 
 class Executor:
 
-    def __init__(self):
+    def __init__(self, gan: bool = False):
+        self.gan = gan
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
         self.device = torch.device('cuda:{}'.format(self.rank))
 
-    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
+    def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
        ''' Train one epoch
        '''
 
@@ -64,13 +65,72 @@ class Executor:
                    context = nullcontext
 
                with context():
-                    info_dict = batch_forward(model, batch_dict, info_dict)
-                    info_dict = batch_backward(model, info_dict)
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
 
-                info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
+                info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
                log_per_step(writer, info_dict)
                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    dist.barrier()
+                    self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
+                    model.train()
+                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                    self.step += 1
+        dist.barrier()
+        self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
+
+    def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
+                           writer, info_dict, scaler, group_join):
+        ''' Train one epoch
+        '''
+
+        lr = optimizer.param_groups[0]['lr']
+        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
+        logging.info('using accumulate grad, new batch size is {} times'
+                     ' larger than before'.format(info_dict['accum_grad']))
+        # A context manager to be used in conjunction with an instance of
+        # torch.nn.parallel.DistributedDataParallel to be able to train
+        # with uneven inputs across participating processes.
+        model.train()
+        model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
+        with model_context():
+            for batch_idx, batch_dict in enumerate(train_data_loader):
+                info_dict["tag"] = "TRAIN"
+                info_dict["step"] = self.step
+                info_dict["epoch"] = self.epoch
+                info_dict["batch_idx"] = batch_idx
+                if cosyvoice_join(group_join, info_dict):
+                    break
+
+                # Disable gradient synchronizations across DDP processes.
+                # Within this context, gradients will be accumulated on module
+                # variables, which will later be synchronized.
+                if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
+                    context = model.no_sync
+                # Used for single gpu training and DDP gradient synchronization
+                # processes.
+                else:
+                    context = nullcontext
+
+                with context():
+                    batch_dict['turn'] = 'discriminator'
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
+                    info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
+                    optimizer.zero_grad()
+                    log_per_step(writer, info_dict)
+                with context():
+                    batch_dict['turn'] = 'generator'
+                    info_dict = batch_forward(model, batch_dict, scaler, info_dict)
+                    info_dict = batch_backward(model, scaler, info_dict)
+                    info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
+                    optimizer_d.zero_grad()
+                    log_per_step(writer, info_dict)
+                # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
+                if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                        (batch_idx + 1) % info_dict["accum_grad"] == 0:
                     dist.barrier()
                     self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                     model.train()
@@ -95,7 +155,9 @@ class Executor:
            num_utts = len(batch_dict["utts"])
            total_num_utts += num_utts
 
-            info_dict = batch_forward(model, batch_dict, info_dict)
+            if self.gan is True:
+                batch_dict['turn'] = 'generator'
+            info_dict = batch_forward(model, batch_dict, None, info_dict)
 
            for k, v in info_dict['loss_dict'].items():
                if k not in total_loss_dict:
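
The new train_one_epoc_gan above follows the usual alternating GAN schedule: one discriminator turn and then one generator turn per batch, with batch_dict['turn'] telling the wrapped model (presumably the new hifigan module) which loss to compute. A generic sketch of that alternating-update pattern, with the model, loss and loader as placeholders rather than the CosyVoice classes:

    import torch

    def train_gan_epoch(model, loader, optimizer_g, optimizer_d):
        # Mirrors the ordering in train_one_epoc_gan: D update first, then G update.
        model.train()
        for batch in loader:
            batch['turn'] = 'discriminator'
            d_loss = model(batch)['loss']
            d_loss.backward()
            optimizer_d.step()
            optimizer_d.zero_grad()
            optimizer_g.zero_grad()  # drop generator grads produced during the D turn

            batch['turn'] = 'generator'
            g_loss = model(batch)['loss']
            g_loss.backward()
            optimizer_g.step()
            optimizer_g.zero_grad()
            optimizer_d.zero_grad()  # drop discriminator grads produced during the G turn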
xinference/thirdparty/cosyvoice/utils/file_utils.py:

@@ -28,6 +28,7 @@ def read_lists(list_file):
            lists.append(line.strip())
    return lists
 
+
 def read_json_lists(list_file):
    lists = read_lists(list_file)
    results = {}
@@ -36,6 +37,7 @@ def read_json_lists(list_file):
            results.update(json.load(fin))
    return results
 
+
 def load_wav(wav, target_sr):
    speech, sample_rate = torchaudio.load(wav)
    speech = speech.mean(dim=0, keepdim=True)
@@ -43,15 +45,3 @@ def load_wav(wav, target_sr):
        assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech
-
-
-def speed_change(waveform, sample_rate, speed_factor: str):
-    effects = [
-        ["tempo", speed_factor],  # speed_factor
-        ["rate", f"{sample_rate}"]
-    ]
-    augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
-        waveform,
-        sample_rate,
-        effects
-    )
-    return augmented_waveform, new_sample_rate
xinference/thirdparty/cosyvoice/utils/frontend_utils.py:

@@ -15,6 +15,7 @@
 import re
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
+
 # whether contain chinese character
 def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))
@@ -79,6 +80,13 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend([',', ','])
+
+    if text[-1] not in pounc:
+        if lang == "zh":
+            text += "。"
+        else:
+            text += "."
+
    st = 0
    utts = []
    for i, c in enumerate(text):
@@ -91,11 +99,7 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
                st = i + 2
            else:
                st = i + 1
-    if len(utts) == 0:
-        if lang == "zh":
-            utts.append(text + '。')
-        else:
-            utts.append(text + '.')
+
    final_utts = []
    cur_utt = ""
    for utt in utts:
xinference/thirdparty/cosyvoice/utils/losses.py (new file):

@@ -0,0 +1,20 @@
+import torch
+import torch.nn.functional as F
+
+
+def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
+    loss = 0
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        m_DG = torch.median((dr - dg))
+        L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
+        loss += tau - F.relu(tau - L_rel)
+    return loss
+
+
+def mel_loss(real_speech, generated_speech, mel_transforms):
+    loss = 0
+    for transform in mel_transforms:
+        mel_r = transform(real_speech)
+        mel_g = transform(generated_speech)
+        loss += F.l1_loss(mel_g, mel_r)
+    return loss
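
The new losses are small functionals; mel_loss, for instance, just sums L1 distances over a list of mel-spectrogram transforms. A hedged usage sketch (the sample rate and mel settings below are made up for illustration, not taken from the CosyVoice configs):

    import torch
    import torchaudio

    from xinference.thirdparty.cosyvoice.utils.losses import mel_loss

    # A few mel resolutions, as multi-scale mel losses typically use.
    mel_transforms = [
        torchaudio.transforms.MelSpectrogram(sample_rate=22050, n_fft=n_fft, n_mels=80)
        for n_fft in (512, 1024, 2048)
    ]

    real = torch.randn(1, 22050)  # one second of "real" audio
    fake = torch.randn(1, 22050)  # generator output of the same length
    print(mel_loss(real, fake, mel_transforms))  # scalar: sum of per-scale L1 mel distances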
xinference/thirdparty/cosyvoice/utils/scheduler.py:

@@ -567,8 +567,7 @@ class NoamAnnealing(_LRScheduler):
                 min_lr=0.0,
                 last_epoch=-1):
        self._normalize = d_model**(-0.5)
-        assert not (warmup_steps is not None
-                    and warmup_ratio is not None), \
+        assert not (warmup_steps is not None and warmup_ratio is not None), \
            "Either use particular number of step or ratio"
        assert warmup_ratio is None or max_steps is not None, \
            "If there is a ratio, there should be a total steps"
xinference/thirdparty/cosyvoice/utils/train_utils.py:

@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
 import logging
 import os
 import torch
@@ -51,9 +50,10 @@ def init_distributed(args):
    return world_size, local_rank, rank
 
 
-def init_dataset_and_dataloader(args, configs):
-    train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
-    cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
+def init_dataset_and_dataloader(args, configs, gan):
+    data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
+    train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=True, partition=True)
+    cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='train', gan=gan, shuffle=False, partition=False)
 
    # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
    train_data_loader = DataLoader(train_dataset,
@@ -69,7 +69,6 @@ def init_dataset_and_dataloader(args, configs):
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader
 
 
-
 def check_modify_and_save_config(args, configs):
    if args.train_engine == "torch_ddp":
        configs['train_conf']["dtype"] = 'fp32'
@@ -84,7 +83,8 @@ def check_modify_and_save_config(args, configs):
        configs['train_conf']["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        # if use deepspeed, override ddp config
-        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
+        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+                                                     configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
        configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
        configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
        configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
@@ -108,38 +108,80 @@ def wrap_cuda_model(args, model):
    return model
 
 
-def init_optimizer_and_scheduler(args, configs, model):
-    if configs['train_conf']['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
-    elif configs['train_conf']['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
-    else:
-        raise ValueError("unknown optimizer: " + configs['train_conf'])
-
-    if configs['train_conf']['scheduler'] == 'warmuplr':
-        scheduler_type = WarmupLR
-        scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
-        scheduler_type = NoamHoldAnnealing
-        scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
-    elif configs['train_conf']['scheduler'] == 'constantlr':
-        scheduler_type = ConstantLR
-        scheduler = ConstantLR(optimizer)
-    else:
-        raise ValueError("unknown scheduler: " + configs['train_conf'])
+def init_optimizer_and_scheduler(args, configs, model, gan):
+    if gan is False:
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+
+        # use deepspeed optimizer for speedup
+        if args.train_engine == "deepspeed":
+            def scheduler(opt):
+                return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
+            model, optimizer, _, scheduler = deepspeed.initialize(
+                args=args,
+                model=model,
+                optimizer=None,
+                lr_scheduler=scheduler,
+                model_parameters=model.parameters())
+
+        optimizer_d, scheduler_d = None, None
 
-    # use deepspeed optimizer for speedup
-    if args.train_engine == "deepspeed":
-        def scheduler(opt):
-            return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
-        model, optimizer, _, scheduler = deepspeed.initialize(
-            args=args,
-            model=model,
-            optimizer=None,
-            lr_scheduler=scheduler,
-            model_parameters=model.parameters())
+    else:
+        # currently we wrap generator and discriminator in one model, so we cannot use deepspeed
+        if configs['train_conf']['optim'] == 'adam':
+            optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim'] == 'adamw':
+            optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler = ConstantLR(optimizer)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
 
-    return model, optimizer, scheduler
+        if configs['train_conf']['optim_d'] == 'adam':
+            optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+        elif configs['train_conf']['optim_d'] == 'adamw':
+            optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf'])
+        else:
+            raise ValueError("unknown optimizer: " + configs['train_conf'])
+
+        if configs['train_conf']['scheduler_d'] == 'warmuplr':
+            scheduler_type = WarmupLR
+            scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
+            scheduler_type = NoamHoldAnnealing
+            scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_conf'])
+        elif configs['train_conf']['scheduler'] == 'constantlr':
+            scheduler_type = ConstantLR
+            scheduler_d = ConstantLR(optimizer_d)
+        else:
+            raise ValueError("unknown scheduler: " + configs['train_conf'])
+    return model, optimizer, scheduler, optimizer_d, scheduler_d
 
 
 def init_summarywriter(args):
@@ -157,7 +199,7 @@ def save_model(model, model_name, info_dict):
 
    if info_dict["train_engine"] == "torch_ddp":
        if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
+            torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
    else:
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
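
Because torch_ddp checkpoints now carry 'epoch' and 'step' next to the weights, a loader that expects a plain state dict has to drop those two keys first. A small round-trip sketch (the Linear layer stands in for the real CosyVoice module):

    import torch

    model = torch.nn.Linear(4, 4)  # stand-in for the actual model

    # What save_model now writes for the torch_ddp engine:
    torch.save({**model.state_dict(), "epoch": 3, "step": 1200}, "example_ckpt.pt")

    # Loading side: strip the bookkeeping keys before strict state-dict loading.
    state = torch.load("example_ckpt.pt", map_location="cpu")
    epoch = state.pop("epoch", -1)
    step = state.pop("step", -1)
    model.load_state_dict(state)
    print(f"resumed from epoch {epoch}, step {step}")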
xinference/thirdparty/cosyvoice/utils/train_utils.py (continued):

@@ -193,7 +235,7 @@ def cosyvoice_join(group_join, info_dict):
        return False
 
 
-def batch_forward(model, batch, info_dict):
+def batch_forward(model, batch, scaler, info_dict):
    device = int(os.environ.get('LOCAL_RANK', 0))
 
    dtype = info_dict["dtype"]
@@ -205,7 +247,7 @@ def batch_forward(model, batch, info_dict):
        dtype = torch.float32
 
    if info_dict['train_engine'] == 'torch_ddp':
-        autocast = nullcontext()
+        autocast = torch.cuda.amp.autocast(enabled=scaler is not None)
    else:
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
 
@@ -214,27 +256,41 @@ def batch_forward(model, batch, info_dict):
    return info_dict
 
 
-def batch_backward(model, info_dict):
+def batch_backward(model, scaler, info_dict):
    if info_dict["train_engine"] == "deepspeed":
        scaled_loss = model.backward(info_dict['loss_dict']['loss'])
    else:
        scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
-        scaled_loss.backward()
+        if scaler is not None:
+            scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()
 
    info_dict['loss_dict']['loss'] = scaled_loss
    return info_dict
 
 
-def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
+def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    grad_norm = 0.0
    if info_dict['train_engine'] == "deepspeed":
        info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
    elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
-        grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
-        if torch.isfinite(grad_norm):
-            optimizer.step()
+        # Use mixed precision training
+        if scaler is not None:
+            scaler.unscale_(optimizer)
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            # We don't check grad here since that if the gradient
+            # has inf/nan values, scaler.step will skip
+            # optimizer.step().
+            if torch.isfinite(grad_norm):
+                scaler.step(optimizer)
+            scaler.update()
+        else:
+            grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
+            if torch.isfinite(grad_norm):
+                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
    info_dict["lr"] = optimizer.param_groups[0]['lr']
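
The scaler threaded through batch_forward, batch_backward and update_parameter_and_lr is the standard torch.cuda.amp.GradScaler loop: forward under autocast, scaled backward, unscale before clipping, then a scaler-managed optimizer step. Condensed into a self-contained toy loop (the Linear model, optimizer and scheduler are placeholders; it degrades to plain fp32 when CUDA is absent):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

    for _ in range(3):  # a few fake batches
        x = torch.randn(8, 16, device=device)
        with torch.cuda.amp.autocast(enabled=device == "cuda"):  # batch_forward's autocast
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()                            # batch_backward
        scaler.unscale_(optimizer)                               # unscale before gradient clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        if torch.isfinite(grad_norm):
            scaler.step(optimizer)  # skips the step on its own if grads hold inf/nan
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()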
xinference/thirdparty/f5_tts/api.py (new file):

@@ -0,0 +1,166 @@
+import random
+import sys
+from importlib.resources import files
+
+import soundfile as sf
+import tqdm
+from cached_path import cached_path
+
+from f5_tts.infer.utils_infer import (
+    hop_length,
+    infer_process,
+    load_model,
+    load_vocoder,
+    preprocess_ref_audio_text,
+    remove_silence_for_generated_wav,
+    save_spectrogram,
+    transcribe,
+    target_sample_rate,
+)
+from f5_tts.model import DiT, UNetT
+from f5_tts.model.utils import seed_everything
+
+
+class F5TTS:
+    def __init__(
+        self,
+        model_type="F5-TTS",
+        ckpt_file="",
+        vocab_file="",
+        ode_method="euler",
+        use_ema=True,
+        vocoder_name="vocos",
+        local_path=None,
+        device=None,
+        hf_cache_dir=None,
+    ):
+        # Initialize parameters
+        self.final_wave = None
+        self.target_sample_rate = target_sample_rate
+        self.hop_length = hop_length
+        self.seed = -1
+        self.mel_spec_type = vocoder_name
+
+        # Set device
+        if device is not None:
+            self.device = device
+        else:
+            import torch
+
+            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+        # Load models
+        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
+        self.load_ema_model(
+            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
+        )
+
+    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
+        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
+
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
+        if model_type == "F5-TTS":
+            if not ckpt_file:
+                if mel_spec_type == "vocos":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                    )
+                elif mel_spec_type == "bigvgan":
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
+                    )
+            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+            model_cls = DiT
+        elif model_type == "E2-TTS":
+            if not ckpt_file:
+                ckpt_file = str(
+                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
+                )
+            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
+            model_cls = UNetT
+        else:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        self.ema_model = load_model(
+            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
+        )
+
+    def transcribe(self, ref_audio, language=None):
+        return transcribe(ref_audio, language)
+
+    def export_wav(self, wav, file_wave, remove_silence=False):
+        sf.write(file_wave, wav, self.target_sample_rate)
+
+        if remove_silence:
+            remove_silence_for_generated_wav(file_wave)
+
+    def export_spectrogram(self, spect, file_spect):
+        save_spectrogram(spect, file_spect)
+
+    def infer(
+        self,
+        ref_file,
+        ref_text,
+        gen_text,
+        show_info=print,
+        progress=tqdm,
+        target_rms=0.1,
+        cross_fade_duration=0.15,
+        sway_sampling_coef=-1,
+        cfg_strength=2,
+        nfe_step=32,
+        speed=1.0,
+        fix_duration=None,
+        remove_silence=False,
+        file_wave=None,
+        file_spect=None,
+        seed=-1,
+    ):
+        if seed == -1:
+            seed = random.randint(0, sys.maxsize)
+        seed_everything(seed)
+        self.seed = seed
+
+        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)
+
+        wav, sr, spect = infer_process(
+            ref_file,
+            ref_text,
+            gen_text,
+            self.ema_model,
+            self.vocoder,
+            self.mel_spec_type,
+            show_info=show_info,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=self.device,
+        )
+
+        if file_wave is not None:
+            self.export_wav(wav, file_wave, remove_silence)
+
+        if file_spect is not None:
+            self.export_spectrogram(spect, file_spect)
+
+        return wav, sr, spect
+
+
+if __name__ == "__main__":
+    f5tts = F5TTS()
+
+    wav, sr, spect = f5tts.infer(
+        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
+        ref_text="some call me nature, others call me mother nature.",
+        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
+        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
+        file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")),
+        seed=-1,  # random seed = -1
+    )
+
+    print("seed :", f5tts.seed)