minicpmo-utils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cosyvoice/__init__.py +17 -0
- cosyvoice/bin/average_model.py +93 -0
- cosyvoice/bin/export_jit.py +103 -0
- cosyvoice/bin/export_onnx.py +120 -0
- cosyvoice/bin/inference_deprecated.py +126 -0
- cosyvoice/bin/train.py +195 -0
- cosyvoice/cli/__init__.py +0 -0
- cosyvoice/cli/cosyvoice.py +209 -0
- cosyvoice/cli/frontend.py +238 -0
- cosyvoice/cli/model.py +386 -0
- cosyvoice/dataset/__init__.py +0 -0
- cosyvoice/dataset/dataset.py +151 -0
- cosyvoice/dataset/processor.py +434 -0
- cosyvoice/flow/decoder.py +494 -0
- cosyvoice/flow/flow.py +281 -0
- cosyvoice/flow/flow_matching.py +227 -0
- cosyvoice/flow/length_regulator.py +70 -0
- cosyvoice/hifigan/discriminator.py +230 -0
- cosyvoice/hifigan/f0_predictor.py +58 -0
- cosyvoice/hifigan/generator.py +582 -0
- cosyvoice/hifigan/hifigan.py +67 -0
- cosyvoice/llm/llm.py +610 -0
- cosyvoice/tokenizer/assets/multilingual_zh_ja_yue_char_del.tiktoken +58836 -0
- cosyvoice/tokenizer/tokenizer.py +279 -0
- cosyvoice/transformer/__init__.py +0 -0
- cosyvoice/transformer/activation.py +84 -0
- cosyvoice/transformer/attention.py +330 -0
- cosyvoice/transformer/convolution.py +145 -0
- cosyvoice/transformer/decoder.py +396 -0
- cosyvoice/transformer/decoder_layer.py +132 -0
- cosyvoice/transformer/embedding.py +302 -0
- cosyvoice/transformer/encoder.py +474 -0
- cosyvoice/transformer/encoder_layer.py +236 -0
- cosyvoice/transformer/label_smoothing_loss.py +96 -0
- cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- cosyvoice/transformer/subsampling.py +383 -0
- cosyvoice/transformer/upsample_encoder.py +320 -0
- cosyvoice/utils/__init__.py +0 -0
- cosyvoice/utils/class_utils.py +83 -0
- cosyvoice/utils/common.py +186 -0
- cosyvoice/utils/executor.py +176 -0
- cosyvoice/utils/file_utils.py +129 -0
- cosyvoice/utils/frontend_utils.py +136 -0
- cosyvoice/utils/losses.py +57 -0
- cosyvoice/utils/mask.py +265 -0
- cosyvoice/utils/scheduler.py +738 -0
- cosyvoice/utils/train_utils.py +367 -0
- cosyvoice/vllm/cosyvoice2.py +103 -0
- matcha/__init__.py +0 -0
- matcha/app.py +357 -0
- matcha/cli.py +418 -0
- matcha/hifigan/__init__.py +0 -0
- matcha/hifigan/config.py +28 -0
- matcha/hifigan/denoiser.py +64 -0
- matcha/hifigan/env.py +17 -0
- matcha/hifigan/meldataset.py +217 -0
- matcha/hifigan/models.py +368 -0
- matcha/hifigan/xutils.py +60 -0
- matcha/models/__init__.py +0 -0
- matcha/models/baselightningmodule.py +209 -0
- matcha/models/components/__init__.py +0 -0
- matcha/models/components/decoder.py +443 -0
- matcha/models/components/flow_matching.py +132 -0
- matcha/models/components/text_encoder.py +410 -0
- matcha/models/components/transformer.py +316 -0
- matcha/models/matcha_tts.py +239 -0
- matcha/onnx/__init__.py +0 -0
- matcha/onnx/export.py +181 -0
- matcha/onnx/infer.py +168 -0
- matcha/text/__init__.py +53 -0
- matcha/text/cleaners.py +116 -0
- matcha/text/numbers.py +71 -0
- matcha/text/symbols.py +17 -0
- matcha/train.py +122 -0
- matcha/utils/__init__.py +5 -0
- matcha/utils/audio.py +82 -0
- matcha/utils/generate_data_statistics.py +111 -0
- matcha/utils/instantiators.py +56 -0
- matcha/utils/logging_utils.py +53 -0
- matcha/utils/model.py +90 -0
- matcha/utils/monotonic_align/__init__.py +22 -0
- matcha/utils/monotonic_align/setup.py +7 -0
- matcha/utils/pylogger.py +21 -0
- matcha/utils/rich_utils.py +101 -0
- matcha/utils/utils.py +219 -0
- minicpmo/__init__.py +24 -0
- minicpmo/utils.py +636 -0
- minicpmo/version.py +2 -0
- minicpmo_utils-0.1.0.dist-info/METADATA +72 -0
- minicpmo_utils-0.1.0.dist-info/RECORD +148 -0
- minicpmo_utils-0.1.0.dist-info/WHEEL +5 -0
- minicpmo_utils-0.1.0.dist-info/top_level.txt +5 -0
- s3tokenizer/__init__.py +153 -0
- s3tokenizer/assets/BAC009S0764W0121.wav +0 -0
- s3tokenizer/assets/BAC009S0764W0122.wav +0 -0
- s3tokenizer/assets/mel_filters.npz +0 -0
- s3tokenizer/cli.py +183 -0
- s3tokenizer/model.py +546 -0
- s3tokenizer/model_v2.py +605 -0
- s3tokenizer/utils.py +390 -0
- stepaudio2/__init__.py +40 -0
- stepaudio2/cosyvoice2/__init__.py +1 -0
- stepaudio2/cosyvoice2/flow/__init__.py +0 -0
- stepaudio2/cosyvoice2/flow/decoder_dit.py +585 -0
- stepaudio2/cosyvoice2/flow/flow.py +230 -0
- stepaudio2/cosyvoice2/flow/flow_matching.py +205 -0
- stepaudio2/cosyvoice2/transformer/__init__.py +0 -0
- stepaudio2/cosyvoice2/transformer/attention.py +328 -0
- stepaudio2/cosyvoice2/transformer/embedding.py +119 -0
- stepaudio2/cosyvoice2/transformer/encoder_layer.py +163 -0
- stepaudio2/cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- stepaudio2/cosyvoice2/transformer/subsampling.py +79 -0
- stepaudio2/cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- stepaudio2/cosyvoice2/utils/__init__.py +1 -0
- stepaudio2/cosyvoice2/utils/class_utils.py +41 -0
- stepaudio2/cosyvoice2/utils/common.py +101 -0
- stepaudio2/cosyvoice2/utils/mask.py +49 -0
- stepaudio2/flashcosyvoice/__init__.py +0 -0
- stepaudio2/flashcosyvoice/cli.py +424 -0
- stepaudio2/flashcosyvoice/config.py +80 -0
- stepaudio2/flashcosyvoice/cosyvoice2.py +160 -0
- stepaudio2/flashcosyvoice/cosyvoice3.py +1 -0
- stepaudio2/flashcosyvoice/engine/__init__.py +0 -0
- stepaudio2/flashcosyvoice/engine/block_manager.py +114 -0
- stepaudio2/flashcosyvoice/engine/llm_engine.py +125 -0
- stepaudio2/flashcosyvoice/engine/model_runner.py +310 -0
- stepaudio2/flashcosyvoice/engine/scheduler.py +77 -0
- stepaudio2/flashcosyvoice/engine/sequence.py +90 -0
- stepaudio2/flashcosyvoice/modules/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow.py +198 -0
- stepaudio2/flashcosyvoice/modules/flow_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/flow_components/estimator.py +974 -0
- stepaudio2/flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- stepaudio2/flashcosyvoice/modules/hifigan.py +249 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- stepaudio2/flashcosyvoice/modules/qwen2.py +92 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- stepaudio2/flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- stepaudio2/flashcosyvoice/modules/sampler.py +231 -0
- stepaudio2/flashcosyvoice/utils/__init__.py +0 -0
- stepaudio2/flashcosyvoice/utils/audio.py +77 -0
- stepaudio2/flashcosyvoice/utils/context.py +28 -0
- stepaudio2/flashcosyvoice/utils/loader.py +116 -0
- stepaudio2/flashcosyvoice/utils/memory.py +19 -0
- stepaudio2/stepaudio2.py +204 -0
- stepaudio2/token2wav.py +248 -0
- stepaudio2/utils.py +91 -0
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Sampler(nn.Module):
    """Batched temperature / top-k sampler.

    Every row of ``logits`` is handled in one vectorized pass: rows whose
    temperature is 0 are decoded greedily, the others are temperature-scaled,
    optionally restricted to their top-k logits, and sampled from the resulting
    softmax distribution.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: int = None):
        """Sample one token id per row.

        Args:
            logits: Logits tensor with shape [batch_size, vocab_size]. It is not
                modified.
            temperatures: Temperature parameters with shape [batch_size]; a value
                of 0 selects greedy (argmax) decoding for that row.
            top_k: If a positive int, only logits greater than or equal to the
                row's k-th largest logit can be sampled (ties may keep more than
                k tokens). ``None`` or a non-positive value disables filtering.

        Returns:
            Long tensor of sampled token ids with shape [batch_size].
        """
        # `.float()` returns the input object itself when it is already float32, so
        # every later step is out-of-place to avoid clobbering the caller's tensor.
        logits = logits.float()
        greedy_tokens = logits.argmax(dim=-1)  # used for rows with temperature == 0

        # Greedy rows discard the sampled result; scale them by 1 instead of 0 so the
        # sampling path never sees inf/NaN values.
        is_greedy = temperatures == 0
        safe_temperatures = torch.where(is_greedy, torch.ones_like(temperatures), temperatures)
        logits = logits / safe_temperatures.unsqueeze(dim=1).to(logits.dtype)

        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))  # k cannot exceed the vocabulary size
            # topk reads the k-th largest logit without sorting the whole vocabulary.
            kth_largest = torch.topk(logits, k, dim=-1).values[:, -1:]  # [batch_size, 1]
            logits = logits.masked_fill(~(logits >= kth_largest), float("-inf"))

        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Exponential race: argmax_i(p_i / E_i) with E_i ~ Exp(1) picks index i with
        # probability p_i, i.e. an exact categorical sample.
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return torch.where(is_greedy, greedy_tokens, sample_tokens)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class RasSampler(nn.Module):
    """Batched Repetition Aware Sampling (RAS).

    For each row:
    1. Draw a candidate with nucleus (top-p + top-k) sampling.
    2. If that candidate already occurs too often in the row's recent history,
       redraw it from the full distribution.
    3. Keep EOS from being returned before the row has reached its minimum length.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, decoded_tokens_list: list,
                win_size: int = 10, tau_r: float = 0.1,
                top_p: float = 0.8, top_k: int = 25,
                eos_token: int = 6561, min_tokens: list[int] = None):
        """Select one token id per row.

        Args:
            logits: Input logits with shape [batch_size, vocab_size]. Not modified.
            decoded_tokens_list: Tokens decoded so far, one sequence per row. Rows
                beyond its length are treated as having no history.
            win_size: Number of most recent tokens inspected for repetition.
            tau_r: A row is redrawn from the full distribution when the fraction of
                its window equal to the nucleus candidate is >= tau_r.
            top_p: Nucleus mass. Tokens are kept while the mass of strictly more
                likely tokens is still below top_p, so the crossing token is included.
            top_k: Maximum number of tokens kept by nucleus sampling.
            eos_token: End of sequence token id.
            min_tokens: Per-row minimum history length before EOS may be returned.
                Defaults to 2 for every row.

        Returns:
            Long tensor of token ids with shape [batch_size].
        """
        batch_size = logits.size(0)
        device = logits.device
        num_given = len(decoded_tokens_list)

        if min_tokens is None:
            min_tokens = [2] * batch_size
        assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}"

        # Force continue decode first token: rows with an empty history may not pick
        # EOS. Mask on a copy so the caller's logits are left untouched.
        first_step_rows = [i for i in range(min(batch_size, num_given)) if len(decoded_tokens_list[i]) == 0]
        if first_step_rows:
            logits = logits.clone()
            logits[first_step_rows, eos_token] = -float("inf")

        # float32 probabilities keep the top-p cumulative sum accurate even when the
        # logits arrive in half precision.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)

        # 1. Nucleus sampling for all rows at once.
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        keep = mass_before < top_p
        keep[:, top_k:] = False  # top-k cap
        keep[:, 0] = True  # always keep the most likely token
        filtered = sorted_probs.masked_fill(~keep, 0.0)
        filtered = filtered / filtered.sum(dim=-1, keepdim=True)
        picked = torch.multinomial(filtered, 1)
        top_ids = torch.gather(sorted_indices, -1, picked).squeeze(-1)

        # 2. Repetition check over each row's last `win_size` decoded tokens.
        windows = []
        for i in range(batch_size):
            tokens = decoded_tokens_list[i] if i < num_given else []
            windows.append(tokens[max(0, len(tokens) - win_size):])
        if any(len(w) > 0 for w in windows):
            max_len = max(len(w) for w in windows)
            # Right-aligned windows padded with -1, which never matches a token id.
            padded = torch.full((batch_size, max_len), -1, dtype=torch.long)
            for i, w in enumerate(windows):
                if len(w) > 0:
                    padded[i, max_len - len(w):] = torch.as_tensor(w, dtype=torch.long)
            padded = padded.to(device)
            repetition_counts = (padded == top_ids.unsqueeze(1)).sum(dim=-1)
            window_lengths = torch.tensor([max(1, len(w)) for w in windows], device=device)
            need_random = repetition_counts / window_lengths >= tau_r
            if need_random.any():
                top_ids[need_random] = torch.multinomial(probs[need_random], 1).squeeze(-1)

        # 3. Rows shorter than their min_tokens, and rows without a history entry,
        #    must not emit EOS.
        ignore_eos = torch.tensor(
            [len(decoded_tokens_list[i]) < min_tokens[i] if i < num_given else True for i in range(batch_size)],
            dtype=torch.bool, device=device,
        )
        need_resample = ignore_eos & (top_ids == eos_token)
        max_trials = 100
        for _ in range(max_trials):
            if not need_resample.any():
                break
            top_ids[need_resample] = torch.multinomial(probs[need_resample], 1).squeeze(-1)
            need_resample = ignore_eos & (top_ids == eos_token)

        if need_resample.any():
            # Still EOS after max_trials: take the most likely non-EOS token. Taking
            # the 2nd-ranked token would be wrong whenever EOS itself ranks second.
            no_eos_probs = probs.clone()
            no_eos_probs[:, eos_token] = -1.0
            top_ids[need_resample] = no_eos_probs[need_resample].argmax(dim=-1)

        return top_ids
