xinference 0.14.3__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (70)
  1. xinference/_version.py +3 -3
  2. xinference/core/worker.py +18 -9
  3. xinference/model/audio/chattts.py +4 -3
  4. xinference/model/audio/cosyvoice.py +4 -3
  5. xinference/model/audio/custom.py +4 -5
  6. xinference/model/embedding/core.py +2 -0
  7. xinference/model/embedding/custom.py +4 -5
  8. xinference/model/flexible/core.py +5 -1
  9. xinference/model/image/custom.py +4 -5
  10. xinference/model/image/stable_diffusion/core.py +21 -6
  11. xinference/model/llm/llm_family.py +5 -6
  12. xinference/model/llm/sglang/core.py +7 -1
  13. xinference/model/llm/transformers/core.py +2 -0
  14. xinference/model/llm/utils.py +3 -0
  15. xinference/model/llm/vllm/core.py +0 -33
  16. xinference/model/rerank/custom.py +4 -5
  17. xinference/model/utils.py +41 -1
  18. xinference/model/video/core.py +3 -1
  19. xinference/model/video/diffusers.py +41 -38
  20. xinference/model/video/model_spec.json +24 -1
  21. xinference/model/video/model_spec_modelscope.json +25 -1
  22. xinference/thirdparty/fish_speech/tools/api.py +1 -1
  23. xinference/thirdparty/matcha/__init__.py +0 -0
  24. xinference/thirdparty/matcha/app.py +357 -0
  25. xinference/thirdparty/matcha/cli.py +419 -0
  26. xinference/thirdparty/matcha/data/__init__.py +0 -0
  27. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  28. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  29. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  30. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  31. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  32. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  33. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  34. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  35. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  36. xinference/thirdparty/matcha/models/__init__.py +0 -0
  37. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  38. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  39. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  40. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  41. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  42. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  43. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  44. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  45. xinference/thirdparty/matcha/onnx/export.py +181 -0
  46. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  47. xinference/thirdparty/matcha/text/__init__.py +53 -0
  48. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  49. xinference/thirdparty/matcha/text/numbers.py +71 -0
  50. xinference/thirdparty/matcha/text/symbols.py +17 -0
  51. xinference/thirdparty/matcha/train.py +122 -0
  52. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  53. xinference/thirdparty/matcha/utils/audio.py +82 -0
  54. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  55. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  56. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  57. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  58. xinference/thirdparty/matcha/utils/model.py +90 -0
  59. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  60. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  61. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  62. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  63. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  64. xinference/thirdparty/matcha/utils/utils.py +259 -0
  65. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/METADATA +20 -12
  66. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/RECORD +70 -28
  67. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  68. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  69. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/thirdparty/matcha/hifigan/xutils.py
@@ -0,0 +1,60 @@
+ """ from https://github.com/jik876/hifi-gan """
+
+ import glob
+ import os
+
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+
+
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+
+     fig.canvas.draw()
+     plt.close()
+
+     return fig
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print(f"Loading '{filepath}'")
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ def save_checkpoint(filepath, obj):
+     print(f"Saving checkpoint to {filepath}")
+     torch.save(obj, filepath)
+     print("Complete.")
+
+
+ def scan_checkpoint(cp_dir, prefix):
+     pattern = os.path.join(cp_dir, prefix + "????????")
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]
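
A minimal usage sketch (illustrative, not part of the diff) for the checkpoint helpers above, assuming the wheel's layout makes the module importable at the path shown. The "checkpoints" directory is an assumption; the "g_" prefix matches HiFi-GAN's g_00000000-style generator filenames, which is what the eight-character "????????" glob in scan_checkpoint targets:

import torch
from xinference.thirdparty.matcha.hifigan.xutils import load_checkpoint, scan_checkpoint

device = torch.device("cpu")
latest = scan_checkpoint("checkpoints", prefix="g_")  # newest file matching g_????????
if latest is not None:  # scan_checkpoint returns None when nothing matches
    state = load_checkpoint(latest, device)  # dict written earlier by save_checkpoint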
xinference/thirdparty/matcha/models/baselightningmodule.py
@@ -0,0 +1,210 @@
+ """
+ This is a base lightning module that can be used to train a model.
+ The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
+ """
+ import inspect
+ from abc import ABC
+ from typing import Any, Dict
+
+ import torch
+ from lightning import LightningModule
+ from lightning.pytorch.utilities import grad_norm
+
+ from matcha import utils
+ from matcha.utils.utils import plot_tensor
+
+ log = utils.get_pylogger(__name__)
+
+
+ class BaseLightningClass(LightningModule, ABC):
+     def update_data_statistics(self, data_statistics):
+         if data_statistics is None:
+             data_statistics = {
+                 "mel_mean": 0.0,
+                 "mel_std": 1.0,
+             }
+
+         self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
+         self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
+
+     def configure_optimizers(self) -> Any:
+         optimizer = self.hparams.optimizer(params=self.parameters())
+         if self.hparams.scheduler not in (None, {}):
+             scheduler_args = {}
+             # Manage last epoch for exponential schedulers
+             if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
+                 if hasattr(self, "ckpt_loaded_epoch"):
+                     current_epoch = self.ckpt_loaded_epoch - 1
+                 else:
+                     current_epoch = -1
+
+             scheduler_args.update({"optimizer": optimizer})
+             scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
+             scheduler.last_epoch = current_epoch
+             return {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": scheduler,
+                     "interval": self.hparams.scheduler.lightning_args.interval,
+                     "frequency": self.hparams.scheduler.lightning_args.frequency,
+                     "name": "learning_rate",
+                 },
+             }
+
+         return {"optimizer": optimizer}
+
+     def get_losses(self, batch):
+         x, x_lengths = batch["x"], batch["x_lengths"]
+         y, y_lengths = batch["y"], batch["y_lengths"]
+         spks = batch["spks"]
+
+         dur_loss, prior_loss, diff_loss, *_ = self(
+             x=x,
+             x_lengths=x_lengths,
+             y=y,
+             y_lengths=y_lengths,
+             spks=spks,
+             out_size=self.out_size,
+             durations=batch["durations"],
+         )
+         return {
+             "dur_loss": dur_loss,
+             "prior_loss": prior_loss,
+             "diff_loss": diff_loss,
+         }
+
+     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+         self.ckpt_loaded_epoch = checkpoint["epoch"]  # pylint: disable=attribute-defined-outside-init
+
+     def training_step(self, batch: Any, batch_idx: int):
+         loss_dict = self.get_losses(batch)
+         self.log(
+             "step",
+             float(self.global_step),
+             on_step=True,
+             prog_bar=True,
+             logger=True,
+             sync_dist=True,
+         )
+
+         self.log(
+             "sub_loss/train_dur_loss",
+             loss_dict["dur_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+         self.log(
+             "sub_loss/train_prior_loss",
+             loss_dict["prior_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+         self.log(
+             "sub_loss/train_diff_loss",
+             loss_dict["diff_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+
+         total_loss = sum(loss_dict.values())
+         self.log(
+             "loss/train",
+             total_loss,
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             prog_bar=True,
+             sync_dist=True,
+         )
+
+         return {"loss": total_loss, "log": loss_dict}
+
+     def validation_step(self, batch: Any, batch_idx: int):
+         loss_dict = self.get_losses(batch)
+         self.log(
+             "sub_loss/val_dur_loss",
+             loss_dict["dur_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+         self.log(
+             "sub_loss/val_prior_loss",
+             loss_dict["prior_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+         self.log(
+             "sub_loss/val_diff_loss",
+             loss_dict["diff_loss"],
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             sync_dist=True,
+         )
+
+         total_loss = sum(loss_dict.values())
+         self.log(
+             "loss/val",
+             total_loss,
+             on_step=True,
+             on_epoch=True,
+             logger=True,
+             prog_bar=True,
+             sync_dist=True,
+         )
+
+         return total_loss
+
+     def on_validation_end(self) -> None:
+         if self.trainer.is_global_zero:
+             one_batch = next(iter(self.trainer.val_dataloaders))
+             if self.current_epoch == 0:
+                 log.debug("Plotting original samples")
+                 for i in range(2):
+                     y = one_batch["y"][i].unsqueeze(0).to(self.device)
+                     self.logger.experiment.add_image(
+                         f"original/{i}",
+                         plot_tensor(y.squeeze().cpu()),
+                         self.current_epoch,
+                         dataformats="HWC",
+                     )
+
+             log.debug("Synthesising...")
+             for i in range(2):
+                 x = one_batch["x"][i].unsqueeze(0).to(self.device)
+                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
+                 spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
+                 output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
+                 y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
+                 attn = output["attn"]
+                 self.logger.experiment.add_image(
+                     f"generated_enc/{i}",
+                     plot_tensor(y_enc.squeeze().cpu()),
+                     self.current_epoch,
+                     dataformats="HWC",
+                 )
+                 self.logger.experiment.add_image(
+                     f"generated_dec/{i}",
+                     plot_tensor(y_dec.squeeze().cpu()),
+                     self.current_epoch,
+                     dataformats="HWC",
+                 )
+                 self.logger.experiment.add_image(
+                     f"alignment/{i}",
+                     plot_tensor(attn.squeeze().cpu()),
+                     self.current_epoch,
+                     dataformats="HWC",
+                 )
+
+     def on_before_optimizer_step(self, optimizer):
+         self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
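
A hypothetical sketch (not part of the diff) of the hyperparameter shapes that configure_optimizers above expects: self.hparams.optimizer must be a callable accepting params, and self.hparams.scheduler an object exposing .scheduler (a callable accepting optimizer) and .lightning_args (with .interval and .frequency). In Matcha-TTS these normally come from Hydra configs; here they are faked with functools.partial and SimpleNamespace:

from functools import partial
from types import SimpleNamespace

import torch

# Invoked by configure_optimizers as self.hparams.optimizer(params=self.parameters())
optimizer = partial(torch.optim.Adam, lr=1e-4)

# .scheduler is invoked as self.hparams.scheduler.scheduler(optimizer=...);
# ExponentialLR has a last_epoch parameter, so the resume branch above applies.
scheduler = SimpleNamespace(
    scheduler=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999),
    lightning_args=SimpleNamespace(interval="step", frequency=1),
)
# A concrete subclass would register these via save_hyperparameters() so they
# surface as self.hparams.optimizer and self.hparams.scheduler.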