minicpmo-utils 0.1.0 (minicpmo_utils-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. cosyvoice/__init__.py +17 -0
  2. cosyvoice/bin/average_model.py +93 -0
  3. cosyvoice/bin/export_jit.py +103 -0
  4. cosyvoice/bin/export_onnx.py +120 -0
  5. cosyvoice/bin/inference_deprecated.py +126 -0
  6. cosyvoice/bin/train.py +195 -0
  7. cosyvoice/cli/__init__.py +0 -0
  8. cosyvoice/cli/cosyvoice.py +209 -0
  9. cosyvoice/cli/frontend.py +238 -0
  10. cosyvoice/cli/model.py +386 -0
  11. cosyvoice/dataset/__init__.py +0 -0
  12. cosyvoice/dataset/dataset.py +151 -0
  13. cosyvoice/dataset/processor.py +434 -0
  14. cosyvoice/flow/decoder.py +494 -0
  15. cosyvoice/flow/flow.py +281 -0
  16. cosyvoice/flow/flow_matching.py +227 -0
  17. cosyvoice/flow/length_regulator.py +70 -0
  18. cosyvoice/hifigan/discriminator.py +230 -0
  19. cosyvoice/hifigan/f0_predictor.py +58 -0
  20. cosyvoice/hifigan/generator.py +582 -0
  21. cosyvoice/hifigan/hifigan.py +67 -0
  22. cosyvoice/llm/llm.py +610 -0
  23. cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
  24. cosyvoice/tokenizer/tokenizer.py +279 -0
  25. cosyvoice/transformer/__init__.py +0 -0
  26. cosyvoice/transformer/activation.py +84 -0
  27. cosyvoice/transformer/attention.py +330 -0
  28. cosyvoice/transformer/convolution.py +145 -0
  29. cosyvoice/transformer/decoder.py +396 -0
  30. cosyvoice/transformer/decoder_layer.py +132 -0
  31. cosyvoice/transformer/embedding.py +302 -0
  32. cosyvoice/transformer/encoder.py +474 -0
  33. cosyvoice/transformer/encoder_layer.py +236 -0
  34. cosyvoice/transformer/label_smoothing_loss.py +96 -0
  35. cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  36. cosyvoice/transformer/subsampling.py +383 -0
  37. cosyvoice/transformer/upsample_encoder.py +320 -0
  38. cosyvoice/utils/__init__.py +0 -0
  39. cosyvoice/utils/class_utils.py +83 -0
  40. cosyvoice/utils/common.py +186 -0
  41. cosyvoice/utils/executor.py +176 -0
  42. cosyvoice/utils/file_utils.py +129 -0
  43. cosyvoice/utils/frontend_utils.py +136 -0
  44. cosyvoice/utils/losses.py +57 -0
  45. cosyvoice/utils/mask.py +265 -0
  46. cosyvoice/utils/scheduler.py +738 -0
  47. cosyvoice/utils/train_utils.py +367 -0
  48. cosyvoice/vllm/cosyvoice2.py +103 -0
  49. matcha/__init__.py +0 -0
  50. matcha/app.py +357 -0
  51. matcha/cli.py +418 -0
  52. matcha/hifigan/__init__.py +0 -0
  53. matcha/hifigan/config.py +28 -0
  54. matcha/hifigan/denoiser.py +64 -0
  55. matcha/hifigan/env.py +17 -0
  56. matcha/hifigan/meldataset.py +217 -0
  57. matcha/hifigan/models.py +368 -0
  58. matcha/hifigan/xutils.py +60 -0
  59. matcha/models/__init__.py +0 -0
  60. matcha/models/baselightningmodule.py +209 -0
  61. matcha/models/components/__init__.py +0 -0
  62. matcha/models/components/decoder.py +443 -0
  63. matcha/models/components/flow_matching.py +132 -0
  64. matcha/models/components/text_encoder.py +410 -0
  65. matcha/models/components/transformer.py +316 -0
  66. matcha/models/matcha_tts.py +239 -0
  67. matcha/onnx/__init__.py +0 -0
  68. matcha/onnx/export.py +181 -0
  69. matcha/onnx/infer.py +168 -0
  70. matcha/text/__init__.py +53 -0
  71. matcha/text/cleaners.py +116 -0
  72. matcha/text/numbers.py +71 -0
  73. matcha/text/symbols.py +17 -0
  74. matcha/train.py +122 -0
  75. matcha/utils/__init__.py +5 -0
  76. matcha/utils/audio.py +82 -0
  77. matcha/utils/generate_data_statistics.py +111 -0
  78. matcha/utils/instantiators.py +56 -0
  79. matcha/utils/logging_utils.py +53 -0
  80. matcha/utils/model.py +90 -0
  81. matcha/utils/monotonic_align/__init__.py +22 -0
  82. matcha/utils/monotonic_align/setup.py +7 -0
  83. matcha/utils/pylogger.py +21 -0
  84. matcha/utils/rich_utils.py +101 -0
  85. matcha/utils/utils.py +219 -0
  86. minicpmo/__init__.py +24 -0
  87. minicpmo/utils.py +636 -0
  88. minicpmo/version.py +2 -0
  89. minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
  90. minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
  91. minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
  92. minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
  93. s3tokenizer/__init__.py +153 -0
  94. s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
  95. s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
  96. s3tokenizer/assets/mel_filters.npz +0 -0
  97. s3tokenizer/cli.py +183 -0
  98. s3tokenizer/model.py +546 -0
  99. s3tokenizer/model_v2.py +605 -0
  100. s3tokenizer/utils.py +390 -0
  101. stepaudio2/__init__.py +40 -0
  102. stepaudio2/cosyvoice2/__init__.py +1 -0
  103. stepaudio2/cosyvoice2/flow/__init__.py +0 -0
  104. stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
  105. stepaudio2/cosyvoice2/flow/flow.py +230 -0
  106. stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
  107. stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
  108. stepaudio2/cosyvoice2/transformer/attention.py +328 -0
  109. stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
  110. stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
  111. stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  112. stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
  113. stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  114. stepaudio2/cosyvoice2/utils/__init__.py +1 -0
  115. stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
  116. stepaudio2/cosyvoice2/utils/common.py +101 -0
  117. stepaudio2/cosyvoice2/utils/mask.py +49 -0
  118. stepaudio2/flashcosyvoice/__init__.py +0 -0
  119. stepaudio2/flashcosyvoice/cli.py +424 -0
  120. stepaudio2/flashcosyvoice/config.py +80 -0
  121. stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
  122. stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
  123. stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
  124. stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
  125. stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
  126. stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
  127. stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
  128. stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
  129. stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
  130. stepaudio2/flashcosyvoice/modules/flow.py +198 -0
  131. stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
  132. stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
  133. stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  134. stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
  135. stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  136. stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  137. stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
  138. stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  139. stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  140. stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
  141. stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
  142. stepaudio2/flashcosyvoice/utils/audio.py +77 -0
  143. stepaudio2/flashcosyvoice/utils/context.py +28 -0
  144. stepaudio2/flashcosyvoice/utils/loader.py +116 -0
  145. stepaudio2/flashcosyvoice/utils/memory.py +19 -0
  146. stepaudio2/stepaudio2.py +204 -0
  147. stepaudio2/token2wav.py +248 -0
  148. stepaudio2/utils.py +91 -0
stepaudio2/flashcosyvoice/modules/sampler.py
@@ -0,0 +1,231 @@
import torch
from torch import nn


class Sampler(nn.Module):
    """
    Optimized sampler implementation using vectorized operations instead of loops, significantly improving performance

    Performance optimizations:
    1. Using batch processing instead of sequence loops, reducing Python loop overhead
    2. Using PyTorch's vectorized operations (like torch.sort, torch.gather) for parallel computation
    3. Using mask operations to apply top-k filtering at once, avoiding per-sequence processing
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: int = None):
        """
        Perform sampling operation using vectorized method for top-k filtering

        Args:
            logits: Logits tensor with shape [batch_size, vocab_size]
            temperatures: Temperature parameters with shape [batch_size]
            top_k: Top-k value for filtering (uniform across all sequences)

        Returns:
            Sampled token IDs
        """
        logits = logits.to(torch.float)
        greedy_tokens = logits.argmax(dim=-1)  # Greedy decoding result, used when temperature=0
        logits.div_(temperatures.unsqueeze(dim=1))  # Apply temperature scaling

        # Apply uniform top-k filtering if top_k is provided
        if top_k is not None and top_k > 0:
            vocab_size = logits.size(-1)

            # Create a mask to store which positions should be kept
            mask = torch.zeros_like(logits, dtype=torch.bool)

            # Batch sorting for all sequences at once
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)

            # Get threshold for each sequence (the k-th largest value)
            k_value = min(top_k, vocab_size)  # Ensure k doesn't exceed vocab size
            thresholds = sorted_logits[:, k_value-1:k_value]  # Shape [batch_size, 1]
            thresholds = thresholds.expand(-1, vocab_size)  # Expand to match logits shape

            # Create mask: only keep logits greater than or equal to threshold
            mask = logits >= thresholds

            # Apply mask: set logits not in top-k to negative infinity
            logits = torch.where(mask, logits, torch.tensor(float('-inf'), device=logits.device))

        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)


class RasSampler(nn.Module):
    """
    Optimized Repetition Aware Sampling implementation

    Performance optimizations:
    1. Using vectorized nucleus sampling instead of loop implementation, improving sampling efficiency
    2. Using tensor operations to calculate repetition rate, reducing Python loop overhead
    3. Optimizing EOS handling logic, reducing unnecessary resampling
    4. Using PyTorch's vectorized operations for parallel computation
    5. Batch processing for all sequences, dramatically improving throughput
    6. Robust handling for sequences of any length, including empty sequences
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, decoded_tokens_list: list,
                win_size: int = 10, tau_r: float = 0.1,
                top_p: float = 0.8, top_k: int = 25,
                eos_token: int = 6561, min_tokens: list[int] = None):
        """
        Execute repetition-aware sampling using optimized vectorized operations with batch processing

        Args:
            logits: Input logits with shape [batch_size, vocab_size]
            decoded_tokens_list: List of decoded tokens, each element is a token list for a batch
            win_size: Window size for repetition detection (uniform across all batch items)
            tau_r: Repetition threshold (uniform across all batch items)
            top_p: Nucleus sampling probability threshold (uniform across all batch items)
            top_k: Nucleus sampling top-k threshold (uniform across all batch items)
            eos_token: End of sequence token ID (uniform across all batch items)
            min_tokens: List of minimum tokens to generate before allowing EOS, one per batch item

        Returns:
            Selected token IDs
        """
        batch_size = logits.size(0)
        device = logits.device
        result = torch.zeros(batch_size, dtype=torch.long, device=device)

        # Set default values if not provided
        if min_tokens is None:
            min_tokens = [2] * batch_size

        # Ensure min_tokens list has the correct length
        assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}"

        # Force continue decode first token
        for i in range(batch_size):
            if i < len(decoded_tokens_list) and len(decoded_tokens_list[i]) == 0:
                logits[i, eos_token] = -float('inf')

        # 1. First, perform nucleus sampling for all sequences
        probs = torch.softmax(logits, dim=-1)

        # Use vectorized nucleus sampling for all sequences
        # This can be done in batch since top_p and top_k are uniform
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Create masks for top-p and top-k filtering
        top_p_mask = cumulative_probs <= top_p

        # Create top-k mask (first top_k positions are True)
        top_k_mask = torch.zeros_like(top_p_mask)
        top_k_mask[:, :top_k] = True

        # Combine masks
        mask = top_p_mask & top_k_mask

        # Ensure at least one token is selected per sequence
        first_token_mask = torch.zeros_like(mask)
        first_token_mask[:, 0] = True
        mask = mask | first_token_mask

        # Sample from the filtered distribution
        sample_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
        sample_probs = sample_probs / sample_probs.sum(dim=-1, keepdim=True)

        # Sample indices from the filtered distribution
        sampled_indices = torch.multinomial(sample_probs, 1).squeeze(-1)
        top_ids = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)

        # 2. Check for repetitions and apply random sampling if needed
        # Extract recent tokens for each sequence, handling empty or short sequences
        recent_tokens_list = []
        for i in range(batch_size):
            # Handle index out of range or empty tokens
            if i < len(decoded_tokens_list):
                tokens = decoded_tokens_list[i]
                if len(tokens) > 0:
                    start_idx = max(0, len(tokens) - win_size)
                    recent_tokens_list.append(tokens[start_idx:])
                else:
                    recent_tokens_list.append([])  # Empty list for empty tokens
            else:
                recent_tokens_list.append([])  # Empty list for missing batch items

        # Check if we have any tokens to process for repetition detection
        if any(len(tokens) > 0 for tokens in recent_tokens_list):
            # Convert to padded tensor for batch processing
            max_recent_len = max(len(tokens) for tokens in recent_tokens_list)
            if max_recent_len > 0:  # Only proceed if we have tokens
                recent_tokens_tensor = torch.zeros((batch_size, max_recent_len), dtype=torch.long, device=device) - 1
                for i, tokens in enumerate(recent_tokens_list):
                    if len(tokens) > 0:
                        recent_tokens_tensor[i, -len(tokens):] = torch.tensor(tokens, device=device)

                # Create a mask for valid positions and to avoid division by zero
                valid_positions_mask = torch.zeros_like(recent_tokens_tensor, dtype=torch.bool)
                for i, tokens in enumerate(recent_tokens_list):
                    if len(tokens) > 0:
                        valid_positions_mask[i, -len(tokens):] = True

                # Check repetition rates
                repetition_counts = torch.zeros(batch_size, device=device)
                for i in range(batch_size):
                    if len(recent_tokens_list[i]) > 0:
                        repetition_counts[i] = (recent_tokens_tensor[i] == top_ids[i]).sum()

                # Calculate repetition rates, avoiding division by zero
                recent_lengths = torch.tensor([max(1, len(tokens)) for tokens in recent_tokens_list], device=device)
                repetition_rates = repetition_counts / recent_lengths

                # Identify sequences needing random sampling
                need_random = repetition_rates >= tau_r

                # Apply random sampling where needed
                if need_random.any():
                    random_indices = torch.multinomial(probs[need_random], 1).squeeze(-1)
                    top_ids[need_random] = random_indices

        # 3. Handle EOS tokens
        # Create mask for sequences that should ignore EOS tokens
        ignore_eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if i < len(decoded_tokens_list):
                ignore_eos_mask[i] = len(decoded_tokens_list[i]) < min_tokens[i]
            else:
                ignore_eos_mask[i] = True  # Default to ignoring EOS for missing sequences

        is_eos_mask = top_ids == eos_token
        need_resample = ignore_eos_mask & is_eos_mask

        # Resample for sequences that need it
        if need_resample.any():
            max_trials = 100
            for attempt in range(max_trials):
                # Break if no more resampling needed
                if not need_resample.any():
                    break

                # Sample new tokens for sequences that need resampling
                new_samples = torch.multinomial(probs[need_resample], 1).squeeze(-1)

                # Update top_ids with new samples
                top_ids[need_resample] = new_samples

                # Update which sequences still need resampling
                is_eos_mask = top_ids == eos_token
                need_resample = ignore_eos_mask & is_eos_mask

            # If still have EOS tokens that should be ignored, force them to be non-EOS
            if need_resample.any():
                # Force to a non-EOS token (e.g., the second most likely token)
                for i in range(batch_size):
                    if need_resample[i]:
                        # Get second most likely token (or first if only one token)
                        second_best_idx = 1 if sorted_indices.size(1) > 1 else 0
                        top_ids[i] = sorted_indices[i, second_best_idx]

        result = top_ids

        return result
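A minimal usage sketch (not part of the packaged file) showing how the two samplers above would be called; the batch size, vocabulary size and token history are made-up values, and the import path follows the wheel layout listed above:

import torch
from stepaudio2.flashcosyvoice.modules.sampler import Sampler, RasSampler

sampler = Sampler()
logits = torch.randn(4, 6562)                      # [batch_size, vocab_size]
temperatures = torch.tensor([0.0, 0.7, 1.0, 1.5])  # 0.0 falls back to greedy decoding
tokens = sampler(logits, temperatures, top_k=25)
print(tokens.shape)                                # torch.Size([4])

ras = RasSampler()
history = [[1, 2, 3], [5, 5, 5, 5], [], [7]]       # decoded tokens per batch item
next_tokens = ras(torch.randn(4, 6562), history,
                  win_size=10, tau_r=0.1, top_p=0.8, top_k=25, eos_token=6561)
print(next_tokens.shape)                           # torch.Size([4])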
stepaudio2/flashcosyvoice/utils/__init__.py (file without changes)
stepaudio2/flashcosyvoice/utils/audio.py
@@ -0,0 +1,77 @@
import numpy as np
import torch
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480,
                    win_size=1920, fmin=0, fmax=8000, center=False):
    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
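A small usage sketch (not part of the packaged file), assuming one second of 24 kHz audio; it exercises mel_spectrogram with the defaults defined above:

import torch
from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram

wav = torch.randn(1, 24000)   # [batch, samples] at 24 kHz (random noise as a stand-in)
mel = mel_spectrogram(wav)    # defaults: n_fft=1920, hop_size=480, 80 mel bins, fmax=8000
print(mel.shape)              # [1, 80, 50], i.e. roughly samples // hop_size frames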
stepaudio2/flashcosyvoice/utils/context.py
@@ -0,0 +1,28 @@
from dataclasses import dataclass

import torch


@dataclass
class Context:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None

_CONTEXT = Context()

def get_context():
    return _CONTEXT

def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    global _CONTEXT
    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)

def reset_context():
    global _CONTEXT
    _CONTEXT = Context()
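A brief usage sketch (not part of the packaged file) of the module-level context object that engine code would populate before running attention; the tensor values below are purely illustrative:

import torch
from stepaudio2.flashcosyvoice.utils.context import get_context, set_context, reset_context

cu_seqlens = torch.tensor([0, 7, 12], dtype=torch.int32)   # cumulative sequence lengths
set_context(is_prefill=True, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=7, max_seqlen_k=7)
ctx = get_context()
print(ctx.is_prefill, ctx.max_seqlen_q)                    # True 7
reset_context()                                            # back to the default Context()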
stepaudio2/flashcosyvoice/utils/loader.py
@@ -0,0 +1,116 @@
import os
from glob import glob

import torch
from safetensors import safe_open
from torch import nn

from stepaudio2.flashcosyvoice.config import CosyVoice2LLMConfig


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    param.data.copy_(loaded_weight)


def load_text_llm(model: nn.Module, path: str):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))


def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})

    # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
    embedding_weights = {}
    tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
    missed, missed_names = 0, []
    for k, v in tmp_weights.items():
        if k == "speech_embedding.weight":  # torch.Size([6564, 896])
            speech_embedding_size = hf_config.speech_vocab_size  # 6562
            # NOTE(xcsong): padding to 6592 for vllm tensor parallel
            if speech_embedding_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert speech_embedding_size <= v.shape[0], f"speech_embedding_size should be less than or equal to {v.shape[0]}, but got {speech_embedding_size}"
                v = v[:speech_embedding_size, :]
            embedding_weights["speech_embedding.weight"] = v
        elif k == "llm_embedding.weight":  # torch.Size([2, 896]), eos and task_id
            assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
            embedding_weights["llm_embedding.weight"] = v
        elif k == "llm.model.model.embed_tokens.weight":  # torch.Size([151936, 896])
            embedding_weights["model.embed_tokens.weight"] = v
        elif k == "llm_decoder.weight":  # torch.Size([6564, 896])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size, :]
            param = model.get_parameter("lm_head.weight")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif k == "llm_decoder.bias":  # torch.Size([6564])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564] -> [6562]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size]
            param = model.get_parameter("lm_head.bias")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif "llm.model." in k:
            weight_name = k.replace("llm.model.", "")
            for kk in packed_modules_mapping:
                if kk in weight_name:
                    vv, shard_id = packed_modules_mapping[kk]
                    param_name = weight_name.replace(kk, vv)
                    try:
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, v, shard_id)
                        break
                    except Exception as e:
                        print(e)
                        print(f"skip parameter (1): {weight_name}")
                        continue
            else:
                try:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, v)
                except Exception as e:
                    print(e)
                    print(f"skip parameter (2): {weight_name}")
                    continue
        else:
            missed += 1
            missed_names.append(k)  # record the unmatched checkpoint key
            continue
    print(f"missed {missed} parameters: {missed_names}")

    # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
    text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu()  # [151936, 896]
    sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu()  # [2, 896]
    speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu()  # [6562, 896]
    final_embedding_weight = torch.cat([speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0)  # [158500, 896]
    param = model.get_parameter("model.embed_tokens.weight")
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, final_embedding_weight)


def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
    if model.model_type == "speech_llm":
        load_speech_llm(model, path, hf_config)
    elif model.model_type == "text_llm":
        load_text_llm(model, path)
    else:
        raise ValueError(f"Unsupported model type: {model.model_type}")
stepaudio2/flashcosyvoice/utils/memory.py
@@ -0,0 +1,19 @@
import os

import torch
from pynvml import *  # noqa


def get_gpu_memory():
    torch.cuda.synchronize()
    nvmlInit()
    visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
    cuda_device_idx = torch.cuda.current_device()
    cuda_device_idx = visible_device[cuda_device_idx]
    handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
    mem_info = nvmlDeviceGetMemoryInfo(handle)
    total_memory = mem_info.total
    used_memory = mem_info.used
    free_memory = mem_info.free
    nvmlShutdown()
    return total_memory, used_memory, free_memory
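A short usage sketch (not part of the packaged file); it assumes an NVIDIA GPU visible to both torch and pynvml, and that CUDA_VISIBLE_DEVICES matches the default mapping above:

from stepaudio2.flashcosyvoice.utils.memory import get_gpu_memory

total, used, free = get_gpu_memory()   # byte counts as reported by NVML
print(f"GPU memory: {used / 2**30:.1f} GiB used of {total / 2**30:.1f} GiB")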
stepaudio2/stepaudio2.py
@@ -0,0 +1,204 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from stepaudio2.utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels


class StepAudio2Base:

    def __init__(self, model_path: str):
        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        messages, mels = self.apply_chat_template(messages)

        # Tokenize prompts
        prompt_ids = []
        for msg in messages:
            if isinstance(msg, str):
                prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
            elif isinstance(msg, list):
                prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(msg)}")
        prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)

        # mels = None if len(mels) == 0 else torch.stack(mels).cuda()
        # mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda')
        if len(mels) == 0:
            mels = None
            mel_lengths = None
        else:
            mels, mel_lengths = padding_mels(mels)
            mels = mels.cuda()
            mel_lengths = mel_lengths.cuda()

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": mels,
            "wav_lens": mel_lengths,
            "attention_mask": attention_mask
        }

        generation_config = dict(max_new_tokens=2048,
                                 pad_token_id=self.llm_tokenizer.pad_token_id,
                                 eos_token_id=self.eos_token_id,
                                 )
        generation_config.update(kwargs)
        generation_config = GenerationConfig(**generation_config)

        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config, tokenizer=self.llm_tokenizer)
        output_token_ids = outputs[0, prompt_ids.shape[-1] : -1].tolist()
        output_text_tokens = [i for i in output_token_ids if i < 151688]
        output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            content = msg
            if isinstance(content, str):
                text_with_audio = content
                results.append(text_with_audio)
            elif isinstance(content, dict):
                if content["type"] == "text":
                    results.append(f"{content['text']}")
                elif content["type"] == "audio":
                    audio = load_audio(content['audio'])
                    for i in range(0, audio.shape[0], 16000 * 25):
                        mel = log_mel_spectrogram(audio[i:i+16000*25], n_mels=128, padding=479)
                        mels.append(mel)
                        audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                        results.append(f"<audio_start>{audio_tokens}<audio_end>")
                elif content["type"] == "token":
                    results.append(content["token"])
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        # print(results)
        return results, mels


class StepAudio2(StepAudio2Base):

    def __init__(self, model_path: str):
        super().__init__(model_path)
        self.llm_tokenizer.eos_token = "<|EOT|>"
        self.llm.config.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            if isinstance(content, str):
                text_with_audio = f"<|BOT|>{role}\n{content}"
                text_with_audio += '<|EOT|>' if msg.get('eot', True) else ''
                results.append(text_with_audio)
            elif isinstance(content, list):
                results.append(f"<|BOT|>{role}\n")
                for item in content:
                    if item["type"] == "text":
                        results.append(f"{item['text']}")
                    elif item["type"] == "audio":
                        audio = load_audio(item['audio'])
                        for i in range(0, audio.shape[0], 16000 * 25):
                            mel = log_mel_spectrogram(audio[i:i+16000*25], n_mels=128, padding=479)
                            mels.append(mel)
                            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                            results.append(f"<audio_start>{audio_tokens}<audio_end>")
                    elif item["type"] == "token":
                        results.append(item["token"])
                if msg.get('eot', True):
                    results.append('<|EOT|>')
            elif content is None:
                results.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        # print(results)
        return results, mels


if __name__ == '__main__':
    from stepaudio2.token2wav import Token2wav

    model = StepAudio2('Step-Audio-2-mini')
    token2wav = Token2wav('Step-Audio-2-mini/token2wav')

    # Text-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Text-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as f:
        f.write(audio)

    # Speech-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as f:
        f.write(audio)

    # Multi-turn conversation
    print()
    messages.pop(-1)
    messages += [
        {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"},
                                          {"type": "token", "token": tokens}]},
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Multi-modal inputs
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "text", "text": "Translate the speech into Chinese."},
                                      {"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)