xinference 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (124)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +4 -7
  3. xinference/client/handlers.py +3 -0
  4. xinference/core/chat_interface.py +6 -1
  5. xinference/core/model.py +2 -0
  6. xinference/core/scheduler.py +4 -7
  7. xinference/core/supervisor.py +114 -23
  8. xinference/core/worker.py +70 -4
  9. xinference/deploy/local.py +2 -1
  10. xinference/model/audio/core.py +11 -0
  11. xinference/model/audio/cosyvoice.py +16 -5
  12. xinference/model/audio/kokoro.py +139 -0
  13. xinference/model/audio/melotts.py +110 -0
  14. xinference/model/audio/model_spec.json +80 -0
  15. xinference/model/audio/model_spec_modelscope.json +18 -0
  16. xinference/model/audio/whisper.py +35 -10
  17. xinference/model/llm/llama_cpp/core.py +21 -14
  18. xinference/model/llm/llm_family.json +527 -1
  19. xinference/model/llm/llm_family.py +4 -1
  20. xinference/model/llm/llm_family_modelscope.json +495 -3
  21. xinference/model/llm/memory.py +1 -1
  22. xinference/model/llm/mlx/core.py +24 -6
  23. xinference/model/llm/transformers/core.py +9 -1
  24. xinference/model/llm/transformers/qwen2_audio.py +3 -1
  25. xinference/model/llm/transformers/qwen2_vl.py +20 -3
  26. xinference/model/llm/transformers/utils.py +22 -11
  27. xinference/model/llm/utils.py +115 -1
  28. xinference/model/llm/vllm/core.py +14 -4
  29. xinference/model/llm/vllm/xavier/block.py +3 -4
  30. xinference/model/llm/vllm/xavier/block_tracker.py +71 -58
  31. xinference/model/llm/vllm/xavier/collective.py +74 -0
  32. xinference/model/llm/vllm/xavier/collective_manager.py +147 -0
  33. xinference/model/llm/vllm/xavier/executor.py +18 -16
  34. xinference/model/llm/vllm/xavier/scheduler.py +79 -63
  35. xinference/model/llm/vllm/xavier/test/test_xavier.py +60 -35
  36. xinference/model/llm/vllm/xavier/transfer.py +53 -32
  37. xinference/thirdparty/cosyvoice/bin/spk2info.pt +0 -0
  38. xinference/thirdparty/melo/__init__.py +0 -0
  39. xinference/thirdparty/melo/api.py +135 -0
  40. xinference/thirdparty/melo/app.py +61 -0
  41. xinference/thirdparty/melo/attentions.py +459 -0
  42. xinference/thirdparty/melo/commons.py +160 -0
  43. xinference/thirdparty/melo/configs/config.json +94 -0
  44. xinference/thirdparty/melo/data/example/metadata.list +20 -0
  45. xinference/thirdparty/melo/data_utils.py +413 -0
  46. xinference/thirdparty/melo/download_utils.py +67 -0
  47. xinference/thirdparty/melo/infer.py +25 -0
  48. xinference/thirdparty/melo/init_downloads.py +14 -0
  49. xinference/thirdparty/melo/losses.py +58 -0
  50. xinference/thirdparty/melo/main.py +36 -0
  51. xinference/thirdparty/melo/mel_processing.py +174 -0
  52. xinference/thirdparty/melo/models.py +1030 -0
  53. xinference/thirdparty/melo/modules.py +598 -0
  54. xinference/thirdparty/melo/monotonic_align/__init__.py +16 -0
  55. xinference/thirdparty/melo/monotonic_align/core.py +46 -0
  56. xinference/thirdparty/melo/preprocess_text.py +135 -0
  57. xinference/thirdparty/melo/split_utils.py +174 -0
  58. xinference/thirdparty/melo/text/__init__.py +35 -0
  59. xinference/thirdparty/melo/text/chinese.py +199 -0
  60. xinference/thirdparty/melo/text/chinese_bert.py +107 -0
  61. xinference/thirdparty/melo/text/chinese_mix.py +253 -0
  62. xinference/thirdparty/melo/text/cleaner.py +36 -0
  63. xinference/thirdparty/melo/text/cleaner_multiling.py +110 -0
  64. xinference/thirdparty/melo/text/cmudict.rep +129530 -0
  65. xinference/thirdparty/melo/text/cmudict_cache.pickle +0 -0
  66. xinference/thirdparty/melo/text/english.py +284 -0
  67. xinference/thirdparty/melo/text/english_bert.py +39 -0
  68. xinference/thirdparty/melo/text/english_utils/__init__.py +0 -0
  69. xinference/thirdparty/melo/text/english_utils/abbreviations.py +35 -0
  70. xinference/thirdparty/melo/text/english_utils/number_norm.py +97 -0
  71. xinference/thirdparty/melo/text/english_utils/time_norm.py +47 -0
  72. xinference/thirdparty/melo/text/es_phonemizer/__init__.py +0 -0
  73. xinference/thirdparty/melo/text/es_phonemizer/base.py +140 -0
  74. xinference/thirdparty/melo/text/es_phonemizer/cleaner.py +109 -0
  75. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.json +79 -0
  76. xinference/thirdparty/melo/text/es_phonemizer/es_symbols.txt +1 -0
  77. xinference/thirdparty/melo/text/es_phonemizer/es_symbols_v2.json +83 -0
  78. xinference/thirdparty/melo/text/es_phonemizer/es_to_ipa.py +12 -0
  79. xinference/thirdparty/melo/text/es_phonemizer/example_ipa.txt +400 -0
  80. xinference/thirdparty/melo/text/es_phonemizer/gruut_wrapper.py +253 -0
  81. xinference/thirdparty/melo/text/es_phonemizer/punctuation.py +174 -0
  82. xinference/thirdparty/melo/text/es_phonemizer/spanish_symbols.txt +1 -0
  83. xinference/thirdparty/melo/text/es_phonemizer/test.ipynb +124 -0
  84. xinference/thirdparty/melo/text/fr_phonemizer/__init__.py +0 -0
  85. xinference/thirdparty/melo/text/fr_phonemizer/base.py +140 -0
  86. xinference/thirdparty/melo/text/fr_phonemizer/cleaner.py +122 -0
  87. xinference/thirdparty/melo/text/fr_phonemizer/en_symbols.json +78 -0
  88. xinference/thirdparty/melo/text/fr_phonemizer/example_ipa.txt +1 -0
  89. xinference/thirdparty/melo/text/fr_phonemizer/fr_symbols.json +89 -0
  90. xinference/thirdparty/melo/text/fr_phonemizer/fr_to_ipa.py +30 -0
  91. xinference/thirdparty/melo/text/fr_phonemizer/french_abbreviations.py +48 -0
  92. xinference/thirdparty/melo/text/fr_phonemizer/french_symbols.txt +1 -0
  93. xinference/thirdparty/melo/text/fr_phonemizer/gruut_wrapper.py +258 -0
  94. xinference/thirdparty/melo/text/fr_phonemizer/punctuation.py +172 -0
  95. xinference/thirdparty/melo/text/french.py +94 -0
  96. xinference/thirdparty/melo/text/french_bert.py +39 -0
  97. xinference/thirdparty/melo/text/japanese.py +647 -0
  98. xinference/thirdparty/melo/text/japanese_bert.py +49 -0
  99. xinference/thirdparty/melo/text/ko_dictionary.py +44 -0
  100. xinference/thirdparty/melo/text/korean.py +192 -0
  101. xinference/thirdparty/melo/text/opencpop-strict.txt +429 -0
  102. xinference/thirdparty/melo/text/spanish.py +122 -0
  103. xinference/thirdparty/melo/text/spanish_bert.py +39 -0
  104. xinference/thirdparty/melo/text/symbols.py +290 -0
  105. xinference/thirdparty/melo/text/tone_sandhi.py +769 -0
  106. xinference/thirdparty/melo/train.py +635 -0
  107. xinference/thirdparty/melo/train.sh +19 -0
  108. xinference/thirdparty/melo/transforms.py +209 -0
  109. xinference/thirdparty/melo/utils.py +424 -0
  110. xinference/types.py +2 -0
  111. xinference/web/ui/build/asset-manifest.json +3 -3
  112. xinference/web/ui/build/index.html +1 -1
  113. xinference/web/ui/build/static/js/{main.1eb206d1.js → main.b0936c54.js} +3 -3
  114. xinference/web/ui/build/static/js/main.b0936c54.js.map +1 -0
  115. xinference/web/ui/node_modules/.cache/babel-loader/a3ff866acddf34917a7ee399e0e571a4dfd8ba66d5057db885f243e16a6eb17d.json +1 -0
  116. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/METADATA +37 -27
  117. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/RECORD +122 -45
  118. xinference/web/ui/build/static/js/main.1eb206d1.js.map +0 -1
  119. xinference/web/ui/node_modules/.cache/babel-loader/2213d49de260e1f67c888081b18f120f5225462b829ae57c9e05a05cec83689d.json +0 -1
  120. /xinference/web/ui/build/static/js/{main.1eb206d1.js.LICENSE.txt → main.b0936c54.js.LICENSE.txt} +0 -0
  121. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/LICENSE +0 -0
  122. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/WHEEL +0 -0
  123. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/entry_points.txt +0 -0
  124. {xinference-1.2.0.dist-info → xinference-1.2.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,635 @@
1
+ # flake8: noqa: E402
2
+
3
+ import os
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.cuda.amp import autocast, GradScaler
11
+ from tqdm import tqdm
12
+ import logging
13
+
14
+ logging.getLogger("numba").setLevel(logging.WARNING)
15
+ import commons
16
+ import utils
17
+ from data_utils import (
18
+ TextAudioSpeakerLoader,
19
+ TextAudioSpeakerCollate,
20
+ DistributedBucketSampler,
21
+ )
22
+ from models import (
23
+ SynthesizerTrn,
24
+ MultiPeriodDiscriminator,
25
+ DurationDiscriminator,
26
+ )
27
+ from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
28
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
29
+ from text.symbols import symbols
30
+ from melo.download_utils import load_pretrain_model
31
+
32
# Allow TF32 for matmuls and cuDNN convolutions (faster, slightly less precise).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = (
    True  # If you encounter training problems, try disabling TF32.
)
torch.set_float32_matmul_precision("medium")


