minicpmo-utils 0.1.0 (minicpmo_utils-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/engine/model_runner.py
@@ -0,0 +1,310 @@
+ import pickle
+ from multiprocessing.shared_memory import SharedMemory
+ from multiprocessing.synchronize import Event
+
+ import torch
+ import torch.distributed as dist
+
+ from stepaudio2.flashcosyvoice.config import Config
+ from stepaudio2.flashcosyvoice.engine.sequence import Sequence
+ from stepaudio2.flashcosyvoice.modules.qwen2 import Qwen2ForCausalLM
+ from stepaudio2.flashcosyvoice.modules.sampler import RasSampler, Sampler
+ from stepaudio2.flashcosyvoice.utils.context import (get_context, reset_context,
+                                                      set_context)
+ from stepaudio2.flashcosyvoice.utils.loader import load_model
+
+
+ class ModelRunner:
+
+     def __init__(self, config: Config, rank: int, event: Event | list[Event]):
+         self.config = config
+         hf_config = config.hf_config
+         self.block_size = config.kvcache_block_size
+         self.enforce_eager = config.enforce_eager
+         self.world_size = config.tensor_parallel_size
+         self.rank = rank
+         self.event = event
+
+         # TODO(xcsong): support tp > 1
+         if self.world_size > 1:
+             dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
+         torch.cuda.set_device(rank)
+         default_dtype = torch.get_default_dtype()
+         torch.set_default_dtype(hf_config.torch_dtype)
+         torch.set_default_device("cuda")
+         self.model = Qwen2ForCausalLM(hf_config)
+         load_model(self.model, config.model, hf_config)
+         self.sampler = Sampler()
+         self.ras_sampler = RasSampler()
+         self.warmup_model()
+         self.allocate_kv_cache()
+         if not self.enforce_eager:
+             self.capture_cudagraph()
+         torch.set_default_device("cpu")
+         torch.set_default_dtype(default_dtype)
+
+         if self.world_size > 1:
+             if rank == 0:
+                 self.shm = SharedMemory(name="flashcosyvoice", create=True, size=2**20)
+                 dist.barrier()
+             else:
+                 dist.barrier()
+                 self.shm = SharedMemory(name="flashcosyvoice")
+                 self.loop()
+
+     def exit(self):
+         if self.world_size > 1:
+             self.shm.close()
+             dist.barrier()
+             if self.rank == 0:
+                 self.shm.unlink()
+         if not self.enforce_eager:
+             del self.graphs, self.graph_pool
+         torch.cuda.synchronize()
+         if self.world_size > 1:
+             dist.destroy_process_group()
+
+     def loop(self):
+         while True:
+             method_name, args = self.read_shm()
+             self.call(method_name, *args)
+             if method_name == "exit":
+                 break
+
+     def read_shm(self):
+         assert self.world_size > 1 and self.rank
+         self.event.wait()
+         n = int.from_bytes(self.shm.buf[0:4], "little")
+         method_name, *args = pickle.loads(self.shm.buf[4:n + 4])
+         self.event.clear()
+         return method_name, args
+
+     def write_shm(self, method_name, *args):
+         assert self.world_size > 1 and not self.rank
+         data = pickle.dumps([method_name, *args])
+         n = len(data)
+         self.shm.buf[0:4] = n.to_bytes(4, "little")
+         self.shm.buf[4:n + 4] = data
+         for event in self.event:
+             event.set()
+
+     def call(self, method_name, *args):
+         if self.world_size > 1 and self.rank == 0:
+             self.write_shm(method_name, *args)
+         method = getattr(self, method_name, None)
+         return method(*args)
+
+     def warmup_model(self):
+         torch.cuda.empty_cache()
+         torch.cuda.reset_peak_memory_stats()
+         max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
+         num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
+         seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
+         self.run(seqs, True)
+         torch.cuda.empty_cache()
+
+     def allocate_kv_cache(self):
+         config = self.config
+         hf_config = config.hf_config
+         free, total = torch.cuda.mem_get_info()
+         used = total - free
+         peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
+         num_kv_heads = hf_config.num_key_value_heads // self.world_size
+         head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
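+         # Illustrative sizing of the line below (hypothetical shapes, not taken
+         # from this package): 24 layers, 2 KV heads, head_dim 64, bf16 and
+         # block_size 16 give 2 * 24 * 16 * 2 * 64 * 2 = 196,608 bytes per block.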
+         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
+         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
+         assert config.num_kvcache_blocks > 0, "try to **increase** gpu_memory_utilization"
+         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
+         layer_id = 0
+         for module in self.model.modules():
+             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
+                 module.k_cache = self.kv_cache[0, layer_id]
+                 module.v_cache = self.kv_cache[1, layer_id]
+                 layer_id += 1
+
+     def prepare_block_tables(self, seqs: list[Sequence]):
+         max_len = max(len(seq.block_table) for seq in seqs)
+         block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
+         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         return block_tables
+
+     def prepare_prefill(self, seqs: list[Sequence]):
+         input_ids = []
+         positions = []
+         cu_seqlens_q = [0]
+         cu_seqlens_k = [0]
+         max_seqlen_q = 0
+         max_seqlen_k = 0
+         slot_mapping = []
+         block_tables = None
+         for seq in seqs:
+             seqlen = len(seq)
+             input_ids.extend(seq[seq.num_cached_tokens:])
+             positions.extend(list(range(seq.num_cached_tokens, seqlen)))
+             seqlen_q = seqlen - seq.num_cached_tokens
+             seqlen_k = seqlen
+             cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
+             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
+             max_seqlen_q = max(seqlen_q, max_seqlen_q)
+             max_seqlen_k = max(seqlen_k, max_seqlen_k)
+             if not seq.block_table:
+                 continue
+             for i in range(seq.num_cached_blocks, seq.num_blocks):
+                 start = seq.block_table[i] * self.block_size
+                 if i != seq.num_blocks - 1:
+                     end = start + self.block_size
+                 else:
+                     end = start + seq.last_block_num_tokens
+                 slot_mapping.extend(list(range(start, end)))
+         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:  # prefix cache
+             block_tables = self.prepare_block_tables(seqs)
+         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+         return input_ids, positions
+
+     def prepare_decode(self, seqs: list[Sequence]):
+         input_ids = []
+         positions = []
+         slot_mapping = []
+         context_lens = []
+         for seq in seqs:
+             input_ids.append(seq.last_token)
+             positions.append(len(seq))
+             context_lens.append(len(seq))
+             slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
+         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+         positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
+         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
+         block_tables = self.prepare_block_tables(seqs)
+         set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+         return input_ids, positions
+
+     def prepare_sample(self, seqs: list[Sequence]):
+         temperatures = []
+         top_ks = []
+         win_sizes = []
+         tau_rs = []
+         top_ps = []
+         min_tokens_list = []
+         use_ras_list = []
+
+         for seq in seqs:
+             temperatures.append(seq.temperature)
+             top_ks.append(seq.top_k)
+             win_sizes.append(seq.win_size)
+             tau_rs.append(seq.tau_r)
+             top_ps.append(seq.top_p)
+             min_tokens_list.append(seq.min_tokens)
+             use_ras_list.append(seq.use_ras)
+
+         temperatures_tensor = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
+         # batched sampling assumes every sequence in the batch shares the same settings
+         assert all(item == temperatures[0] for item in temperatures)
+         assert all(item == top_ks[0] for item in top_ks)
+         assert all(item == win_sizes[0] for item in win_sizes)
+         assert all(item == tau_rs[0] for item in tau_rs)
+         assert all(item == top_ps[0] for item in top_ps)
+         assert all(item == use_ras_list[0] for item in use_ras_list)
+
+         return {
+             'temperatures': temperatures_tensor,
+             'top_k': top_ks[0],
+             'win_size': win_sizes[0],
+             'tau_r': tau_rs[0],
+             'top_p': top_ps[0],
+             'eos_token': self.config.eos,
+             'min_tokens': min_tokens_list,
+             'use_ras': use_ras_list[0]
+         }
+
+     @torch.inference_mode()
+     def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
+         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
+             return self.model.compute_logits(self.model(input_ids, positions))
+         else:
+             bs = input_ids.size(0)
+             context = get_context()
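+             # replay the smallest captured graph that can hold this batch size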
+             graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
+             graph_vars = self.graph_vars
+             for k, v in graph_vars.items():
+                 if k != "outputs":
+                     v.zero_()
+             graph_vars["input_ids"][:bs] = input_ids
+             graph_vars["positions"][:bs] = positions
+             graph_vars["slot_mapping"][:bs] = context.slot_mapping
+             graph_vars["context_lens"][:bs] = context.context_lens
+             graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
+             graph.replay()
+             return self.model.compute_logits(graph_vars["outputs"][:bs])
+
+     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
+         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
+         if self.rank == 0 or self.world_size == 1:
+             sample_params = self.prepare_sample(seqs)
+             logits = self.run_model(input_ids, positions, is_prefill)
+
+             if sample_params['use_ras']:
+                 # RAS sampling needs the tokens decoded so far for its repetition checks
+                 decoded_tokens_list = [seq.completion_token_ids for seq in seqs]
+                 # min_tokens stays per-sequence; the other settings are batch-wide scalars
+                 token_ids = self.ras_sampler(
+                     logits,
+                     decoded_tokens_list,
+                     win_size=sample_params['win_size'],
+                     tau_r=sample_params['tau_r'],
+                     top_p=sample_params['top_p'],
+                     top_k=sample_params['top_k'],
+                     eos_token=sample_params['eos_token'],
+                     min_tokens=sample_params['min_tokens']
+                 ).tolist()
+             else:
+                 # default sampler: per-batch temperature tensor plus the shared scalar top_k
+                 token_ids = self.sampler(logits, sample_params['temperatures'], sample_params['top_k']).tolist()
+         else:
+             logits = self.run_model(input_ids, positions, is_prefill)
+             token_ids = None
+         reset_context()
+         return token_ids
+
+     @torch.inference_mode()
+     def capture_cudagraph(self):
+         config = self.config
+         hf_config = config.hf_config
+         max_bs = min(self.config.max_num_seqs, 512)
+         max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
+         input_ids = torch.zeros(max_bs, dtype=torch.int64)
+         positions = torch.zeros(max_bs, dtype=torch.int64)
+         slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
+         context_lens = torch.zeros(max_bs, dtype=torch.int32)
+         block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
+         outputs = torch.zeros(max_bs, hf_config.hidden_size)
+         self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
+         self.graphs = {}
+         self.graph_pool = None
+
+         for bs in reversed(self.graph_bs):
+             graph = torch.cuda.CUDAGraph()
+             set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
+             outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # warmup
+             with torch.cuda.graph(graph, self.graph_pool):
+                 outputs[:bs] = self.model(input_ids[:bs], positions[:bs])  # capture
+             if self.graph_pool is None:
+                 self.graph_pool = graph.pool()
+             self.graphs[bs] = graph
+             torch.cuda.synchronize()
+             reset_context()
+
+         self.graph_vars = dict(
+             input_ids=input_ids,
+             positions=positions,
+             slot_mapping=slot_mapping,
+             context_lens=context_lens,
+             block_tables=block_tables,
+             outputs=outputs,
+         )
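The write_shm/read_shm pair above implements a tiny RPC over shared memory: a 4-byte little-endian length header followed by a pickled [method_name, *args] list, with an Event per worker to signal a pending message. A standalone sketch of just the wire format (the buffer name and payload are made up for illustration):

    import pickle
    from multiprocessing.shared_memory import SharedMemory

    shm = SharedMemory(name="demo_rpc", create=True, size=2**20)
    data = pickle.dumps(["run", [1, 2, 3], True])   # method name plus args
    shm.buf[0:4] = len(data).to_bytes(4, "little")  # length header
    shm.buf[4:4 + len(data)] = data                 # payload
    n = int.from_bytes(shm.buf[0:4], "little")      # reader side
    method_name, *args = pickle.loads(shm.buf[4:n + 4])
    assert method_name == "run" and args == [[1, 2, 3], True]
    shm.close()
    shm.unlink()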
stepaudio2/flashcosyvoice/engine/scheduler.py
@@ -0,0 +1,77 @@
+ from collections import deque
+
+ from stepaudio2.flashcosyvoice.config import Config
+ from stepaudio2.flashcosyvoice.engine.block_manager import BlockManager
+ from stepaudio2.flashcosyvoice.engine.sequence import Sequence, SequenceStatus
+
+
+ class Scheduler:
+
+     def __init__(self, config: Config):
+         self.max_num_seqs = config.max_num_seqs
+         self.max_num_batched_tokens = config.max_num_batched_tokens
+         self.eos = config.eos
+         self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
+         self.waiting: deque[Sequence] = deque()
+         self.running: deque[Sequence] = deque()
+
+     def is_finished(self):
+         return not self.waiting and not self.running
+
+     def add(self, seq: Sequence):
+         self.waiting.append(seq)
+
+     def schedule(self) -> tuple[list[Sequence], bool]:
+         # prefill
+         scheduled_seqs = []
+         num_seqs = 0
+         num_batched_tokens = 0
+         while self.waiting and num_seqs < self.max_num_seqs:
+             seq = self.waiting[0]
+             if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
+                 break
+             num_seqs += 1
+             self.block_manager.allocate(seq)
+             num_batched_tokens += len(seq) - seq.num_cached_tokens
+             seq.status = SequenceStatus.RUNNING
+             self.waiting.popleft()
+             self.running.append(seq)
+             scheduled_seqs.append(seq)
+         if scheduled_seqs:
+             return scheduled_seqs, True
+
+         # decode
+         while self.running and num_seqs < self.max_num_seqs:
+             seq = self.running.popleft()
+             while not self.block_manager.can_append(seq):
+                 if self.running:
+                     self.preempt(self.running.pop())
+                 else:
+                     self.preempt(seq)
+                     break
+             else:
+                 num_seqs += 1
+                 self.block_manager.may_append(seq)
+                 scheduled_seqs.append(seq)
+         assert scheduled_seqs
+         self.running.extendleft(reversed(scheduled_seqs))
+         return scheduled_seqs, False
+
+     def preempt(self, seq: Sequence):
+         seq.status = SequenceStatus.WAITING
+         self.block_manager.deallocate(seq)
+         self.waiting.appendleft(seq)
+
+     def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
+         for seq, token_id in zip(seqs, token_ids):
+             seq.append_token(token_id)
+             # Check if the sequence has reached the maximum number of tokens
+             reached_max_tokens = seq.num_completion_tokens == seq.max_tokens
+             # Check if the sequence has reached EOS and has generated enough tokens (satisfying min_tokens requirements)
+             eos_with_min_tokens = (not seq.ignore_eos and token_id == self.eos and
+                                    seq.num_completion_tokens >= seq.min_tokens)
+
+             if reached_max_tokens or eos_with_min_tokens:
+                 seq.status = SequenceStatus.FINISHED
+                 self.block_manager.deallocate(seq)
+                 self.running.remove(seq)
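Together with ModelRunner, the scheduler drives generation as a continuous-batching loop: schedule() returns a prefill batch whenever one fits (otherwise a decode batch), and under KV-cache pressure it preempts the most recently added running sequences back to the waiting queue. A minimal sketch of one generation loop; the actual driver lives in llm_engine.py, which is listed above but not shown in this section, and `runner`, `scheduler`, and `prompt_token_ids` are placeholders:

    from stepaudio2.flashcosyvoice.config import SamplingParams
    from stepaudio2.flashcosyvoice.engine.sequence import Sequence

    scheduler.add(Sequence(prompt_token_ids, SamplingParams()))
    while not scheduler.is_finished():
        seqs, is_prefill = scheduler.schedule()   # prefill batch if possible, else decode
        token_ids = runner.call("run", seqs, is_prefill)
        scheduler.postprocess(seqs, token_ids)    # append tokens, retire finished seqs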
stepaudio2/flashcosyvoice/engine/sequence.py
@@ -0,0 +1,90 @@
+ from copy import copy
+ from enum import Enum, auto
+ from itertools import count
+
+ from stepaudio2.flashcosyvoice.config import SamplingParams
+
+
+ class SequenceStatus(Enum):
+     WAITING = auto()
+     RUNNING = auto()
+     FINISHED = auto()
+
+
+ class Sequence:
+     block_size = 256
+     counter = count()
+
+     def __init__(self, token_ids: list[int], sampling_params=SamplingParams()):
+         self.seq_id = next(Sequence.counter)
+         self.status = SequenceStatus.WAITING
+         self.token_ids = copy(token_ids)
+         self.last_token = token_ids[-1]
+         self.num_tokens = len(self.token_ids)
+         self.num_prompt_tokens = len(token_ids)
+         self.num_cached_tokens = 0
+         self.block_table = []
+         self.temperature = sampling_params.temperature
+         self.min_tokens = sampling_params.min_tokens
+         self.max_tokens = sampling_params.max_tokens
+         self.ignore_eos = sampling_params.ignore_eos
+         self.top_k = sampling_params.top_k
+         # RasSampler parameters
+         self.use_ras = sampling_params.use_ras
+         self.win_size = sampling_params.win_size
+         self.tau_r = sampling_params.tau_r
+         self.top_p = sampling_params.top_p
+
+     def __len__(self):
+         return self.num_tokens
+
+     def __getitem__(self, key):
+         return self.token_ids[key]
+
+     @property
+     def is_finished(self):
+         return self.status == SequenceStatus.FINISHED
+
+     @property
+     def num_completion_tokens(self):
+         return self.num_tokens - self.num_prompt_tokens
+
+     @property
+     def prompt_token_ids(self):
+         return self.token_ids[:self.num_prompt_tokens]
+
+     @property
+     def completion_token_ids(self):
+         return self.token_ids[self.num_prompt_tokens:]
+
+     @property
+     def num_cached_blocks(self):
+         return self.num_cached_tokens // self.block_size
+
+     @property
+     def num_blocks(self):
+         return (self.num_tokens + self.block_size - 1) // self.block_size
+
+     @property
+     def last_block_num_tokens(self):
+         return self.num_tokens - (self.num_blocks - 1) * self.block_size
+
+     def block(self, i):
+         assert 0 <= i < self.num_blocks
+         return self.token_ids[i*self.block_size: (i+1)*self.block_size]
+
+     def append_token(self, token_id: int):
+         self.token_ids.append(token_id)
+         self.last_token = token_id
+         self.num_tokens += 1
+
+     def __getstate__(self):
+         return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
+                 self.token_ids if self.num_completion_tokens == 0 else self.last_token)
+
+     def __setstate__(self, state):
+         self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
+         if self.num_completion_tokens == 0:
+             self.token_ids = state[-1]
+         else:
+             self.last_token = state[-1]
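With block_size = 256, the block arithmetic above works out as in this small example (runnable against this module; the token values are arbitrary):

    from stepaudio2.flashcosyvoice.engine.sequence import Sequence

    seq = Sequence(list(range(300)))
    assert seq.num_blocks == 2              # ceil(300 / 256)
    assert seq.last_block_num_tokens == 44  # 300 - 256
    seq.append_token(7)
    assert seq.num_completion_tokens == 1 and seq.last_token == 7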
stepaudio2/flashcosyvoice/modules/__init__.py (file without changes)
stepaudio2/flashcosyvoice/modules/flow.py
@@ -0,0 +1,198 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from stepaudio2.flashcosyvoice.modules.flow_components.estimator import \
+     CausalConditionalDecoder
+ from stepaudio2.flashcosyvoice.modules.flow_components.upsample_encoder import (
+     UpsampleConformerEncoder, make_pad_mask)
+
+
+ # TODO(xcsong): make it configurable
+ @dataclass
+ class CfmParams:
+     sigma_min: float = 1e-6
+     solver: str = "euler"
+     t_scheduler: str = "cosine"
+     training_cfg_rate: float = 0.2
+     inference_cfg_rate: float = 0.7
+
+
+ class CausalConditionalCFM(torch.nn.Module):
+     def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
+         super().__init__()
+         self.n_feats = in_channels
+         self.n_spks = n_spks
+         self.spk_emb_dim = spk_emb_dim
+         self.solver = cfm_params.solver
+         if hasattr(cfm_params, "sigma_min"):
+             self.sigma_min = cfm_params.sigma_min
+         else:
+             self.sigma_min = 1e-4
+         self.t_scheduler = cfm_params.t_scheduler
+         self.training_cfg_rate = cfm_params.training_cfg_rate
+         self.inference_cfg_rate = cfm_params.inference_cfg_rate
+         in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
+         # Just change the architecture of the estimator here
+         self.estimator = CausalConditionalDecoder() if estimator is None else estimator
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond (torch.Tensor, optional): prompt features used as conditioning
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
+         # fix prompt and overlap part mu and z
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
+         if self.t_scheduler == 'cosine':
+             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
+
+     def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
+         """
+         Fixed-step Euler solver for ODEs.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): n_timesteps interpolated
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output_mask
+                 shape: (batch_size, 1, mel_timesteps)
+             spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+                 shape: (batch_size, spk_emb_dim)
+             cond (torch.Tensor, optional): prompt features used as conditioning
+         """
+         batch_size = x.size(0)
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+         # Or in future might add like a return_all_steps flag
+         sol = []
+
+         # Do not use concat, it may cause memory format changed and trt infer with wrong results!
+         # Create tensors with double batch size for CFG (conditional + unconditional)
+         x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
+         mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
+         mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
+         t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
+         spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
+         cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
+
+         for step in range(1, len(t_span)):
+             # Classifier-Free Guidance inference introduced in VoiceBox
+             # Copy conditional and unconditional input
+             x_in[:batch_size] = x
+             x_in[batch_size:] = x
+             mask_in[:batch_size] = mask
+             mask_in[batch_size:] = mask
+             mu_in[:batch_size] = mu
+             # Unconditional part remains 0
+             t_in.fill_(t)
+             spks_in[:batch_size] = spks
+             cond_in[:batch_size] = cond
+
+             dphi_dt = self.estimator(
+                 x_in, mask_in,
+                 mu_in, t_in,
+                 spks_in,
+                 cond_in,
+                 streaming
+             )
+             dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
+             dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
+             x = x + dt * dphi_dt
+             t = t + dt
+             sol.append(x)
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+         return sol[-1].float()
+
+
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
+     def __init__(
+         self,
+         input_size: int = 512,
+         output_size: int = 80,
+         spk_embed_dim: int = 192,
+         output_type: str = "mel",
+         vocab_size: int = 6561,
+         input_frame_rate: int = 25,
+         token_mel_ratio: int = 2,
+         pre_lookahead_len: int = 3,
+         encoder: torch.nn.Module = None,
+         decoder: torch.nn.Module = None,
+     ):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+         self.vocab_size = vocab_size
+         self.output_type = output_type
+         self.input_frame_rate = input_frame_rate
+         self.input_embedding = nn.Embedding(vocab_size, input_size)
+         self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
+         self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
+         self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
+         self.decoder = CausalConditionalCFM() if decoder is None else decoder
+         self.token_mel_ratio = token_mel_ratio
+         self.pre_lookahead_len = pre_lookahead_len
+
+     @torch.inference_mode()
+     def forward(self,
+                 token,
+                 token_len,
+                 prompt_feat,
+                 prompt_feat_len,
+                 embedding,
+                 streaming,
+                 finalize):
+         # xvec projection
+         embedding = F.normalize(embedding, dim=1)
+         embedding = self.spk_embed_affine_layer(embedding)
+
+         # embed speech tokens and zero out the padded positions
+         mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
+         token = self.input_embedding(torch.clamp(token, min=0)) * mask
+
+         # text encode
+         if finalize is True:
+             h, h_lengths = self.encoder(token, token_len, streaming=streaming)
+         else:
+             token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
+             h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
+         h = self.encoder_proj(h)
+
+         # get conditions
+         conds = torch.zeros_like(h, device=token.device)
+         for i, j in enumerate(prompt_feat_len):
+             conds[i, :j] = prompt_feat[i, :j]
+         conds = conds.transpose(1, 2)
+
+         h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
+         mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
+         feat, _ = self.decoder(
+             mu=h.transpose(1, 2).contiguous(),
+             mask=mask.unsqueeze(1),
+             spks=embedding,
+             cond=conds,
+             n_timesteps=10,
+             streaming=streaming
+         )  # [B, num_mels, T]
+         return feat.float(), h_lengths
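For reference, solve_euler above is a fixed-step Euler integration of the classifier-free-guidance velocity field. With r = inference_cfg_rate (0.7 here), n = n_timesteps, and the cosine scheduler warping the uniform time grid, each step computes:

    t_k = 1 - cos(pi * k / (2 * n)),  k = 0, ..., n
    v = (1 + r) * v_cond - r * v_uncond
    x_{k+1} = x_k + (t_{k+1} - t_k) * v

where v_cond and v_uncond are the two halves of the doubled-batch estimator output; the unconditional half sees zeroed mu, spks, and cond.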