xinference 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (94)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +91 -6
  4. xinference/client/restful/restful_client.py +39 -0
  5. xinference/core/model.py +41 -13
  6. xinference/deploy/cmdline.py +3 -1
  7. xinference/deploy/test/test_cmdline.py +56 -0
  8. xinference/isolation.py +24 -0
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +26 -4
  11. xinference/model/audio/f5tts.py +195 -0
  12. xinference/model/audio/fish_speech.py +71 -35
  13. xinference/model/audio/model_spec.json +88 -0
  14. xinference/model/audio/model_spec_modelscope.json +9 -0
  15. xinference/model/audio/whisper_mlx.py +208 -0
  16. xinference/model/embedding/core.py +322 -6
  17. xinference/model/embedding/model_spec.json +8 -1
  18. xinference/model/embedding/model_spec_modelscope.json +9 -1
  19. xinference/model/llm/__init__.py +4 -2
  20. xinference/model/llm/llm_family.json +479 -53
  21. xinference/model/llm/llm_family_modelscope.json +423 -17
  22. xinference/model/llm/mlx/core.py +230 -50
  23. xinference/model/llm/sglang/core.py +2 -0
  24. xinference/model/llm/transformers/chatglm.py +9 -5
  25. xinference/model/llm/transformers/core.py +1 -0
  26. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  27. xinference/model/llm/transformers/utils.py +16 -8
  28. xinference/model/llm/utils.py +23 -1
  29. xinference/model/llm/vllm/core.py +89 -2
  30. xinference/thirdparty/f5_tts/__init__.py +0 -0
  31. xinference/thirdparty/f5_tts/api.py +166 -0
  32. xinference/thirdparty/f5_tts/configs/E2TTS_Base_train.yaml +44 -0
  33. xinference/thirdparty/f5_tts/configs/E2TTS_Small_train.yaml +44 -0
  34. xinference/thirdparty/f5_tts/configs/F5TTS_Base_train.yaml +46 -0
  35. xinference/thirdparty/f5_tts/configs/F5TTS_Small_train.yaml +46 -0
  36. xinference/thirdparty/f5_tts/eval/README.md +49 -0
  37. xinference/thirdparty/f5_tts/eval/ecapa_tdnn.py +330 -0
  38. xinference/thirdparty/f5_tts/eval/eval_infer_batch.py +207 -0
  39. xinference/thirdparty/f5_tts/eval/eval_infer_batch.sh +13 -0
  40. xinference/thirdparty/f5_tts/eval/eval_librispeech_test_clean.py +84 -0
  41. xinference/thirdparty/f5_tts/eval/eval_seedtts_testset.py +84 -0
  42. xinference/thirdparty/f5_tts/eval/utils_eval.py +405 -0
  43. xinference/thirdparty/f5_tts/infer/README.md +191 -0
  44. xinference/thirdparty/f5_tts/infer/SHARED.md +74 -0
  45. xinference/thirdparty/f5_tts/infer/examples/basic/basic.toml +11 -0
  46. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
  47. xinference/thirdparty/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
  48. xinference/thirdparty/f5_tts/infer/examples/multi/country.flac +0 -0
  49. xinference/thirdparty/f5_tts/infer/examples/multi/main.flac +0 -0
  50. xinference/thirdparty/f5_tts/infer/examples/multi/story.toml +19 -0
  51. xinference/thirdparty/f5_tts/infer/examples/multi/story.txt +1 -0
  52. xinference/thirdparty/f5_tts/infer/examples/multi/town.flac +0 -0
  53. xinference/thirdparty/f5_tts/infer/examples/vocab.txt +2545 -0
  54. xinference/thirdparty/f5_tts/infer/infer_cli.py +226 -0
  55. xinference/thirdparty/f5_tts/infer/infer_gradio.py +851 -0
  56. xinference/thirdparty/f5_tts/infer/speech_edit.py +193 -0
  57. xinference/thirdparty/f5_tts/infer/utils_infer.py +538 -0
  58. xinference/thirdparty/f5_tts/model/__init__.py +10 -0
  59. xinference/thirdparty/f5_tts/model/backbones/README.md +20 -0
  60. xinference/thirdparty/f5_tts/model/backbones/dit.py +163 -0
  61. xinference/thirdparty/f5_tts/model/backbones/mmdit.py +146 -0
  62. xinference/thirdparty/f5_tts/model/backbones/unett.py +219 -0
  63. xinference/thirdparty/f5_tts/model/cfm.py +285 -0
  64. xinference/thirdparty/f5_tts/model/dataset.py +319 -0
  65. xinference/thirdparty/f5_tts/model/modules.py +658 -0
  66. xinference/thirdparty/f5_tts/model/trainer.py +366 -0
  67. xinference/thirdparty/f5_tts/model/utils.py +185 -0
  68. xinference/thirdparty/f5_tts/scripts/count_max_epoch.py +33 -0
  69. xinference/thirdparty/f5_tts/scripts/count_params_gflops.py +39 -0
  70. xinference/thirdparty/f5_tts/socket_server.py +159 -0
  71. xinference/thirdparty/f5_tts/train/README.md +77 -0
  72. xinference/thirdparty/f5_tts/train/datasets/prepare_csv_wavs.py +139 -0
  73. xinference/thirdparty/f5_tts/train/datasets/prepare_emilia.py +230 -0
  74. xinference/thirdparty/f5_tts/train/datasets/prepare_libritts.py +92 -0
  75. xinference/thirdparty/f5_tts/train/datasets/prepare_ljspeech.py +65 -0
  76. xinference/thirdparty/f5_tts/train/datasets/prepare_wenetspeech4tts.py +125 -0
  77. xinference/thirdparty/f5_tts/train/finetune_cli.py +174 -0
  78. xinference/thirdparty/f5_tts/train/finetune_gradio.py +1846 -0
  79. xinference/thirdparty/f5_tts/train/train.py +75 -0
  80. xinference/types.py +2 -1
  81. xinference/web/ui/build/asset-manifest.json +3 -3
  82. xinference/web/ui/build/index.html +1 -1
  83. xinference/web/ui/build/static/js/{main.2f269bb3.js → main.4eb4ee80.js} +3 -3
  84. xinference/web/ui/build/static/js/main.4eb4ee80.js.map +1 -0
  85. xinference/web/ui/node_modules/.cache/babel-loader/8c5eeb02f772d02cbe8b89c05428d0dd41a97866f75f7dc1c2164a67f5a1cf98.json +1 -0
  86. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/METADATA +39 -18
  87. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/RECORD +92 -39
  88. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/WHEEL +1 -1
  89. xinference/web/ui/build/static/js/main.2f269bb3.js.map +0 -1
  90. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +0 -1
  91. /xinference/web/ui/build/static/js/{main.2f269bb3.js.LICENSE.txt → main.4eb4ee80.js.LICENSE.txt} +0 -0
  92. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/LICENSE +0 -0
  93. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/entry_points.txt +0 -0
  94. {xinference-1.0.0.dist-info → xinference-1.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,163 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from x_transformers.x_transformers import RotaryEmbedding
17
+
18
+ from f5_tts.model.modules import (
19
+ TimestepEmbedding,
20
+ ConvNeXtV2Block,
21
+ ConvPositionEmbedding,
22
+ DiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis,
25
+ get_pos_embed_indices,
26
+ )
27
+
28
+
29
+ # Text embedding
30
+
31
+
32
class TextEmbedding(nn.Module):
    """Text-token embedding used by the DiT backbone.

    Token ids are shifted up by one so that id 0 can serve as a filler token.
    The sequence is then truncated or right-padded with that filler to exactly
    ``seq_len`` positions, so that it lines up with the mel frames. When
    ``conv_layers > 0``, a positional term taken from the precomputed
    ``freqs_cis`` table is added. The result is then refined by a stack of
    ``ConvNeXtV2Block`` layers.

    Args:
        text_num_embeds: number of text token ids before the +1 shift.
        text_dim: embedding width.
        conv_layers: number of ConvNeXtV2 blocks; 0 disables the extra modeling.
        conv_mult: hidden-width multiplier handed to each ConvNeXtV2 block.
    """

    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        # One extra row because id 0 is reserved as the filler token.
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)

        self.extra_modeling = bool(conv_layers > 0)
        if not self.extra_modeling:
            return

        self.precompute_max_pos = 4096  # ~44s of 24khz audio
        freqs = precompute_freqs_cis(text_dim, self.precompute_max_pos)
        # Non-persistent: the table is rebuilt at construction, not loaded from checkpoints.
        self.register_buffer("freqs_cis", freqs, persistent=False)
        conv_stack = [ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
        self.text_blocks = nn.Sequential(*conv_stack)

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        """Embed ``text`` to a ``[b, seq_len, text_dim]`` tensor.

        Args:
            text: integer token ids. Batch padding is -1, which becomes the
                filler id 0 after the shift (see list_str_to_idx()).
            seq_len: target length. Tokens beyond it (more characters than mel
                frames) are dropped, and shorter inputs are padded with 0.
            drop_text: if True, every position becomes the filler token
                (classifier-free guidance on text).
        """
        shifted = (text + 1)[:, :seq_len]
        n_batch, n_tokens = shifted.shape[0], shifted.shape[1]
        token_ids = F.pad(shifted, (0, seq_len - n_tokens), value=0)
        if drop_text:
            token_ids = torch.zeros_like(token_ids)

        embedded = self.text_embed(token_ids)
        if not self.extra_modeling:
            return embedded

        # Sinusoidal position term. Every sequence starts at position 0.
        starts = torch.zeros((n_batch,), dtype=torch.long)
        positions = get_pos_embed_indices(starts, seq_len, max_pos=self.precompute_max_pos)
        embedded = embedded + self.freqs_cis[positions]
        return self.text_blocks(embedded)
70
+
71
+
72
+ # noised input audio and context mixing embedding
73
+
74
+
75
class InputEmbedding(nn.Module):
    """Fuse noised audio, masked conditioning audio and text into one sequence.

    The three inputs are concatenated on the feature axis and linearly
    projected to ``out_dim``. A convolutional position embedding of that
    projection is then added back as a residual.

    Args:
        mel_dim: feature width of each audio input (x and cond).
        text_dim: feature width of the text embedding.
        out_dim: width of the fused output.
    """

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        fused_width = 2 * mel_dim + text_dim  # x + cond + text features
        self.proj = nn.Linear(fused_width, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        """Return the fused ``[b, n, out_dim]`` sequence.

        If ``drop_audio_cond`` is True, ``cond`` is replaced by zeros
        (classifier-free guidance on the audio condition).
        """
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        hidden = self.proj(torch.cat((x, audio_cond, text_embed), dim=-1))
        return hidden + self.conv_pos_embed(hidden)
88
+
89
+
90
+ # Transformer backbone using DiT blocks
91
+
92
+
93
class DiT(nn.Module):
    """Transformer backbone that stacks ``depth`` ``DiTBlock`` layers.

    The forward pass runs as follows:

    1. Text tokens are embedded to the audio length.
    2. ``InputEmbedding`` fuses them with the noised and conditioning audio.
    3. The result runs through the blocks. Every block is conditioned on the
       time embedding and shares one rotary embedding.
    4. The output is optionally merged with the pre-block activations by a
       long skip connection.
    5. A final time-modulated norm is applied, then a projection back to
       ``mel_dim`` features.

    Keyword Args:
        dim: model width.
        depth: number of ``DiTBlock`` layers.
        heads, dim_head: attention heads and per-head width.
        dropout, ff_mult: forwarded to every ``DiTBlock``.
        mel_dim: number of output features (also the audio input width).
        text_num_embeds: text vocabulary size handed to ``TextEmbedding``.
        text_dim: text embedding width; defaults to ``mel_dim``.
        conv_layers: ConvNeXtV2 layers inside the text embedding.
        long_skip_connection: if True, a bias-free ``Linear(2 * dim, dim)`` fuses
            the block-stack output with its input.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        conv_layers=0,
        long_skip_connection=False,
    ):
        super().__init__()

        # Submodules are created in a fixed order: it determines parameter
        # registration (checkpoint keys) and random-init consumption.
        self.time_embed = TimestepEmbedding(dim)
        text_width = mel_dim if text_dim is None else text_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_width, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_width, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        blocks = [
            DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)
        if long_skip_connection:
            self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False)
        else:
            self.long_skip_connection = None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        """Return a ``[b, n, mel_dim]`` prediction for the noised audio ``x``.

        A 0-d ``time`` is repeated across the batch. ``drop_audio_cond`` and
        ``drop_text`` are forwarded to the embeddings for classifier-free
        guidance. ``mask`` is passed unchanged to every block.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        hidden = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        skip_proj = self.long_skip_connection
        block_input = hidden if skip_proj is not None else None

        for block in self.transformer_blocks:
            hidden = block(hidden, t, mask=mask, rope=rope)

        if skip_proj is not None:
            hidden = skip_proj(torch.cat((hidden, block_input), dim=-1))

        hidden = self.norm_out(hidden, t)
        return self.proj_out(hidden)
@@ -0,0 +1,146 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from x_transformers.x_transformers import RotaryEmbedding
16
+
17
+ from f5_tts.model.modules import (
18
+ TimestepEmbedding,
19
+ ConvPositionEmbedding,
20
+ MMDiTBlock,
21
+ AdaLayerNormZero_Final,
22
+ precompute_freqs_cis,
23
+ get_pos_embed_indices,
24
+ )
25
+
26
+
27
+ # text embedding
28
+
29
+
30
class TextEmbedding(nn.Module):
    """Text-token embedding for the MM-DiT text stream.

    Ids are shifted up by one, so that 0 is the filler token, and then
    embedded. A positional term from the precomputed ``freqs_cis`` table
    (``precompute_max_pos = 1024`` entries) is added.

    Args:
        out_dim: embedding width.
        text_num_embeds: number of text token ids before the +1 shift.
    """

    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        # Extra row because id 0 is reserved as the filler token.
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)

        self.precompute_max_pos = 1024
        table = precompute_freqs_cis(out_dim, self.precompute_max_pos)
        self.register_buffer("freqs_cis", table, persistent=False)

    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
        """Embed ``text`` without changing its length.

        If ``drop_text`` is True, every position becomes the filler id 0
        (classifier-free guidance on text).
        """
        token_ids = torch.zeros_like(text) if drop_text else text + 1
        embedded = self.text_embed(token_ids)

        # Sinusoidal position term. Every sequence starts at position 0.
        n_batch, n_tokens = embedded.shape[0], embedded.shape[1]
        starts = torch.zeros((n_batch,), dtype=torch.long)
        positions = get_pos_embed_indices(starts, n_tokens, max_pos=self.precompute_max_pos)
        return embedded + self.freqs_cis[positions]
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+
58
class AudioEmbedding(nn.Module):
    """Embed noised audio together with the masked conditioning audio.

    ``x`` and ``cond`` are concatenated on the feature axis and linearly
    projected to ``out_dim``. A convolutional position embedding of the
    projection is then added back as a residual.

    Args:
        in_dim: feature width of each of ``x`` and ``cond``.
        out_dim: output width.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
        """Return the ``[b, n, out_dim]`` audio embedding.

        If ``drop_audio_cond`` is True, ``cond`` is replaced by zeros
        (classifier-free guidance on the audio condition).
        """
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        hidden = self.linear(torch.cat((x, audio_cond), dim=-1))
        return hidden + self.conv_pos_embed(hidden)
71
+
72
+
73
+ # Transformer backbone using MM-DiT blocks
74
+
75
+
76
class MMDiT(nn.Module):
    """Transformer backbone built from ``MMDiTBlock`` layers.

    The text stream (``c``) and the audio stream (``x``) are embedded
    separately and updated jointly by every block. The last block is built
    with ``context_pre_only=True``. Only the audio stream is
    normalised (time-modulated) and projected back to ``mel_dim`` features.

    Keyword Args:
        dim: model width for both streams.
        depth: number of ``MMDiTBlock`` layers.
        heads, dim_head: attention heads and per-head width.
        dropout, ff_mult: forwarded to every block.
        text_num_embeds: text vocabulary size handed to ``TextEmbedding``.
        mel_dim: number of output features (also the audio input width).
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        text_num_embeds=256,
        mel_dim=100,
    ):
        super().__init__()

        # Submodules are created in a fixed order: it determines parameter
        # registration (checkpoint keys) and random-init consumption.
        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        last_idx = depth - 1
        blocks = []
        for layer_idx in range(depth):
            blocks.append(
                MMDiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    context_pre_only=(layer_idx == last_idx),
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        """Return a ``[b, n, mel_dim]`` prediction for the noised audio ``x``.

        A 0-d ``time`` is repeated across the batch. Each stream gets its own
        rotary embedding sized to its length. ``mask`` is passed to every block.
        """
        if time.ndim == 0:
            time = time.repeat(x.shape[0])

        t = self.time_embed(time)
        context = self.text_embed(text, drop_text=drop_text)
        hidden = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        rope_audio = self.rotary_embed.forward_from_seq_len(hidden.shape[1])
        rope_text = self.rotary_embed.forward_from_seq_len(text.shape[1])

        for block in self.transformer_blocks:
            context, hidden = block(hidden, context, t, mask=mask, rope=rope_audio, c_rope=rope_text)

        hidden = self.norm_out(hidden, t)
        return self.proj_out(hidden)
@@ -0,0 +1,219 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from x_transformers import RMSNorm
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from f5_tts.model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ Attention,
25
+ AttnProcessor,
26
+ FeedForward,
27
+ precompute_freqs_cis,
28
+ get_pos_embed_indices,
29
+ )
30
+
31
+
32
+ # Text embedding
33
+
34
+
35
class TextEmbedding(nn.Module):
    """Text-token embedding used by the UNet-Transformer backbone.

    Token ids are shifted up by one so that id 0 can serve as a filler token.
    The sequence is then truncated or right-padded with that filler to exactly
    ``seq_len`` positions, so that it lines up with the mel frames. When
    ``conv_layers > 0``, a positional term taken from the precomputed
    ``freqs_cis`` table is added. The result is then refined by a stack of
    ``ConvNeXtV2Block`` layers.

    Args:
        text_num_embeds: number of text token ids before the +1 shift.
        text_dim: embedding width.
        conv_layers: number of ConvNeXtV2 blocks; 0 disables the extra modeling.
        conv_mult: hidden-width multiplier handed to each ConvNeXtV2 block.
    """

    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        # One extra row because id 0 is reserved as the filler token.
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)

        self.extra_modeling = bool(conv_layers > 0)
        if not self.extra_modeling:
            return

        self.precompute_max_pos = 4096  # ~44s of 24khz audio
        freqs = precompute_freqs_cis(text_dim, self.precompute_max_pos)
        # Non-persistent: the table is rebuilt at construction, not loaded from checkpoints.
        self.register_buffer("freqs_cis", freqs, persistent=False)
        conv_stack = [ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
        self.text_blocks = nn.Sequential(*conv_stack)

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        """Embed ``text`` to a ``[b, seq_len, text_dim]`` tensor.

        Args:
            text: integer token ids. Batch padding is -1, which becomes the
                filler id 0 after the shift (see list_str_to_idx()).
            seq_len: target length. Tokens beyond it (more characters than mel
                frames) are dropped, and shorter inputs are padded with 0.
            drop_text: if True, every position becomes the filler token
                (classifier-free guidance on text).
        """
        shifted = (text + 1)[:, :seq_len]
        n_batch, n_tokens = shifted.shape[0], shifted.shape[1]
        token_ids = F.pad(shifted, (0, seq_len - n_tokens), value=0)
        if drop_text:
            token_ids = torch.zeros_like(token_ids)

        embedded = self.text_embed(token_ids)
        if not self.extra_modeling:
            return embedded

        # Sinusoidal position term. Every sequence starts at position 0.
        starts = torch.zeros((n_batch,), dtype=torch.long)
        positions = get_pos_embed_indices(starts, seq_len, max_pos=self.precompute_max_pos)
        embedded = embedded + self.freqs_cis[positions]
        return self.text_blocks(embedded)
73
+
74
+
75
+ # noised input audio and context mixing embedding
76
+
77
+
78
class InputEmbedding(nn.Module):
    """Fuse noised audio, masked conditioning audio and text into one sequence.

    The three inputs are concatenated on the feature axis and linearly
    projected to ``out_dim``. A convolutional position embedding of that
    projection is then added back as a residual.

    Args:
        mel_dim: feature width of each audio input (x and cond).
        text_dim: feature width of the text embedding.
        out_dim: width of the fused output.
    """

    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        fused_width = 2 * mel_dim + text_dim  # x + cond + text features
        self.proj = nn.Linear(fused_width, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        """Return the fused ``[b, n, out_dim]`` sequence.

        If ``drop_audio_cond`` is True, ``cond`` is replaced by zeros
        (classifier-free guidance on the audio condition).
        """
        audio_cond = torch.zeros_like(cond) if drop_audio_cond else cond
        hidden = self.proj(torch.cat((x, audio_cond, text_embed), dim=-1))
        return hidden + self.conv_pos_embed(hidden)
91
+
92
+
93
+ # Flat UNet Transformer backbone
94
+
95
+
96
class UNetT(nn.Module):
    """Flat UNet-Transformer backbone.

    ``depth`` (must be even) layers apply RMSNorm -> attention and
    RMSNorm -> feed-forward, each with a residual. The layers are wired like
    a U-Net:

    - Each layer in the first half pushes its input onto a stack.
    - Each layer in the second half pops the matching entry and merges it
      according to ``skip_connect_type``:
        - ``"concat"``: concatenate, then a bias-free ``Linear(2 * dim, dim)``.
        - ``"add"``: element-wise sum.
        - ``"none"``: the popped entry is ignored.

    The time embedding is prepended to the sequence as one extra token and
    removed again before the output projection to ``mel_dim`` features.
    """

    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        conv_layers=0,
        skip_connect_type: Literal["add", "concat", "none"] = "concat",
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        # Submodules are created in a fixed order: it determines parameter
        # registration (checkpoint keys) and random-init consumption.
        self.time_embed = TimestepEmbedding(dim)
        text_width = mel_dim if text_dim is None else text_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_width, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_width, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.skip_connect_type = skip_connect_type
        concat_skips = skip_connect_type == "concat"

        self.depth = depth
        self.layers = nn.ModuleList([])

        half = depth // 2
        for layer_idx in range(depth):
            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor=AttnProcessor(),
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            # Only second-half layers in "concat" mode need a projection back to dim.
            skip_proj = None
            if concat_skips and layer_idx >= half:
                skip_proj = nn.Linear(dim * 2, dim, bias=False)

            # Entry layout is positional: forward() unpacks it in this order,
            # and the ModuleList indices become state_dict keys.
            self.layers.append(nn.ModuleList([skip_proj, attn_norm, attn, ff_norm, ff]))

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        """Return a ``[b, n, mel_dim]`` prediction for the noised audio ``x``.

        A 0-d ``time`` is repeated across the batch. ``mask`` (if given) is
        extended by one leading position for the prepended time token.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        hidden = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        # Prepend the time embedding as an extra leading token: [b n d] -> [b n+1 d].
        hidden = torch.cat([t.unsqueeze(1), hidden], dim=1)
        if mask is not None:
            # Mark the prepended time-token position with value 1.
            mask = F.pad(mask, (1, 0), value=1)

        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)

        merge_mode = self.skip_connect_type
        half = self.depth // 2
        stack = []
        for layer_no, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers, start=1):
            if layer_no <= half:
                stack.append(hidden)
            else:
                skip = stack.pop()
                if merge_mode == "concat":
                    hidden = maybe_skip_proj(torch.cat((hidden, skip), dim=-1))
                elif merge_mode == "add":
                    hidden = hidden + skip

            hidden = attn(attn_norm(hidden), rope=rope, mask=mask) + hidden
            hidden = ff(ff_norm(hidden)) + hidden

        assert len(stack) == 0

        hidden = self.norm_out(hidden)[:, 1:, :]  # strip the time token again
        return self.proj_out(hidden)