|
|
File without changes
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from librosa.filters import mel as librosa_mel_fn
|
|
4
|
+
from scipy.io.wavfile import read
|
|
5
|
+
|
|
6
|
+
MAX_WAV_VALUE = 32768.0
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def load_wav(full_path):
    """Read a WAV file with scipy and return ``(samples, sample_rate)``.

    The tuple order is swapped relative to ``scipy.io.wavfile.read``.
    """
    rate, samples = read(full_path)
    return samples, rate
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """NumPy log compression: ``log(max(x, clip_val) * C)``."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def dynamic_range_decompression(x, C=1):
    """NumPy inverse of the log compression: ``exp(x) / C``."""
    expanded = np.exp(x)
    return expanded / C
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Torch log compression: ``log(max(x, clip_val) * C)``."""
    return (x.clamp(min=clip_val) * C).log()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def dynamic_range_decompression_torch(x, C=1):
    """Torch inverse of the log compression: ``exp(x) / C``."""
    return x.exp() / C
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes with the default clip value and C."""
    return dynamic_range_compression_torch(magnitudes)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def spectral_de_normalize_torch(magnitudes):
    """Undo ``spectral_normalize_torch`` (exponentiate with the default C)."""
    return dynamic_range_decompression_torch(magnitudes)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Caches of mel filterbanks and Hann windows. Keys encode every parameter the
# cached tensor depends on, plus the device it lives on.
mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480,
                    win_size=1920, fmin=0, fmax=8000, center=False):
    """Compute a log-compressed mel spectrogram.

    ``y`` is reflect-padded by ``(n_fft - hop_size) / 2`` on each side and passed
    through ``torch.stft``. The magnitude is projected onto a librosa mel
    filterbank and log-compressed with ``spectral_normalize_torch``.

    Args:
        y: Waveform batch. Presumably shaped (batch, samples), since it is
            unsqueezed at dim 1 for reflect padding — TODO confirm.
        n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center:
            STFT / mel filterbank parameters.

    Returns:
        Mel spectrogram tensor (the mel filterbank matmul'd with the STFT magnitude).
    """
    device_key = str(y.device)
    # The old cache key used only fmax/device, so changing n_fft, num_mels, sr or fmin
    # silently reused a stale filterbank. Key on everything the basis depends on.
    basis_key = f"{sampling_rate}_{n_fft}_{num_mels}_{fmin}_{fmax}_{device_key}"
    if basis_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    window_key = f"{win_size}_{device_key}"
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude from (real, imag) pairs; the epsilon keeps sqrt differentiable at 0.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class Context:
|
|
8
|
+
is_prefill: bool = False
|
|
9
|
+
cu_seqlens_q: torch.Tensor | None = None
|
|
10
|
+
cu_seqlens_k: torch.Tensor | None = None
|
|
11
|
+
max_seqlen_q: int = 0
|
|
12
|
+
max_seqlen_k: int = 0
|
|
13
|
+
slot_mapping: torch.Tensor | None = None
|
|
14
|
+
context_lens: torch.Tensor | None = None
|
|
15
|
+
block_tables: torch.Tensor | None = None
|
|
16
|
+
|
|
17
|
+
_CONTEXT = Context()
|
|
18
|
+
|
|
19
|
+
def get_context():
    """Return the module-level ``Context`` installed by the last set/reset call."""
    current = _CONTEXT
    return current
|
|
21
|
+
|
|
22
|
+
def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    """Replace the module-level context with a new ``Context`` built from the arguments."""
    global _CONTEXT
    _CONTEXT = Context(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )
|
|
25
|
+
|
|
26
|
+
def reset_context():
    """Restore the module-level context to a default-constructed ``Context``."""
    global _CONTEXT
    fresh = Context()
    _CONTEXT = fresh
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from glob import glob
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from safetensors import safe_open
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from stepaudio2.flashcosyvoice.config import CosyVoice2LLMConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param`` in place.

    Used as the fallback when a parameter has no ``weight_loader`` attribute.
    """
    destination = param.data
    destination.copy_(loaded_weight)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_text_llm(model: nn.Module, path: str):
    """Load every ``*.safetensors`` file under ``path`` into ``model``.

    A checkpoint name containing a key of ``model.packed_modules_mapping`` (the first
    matching key wins) is renamed to the packed parameter. It is then loaded with that
    parameter's ``weight_loader`` plus the mapped shard id. Every other name is loaded
    directly, using the parameter's ``weight_loader`` if it has one, else
    ``default_weight_loader``.
    """
    packing = getattr(model, "packed_modules_mapping", {})
    for shard_file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(shard_file, "pt", "cpu") as shard:
            for name in shard.keys():
                packed_key = next((key for key in packing if key in name), None)
                if packed_key is None:
                    param = model.get_parameter(name)
                    loader = getattr(param, "weight_loader", default_weight_loader)
                    loader(param, shard.get_tensor(name))
                    continue
                fused_name, shard_id = packing[packed_key]
                param = model.get_parameter(name.replace(packed_key, fused_name))
                param.weight_loader(param, shard.get_tensor(name), shard_id)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _truncate_vocab_rows(weight: torch.Tensor, vocab_size: int, size_name: str) -> torch.Tensor:
    """Keep the first ``vocab_size`` rows of ``weight``; assert it has at least that many."""
    if vocab_size != weight.shape[0]:
        assert vocab_size <= weight.shape[0], f"{size_name} should be less than or equal to {weight.shape[0]}, but got {vocab_size}"
        weight = weight[:vocab_size]
    return weight


def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
    """Load ``{path}/llm.pt`` into a speech LLM ``model``.

    1. Copy the lm head (``llm_decoder.*``, trimmed to ``hf_config.speech_vocab_size``
       rows) into ``lm_head.*``.
    2. Copy the transformer weights (``llm.model.*``) into the model. Names containing
       a ``packed_modules_mapping`` key go to the packed parameter's loader.
       Parameters that fail to load are printed and skipped (best effort).
    3. Concatenate the speech, sos/task-id and text embeddings, in that order, into
       ``model.embed_tokens.weight``.

    Raises:
        KeyError: if the checkpoint lacks one of the three embedding tables.
    """
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})

    # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
    embedding_weights = {}
    tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
    missed, missed_names = 0, []
    for k, v in tmp_weights.items():
        if k == "speech_embedding.weight":
            # Drop checkpoint rows beyond speech_vocab_size (e.g. [6564, 896] -> [6562, 896]).
            embedding_weights["speech_embedding.weight"] = _truncate_vocab_rows(
                v, hf_config.speech_vocab_size, "speech_embedding_size")
        elif k == "llm_embedding.weight":  # two rows: sos/eos and task_id
            assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
            embedding_weights["llm_embedding.weight"] = v
        elif k == "llm.model.model.embed_tokens.weight":
            embedding_weights["model.embed_tokens.weight"] = v
        elif k in ("llm_decoder.weight", "llm_decoder.bias"):
            v = _truncate_vocab_rows(v, hf_config.speech_vocab_size, "lm_head_size")
            param = model.get_parameter("lm_head." + k.rsplit(".", 1)[1])
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif "llm.model." in k:
            weight_name = k.replace("llm.model.", "")
            # If every matching packed key fails, the loop ends without `break` and the
            # `else` branch retries under the unpacked name.
            for kk in packed_modules_mapping:
                if kk in weight_name:
                    vv, shard_id = packed_modules_mapping[kk]
                    param_name = weight_name.replace(kk, vv)
                    try:
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, v, shard_id)
                        break
                    except Exception as e:
                        print(e)
                        print(f"skip parameter (1): {weight_name}")
                        continue
            else:
                try:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, v)
                except Exception as e:
                    print(e)
                    print(f"skip parameter (2): {weight_name}")
        else:
            missed += 1
            # Record the checkpoint key itself; `weight_name` is only bound in the
            # "llm.model." branch (NameError or a stale name before this fix).
            missed_names.append(k)
    print(f"missed {missed} parameters: {missed_names}")

    # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
    required = {
        "model.embed_tokens.weight": "llm.model.model.embed_tokens.weight",
        "llm_embedding.weight": "llm_embedding.weight",
        "speech_embedding.weight": "speech_embedding.weight",
    }
    absent = [ckpt_key for key, ckpt_key in required.items() if key not in embedding_weights]
    if absent:
        raise KeyError(f"{path}/llm.pt is missing embedding weights: {absent}")
    text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu()
    sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu()
    speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu()
    final_embedding_weight = torch.cat(
        [speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0)
    param = model.get_parameter("model.embed_tokens.weight")
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, final_embedding_weight)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
    """Load weights from ``path`` according to ``model.model_type``.

    ``"speech_llm"`` uses ``load_speech_llm`` and requires ``hf_config``.
    ``"text_llm"`` uses ``load_text_llm``.

    Raises:
        ValueError: for an unsupported ``model_type``, or a speech LLM without ``hf_config``.
    """
    model_type = model.model_type
    if model_type == "speech_llm":
        # Fail fast: load_speech_llm reads hf_config.speech_vocab_size and would
        # otherwise crash with AttributeError after partially loading weights.
        if hf_config is None:
            raise ValueError("hf_config is required to load a speech_llm model")
        load_speech_llm(model, path, hf_config)
    elif model_type == "text_llm":
        load_text_llm(model, path)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from pynvml import * # noqa
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_gpu_memory():
    """Return ``(total, used, free)`` memory for the current CUDA device, as reported by NVML.

    CUDA device indices are relative to ``CUDA_VISIBLE_DEVICES``, while NVML indexes
    physical GPUs. When the variable is set, the current device index is mapped
    through it; otherwise the two indices coincide.

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` has non-integer entries (e.g. GPU UUIDs).
        IndexError: if the current device is not covered by ``CUDA_VISIBLE_DEVICES``.
    """
    torch.cuda.synchronize()
    cuda_device_idx = torch.cuda.current_device()
    visible = os.getenv("CUDA_VISIBLE_DEVICES")
    if visible:
        # Logical device i is the i-th entry of CUDA_VISIBLE_DEVICES.
        nvml_idx = list(map(int, visible.split(",")))[cuda_device_idx]
    else:
        # No remapping. The old hard-coded "0,...,7" default raised IndexError on
        # hosts with more than eight GPUs.
        nvml_idx = cuda_device_idx
    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(nvml_idx)
        mem_info = nvmlDeviceGetMemoryInfo(handle)
    finally:
        # Balance nvmlInit even when the NVML query raises.
        nvmlShutdown()
    return mem_info.total, mem_info.used, mem_info.free