torch.backends.cudnn.benchmark = True
# NOTE: a former ``torch.backends.cuda.sdp_kernel("flash")`` call was removed.
# ``sdp_kernel`` is a (deprecated) context manager whose settings only apply
# inside a ``with`` block, so calling it bare had no effect.  The process-wide
# toggles below are what actually select the scaled-dot-product backends.
torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(
#     True
# )  # Not available if torch version is lower than 2.0
torch.backends.cuda.enable_math_sdp(True)

# Optimisation-step counter shared (via ``global``) by run(),
# train_and_evaluate() and evaluate().
global_step = 0
47
+
48
+
49
def run():
    """Entry point for distributed training; one process per GPU.

    Expects to be launched by ``torchrun``: the local rank is read from the
    ``LOCAL_RANK`` environment variable and the process group is initialised
    from the ``env://`` rendezvous variables.

    Steps:
      1. parse hyper-parameters (``utils.get_hparams``) and join the group;
      2. build the bucketed training loader (plus an eval loader on rank 0);
      3. build the generator, the multi-period discriminator and, if
         ``hps.model.use_duration_discriminator`` is True, the duration
         discriminator, together with their AdamW optimisers, all wrapped
         in DDP;
      4. load pretrained weights, then try to resume from the latest
         ``G_*.pth`` / ``D_*.pth`` / ``DUR_*.pth`` in ``hps.model_dir``;
      5. for every remaining epoch call :func:`train_and_evaluate`, then
         step the exponential LR schedulers once.

    NOTE(review): ``LOCAL_RANK`` is used as the global rank and the CUDA
    device index, which is only correct for single-node runs.
    """
    import traceback

    hps = utils.get_hparams()
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(
        backend="gloo",
        init_method="env://",  # Due to some training problem,we proposed to use gloo instead of nccl.
        rank=local_rank,
    )  # Use torchrun instead of mp.spawn
    rank = dist.get_rank()
    n_gpus = dist.get_world_size()

    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    # ---- data -------------------------------------------------------------
    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=16,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=4,
    )  # DataLoader config could be adjusted.
    if rank == 0:
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(
            eval_dataset,
            num_workers=0,
            shuffle=False,
            batch_size=1,
            pin_memory=True,
            drop_last=False,
            collate_fn=collate_fn,
        )

    # ---- model options ----------------------------------------------------
    if (
        "use_noise_scaled_mas" in hps.model.keys()
        and hps.model.use_noise_scaled_mas is True
    ):
        print("Using noise scaled MAS for VITS2")
        mas_noise_scale_initial = 0.01
        noise_scale_delta = 2e-6
    else:
        print("Using normal MAS for VITS1")
        mas_noise_scale_initial = 0.0
        noise_scale_delta = 0.0

    # Stays None unless the duration discriminator is enabled.  Previously the
    # name was left unbound in that case and the later ``is not None`` checks
    # raised UnboundLocalError.
    net_dur_disc = None
    if (
        "use_duration_discriminator" in hps.model.keys()
        and hps.model.use_duration_discriminator is True
    ):
        print("Using duration discriminator for VITS2")
        net_dur_disc = DurationDiscriminator(
            hps.model.hidden_channels,
            hps.model.hidden_channels,
            3,
            0.1,
            gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
        ).cuda(rank)
    if (
        "use_spk_conditioned_encoder" in hps.model.keys()
        and hps.model.use_spk_conditioned_encoder is True
    ):
        if hps.data.n_speakers == 0:
            raise ValueError(
                "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
            )
    else:
        print("Using normal encoder for VITS1")

    # ---- networks and optimisers -----------------------------------------
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        mas_noise_scale_initial=mas_noise_scale_initial,
        noise_scale_delta=noise_scale_delta,
        **hps.model,
    ).cuda(rank)

    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    optim_g = torch.optim.AdamW(
        (p for p in net_g.parameters() if p.requires_grad),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    if net_dur_disc is not None:
        optim_dur_disc = torch.optim.AdamW(
            net_dur_disc.parameters(),
            hps.train.learning_rate,
            betas=hps.train.betas,
            eps=hps.train.eps,
        )
    else:
        optim_dur_disc = None
    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)

    # ---- pretrained weights (explicit hps paths win over the defaults) ----
    pretrain_G, pretrain_D, pretrain_dur = load_pretrain_model()
    hps.pretrain_G = hps.pretrain_G or pretrain_G
    hps.pretrain_D = hps.pretrain_D or pretrain_D
    hps.pretrain_dur = hps.pretrain_dur or pretrain_dur

    if hps.pretrain_G:
        utils.load_checkpoint(hps.pretrain_G, net_g, None, skip_optimizer=True)
    if hps.pretrain_D:
        utils.load_checkpoint(hps.pretrain_D, net_d, None, skip_optimizer=True)

    if net_dur_disc is not None:
        net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
        if hps.pretrain_dur:
            utils.load_checkpoint(
                hps.pretrain_dur, net_dur_disc, None, skip_optimizer=True
            )

    # ---- resume from the latest checkpoints, if any ------------------------
    try:
        skip_optimizer = (
            hps.train.skip_optimizer if "skip_optimizer" in hps.train else True
        )
        if net_dur_disc is not None:
            _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
                utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
                net_dur_disc,
                optim_dur_disc,
                skip_optimizer=skip_optimizer,
            )
        _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
            net_g,
            optim_g,
            skip_optimizer=skip_optimizer,
        )
        _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
            net_d,
            optim_d,
            skip_optimizer=skip_optimizer,
        )
        # ExponentialLR with last_epoch != -1 requires "initial_lr".
        if not optim_g.param_groups[0].get("initial_lr"):
            optim_g.param_groups[0]["initial_lr"] = g_resume_lr
        if not optim_d.param_groups[0].get("initial_lr"):
            optim_d.param_groups[0]["initial_lr"] = d_resume_lr
        # optim_dur_disc is None when the duration discriminator is disabled.
        # Accessing it unguarded raised AttributeError, which the except
        # below swallowed, throwing away an otherwise successful resume.
        if optim_dur_disc is not None and not optim_dur_disc.param_groups[0].get(
            "initial_lr"
        ):
            optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr

        epoch_str = max(epoch_str, 1)
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception as e:
        # Any failure to resume (for instance, no checkpoint in model_dir
        # yet) falls back to training from epoch 1.
        print(e)
        epoch_str = 1
        global_step = 0

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    if net_dur_disc is not None:
        scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
            optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
        )
    else:
        scheduler_dur_disc = None
    scaler = GradScaler(enabled=hps.train.fp16_run)

    # Only rank 0 evaluates, logs and writes checkpoints.
    if rank == 0:
        epoch_loaders = [train_loader, eval_loader]
        epoch_logger = logger
        epoch_writers = [writer, writer_eval]
    else:
        epoch_loaders = [train_loader, None]
        epoch_logger = None
        epoch_writers = None

    for epoch in range(epoch_str, hps.train.epochs + 1):
        try:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d, net_dur_disc],
                [optim_g, optim_d, optim_dur_disc],
                [scheduler_g, scheduler_d, scheduler_dur_disc],
                scaler,
                epoch_loaders,
                epoch_logger,
                epoch_writers,
            )
        except Exception:
            # Best effort: report the failed epoch with its full traceback
            # and continue with the next one.
            traceback.print_exc()
            torch.cuda.empty_cache()
        scheduler_g.step()
        scheduler_d.step()
        if scheduler_dur_disc is not None:
            scheduler_dur_disc.step()
