minicpmo-utils 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0

stepaudio2/flashcosyvoice/engine/model_runner.py
@@ -0,0 +1,310 @@
+import pickle
+from multiprocessing.shared_memory import SharedMemory
+from multiprocessing.synchronize import Event
+
+import torch
+import torch.distributed as dist
+
+from stepaudio2.flashcosyvoice.config import Config
+from stepaudio2.flashcosyvoice.engine.sequence import Sequence
+from stepaudio2.flashcosyvoice.modules.qwen2 import Qwen2ForCausalLM
+from stepaudio2.flashcosyvoice.modules.sampler import RasSampler, Sampler
+from stepaudio2.flashcosyvoice.utils.context import (get_context, reset_context,
+                                                     set_context)
+from stepaudio2.flashcosyvoice.utils.loader import load_model
+
+
+class ModelRunner:
+
+    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
+        self.config = config
+        hf_config = config.hf_config
+        self.block_size = config.kvcache_block_size
+        self.enforce_eager = config.enforce_eager
+        self.world_size = config.tensor_parallel_size
+        self.rank = rank
+        self.event = event
+
+        # TODO(xcsong): support tp > 1
+        if self.world_size > 1:
+            dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+        torch.cuda.set_device(rank)
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(hf_config.torch_dtype)
+        torch.set_default_device("cuda")
+        self.model = Qwen2ForCausalLM(hf_config)
+        load_model(self.model, config.model, hf_config)
+        self.sampler = Sampler()
+        self.ras_sampler = RasSampler()
+        self.warmup_model()
+        self.allocate_kv_cache()
+        if not self.enforce_eager:
+            self.capture_cudagraph()
+        torch.set_default_device("cpu")
+        torch.set_default_dtype(default_dtype)
+
+        if self.world_size > 1:
+            if rank == 0:
+                self.shm = SharedMemory(name="flashcosyvoice", create=True, size=2**20)
+                dist.barrier()
+            else:
+                dist.barrier()
+                self.shm = SharedMemory(name="flashcosyvoice")
+                self.loop()
+
+    def exit(self):
+        if self.world_size > 1:
+            self.shm.close()
+            dist.barrier()
+            if self.rank == 0:
+                self.shm.unlink()
+        if not self.enforce_eager:
+            del self.graphs, self.graph_pool
+        torch.cuda.synchronize()
+        if self.world_size > 1:
+            dist.destroy_process_group()
+
+    def loop(self):
+        while True:
+            method_name, args = self.read_shm()
+            self.call(method_name, *args)
+            if method_name == "exit":
+                break
+
+    def read_shm(self):
+        assert self.world_size > 1 and self.rank
+        self.event.wait()
+        n = int.from_bytes(self.shm.buf[0:4], "little")
+        method_name, *args = pickle.loads(self.shm.buf[4:n + 4])
+        self.event.clear()
+        return method_name, args
+
+    def write_shm(self, method_name, *args):
+        assert self.world_size > 1 and not self.rank
+        data = pickle.dumps([method_name, *args])
+        n = len(data)
+        self.shm.buf[0:4] = n.to_bytes(4, "little")
+        self.shm.buf[4:n + 4] = data
+        for event in self.event:
+            event.set()
+
+    def call(self, method_name, *args):
+        if self.world_size > 1 and self.rank == 0:
+            self.write_shm(method_name, *args)
+        method = getattr(self, method_name, None)
+        return method(*args)
+
+    def warmup_model(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+        self.run(seqs, True)
+        torch.cuda.empty_cache()
+
+    def allocate_kv_cache(self):
+        config = self.config
+        hf_config = config.hf_config
+        free, total = torch.cuda.mem_get_info()
+        used = total - free
+        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+        num_kv_heads = hf_config.num_key_value_heads // self.world_size
+        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
+        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+        assert config.num_kvcache_blocks > 0, "try to **increase** gpu_memory_utilization"
+        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
+        layer_id = 0
+        for module in self.model.modules():
+            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
+                module.k_cache = self.kv_cache[0, layer_id]
+                module.v_cache = self.kv_cache[1, layer_id]
+                layer_id += 1
+
+    def prepare_block_tables(self, seqs: list[Sequence]):
+        max_len = max(len(seq.block_table) for seq in seqs)
+        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
+        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        return block_tables
+
+    def prepare_prefill(self, seqs: list[Sequence]):
+        input_ids = []
+        positions = []
+        cu_seqlens_q = [0]
+        cu_seqlens_k = [0]
+        max_seqlen_q = 0
+        max_seqlen_k = 0
+        slot_mapping = []
+        block_tables = None
+        for seq in seqs:
+            seqlen = len(seq)
+            input_ids.extend(seq[seq.num_cached_tokens:])
+            positions.extend(list(range(seq.num_cached_tokens, seqlen)))
+            seqlen_q = seqlen - seq.num_cached_tokens
+            seqlen_k = seqlen
+            cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
+            cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
+            max_seqlen_q = max(seqlen_q, max_seqlen_q)
+            max_seqlen_k = max(seqlen_k, max_seqlen_k)
+            if not seq.block_table:
+                continue
+            for i in range(seq.num_cached_blocks, seq.num_blocks):
+                start = seq.block_table[i] * self.block_size
+                if i != seq.num_blocks - 1:
+                    end = start + self.block_size
+                else:
+                    end = start + seq.last_block_num_tokens
+                slot_mapping.extend(list(range(start, end)))
+        if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
+            block_tables = self.prepare_block_tables(seqs)
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+        return input_ids, positions
+
+    def prepare_decode(self, seqs: list[Sequence]):
+        input_ids = []
+        positions = []
+        slot_mapping = []
+        context_lens = []
+        for seq in seqs:
+            input_ids.append(seq.last_token)
+            positions.append(len(seq))
+            context_lens.append(len(seq))
+            slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+        block_tables = self.prepare_block_tables(seqs)
+        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+        return input_ids, positions
+
+    def prepare_sample(self, seqs: list[Sequence]):
+        temperatures = []
+        top_ks = []
+        win_sizes = []
+        tau_rs = []
+        top_ps = []
+        min_tokens_list = []
+        use_ras_list = []
+
+        for seq in seqs:
+            temperatures.append(seq.temperature)
+            top_ks.append(seq.top_k)
+            win_sizes.append(seq.win_size)
+            tau_rs.append(seq.tau_r)
+            top_ps.append(seq.top_p)
+            min_tokens_list.append(seq.min_tokens)
+            use_ras_list.append(seq.use_ras)
+
+        temperatures_tensor = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
+        # check all items equal
+        assert all(item == temperatures[0] for item in temperatures)
+        assert all(item == top_ks[0] for item in top_ks)
+        assert all(item == win_sizes[0] for item in win_sizes)
+        assert all(item == tau_rs[0] for item in tau_rs)
+        assert all(item == top_ps[0] for item in top_ps)
+        assert all(item == use_ras_list[0] for item in use_ras_list)
+
+        return {
+            'temperatures': temperatures_tensor,
+            'top_k': top_ks[0],
+            'win_size': win_sizes[0],
+            'tau_r': tau_rs[0],
+            'top_p': top_ps[0],
+            'eos_token': self.config.eos,
+            'min_tokens': min_tokens_list,
+            'use_ras': use_ras_list[0]
+        }
+
+    @torch.inference_mode()
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
+        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
+            return self.model.compute_logits(self.model(input_ids, positions))
+        else:
+            bs = input_ids.size(0)
+            context = get_context()
+            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
+            graph_vars = self.graph_vars
+            for k, v in graph_vars.items():
+                if k != "outputs":
+                    v.zero_()
+            graph_vars["input_ids"][:bs] = input_ids
+            graph_vars["positions"][:bs] = positions
+            graph_vars["slot_mapping"][:bs] = context.slot_mapping
+            graph_vars["context_lens"][:bs] = context.context_lens
+            graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
+            graph.replay()
+            return self.model.compute_logits(graph_vars["outputs"][:bs])
+
+    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
+        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
+        if self.rank == 0 or self.world_size == 1:
+            sample_params = self.prepare_sample(seqs)
+            logits = self.run_model(input_ids, positions, is_prefill)
+
+            if sample_params['use_ras']:
+                # Prepare decoded tokens list for RasSampler
+                decoded_tokens_list = [seq.completion_token_ids for seq in seqs]
+                # Pass all parameters as lists to RasSampler
+                token_ids = self.ras_sampler(
+                    logits,
+                    decoded_tokens_list,
+                    win_size=sample_params['win_size'],
+                    tau_r=sample_params['tau_r'],
+                    top_p=sample_params['top_p'],
+                    top_k=sample_params['top_k'],
+                    eos_token=sample_params['eos_token'],
+                    min_tokens=sample_params['min_tokens']
+                ).tolist()
+            else:
+                # Use the default sampler with list form of top_ks
+                token_ids = self.sampler(logits, sample_params['temperatures'], sample_params['top_k']).tolist()
+        else:
+            logits = self.run_model(input_ids, positions, is_prefill)
+            token_ids = None
+        reset_context()
+        return token_ids
+
+    @torch.inference_mode()
+    def capture_cudagraph(self):
+        config = self.config
+        hf_config = config.hf_config
+        max_bs = min(self.config.max_num_seqs, 512)
+        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
+        input_ids = torch.zeros(max_bs, dtype=torch.int64)
+        positions = torch.zeros(max_bs, dtype=torch.int64)
+        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
+        context_lens = torch.zeros(max_bs, dtype=torch.int32)
+        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
+        outputs = torch.zeros(max_bs, hf_config.hidden_size)
+        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
+        self.graphs = {}
+        self.graph_pool = None
+
+        for bs in reversed(self.graph_bs):
+            graph = torch.cuda.CUDAGraph()
+            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
+            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # warmup
+            with torch.cuda.graph(graph, self.graph_pool):
+                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # capture
+            if self.graph_pool is None:
+                self.graph_pool = graph.pool()
+            self.graphs[bs] = graph
+            torch.cuda.synchronize()
+            reset_context()
+
+        self.graph_vars = dict(
+            input_ids=input_ids,
+            positions=positions,
+            slot_mapping=slot_mapping,
+            context_lens=context_lens,
+            block_tables=block_tables,
+            outputs=outputs,
+        )
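
The read_shm/write_shm pair above is a small single-producer broadcast channel over POSIX shared memory: rank 0 writes a 4-byte little-endian length prefix followed by a pickled [method_name, *args] list and sets each worker's Event; workers wake in loop(), decode the payload, dispatch it via call(), and clear their Event. Below is a minimal standalone sketch of just that framing, independent of the package; the segment name "demo_rpc" and the payload are made up for illustration.

import pickle
from multiprocessing.shared_memory import SharedMemory


def write_msg(shm: SharedMemory, method_name: str, *args) -> None:
    # 4-byte little-endian length prefix, then the pickled [method_name, *args] payload.
    data = pickle.dumps([method_name, *args])
    shm.buf[0:4] = len(data).to_bytes(4, "little")
    shm.buf[4:4 + len(data)] = data


def read_msg(shm: SharedMemory):
    n = int.from_bytes(shm.buf[0:4], "little")
    method_name, *args = pickle.loads(shm.buf[4:4 + n])
    return method_name, args


if __name__ == "__main__":
    shm = SharedMemory(name="demo_rpc", create=True, size=2**20)  # name and size are illustrative
    try:
        write_msg(shm, "run", [1, 2, 3], True)  # what rank 0 would broadcast before a step
        print(read_msg(shm))                    # ('run', [[1, 2, 3], True])
    finally:
        shm.close()
        shm.unlink()

The fixed 2**20-byte segment bounds the pickled payload, which is presumably why only lightweight objects (method names, token id lists, and Sequence state via its compact __getstate__) cross this channel.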
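
For small decode batches, run_model above replays a pre-captured CUDA graph instead of launching kernels eagerly: capture_cudagraph records one forward pass per padded batch size into fixed buffers, and at run time the live batch is copied into those buffers before graph.replay(). Below is a minimal sketch of the same capture-and-replay pattern with a stand-in linear layer; it needs a CUDA device, and the module and shapes are illustrative, not the package's model.

import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"
model = torch.nn.Linear(16, 16).cuda()

with torch.inference_mode():
    # Static buffers: a captured graph always reads from and writes to these exact tensors.
    static_in = torch.zeros(8, 16, device="cuda")
    static_out = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream before capture, as recommended by the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(model(static_in))
    torch.cuda.current_stream().wait_stream(s)

    # Record one forward pass into the graph; operations are recorded, not executed, here.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(model(static_in))

    # Replay with fresh data: refill the static input, replay, read the static output.
    static_in.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_out[0, :4])

Because a replayed graph always touches the same memory, the padded graph_vars buffers above are zeroed and refilled on every call, and only the first :bs rows of the output buffer are read back.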

stepaudio2/flashcosyvoice/engine/scheduler.py
@@ -0,0 +1,77 @@
+from collections import deque
+
+from stepaudio2.flashcosyvoice.config import Config
+from stepaudio2.flashcosyvoice.engine.block_manager import BlockManager
+from stepaudio2.flashcosyvoice.engine.sequence import Sequence, SequenceStatus
+
+
+class Scheduler:
+
+    def __init__(self, config: Config):
+        self.max_num_seqs = config.max_num_seqs
+        self.max_num_batched_tokens = config.max_num_batched_tokens
+        self.eos = config.eos
+        self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
+        self.waiting: deque[Sequence] = deque()
+        self.running: deque[Sequence] = deque()
+
+    def is_finished(self):
+        return not self.waiting and not self.running
+
+    def add(self, seq: Sequence):
+        self.waiting.append(seq)
+
+    def schedule(self) -> tuple[list[Sequence], bool]:
+        # prefill
+        scheduled_seqs = []
+        num_seqs = 0
+        num_batched_tokens = 0
+        while self.waiting and num_seqs < self.max_num_seqs:
+            seq = self.waiting[0]
+            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
+                break
+            num_seqs += 1
+            self.block_manager.allocate(seq)
+            num_batched_tokens += len(seq) - seq.num_cached_tokens
+            seq.status = SequenceStatus.RUNNING
+            self.waiting.popleft()
+            self.running.append(seq)
+            scheduled_seqs.append(seq)
+        if scheduled_seqs:
+            return scheduled_seqs, True
+
+        # decode
+        while self.running and num_seqs < self.max_num_seqs:
+            seq = self.running.popleft()
+            while not self.block_manager.can_append(seq):
+                if self.running:
+                    self.preempt(self.running.pop())
+                else:
+                    self.preempt(seq)
+                    break
+            else:
+                num_seqs += 1
+                self.block_manager.may_append(seq)
+                scheduled_seqs.append(seq)
+        assert scheduled_seqs
+        self.running.extendleft(reversed(scheduled_seqs))
+        return scheduled_seqs, False
+
+    def preempt(self, seq: Sequence):
+        seq.status = SequenceStatus.WAITING
+        self.block_manager.deallocate(seq)
+        self.waiting.appendleft(seq)
+
+    def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
+        for seq, token_id in zip(seqs, token_ids):
+            seq.append_token(token_id)
+            # Check if the sequence has reached the maximum number of tokens
+            reached_max_tokens = seq.num_completion_tokens == seq.max_tokens
+            # Check if the sequence has reached EOS and has generated enough tokens (satisfying min_tokens requirements)
+            eos_with_min_tokens = (not seq.ignore_eos and token_id == self.eos and
+                                   seq.num_completion_tokens >= seq.min_tokens)
+
+            if reached_max_tokens or eos_with_min_tokens:
+                seq.status = SequenceStatus.FINISHED
+                self.block_manager.deallocate(seq)
+                self.running.remove(seq)

stepaudio2/flashcosyvoice/engine/sequence.py
@@ -0,0 +1,90 @@
+from copy import copy
+from enum import Enum, auto
+from itertools import count
+
+from stepaudio2.flashcosyvoice.config import SamplingParams
+
+
+class SequenceStatus(Enum):
+    WAITING = auto()
+    RUNNING = auto()
+    FINISHED = auto()
+
+
+class Sequence:
+    block_size = 256
+    counter = count()
+
+    def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
+        self.seq_id = next(Sequence.counter)
+        self.status = SequenceStatus.WAITING
+        self.token_ids = copy(token_ids)
+        self.last_token = token_ids[-1]
+        self.num_tokens = len(self.token_ids)
+        self.num_prompt_tokens = len(token_ids)
+        self.num_cached_tokens = 0
+        self.block_table = []
+        self.temperature = sampling_params.temperature
+        self.min_tokens = sampling_params.min_tokens
+        self.max_tokens = sampling_params.max_tokens
+        self.ignore_eos = sampling_params.ignore_eos
+        self.top_k = sampling_params.top_k
+        # RasSampler parameters
+        self.use_ras = sampling_params.use_ras
+        self.win_size = sampling_params.win_size
+        self.tau_r = sampling_params.tau_r
+        self.top_p = sampling_params.top_p
+
+    def __len__(self):
+        return self.num_tokens
+
+    def __getitem__(self, key):
+        return self.token_ids[key]
+
+    @property
+    def is_finished(self):
+        return self.status == SequenceStatus.FINISHED
+
+    @property
+    def num_completion_tokens(self):
+        return self.num_tokens - self.num_prompt_tokens
+
+    @property
+    def prompt_token_ids(self):
+        return self.token_ids[:self.num_prompt_tokens]
+
+    @property
+    def completion_token_ids(self):
+        return self.token_ids[self.num_prompt_tokens:]
+
+    @property
+    def num_cached_blocks(self):
+        return self.num_cached_tokens // self.block_size
+
+    @property
+    def num_blocks(self):
+        return (self.num_tokens + self.block_size - 1) // self.block_size
+
+    @property
+    def last_block_num_tokens(self):
+        return self.num_tokens - (self.num_blocks - 1) * self.block_size
+
+    def block(self, i):
+        assert 0 <= i < self.num_blocks
+        return self.token_ids[i*self.block_size: (i+1)*self.block_size]
+
+    def append_token(self, token_id: int):
+        self.token_ids.append(token_id)
+        self.last_token = token_id
+        self.num_tokens += 1
+
+    def __getstate__(self):
+        return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
+                self.token_ids if self.num_completion_tokens == 0 else self.last_token)
+
+    def __setstate__(self, state):
+        self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
+        if self.num_completion_tokens == 0:
+            self.token_ids = state[-1]
+        else:
+            self.last_token = state[-1]
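
Sequence above does the paged-KV bookkeeping that prepare_prefill in model_runner.py flattens into slot_mapping: tokens are grouped into fixed-size logical blocks, block_table maps each logical block to a physical cache block, and the last block may be only partially filled (last_block_num_tokens). Below is a standalone sketch of that arithmetic with a deliberately tiny block size and made-up block-table values.

BLOCK_SIZE = 4  # the package uses 256; a small value keeps the example readable


def slot_mapping(num_tokens: int, block_table: list[int]) -> list[int]:
    """Map token positions 0..num_tokens-1 to physical KV-cache slots."""
    num_blocks = (num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE
    last_block_num_tokens = num_tokens - (num_blocks - 1) * BLOCK_SIZE
    slots = []
    for i in range(num_blocks):
        start = block_table[i] * BLOCK_SIZE
        end = start + (BLOCK_SIZE if i != num_blocks - 1 else last_block_num_tokens)
        slots.extend(range(start, end))
    return slots


# 10 tokens spread over physical blocks 7, 2 and 5 (three blocks, the last one half full)
print(slot_mapping(10, [7, 2, 5]))
# [28, 29, 30, 31, 8, 9, 10, 11, 20, 21]

With the package's defaults the same arithmetic runs with a block size of 256 and physical block ids handed out by BlockManager.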

stepaudio2/flashcosyvoice/modules/__init__.py
File without changes

stepaudio2/flashcosyvoice/modules/flow.py
@@ -0,0 +1,198 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from stepaudio2.flashcosyvoice.modules.flow_components.estimator import \
+    CausalConditionalDecoder
+from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import (
+    UpsampleConformerEncoder, make_pad_mask)
+
+
+# TODO(xcsong): make it configurable
+@dataclass
+class CfmParams:
+    sigma_min: float = 1e-6
+    solver: str = "euler"
+    t_scheduler: str = "cosine"
+    training_cfg_rate: float = 0.2
+    inference_cfg_rate: float = 0.7
+
+
+class CausalConditionalCFM(torch.nn.Module):
+    def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
+        super().__init__()
+        self.n_feats = in_channels
+        self.n_spks = n_spks
+        self.spk_emb_dim = spk_emb_dim
+        self.solver = cfm_params.solver
+        if hasattr(cfm_params, "sigma_min"):
+            self.sigma_min = cfm_params.sigma_min
+        else:
+            self.sigma_min = 1e-4
+        self.t_scheduler = cfm_params.t_scheduler
+        self.training_cfg_rate = cfm_params.training_cfg_rate
+        self.inference_cfg_rate = cfm_params.inference_cfg_rate
+        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+        # Just change the architecture of the estimator here
+        self.estimator = CausalConditionalDecoder() if estimator is None else estimator
+
+    @torch.inference_mode()
+    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+        # fix prompt and overlap part mu and z
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+        if self.t_scheduler == 'cosine':
+            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
+
+    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
+        """
+        Fixed euler solver for ODEs.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): output of encoder
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): output_mask
+                shape: (batch_size, 1, mel_timesteps)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size, spk_emb_dim)
+            cond: Not used but kept for future purposes
+        """
+        batch_size = x.size(0)
+        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+
+        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+        # Create tensors with double batch size for CFG (conditional + unconditional)
+        x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
+        mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
+        mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
+        t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
+        spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
+        cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
+
+        for step in range(1, len(t_span)):
+            # Classifier-Free Guidance inference introduced in VoiceBox
+            # Copy conditional and unconditional input
+            x_in[:batch_size] = x
+            x_in[batch_size:] = x
+            mask_in[:batch_size] = mask
+            mask_in[batch_size:] = mask
+            mu_in[:batch_size] = mu
+            # Unconditional part remains 0
+            t_in.fill_(t)
+            spks_in[:batch_size] = spks
+            cond_in[:batch_size] = cond
+
+            dphi_dt = self.estimator(
+                x_in, mask_in,
+                mu_in, t_in,
+                spks_in,
+                cond_in,
+                streaming
+            )
+            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
+            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+
+        return sol[-1].float()
+
+
+class CausalMaskedDiffWithXvec(torch.nn.Module):
+    def __init__(
+        self,
+        input_size: int = 512,
+        output_size: int = 80,
+        spk_embed_dim: int = 192,
+        output_type: str = "mel",
+        vocab_size: int = 6561,
+        input_frame_rate: int = 25,
+        token_mel_ratio: int = 2,
+        pre_lookahead_len: int = 3,
+        encoder: torch.nn.Module = None,
+        decoder: torch.nn.Module = None,
+    ):
+        super().__init__()
+        self.input_size = input_size
+        self.output_size = output_size
+        self.vocab_size = vocab_size
+        self.output_type = output_type
+        self.input_frame_rate = input_frame_rate
+        self.input_embedding = nn.Embedding(vocab_size, input_size)
+        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+        self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
+        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+        self.decoder = CausalConditionalCFM() if decoder is None else decoder
+        self.token_mel_ratio = token_mel_ratio
+        self.pre_lookahead_len = pre_lookahead_len
+
+    @torch.inference_mode()
+    def forward(self,
+                token,
+                token_len,
+                prompt_feat,
+                prompt_feat_len,
+                embedding,
+                streaming,
+                finalize):
+        # xvec projection
+        embedding = F.normalize(embedding, dim=1)
+        embedding = self.spk_embed_affine_layer(embedding)
+
+        # concat text and prompt_text
+        mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
+        token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+        # text encode
+        if finalize is True:
+            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+        else:
+            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
+        h = self.encoder_proj(h)
+
+        # get conditions
+        conds = torch.zeros_like(h, device=token.device)
+        for i, j in enumerate(prompt_feat_len):
+            conds[i, :j] = prompt_feat[i, :j]
+        conds = conds.transpose(1, 2)
+
+        h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
+        mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
+        feat, _ = self.decoder(
+            mu=h.transpose(1, 2).contiguous(),
+            mask=mask.unsqueeze(1),
+            spks=embedding,
+            cond=conds,
+            n_timesteps=10,
+            streaming=streaming
+        )  # [B, num_mels, T]
+        return feat.float(), h_lengths
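
solve_euler above is a fixed-step Euler integrator with classifier-free guidance: each step evaluates the estimator on a doubled batch (conditional inputs in the first half, zeroed conditioning in the second), extrapolates the two velocity estimates with inference_cfg_rate, and takes one Euler step along the cosine-warped time grid. Below is a self-contained sketch of just that update rule, with a toy velocity field standing in for CausalConditionalDecoder; the field, shapes, and values are illustrative only.

import torch


def toy_estimator(x_in: torch.Tensor, mu_in: torch.Tensor, t_in: torch.Tensor) -> torch.Tensor:
    # Stand-in velocity field: pull x toward mu (zero mu in the unconditional half).
    # It ignores t_in; a real estimator is time-conditioned.
    return mu_in - x_in


def solve_euler_cfg(x, mu, n_timesteps=10, cfg_rate=0.7):
    b = x.size(0)
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)           # cosine schedule, as in CfmParams
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        x_in = torch.cat([x, x], dim=0)                       # conditional + unconditional copies
        mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)  # unconditional branch sees zero condition
        t_in = torch.full((2 * b,), float(t))
        v = toy_estimator(x_in, mu_in, t_in)
        v_cond, v_uncond = v[:b], v[b:]
        v = (1.0 + cfg_rate) * v_cond - cfg_rate * v_uncond   # CFG extrapolation
        x = x + dt * v                                        # Euler step
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x


x0 = torch.randn(2, 80, 50)  # noise with a mel-like shape (illustrative)
mu = torch.randn(2, 80, 50)  # stand-in "encoder output"
print(solve_euler_cfg(x0, mu).shape)  # torch.Size([2, 80, 50])

The package pre-allocates the doubled x_in/mu_in/... buffers rather than calling torch.cat each step; per the comment in solve_euler, that is a TensorRT-compatibility concern, not a numerical one.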

stepaudio2/flashcosyvoice/modules/flow_components/__init__.py
File without changes