|
stepaudio2/stepaudio2.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
3
|
+
|
|
4
|
+
from stepaudio2.utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StepAudio2Base:
|
|
8
|
+
|
|
9
|
+
def __init__(self, model_path: str):
|
|
10
|
+
self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
|
|
11
|
+
self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
|
|
12
|
+
self.eos_token_id = self.llm_tokenizer.eos_token_id
|
|
13
|
+
|
|
14
|
+
def __call__(self, messages: list, **kwargs):
|
|
15
|
+
messages, mels = self.apply_chat_template(messages)
|
|
16
|
+
|
|
17
|
+
# Tokenize prompts
|
|
18
|
+
prompt_ids = []
|
|
19
|
+
for msg in messages:
|
|
20
|
+
if isinstance(msg, str):
|
|
21
|
+
prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
|
|
22
|
+
elif isinstance(msg, list):
|
|
23
|
+
prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
|
|
24
|
+
else:
|
|
25
|
+
raise ValueError(f"Unsupported content type: {type(msg)}")
|
|
26
|
+
prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
|
|
27
|
+
attention_mask = torch.ones_like(prompt_ids)
|
|
28
|
+
|
|
29
|
+
#mels = None if len(mels) == 0 else torch.stack(mels).cuda()
|
|
30
|
+
#mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda')
|
|
31
|
+
if len(mels)==0:
|
|
32
|
+
mels = None
|
|
33
|
+
mel_lengths = None
|
|
34
|
+
else:
|
|
35
|
+
mels, mel_lengths = padding_mels(mels)
|
|
36
|
+
mels = mels.cuda()
|
|
37
|
+
mel_lengths = mel_lengths.cuda()
|
|
38
|
+
|
|
39
|
+
generate_inputs = {
|
|
40
|
+
"input_ids": prompt_ids,
|
|
41
|
+
"wavs": mels,
|
|
42
|
+
"wav_lens": mel_lengths,
|
|
43
|
+
"attention_mask":attention_mask
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
generation_config = dict(max_new_tokens=2048,
|
|
47
|
+
pad_token_id=self.llm_tokenizer.pad_token_id,
|
|
48
|
+
eos_token_id=self.eos_token_id,
|
|
49
|
+
)
|
|
50
|
+
generation_config.update(kwargs)
|
|
51
|
+
generation_config = GenerationConfig(**generation_config)
|
|
52
|
+
|
|
53
|
+
outputs = self.llm.generate(**generate_inputs, generation_config=generation_config, tokenizer=self.llm_tokenizer)
|
|
54
|
+
output_token_ids = outputs[0, prompt_ids.shape[-1] : -1].tolist()
|
|
55
|
+
output_text_tokens = [i for i in output_token_ids if i < 151688]
|
|
56
|
+
output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
|
|
57
|
+
output_text = self.llm_tokenizer.decode(output_text_tokens)
|
|
58
|
+
return output_token_ids, output_text, output_audio_tokens
|
|
59
|
+
|
|
60
|
+
def apply_chat_template(self, messages: list):
    """Flatten raw prompt pieces into tokenizer-ready segments plus mel features.

    Each element of ``messages`` is one of:

    * ``str`` -- appended unchanged.
    * ``dict`` with ``"type": "text"`` -- its ``"text"`` value is appended.
    * ``dict`` with ``"type": "audio"`` -- ``"audio"`` is loaded with
      ``load_audio`` and split into windows of ``16000 * 25`` samples; each
      window yields one log-mel spectrogram (128 mel bins) in ``mels`` and one
      ``<audio_start>...<audio_end>`` placeholder string (``<audio_patch>``
      repeated ``compute_token_num(mel.shape[1])`` times) in the results.
    * ``dict`` with ``"type": "token"`` -- its ``"token"`` value is appended
      as-is (presumably a list of token ids; the caller turns list segments
      into id tensors).

    Args:
        messages: Sequence of prompt pieces as described above.

    Returns:
        ``(results, mels)``: the ordered list of string / token-list segments,
        and the list of log-mel spectrograms, one per audio window.

    Raises:
        ValueError: if an element is neither ``str`` nor ``dict``, or a dict
            carries an unrecognised ``"type"``.
    """
    # Window length in samples; 16000 * 25 suggests 25 s at 16 kHz, assuming
    # load_audio resamples to 16 kHz -- TODO confirm.
    chunk_size = 16000 * 25
    results = []
    mels = []
    for content in messages:
        if isinstance(content, str):
            results.append(content)
        elif isinstance(content, dict):
            item_type = content["type"]
            if item_type == "text":
                results.append(f"{content['text']}")
            elif item_type == "audio":
                audio = load_audio(content['audio'])
                for start in range(0, audio.shape[0], chunk_size):
                    mel = log_mel_spectrogram(audio[start:start + chunk_size], n_mels=128, padding=479)
                    mels.append(mel)
                    audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                    results.append(f"<audio_start>{audio_tokens}<audio_end>")
            elif item_type == "token":
                results.append(content["token"])
            else:
                # Unknown item types used to be dropped silently, producing a
                # prompt with missing content; fail loudly instead.
                raise ValueError(f"Unsupported item type: {item_type!r}")
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")
    return results, mels
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class StepAudio2(StepAudio2Base):
    """Chat-formatted variant of the base Step-Audio2 wrapper.

    Every message is rendered as ``<|BOT|>{role}`` plus a newline, followed by
    its content and (by default) ``<|EOT|>``. ``<|EOT|>`` is also installed as
    the tokenizer/model end-of-sequence token and as ``self.eos_token_id``, so
    generation stops at the end of the assistant turn.
    """

    def __init__(self, model_path: str):
        super().__init__(model_path)
        # Set the tokenizer's EOS token first, then resolve its id once and
        # reuse it for both the model config and generation.
        self.llm_tokenizer.eos_token = "<|EOT|>"
        eot_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.llm.config.eos_token_id = eot_id
        self.eos_token_id = eot_id

    def apply_chat_template(self, messages: list):
        """Render chat messages into prompt segments plus audio mel features.

        Each message is a dict with ``"role"`` and ``"content"`` keys and an
        optional ``"eot"`` flag (default ``True``). Role ``"user"`` is renamed
        to ``"human"``. ``content`` may be:

        * ``str`` -- emitted together with the ``<|BOT|>{role}`` header as one
          segment, with ``<|EOT|>`` appended unless ``eot`` is false.
        * ``list`` of item dicts -- the header is emitted, then each item:
          ``"text"`` appends its ``"text"``; ``"audio"`` loads the audio,
          splits it into ``16000 * 25``-sample windows and emits one log-mel
          spectrogram plus one ``<audio_start>...<audio_end>`` placeholder per
          window; ``"token"`` appends its ``"token"`` value as-is. A trailing
          ``<|EOT|>`` segment follows unless ``eot`` is false.
        * ``None`` -- only the header is emitted, leaving the turn open for
          generation.

        Args:
            messages: List of message dicts as described above.

        Returns:
            ``(results, mels)``: ordered string / token-list segments and the
            list of log-mel spectrograms, one per audio window.

        Raises:
            ValueError: for an unsupported ``content`` type or an unrecognised
                list-item ``"type"``.
        """
        # Window length in samples; 16000 * 25 suggests 25 s at 16 kHz,
        # assuming load_audio resamples to 16 kHz -- TODO confirm.
        chunk_size = 16000 * 25
        results = []
        mels = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            add_eot = msg.get('eot', True)
            if isinstance(content, str):
                segment = f"<|BOT|>{role}\n{content}"
                segment += '<|EOT|>' if add_eot else ''
                results.append(segment)
            elif isinstance(content, list):
                results.append(f"<|BOT|>{role}\n")
                for item in content:
                    item_type = item["type"]
                    if item_type == "text":
                        results.append(f"{item['text']}")
                    elif item_type == "audio":
                        audio = load_audio(item['audio'])
                        for start in range(0, audio.shape[0], chunk_size):
                            mel = log_mel_spectrogram(audio[start:start + chunk_size], n_mels=128, padding=479)
                            mels.append(mel)
                            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                            results.append(f"<audio_start>{audio_tokens}<audio_end>")
                    elif item_type == "token":
                        results.append(item["token"])
                    else:
                        # Unknown item types used to be dropped silently,
                        # producing a prompt with missing content.
                        raise ValueError(f"Unsupported item type: {item_type!r}")
                if add_eot:
                    results.append('<|EOT|>')
            elif content is None:
                # Open turn: header only, the model continues from here.
                results.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        return results, mels
|
|
128
|
+
|
|
129
|
+
if __name__ == '__main__':
    # Demo: runs each supported conversation mode once. Expects the
    # 'Step-Audio-2-mini' checkpoint and an assets/ directory relative to the
    # current working directory.
    from stepaudio2.token2wav import Token2wav

    model = StepAudio2('Step-Audio-2-mini')
    token2wav = Token2wav('Step-Audio-2-mini/token2wav')

    # Sampling settings shared by every turn below; only the token budget varies.
    sampling = dict(temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    great_wall_wav = "assets/give_me_a_brief_introduction_to_the_great_wall.wav"

    def chat(conversation, budget):
        """Generate one reply with the shared sampling settings."""
        return model(conversation, max_new_tokens=budget, **sampling)

    def single_turn(user_content, speech_reply):
        """Build a fresh system + human + open-assistant message list."""
        if speech_reply:
            # Insert <tts_start> (without <|EOT|>) for a speech response.
            reply = {"role": "assistant", "content": "<tts_start>", "eot": False}
        else:
            reply = {"role": "assistant", "content": None}
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "human", "content": user_content},
            reply,
        ]

    def save_speech(audio_tokens, voice_prompt, out_path):
        """Vocode audio tokens using voice_prompt and write the returned bytes."""
        wav_bytes = token2wav(audio_tokens, prompt_wav=voice_prompt)
        with open(out_path, 'wb') as fh:
            fh.write(wav_bytes)

    # Text-to-text conversation
    print()
    messages = single_turn("Give me a brief introduction to the Great Wall.", speech_reply=False)
    tokens, text, _ = chat(messages, 256)
    print(text)

    # Text-to-speech conversation
    print()
    messages = single_turn("Give me a brief introduction to the Great Wall.", speech_reply=True)
    tokens, text, audio = chat(messages, 4096)
    print(text)
    print(tokens)
    save_speech(audio, 'assets/default_male.wav', 'output-male.wav')

    # Speech-to-text conversation
    print()
    messages = single_turn([{"type": "audio", "audio": great_wall_wav}], speech_reply=False)
    tokens, text, _ = chat(messages, 256)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = single_turn([{"type": "audio", "audio": great_wall_wav}], speech_reply=True)
    tokens, text, audio = chat(messages, 4096)
    print(text)
    print(tokens)
    save_speech(audio, 'assets/default_female.wav', 'output-female.wav')

    # Multi-turn conversation: swap the open <tts_start> turn for the full
    # generated reply (all generated token ids), then ask a follow-up.
    print()
    messages.pop(-1)
    messages += [
        {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"},
                                          {"type": "token", "token": tokens}]},
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None},
    ]
    tokens, text, audio = chat(messages, 256)
    print(text)

    # Multi-modal inputs
    print()
    messages = single_turn(
        [{"type": "text", "text": "Translate the speech into Chinese."},
         {"type": "audio", "audio": great_wall_wav}],
        speech_reply=False,
    )
    tokens, text, audio = chat(messages, 256)
    print(text)
|