289
+
290
+
291
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """Run one training epoch; on rank 0 also log, evaluate and checkpoint.

    Args:
        rank: process rank, also used as the CUDA device index for batches.
        epoch: epoch number, forwarded to the bucket sampler and checkpoints.
        hps: hyper-parameters from ``utils.get_hparams``.
        nets: ``[net_g, net_d, net_dur_disc]``. ``net_dur_disc`` may be None.
        optims: optimisers matching ``nets``. ``optim_dur_disc`` may be None.
        schedulers: kept for interface compatibility. It is not used here,
            because the caller steps the schedulers once per epoch.
        scaler: ``GradScaler`` shared by all optimisers. ``scaler.update()``
            runs once per batch, after the generator step.
        loaders: ``[train_loader, eval_loader]``. ``eval_loader`` is only read
            on rank 0.
        logger: logger used on rank 0 (None elsewhere).
        writers: ``[writer, writer_eval]`` on rank 0, None elsewhere.

    Side effects:
        Increments the module-level ``global_step`` once per batch.
    """
    net_g, net_d, net_dur_disc = nets
    optim_g, optim_d, optim_dur_disc = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    if net_dur_disc is not None:
        net_dur_disc.train()
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        if net_g.module.use_noise_scaled_mas:
            # Linearly decay the MAS noise scale with the global step, floored at 0.
            current_mas_noise_scale = (
                net_g.module.mas_noise_scale_initial
                - net_g.module.noise_scale_delta * global_step
            )
            net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
        (
            x,
            x_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            speakers,
            tone,
            language,
            bert,
            ja_bert,
        ) = (t.cuda(rank, non_blocking=True) for t in batch)

        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                l_length,
                attn,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                (hidden_x, logw, logw_),
            ) = net_g(
                x,
                x_lengths,
                spec,
                spec_lengths,
                speakers,
                tone,
                language,
                bert,
                ja_bert,
            )
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            # Ground-truth mel for the same random segment the generator decoded.
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator forward on real audio vs. detached generator output.
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(
                    hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
                )
                with autocast(enabled=False):
                    # TODO: I think need to mean using the mask, but for now, just mean all
                    (
                        loss_dur_disc,
                        losses_dur_disc_r,
                        losses_dur_disc_g,
                    ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
                    loss_dur_disc_all = loss_dur_disc

        # Backward passes and optimiser steps run outside autocast, as the
        # PyTorch AMP documentation recommends.
        if net_dur_disc is not None:
            optim_dur_disc.zero_grad()
            scaler.scale(loss_dur_disc_all).backward()
            scaler.unscale_(optim_dur_disc)
            commons.clip_grad_value_(net_dur_disc.parameters(), None)
            scaler.step(optim_dur_disc)

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        # With clip_value=None this presumably only measures the grad norm.
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
            with autocast(enabled=False):
                loss_dur = torch.sum(l_length.float())
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
                if net_dur_disc is not None:
                    loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
                    loss_gen_all += loss_dur_gen
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([loss.item() for loss in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                    "loss/g/fm": loss_fm,
                    "loss/g/mel": loss_mel,
                    "loss/g/dur": loss_dur,
                    "loss/g/kl": loss_kl,
                }
                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/attn": utils.plot_alignment_to_numpy(
                        attn[0, 0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                # Save G_/D_/DUR_<step>.pth. DUR is skipped when disabled.
                for prefix, net, optim in (
                    ("G", net_g, optim_g),
                    ("D", net_d, optim_d),
                    ("DUR", net_dur_disc, optim_dur_disc),
                ):
                    if net is None:
                        continue
                    utils.save_checkpoint(
                        net,
                        optim,
                        hps.train.learning_rate,
                        epoch,
                        os.path.join(
                            hps.model_dir, "{}_{}.pth".format(prefix, global_step)
                        ),
                    )
                keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
                if keep_ckpts > 0:
                    utils.clean_checkpoints(
                        path_to_models=hps.model_dir,
                        n_ckpts_to_keep=keep_ckpts,
                        sort_by_time=True,
                    )

        global_step += 1

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
    torch.cuda.empty_cache()
537
+
538
+
539
def evaluate(hps, generator, eval_loader, writer_eval):
    """Synthesize each evaluation batch and log the results to TensorBoard.

    Every batch is inferred twice, with ``sdp_ratio=1.0`` and
    ``sdp_ratio=0.0``. The generated mel images and audio clips, plus the
    ground-truth mel and audio, are written under the module-level
    ``global_step``. The generator is put back into train mode afterwards.

    Args:
        hps: hyper-parameters. Uses the mel/STFT settings in ``hps.data``.
        generator: DDP-wrapped generator. ``generator.module.infer`` is called.
        eval_loader: yields the same 11-tensor batches as the training loader.
            Only element 0 of each batch is logged.
        writer_eval: TensorBoard writer passed to ``utils.summarize``.
    """
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(eval_loader):
            (
                x,
                x_lengths,
                spec,
                spec_lengths,
                y,
                y_lengths,
                speakers,
                tone,
                language,
                bert,
                ja_bert,
            ) = (t.cuda() for t in batch)

            # Ground-truth mel depends only on the batch, not on use_sdp.
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            for use_sdp in [True, False]:
                y_hat, attn, mask, *_ = generator.module.infer(
                    x,
                    x_lengths,
                    speakers,
                    tone,
                    language,
                    bert,
                    ja_bert,
                    y=spec,
                    max_len=1000,
                    sdp_ratio=1.0 if use_sdp else 0.0,
                )
                # Frame count from the mask, converted to audio samples.
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                # Key the image by use_sdp too, as the audio key already was.
                # Otherwise the non-SDP image overwrites the SDP one.
                image_dict[f"gen/mel_{batch_idx}_{use_sdp}"] = (
                    utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
                )
                audio_dict[f"gen/audio_{batch_idx}_{use_sdp}"] = y_hat[
                    0, :, : y_hat_lengths[0]
                ]
            image_dict[f"gt/mel_{batch_idx}"] = utils.plot_spectrogram_to_numpy(
                mel[0].cpu().numpy()
            )
            audio_dict[f"gt/audio_{batch_idx}"] = y[0, :, : y_lengths[0]]

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()
    print("Evaluate done")
    torch.cuda.empty_cache()
632
+
633
+
634
if __name__ == "__main__":
    # Launched once per process by torchrun; run() reads LOCAL_RANK from the env.
    run()
@@ -0,0 +1,19 @@
1
#!/usr/bin/env bash
# Launch distributed training with torchrun and relaunch it if it fails.
#
# Usage: bash train.sh <config.json> <num_gpus>
# MODEL_NAME passed to train.py is the name of the directory containing the
# config file.

if [ "$#" -lt 2 ]; then
    echo "Usage: $0 <config.json> <num_gpus>" >&2
    exit 1
fi

CONFIG=$1
GPUS=$2
MODEL_NAME=$(basename "$(dirname "$CONFIG")")

PORT=10902

# Auto-resume: the job sometimes crashes due to a gloo bug on some GPUs.
# Relaunch until torchrun exits successfully. Previously the loop never
# terminated, even after training had finished. On each launch, train.py
# tries to resume from its latest G_/D_/DUR_ checkpoints.
while :
do
    if torchrun --nproc_per_node="$GPUS" \
        --master_port="$PORT" \
        train.py --c "$CONFIG" --model "$MODEL_NAME"; then
        break
    fi

    # Kill leftover python workers for this config before relaunching.
    for PID in $(ps aux | grep -F -- "$CONFIG" | grep python | awk '{print $2}')
    do
        echo "$PID"
        kill -9 "$PID"
    done
    sleep 30
done