lollms-client 0.15.2__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/generate_and_speak/generate_and_speak.py +251 -0
- examples/generate_game_sfx/generate_game_fx.py +240 -0
- examples/simple_text_gen_with_image_test.py +8 -8
- examples/text_2_image.py +0 -1
- examples/text_gen.py +1 -1
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/llamacpp/__init__.py +61 -11
- lollms_client/llm_bindings/lollms/__init__.py +31 -24
- lollms_client/llm_bindings/ollama/__init__.py +47 -27
- lollms_client/llm_bindings/openai/__init__.py +62 -35
- lollms_client/llm_bindings/openllm/__init__.py +4 -1
- lollms_client/llm_bindings/pythonllamacpp/__init__.py +3 -0
- lollms_client/llm_bindings/tensor_rt/__init__.py +4 -1
- lollms_client/llm_bindings/transformers/__init__.py +3 -0
- lollms_client/llm_bindings/vllm/__init__.py +4 -1
- lollms_client/lollms_core.py +65 -33
- lollms_client/lollms_llm_binding.py +76 -22
- lollms_client/lollms_stt_binding.py +3 -15
- lollms_client/lollms_tti_binding.py +5 -29
- lollms_client/lollms_ttm_binding.py +5 -28
- lollms_client/lollms_tts_binding.py +4 -28
- lollms_client/lollms_ttv_binding.py +4 -28
- lollms_client/lollms_utilities.py +5 -3
- lollms_client/stt_bindings/lollms/__init__.py +5 -4
- lollms_client/stt_bindings/whisper/__init__.py +304 -0
- lollms_client/stt_bindings/whispercpp/__init__.py +380 -0
- lollms_client/tti_bindings/lollms/__init__.py +4 -6
- lollms_client/ttm_bindings/audiocraft/__init__.py +281 -0
- lollms_client/ttm_bindings/bark/__init__.py +339 -0
- lollms_client/tts_bindings/bark/__init__.py +336 -0
- lollms_client/tts_bindings/piper_tts/__init__.py +343 -0
- lollms_client/tts_bindings/xtts/__init__.py +317 -0
- lollms_client-0.17.0.dist-info/METADATA +183 -0
- lollms_client-0.17.0.dist-info/RECORD +65 -0
- lollms_client-0.15.2.dist-info/METADATA +0 -192
- lollms_client-0.15.2.dist-info/RECORD +0 -56
- {lollms_client-0.15.2.dist-info → lollms_client-0.17.0.dist-info}/WHEEL +0 -0
- {lollms_client-0.15.2.dist-info → lollms_client-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-0.15.2.dist-info → lollms_client-0.17.0.dist-info}/top_level.txt +0 -0
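The headline changes in this release are the new local audio bindings: Whisper and whisper.cpp for STT, Bark, Piper and XTTS for TTS, and AudioCraft and Bark for TTM. The two new TTM modules are reproduced in full below. For orientation, here is a minimal usage sketch of the AudioCraft binding, assuming the 0.17.0 wheel and its audio dependencies are installed; the direct import path is inferred from the RECORD layout above rather than documented API, and model weights are fetched from Hugging Face on first load:

# Minimal sketch, not documented API: drives the new TTM binding directly.
from lollms_client.ttm_bindings.audiocraft import AudioCraftTTMBinding

binding = AudioCraftTTMBinding(model_name="facebook/musicgen-small")  # auto-picks cuda/mps/cpu
wav_bytes = binding.generate_music("A calm ambient pad with soft strings", duration=4)
with open("ambient.wav", "wb") as f:  # generate_music returns raw WAV bytes
    f.write(wav_bytes)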
@@ -0,0 +1,281 @@
# lollms_client/ttm_bindings/audiocraft/__init__.py
import io
import os
from pathlib import Path
from typing import Optional, List, Union, Dict, Any

from ascii_colors import trace_exception, ASCIIColors

# --- Package Management and Conditional Imports ---
_audiocraft_installed_with_correct_torch = False
_audiocraft_installation_error = ""

try:
    import pipmaster as pm
    import platform  # For OS detection for the torch index

    # Determine initial device preference to guide torch installation
    preferred_torch_device_for_install = "cpu"  # Default assumption

    # Tentatively set preference based on OS, assuming the user may want GPU if available
    if platform.system() == "Linux" or platform.system() == "Windows":
        # On Linux/Windows, CUDA is the primary GPU acceleration for PyTorch.
        # We will try to install a CUDA build of PyTorch.
        preferred_torch_device_for_install = "cuda"
    elif platform.system() == "Darwin":
        # On macOS, MPS is the acceleration. A standard torch install usually handles this.
        preferred_torch_device_for_install = "mps"  # or keep cpu if MPS detection happens later

    torch_pkgs = ["torch", "torchaudio", "xformers"]
    audiocraft_core_pkgs = ["audiocraft"]
    other_deps = ["scipy", "numpy"]

    torch_index_url = None
    if preferred_torch_device_for_install == "cuda":
        # Specify a common CUDA version index; pip should resolve the correct torch version.
        # We target cu126 here. Users with different CUDA setups may need to pre-install torch manually.
        torch_index_url = "https://download.pytorch.org/whl/cu126"
        ASCIIColors.info(f"Attempting to ensure PyTorch with CUDA support (target index: {torch_index_url})")
        # Install torch and torchaudio first from the specific index
        pm.ensure_packages(torch_pkgs, index_url=torch_index_url)
        # Then install audiocraft and other dependencies; pip should use the already installed torch
        pm.ensure_packages(audiocraft_core_pkgs + other_deps)
    else:
        # For CPU, MPS, or if no specific CUDA preference was determined for install
        ASCIIColors.info("Ensuring PyTorch, AudioCraft, and dependencies using the default PyPI index.")
        pm.ensure_packages(torch_pkgs + audiocraft_core_pkgs + other_deps)

    # Now, perform the actual imports
    import torch, torchaudio
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write  # For saving to bytes
    import numpy as np
    import scipy.io.wavfile  # For direct WAV manipulation if needed

    _audiocraft_installed_with_correct_torch = True  # If imports succeed after ensure_packages
except Exception as e:
    _audiocraft_installation_error = str(e)
    # Set placeholders if imports fail
    MusicGen, torch, torchaudio, audio_write, np, scipy = None, None, None, None, None, None
# --- End Package Management ---

from lollms_client.lollms_ttm_binding import LollmsTTMBinding

BindingName = "AudioCraftTTMBinding"

# Common MusicGen model IDs from Hugging Face
DEFAULT_AUDIOCRAFT_MODELS = [
    "facebook/musicgen-small",
    "facebook/musicgen-medium",
    "facebook/musicgen-melody",  # Can also be conditioned on a melody audio file
    "facebook/musicgen-large",
]


class AudioCraftTTMBinding(LollmsTTMBinding):
    def __init__(self,
                 model_name: str = "facebook/musicgen-small",  # HF ID or local path
                 device: Optional[str] = None,                 # "cpu", "cuda", "mps", or None for auto
                 output_format: str = "wav",                   # 'wav', 'mp3' (mp3 needs ffmpeg via audiocraft)
                 # Catch LollmsTTMBinding standard args
                 host_address: Optional[str] = None,           # Not used by local binding
                 service_key: Optional[str] = None,            # Not used by local binding
                 verify_ssl_certificate: bool = True,          # Not used by local binding
                 **kwargs):  # Catch-all for future compatibility or specific audiocraft params

        super().__init__(binding_name="audiocraft")

        if not _audiocraft_installed_with_correct_torch:
            raise ImportError(f"AudioCraft TTM binding dependencies not met. Please ensure 'audiocraft', 'torch', 'torchaudio', 'scipy', 'numpy' are installed. Error: {_audiocraft_installation_error}")

        self.device = device
        if self.device is None:  # Auto-detect if not specified by user
            if torch.cuda.is_available():
                self.device = "cuda"
                ASCIIColors.info("CUDA device detected by PyTorch.")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # For Apple Silicon
                self.device = "mps"
                ASCIIColors.info("MPS device detected by PyTorch for Apple Silicon.")
            else:
                self.device = "cpu"
                ASCIIColors.info("No GPU (CUDA/MPS) detected by PyTorch, using CPU.")
        elif self.device == "cuda" and not torch.cuda.is_available():
            ASCIIColors.warning("CUDA device requested, but torch.cuda.is_available() is False. Falling back to CPU.")
            self.device = "cpu"
        elif self.device == "mps" and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
            ASCIIColors.warning("MPS device requested, but not available. Falling back to CPU.")
            self.device = "cpu"

        ASCIIColors.info(f"AudioCraftTTMBinding: Using device '{self.device}'.")

        self.loaded_model_name = None
        self.model: Optional[MusicGen] = None
        self.output_format = output_format.lower()
        if self.output_format not in ["wav", "mp3"]:
            ASCIIColors.warning(f"Unsupported output_format '{self.output_format}'. Defaulting to 'wav'.")
            self.output_format = "wav"

        self._load_audiocraft_model(model_name)

    def _load_audiocraft_model(self, model_name_to_load: str):
        if self.model is not None and self.loaded_model_name == model_name_to_load:
            ASCIIColors.info(f"AudioCraft model '{model_name_to_load}' already loaded.")
            return

        ASCIIColors.info(f"Loading AudioCraft (MusicGen) model: '{model_name_to_load}' on device '{self.device}'...")
        try:
            self.model = MusicGen.get_pretrained(model_name_to_load, device=self.device)
            self.loaded_model_name = model_name_to_load
            # self.model_name is part of the LollmsBinding base, but this binding uses loaded_model_name
            # for its own logic. It could be assigned for consistency if LollmsClient core needs it,
            # though it's not directly used by this binding post-load.
            # self.model_name = model_name_to_load

            ASCIIColors.green(f"AudioCraft model '{model_name_to_load}' loaded successfully.")
        except Exception as e:
            self.model = None
            self.loaded_model_name = None
            ASCIIColors.error(f"Failed to load AudioCraft model '{model_name_to_load}': {e}")
            trace_exception(e)
            raise RuntimeError(f"Failed to load AudioCraft model '{model_name_to_load}'") from e

    def generate_music(self,
                       prompt: str,
                       duration: int = 8,
                       temperature: float = 1.0,
                       top_k: int = 250,
                       top_p: float = 0.0,
                       cfg_coef: float = 3.0,
                       progress: bool = True,
                       **kwargs) -> bytes:
        if self.model is None:
            raise RuntimeError("AudioCraft model is not loaded. Cannot generate music.")

        self.model.set_generation_params(
            duration=duration,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            cfg_coef=cfg_coef,
            **kwargs
        )

        ASCIIColors.info(f"Generating music for prompt: '{prompt[:50]}...' (Duration: {duration}s, Temp: {temperature}, TopK: {top_k}, TopP: {top_p}, CFG: {cfg_coef})")
        try:
            wav_tensor = self.model.generate(descriptions=[prompt], progress=progress)

            if wav_tensor is None or wav_tensor.numel() == 0:
                raise RuntimeError("MusicGen returned empty audio data.")

            if wav_tensor.ndim == 3 and wav_tensor.shape[0] == 1:
                wav_tensor_single = wav_tensor.squeeze(0)
            elif wav_tensor.ndim == 2:
                wav_tensor_single = wav_tensor
            else:
                raise ValueError(f"Unexpected tensor shape from MusicGen: {wav_tensor.shape}")

            buffer = io.BytesIO()

            # torchaudio.save needs the tensor on CPU
            torchaudio.save(
                buffer,
                wav_tensor_single.cpu(),
                self.model.sample_rate,
                format="wav"  # Explicitly WAV
            )

            audio_bytes = buffer.getvalue()
            buffer.close()

            ASCIIColors.green("Music generation successful.")
            return audio_bytes

        except Exception as e:
            ASCIIColors.error(f"AudioCraft music generation failed: {e}")
            trace_exception(e)
            # Provide more specific feedback for common issues
            if "out of memory" in str(e).lower() and self.device == "cuda":
                ASCIIColors.yellow("CUDA out of memory. Consider using a smaller model (e.g., 'facebook/musicgen-small'), a shorter duration, or ensure your GPU has sufficient VRAM (medium models might need ~10-12GB, large ~16GB+).")
            elif "ffmpeg" in str(e).lower() and self.output_format == "mp3":
                ASCIIColors.yellow("An FFmpeg error occurred. Ensure FFmpeg is installed and accessible in your system's PATH if you are generating MP3s.")
            raise RuntimeError(f"AudioCraft music generation error: {e}") from e

    def list_models(self, **kwargs) -> List[str]:
        return DEFAULT_AUDIOCRAFT_MODELS.copy()

    def __del__(self):
        if hasattr(self, 'model') and self.model is not None:  # Check if model attribute exists
            del self.model
            self.model = None
            if torch and hasattr(torch, 'cuda') and torch.cuda.is_available():
                torch.cuda.empty_cache()
        if hasattr(self, 'loaded_model_name') and self.loaded_model_name:  # Check if loaded_model_name exists
            ASCIIColors.info(f"AudioCraftTTMBinding for model '{self.loaded_model_name}' destroyed and resources released.")
        else:
            ASCIIColors.info("AudioCraftTTMBinding destroyed (no model was fully loaded or name not set).")


# --- Main Test Block (Example Usage) ---
if __name__ == '__main__':
    if not _audiocraft_installed_with_correct_torch:
        print(f"{ASCIIColors.RED}AudioCraft dependencies not met or import failed. Skipping tests. Error: {_audiocraft_installation_error}{ASCIIColors.RESET}")
        exit()

    ASCIIColors.yellow("--- AudioCraftTTMBinding Test ---")
    test_model_id = "facebook/musicgen-small"  # Smallest model for quicker testing
    test_output_dir = Path("./test_audiocraft_output")
    test_output_dir.mkdir(exist_ok=True)
    ttm_binding = None

    try:
        ASCIIColors.cyan(f"\n--- Initializing AudioCraftTTMBinding (model: '{test_model_id}') ---")
        # Explicitly set device to CPU for a basic test if no GPU, or let it auto-detect
        # device_for_test = "cpu" if not (torch and torch.cuda.is_available()) else None
        ttm_binding = AudioCraftTTMBinding(model_name=test_model_id, output_format="wav")  # device=device_for_test

        ASCIIColors.cyan("\n--- Listing common MusicGen models ---")
        models = ttm_binding.list_models()
        print(f"Common MusicGen models: {models}")

        test_prompt_1 = "A lo-fi hip hop beat with a chill piano melody and soft drums, perfect for studying."
        test_prompt_2 = "Epic orchestral score for a fantasy battle scene, with choirs and horns."

        prompts_to_test = [
            ("lofi_chill", test_prompt_1),
            ("epic_battle", test_prompt_2),
        ]

        for name, prompt in prompts_to_test:
            ASCIIColors.cyan(f"\n--- Generating music for: '{name}' (duration 3s) ---")
            print(f"Prompt: {prompt}")
            try:
                music_bytes = ttm_binding.generate_music(prompt, duration=3, progress=True)

                if music_bytes:
                    output_filename = f"test_{name}_{test_model_id.split('/')[-1]}.{ttm_binding.output_format}"
                    output_path = test_output_dir / output_filename
                    with open(output_path, "wb") as f:
                        f.write(music_bytes)
                    ASCIIColors.green(f"Music for '{name}' saved to: {output_path} ({len(music_bytes) / 1024:.2f} KB)")
                else:
                    ASCIIColors.error(f"Music generation for '{name}' returned empty bytes.")
            except Exception as e_gen:
                ASCIIColors.error(f"Failed to generate music for '{name}': {e_gen}")
                # Error details already printed by the generate_music method

    except ImportError as e_imp:
        ASCIIColors.error(f"Import error during test setup: {e_imp}")
    except RuntimeError as e_rt:  # Catch runtime errors from init or generate
        ASCIIColors.error(f"Runtime error during test: {e_rt}")
    except Exception as e:
        ASCIIColors.error(f"An unexpected error occurred during testing: {e}")
        trace_exception(e)
    finally:
        if ttm_binding:
            del ttm_binding
        ASCIIColors.info(f"Test artifacts (if any) are in: {test_output_dir.resolve()}")
        print(f"{ASCIIColors.YELLOW}Remember to check the audio files in '{test_output_dir.resolve()}'!{ASCIIColors.RESET}")

    ASCIIColors.yellow("\n--- AudioCraftTTMBinding Test Finished ---")
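Note that generate_music returns in-memory WAV bytes rather than writing a file, so callers can post-process without touching disk. A small standard-library sketch for inspecting the result (assuming wav_bytes came from the binding above; the reported rate is whatever model.sample_rate was):

# Sketch: inspect the WAV bytes returned by generate_music without writing to disk.
import io, wave

with wave.open(io.BytesIO(wav_bytes), "rb") as w:
    # Frame count divided by frame rate gives the clip length in seconds.
    frames, rate = w.getnframes(), w.getframerate()
    print(f"{frames / rate:.2f}s at {rate} Hz, {w.getnchannels()} channel(s)")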
@@ -0,0 +1,339 @@
# lollms_client/ttm_bindings/bark/__init__.py
import io
import os
from pathlib import Path
from typing import Optional, List, Union, Dict, Any

from ascii_colors import trace_exception, ASCIIColors

# --- Package Management and Conditional Imports ---
_bark_deps_installed_with_correct_torch = False
_bark_installation_error = ""
try:
    import pipmaster as pm
    import platform

    preferred_torch_device_for_install = "cpu"
    if platform.system() == "Linux" or platform.system() == "Windows":
        preferred_torch_device_for_install = "cuda"
    elif platform.system() == "Darwin":
        preferred_torch_device_for_install = "mps"

    torch_pkgs = ["torch"]
    bark_core_pkgs = ["transformers", "accelerate", "sentencepiece"]
    other_deps = ["scipy", "numpy"]

    torch_index_url = None
    if preferred_torch_device_for_install == "cuda":
        torch_index_url = "https://download.pytorch.org/whl/cu126"
        ASCIIColors.info(f"Attempting to ensure PyTorch with CUDA support (target index: {torch_index_url}) for Bark binding.")
        pm.ensure_packages(torch_pkgs, index_url=torch_index_url)
        pm.ensure_packages(bark_core_pkgs + other_deps)
    else:
        ASCIIColors.info("Ensuring PyTorch, Bark dependencies, and others using the default PyPI index for Bark binding.")
        pm.ensure_packages(torch_pkgs + bark_core_pkgs + other_deps)

    import torch
    from transformers import AutoProcessor, BarkModel, GenerationConfig
    import scipy.io.wavfile
    import numpy as np

    _bark_deps_installed_with_correct_torch = True
except Exception as e:
    _bark_installation_error = str(e)
    AutoProcessor, BarkModel, GenerationConfig, torch, scipy, np = None, None, None, None, None, None
# --- End Package Management ---

from lollms_client.lollms_ttm_binding import LollmsTTMBinding

BindingName = "BarkTTMBinding"

DEFAULT_BARK_MODELS = [
    "suno/bark",
    "suno/bark-small",
]

BARK_VOICE_PRESETS_EXAMPLES = [
    "v2/en_speaker_0", "v2/en_speaker_1", "v2/en_speaker_2", "v2/en_speaker_3",
    "v2/en_speaker_4", "v2/en_speaker_5", "v2/en_speaker_6", "v2/en_speaker_7",
    "v2/en_speaker_8", "v2/en_speaker_9",
    "v2/de_speaker_0", "v2/es_speaker_0", "v2/fr_speaker_0",
]


class BarkTTMBinding(LollmsTTMBinding):
    def __init__(self,
                 model_name: str = "suno/bark-small",
                 device: Optional[str] = None,
                 default_voice_preset: Optional[str] = "v2/en_speaker_6",
                 enable_better_transformer: bool = True,
                 **kwargs):

        super().__init__(binding_name="bark")

        if not _bark_deps_installed_with_correct_torch:
            raise ImportError(f"Bark TTM binding dependencies not met. Error: {_bark_installation_error}")

        self.device = device
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                ASCIIColors.info("CUDA device detected by PyTorch for Bark.")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                self.device = "mps"
                ASCIIColors.info("MPS device detected for Bark.")
            else:
                self.device = "cpu"
                ASCIIColors.info("No GPU (CUDA/MPS) detected by PyTorch, using CPU for Bark.")
        elif self.device == "cuda" and not torch.cuda.is_available():
            self.device = "cpu"
            ASCIIColors.warning("CUDA requested but not available. Falling back to CPU for Bark.")
        elif self.device == "mps" and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
            self.device = "cpu"
            ASCIIColors.warning("MPS requested but not available. Falling back to CPU for Bark.")

        ASCIIColors.info(f"BarkTTMBinding: Using device '{self.device}'.")

        self.loaded_model_name = None
        self.model: Optional[BarkModel] = None
        self.processor: Optional[AutoProcessor] = None
        self.default_voice_preset = default_voice_preset
        self.enable_better_transformer = enable_better_transformer

        self.default_generation_params = {}
        temp_gen_config = GenerationConfig()
        for key, value in kwargs.items():
            if hasattr(temp_gen_config, key):
                self.default_generation_params[key] = value

        self._load_bark_model(model_name)

    def _load_bark_model(self, model_name_to_load: str):
        if self.model is not None and self.loaded_model_name == model_name_to_load:
            ASCIIColors.info(f"Bark model '{model_name_to_load}' already loaded.")
            return

        ASCIIColors.info(f"Loading Bark model: '{model_name_to_load}' on device '{self.device}'...")
        try:
            dtype_for_bark = torch.float16 if self.device == "cuda" else None

            self.processor = AutoProcessor.from_pretrained(model_name_to_load)
            self.model = BarkModel.from_pretrained(
                model_name_to_load,
                torch_dtype=dtype_for_bark,
                low_cpu_mem_usage=True if self.device != "cpu" else False
            ).to(self.device)

            if self.enable_better_transformer and self.device == "cuda":
                try:
                    self.model = self.model.to_bettertransformer()
                    ASCIIColors.info("Applied BetterTransformer optimization to Bark model.")
                except Exception as e_bt:
                    ASCIIColors.warning(f"Failed to apply BetterTransformer: {e_bt}. Proceeding without it.")

            if "small" not in model_name_to_load and self.device == "cpu":
                ASCIIColors.warning("Using full Bark model on CPU. Generation might be slow.")
            elif self.device != "cpu" and "small" not in model_name_to_load:
                if hasattr(self.model, "enable_model_cpu_offload"):
                    try:
                        self.model.enable_model_cpu_offload()
                        ASCIIColors.info("Enabled model_cpu_offload for Bark.")
                    except Exception as e:
                        ASCIIColors.warning(f"Could not enable model_cpu_offload: {e}")
                elif hasattr(self.model, "enable_cpu_offload"):
                    try:
                        self.model.enable_cpu_offload()
                        ASCIIColors.info("Enabled cpu_offload for Bark (older API).")
                    except Exception as e:
                        ASCIIColors.warning(f"Could not enable cpu_offload (older API): {e}")
                else:
                    ASCIIColors.info("CPU offload not explicitly enabled.")

            self.loaded_model_name = model_name_to_load
            ASCIIColors.green(f"Bark model '{model_name_to_load}' loaded successfully.")
        except Exception as e:
            self.model, self.processor, self.loaded_model_name = None, None, None
            ASCIIColors.error(f"Failed to load Bark model '{model_name_to_load}': {e}")
            trace_exception(e)
            raise RuntimeError(f"Failed to load Bark model '{model_name_to_load}'") from e

    def generate_music(self,
                       prompt: str,
                       voice_preset: Optional[str] = None,
                       do_sample: Optional[bool] = None,
                       temperature: Optional[float] = None,
                       **kwargs) -> bytes:
        if self.model is None or self.processor is None:
            raise RuntimeError("Bark model or processor not loaded.")

        effective_voice_preset = voice_preset if voice_preset is not None else self.default_voice_preset

        ASCIIColors.info(f"Generating SFX/audio with Bark: '{prompt[:60]}...' (Preset: {effective_voice_preset})")
        try:
            # The processor returns 'input_ids' and 'attention_mask'
            inputs = self.processor(
                text=[prompt],  # Processor expects a list of texts
                voice_preset=effective_voice_preset,
                return_tensors="pt",
                # Explicitly asking for padding is possible if the tokenizer supports it,
                # though Bark's processor may handle this internally.
                # padding=True,    # Let the processor decide the best padding strategy
                # truncation=True  # Ensure inputs fit the model context
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Ensure attention_mask is present
            if 'attention_mask' not in inputs:
                ASCIIColors.warning("Processor did not return attention_mask. Creating a default one (all ones). This might lead to suboptimal results if padding was intended.")
                inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])

            if hasattr(self.model, 'generation_config') and self.model.generation_config is not None:
                gen_config = GenerationConfig.from_dict(self.model.generation_config.to_dict())
            else:
                gen_config = GenerationConfig()

            for key, value in self.default_generation_params.items():
                if hasattr(gen_config, key):
                    setattr(gen_config, key, value)

            if do_sample is not None:
                gen_config.do_sample = do_sample

            if temperature is not None:
                if 'semantic_temperature' not in kwargs and hasattr(gen_config, 'semantic_temperature'):
                    gen_config.semantic_temperature = temperature
                if 'coarse_temperature' not in kwargs and hasattr(gen_config, 'coarse_temperature'):
                    gen_config.coarse_temperature = temperature
                if 'fine_temperature' not in kwargs and hasattr(gen_config, 'fine_temperature'):
                    gen_config.fine_temperature = temperature

            for key, value in kwargs.items():
                if hasattr(gen_config, key):
                    setattr(gen_config, key, value)

            # Critical: set pad_token_id in GenerationConfig.
            # Bark uses specific token IDs for its different codebooks.
            # The processor's tokenizer should have the correct pad_token_id for text inputs.
            # For Bark, the semantic vocabulary has its own pad_token_id, often the same as EOS.
            # Try to get it from the model's semantic config or text config.
            pad_token_id_to_set = None
            if hasattr(self.model.config, 'semantic_config') and hasattr(self.model.config.semantic_config, 'pad_token_id'):
                pad_token_id_to_set = self.model.config.semantic_config.pad_token_id
            elif hasattr(self.model.config, 'text_config') and hasattr(self.model.config.text_config, 'pad_token_id'):
                pad_token_id_to_set = self.model.config.text_config.pad_token_id
            elif hasattr(self.processor, 'tokenizer') and self.processor.tokenizer and self.processor.tokenizer.pad_token_id is not None:
                pad_token_id_to_set = self.processor.tokenizer.pad_token_id

            if pad_token_id_to_set is not None:
                gen_config.pad_token_id = pad_token_id_to_set
                # Also set the EOS token if it's distinct and meaningful for stopping generation
                if hasattr(gen_config, 'eos_token_id') and gen_config.eos_token_id is None:
                    eos_id = None
                    if hasattr(self.model.config, 'semantic_config') and hasattr(self.model.config.semantic_config, 'eos_token_id'):
                        eos_id = self.model.config.semantic_config.eos_token_id
                    if eos_id is not None:
                        gen_config.eos_token_id = eos_id
            else:
                # This state is problematic for Bark if pad_token_id is truly needed and distinct from EOS
                ASCIIColors.warning("Could not determine a specific pad_token_id from Bark's config for GenerationConfig. This might lead to issues.")
                # If eos_token_id is also not set, generation might not stop correctly.
                # Default pad_token_id to eos_token_id if eos_token_id exists.
                if gen_config.eos_token_id is not None:
                    gen_config.pad_token_id = gen_config.eos_token_id
                    ASCIIColors.info(f"Setting pad_token_id to eos_token_id ({gen_config.eos_token_id}) as a fallback.")
                else:
                    # This is a last resort and might not be correct for Bark specifically
                    gen_config.pad_token_id = 0
                    ASCIIColors.warning("pad_token_id defaulted to 0 as a last resort.")

            ASCIIColors.debug(f"Bark final generation_config: {gen_config.to_json_string()}")

            with torch.no_grad():
                output = self.model.generate(
                    input_ids=inputs['input_ids'],              # Explicitly pass input_ids
                    attention_mask=inputs.get('attention_mask'),  # Pass attention_mask if available
                    generation_config=gen_config
                )

            if isinstance(output, torch.Tensor):
                speech_output_tensor = output
            elif isinstance(output, dict) and "audio_features" in output:
                speech_output_tensor = output["audio_features"]
            elif isinstance(output, dict) and "waveform" in output:
                speech_output_tensor = output["waveform"]  # Bark might return this key
            else:
                raise TypeError(f"Unexpected output type from BarkModel.generate: {type(output)}. Content: {output}")

            audio_array_np = speech_output_tensor.cpu().numpy().squeeze()
            if audio_array_np.ndim == 0 or audio_array_np.size == 0:
                raise RuntimeError("Bark model returned empty audio data.")

            audio_int16 = (audio_array_np * 32767).astype(np.int16)

            buffer = io.BytesIO()
            sample_rate_to_use = int(self.model.generation_config.sample_rate if hasattr(self.model.generation_config, 'sample_rate') and self.model.generation_config.sample_rate else 24_000)
            scipy.io.wavfile.write(buffer, rate=sample_rate_to_use, data=audio_int16)
            audio_bytes = buffer.getvalue()
            buffer.close()

            ASCIIColors.green("Bark audio generation successful.")
            return audio_bytes
        except Exception as e:
            ASCIIColors.error(f"Bark audio generation failed: {e}")
            trace_exception(e)
            if "out of memory" in str(e).lower() and self.device == "cuda":
                ASCIIColors.yellow("CUDA out of memory. Consider using suno/bark-small or ensure the GPU has sufficient VRAM.")
            raise RuntimeError(f"Bark audio generation error: {e}") from e

    def list_models(self, **kwargs) -> List[str]:
        return DEFAULT_BARK_MODELS.copy()

    def list_voice_presets(self) -> List[str]:
        return BARK_VOICE_PRESETS_EXAMPLES.copy()

    def __del__(self):
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
            self.processor = None
        if torch and hasattr(torch, 'cuda') and torch.cuda.is_available():
            torch.cuda.empty_cache()
        loaded_name = getattr(self, 'loaded_model_name', None)
        msg = f"BarkTTMBinding for model '{loaded_name}' destroyed." if loaded_name else "BarkTTMBinding destroyed."
        ASCIIColors.info(msg)


# --- Main Test Block ---
if __name__ == '__main__':
    if not _bark_deps_installed_with_correct_torch:
        print(f"{ASCIIColors.RED}Bark TTM binding dependencies not met. Skipping tests. Error: {_bark_installation_error}{ASCIIColors.RESET}")
        exit()

    ASCIIColors.yellow("--- BarkTTMBinding Test ---")
    test_model_id = "suno/bark-small"
    test_output_dir = Path("./test_bark_sfx_output")
    test_output_dir.mkdir(exist_ok=True)
    ttm_binding = None

    try:
        ASCIIColors.cyan(f"\n--- Initializing BarkTTMBinding (model: '{test_model_id}') ---")
        ttm_binding = BarkTTMBinding(model_name=test_model_id)

        ASCIIColors.cyan("\n--- Listing common Bark models ---")
        models = ttm_binding.list_models()
        print(f"Common Bark models: {models}")
        ASCIIColors.cyan("\n--- Listing example Bark voice presets ---")
        presets = ttm_binding.list_voice_presets()
        print(f"Example presets: {presets[:5]}...")

        sfx_prompts_to_test = [
            ("laser_blast", "A short, sharp laser blast sound effect [SFX]"),
            ("footsteps_gravel", "Footsteps walking on gravel [footsteps]."),
            ("explosion_distant", "A distant explosion [boom] with a slight echo."),
            ("interface_click", "A clean, quick digital interface click sound. [click]"),
            ("creature_roar_short", "[roar] A short, guttural creature roar."),
            ("ambient_wind", "[wind] Gentle wind blowing through trees."),
            ("speech_hello", "Hello, this is a test of Bark's speech capabilities."),
        ]

        for name, prompt in sfx_prompts_to_test:
            ASCIIColors.cyan(f"\n--- Generating SFX/Audio for: '{name}' ---")
            print(f"Prompt: {prompt}")
            try:
                call_kwargs = {}
                if "speech" in name:
                    call_kwargs = {"semantic_temperature": 0.6, "coarse_temperature": 0.8, "fine_temperature": 0.5, "do_sample": True}
                elif name == "laser_blast":
                    call_kwargs = {"semantic_temperature": 0.5, "coarse_temperature": 0.6, "fine_temperature": 0.4, "do_sample": True}
                else:  # For SFX, sometimes more deterministic sampling helps consistency
                    call_kwargs = {"do_sample": True, "semantic_temperature": 0.7, "coarse_temperature": 0.7, "fine_temperature": 0.7}

                sfx_bytes = ttm_binding.generate_music(prompt, voice_preset=None, **call_kwargs)
                if sfx_bytes:
                    output_filename = f"sfx_{name}_{test_model_id.split('/')[-1]}.wav"
                    output_path = test_output_dir / output_filename
                    with open(output_path, "wb") as f:
                        f.write(sfx_bytes)
                    ASCIIColors.green(f"SFX for '{name}' saved to: {output_path} ({len(sfx_bytes) / 1024:.2f} KB)")
                else:
                    ASCIIColors.error(f"SFX generation for '{name}' returned empty bytes.")
            except Exception as e_gen:
                ASCIIColors.error(f"Failed to generate SFX for '{name}': {e_gen}")

    except ImportError as e_imp:
        ASCIIColors.error(f"Import error: {e_imp}")
    except RuntimeError as e_rt:
        ASCIIColors.error(f"Runtime error: {e_rt}")
    except Exception as e:
        ASCIIColors.error(f"Unexpected error: {e}")
        trace_exception(e)
    finally:
        if ttm_binding:
            del ttm_binding
        ASCIIColors.info(f"Test SFX (if any) are in: {test_output_dir.resolve()}")
        print(f"{ASCIIColors.YELLOW}Check the audio files in '{test_output_dir.resolve()}'!{ASCIIColors.RESET}")

    ASCIIColors.yellow("\n--- BarkTTMBinding Test Finished ---")
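For completeness, a matching sketch for the Bark binding above, under the same assumptions (direct module import inferred from the RECORD layout, weights downloaded from Hugging Face on first load). The per-codebook temperature kwargs mirror the ones its own test block passes:

# Sketch based on the bark binding above, not documented API.
from lollms_client.ttm_bindings.bark import BarkTTMBinding

bark = BarkTTMBinding(model_name="suno/bark-small", default_voice_preset="v2/en_speaker_6")
sfx_bytes = bark.generate_music(
    "A distant explosion [boom] with a slight echo.",
    do_sample=True, semantic_temperature=0.7, coarse_temperature=0.7, fine_temperature=0.7,
)
with open("boom.wav", "wb") as f:  # generate_music returns raw WAV bytes
    f.write(sfx_bytes)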