lollms-client 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lollms-client might be problematic. See the registry's advisory for this release for more details.

@@ -0,0 +1,281 @@
1
+ # lollms_client/ttm_bindings/audiocraft/__init__.py
2
+ import io
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional, List, Union, Dict, Any
6
+
7
+ from ascii_colors import trace_exception, ASCIIColors
8
+
9
# --- Package Management and Conditional Imports ---
# Dependencies are ensured through pipmaster when this module is imported.
# Any failure is recorded (not raised) so that AudioCraftTTMBinding can report
# it with a clear ImportError when it is instantiated.
_audiocraft_installed_with_correct_torch = False
_audiocraft_installation_error = ""

try:
    import pipmaster as pm
    import platform  # For OS detection for torch index

    # Determine initial device preference to guide torch installation
    preferred_torch_device_for_install = "cpu"  # Default assumption

    # Tentatively set preference based on OS, assuming user might want GPU if available
    if platform.system() in ("Linux", "Windows"):
        # On Linux/Windows, CUDA is the primary GPU acceleration for PyTorch,
        # so try to install a CUDA build of PyTorch.
        preferred_torch_device_for_install = "cuda"
    elif platform.system() == "Darwin":
        # On macOS the acceleration backend is MPS; the default PyPI install is used.
        preferred_torch_device_for_install = "mps"

    torch_pkgs = ["torch", "torchaudio", "xformers"]
    audiocraft_core_pkgs = ["audiocraft"]
    other_deps = ["scipy", "numpy"]

    torch_index_url = None
    if preferred_torch_device_for_install == "cuda":
        # PyTorch wheel index for CUDA 12.6; pip resolves the matching torch version.
        # Users with a different CUDA setup might need to pre-install torch manually.
        torch_index_url = "https://download.pytorch.org/whl/cu126"
        ASCIIColors.info(f"Attempting to ensure PyTorch with CUDA support (target index: {torch_index_url})")
        # Install the torch stack (torch, torchaudio, xformers) first from the CUDA index...
        pm.ensure_packages(torch_pkgs, index_url=torch_index_url)
        # ...then audiocraft and other dependencies; pip should reuse the installed torch.
        pm.ensure_packages(audiocraft_core_pkgs + other_deps)
    else:
        # For CPU, MPS, or if no specific CUDA preference was determined for install
        ASCIIColors.info("Ensuring PyTorch, AudioCraft, and dependencies using default PyPI index.")
        pm.ensure_packages(torch_pkgs + audiocraft_core_pkgs + other_deps)

    # Now, perform the actual imports
    import torch
    import torchaudio
    from audiocraft.models import MusicGen
    from audiocraft.data.audio import audio_write
    import numpy as np
    import scipy.io.wavfile  # For direct WAV manipulation if needed

    _audiocraft_installed_with_correct_torch = True  # If imports succeed after ensure_packages
except Exception as e:
    _audiocraft_installation_error = str(e)
    # Placeholders so later references see None instead of raising NameError;
    # torchaudio is included so every optional name is treated the same way.
    MusicGen, torch, torchaudio, audio_write, np, scipy = None, None, None, None, None, None
# --- End Package Management ---
62
+
63
+ from lollms_client.lollms_ttm_binding import LollmsTTMBinding
64
+
65
BindingName = "AudioCraftTTMBinding"

# MusicGen checkpoints published under the "facebook" namespace on the
# Hugging Face Hub. The "melody" variant can also be conditioned on a melody
# audio file.
DEFAULT_AUDIOCRAFT_MODELS = [
    f"facebook/musicgen-{variant}"
    for variant in ("small", "medium", "melody", "large")
]
74
+
75
+
76
class AudioCraftTTMBinding(LollmsTTMBinding):
    """Local text-to-music binding built on AudioCraft's MusicGen models.

    The connection-related arguments (``host_address``, ``service_key``,
    ``verify_ssl_certificate``) are accepted for signature compatibility with
    other LollmsTTMBinding implementations but are not used by this binding.
    """

    def __init__(self,
                 model_name: str = "facebook/musicgen-small",  # HF ID or local path
                 device: Optional[str] = None,  # "cpu", "cuda", "mps", or None for auto
                 output_format: str = "wav",  # 'wav' or 'mp3' (mp3 encoding needs FFmpeg support in torchaudio)
                 # Catch LollmsTTMBinding standard args
                 host_address: Optional[str] = None,  # Not used by local binding
                 service_key: Optional[str] = None,  # Not used by local binding
                 verify_ssl_certificate: bool = True,  # Not used by local binding
                 **kwargs):  # Catch-all for future compatibility or specific audiocraft params
        """Resolve the compute device, validate the output format and load the model.

        Raises:
            ImportError: If the AudioCraft/PyTorch dependencies could not be imported.
            RuntimeError: If the requested MusicGen model fails to load.
        """
        super().__init__(binding_name="audiocraft")

        if not _audiocraft_installed_with_correct_torch:
            raise ImportError(f"AudioCraft TTM binding dependencies not met. Please ensure 'audiocraft', 'torch', 'torchaudio', 'scipy', 'numpy' are installed. Error: {_audiocraft_installation_error}")

        self.device = device
        if self.device is None:  # Auto-detect if not specified by user
            if torch.cuda.is_available():
                self.device = "cuda"
                ASCIIColors.info("CUDA device detected by PyTorch.")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():  # For Apple Silicon
                self.device = "mps"
                ASCIIColors.info("MPS device detected by PyTorch for Apple Silicon.")
            else:
                self.device = "cpu"
                ASCIIColors.info("No GPU (CUDA/MPS) detected by PyTorch, using CPU.")
        elif self.device == "cuda" and not torch.cuda.is_available():
            ASCIIColors.warning("CUDA device requested, but torch.cuda.is_available() is False. Falling back to CPU.")
            self.device = "cpu"
        elif self.device == "mps" and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
            ASCIIColors.warning("MPS device requested, but not available. Falling back to CPU.")
            self.device = "cpu"

        ASCIIColors.info(f"AudioCraftTTMBinding: Using device '{self.device}'.")

        self.loaded_model_name = None
        self.model: Optional[MusicGen] = None
        self.output_format = output_format.lower()
        if self.output_format not in ["wav", "mp3"]:
            ASCIIColors.warning(f"Unsupported output_format '{self.output_format}'. Defaulting to 'wav'.")
            self.output_format = "wav"

        self._load_audiocraft_model(model_name)

    def _load_audiocraft_model(self, model_name_to_load: str):
        """Load *model_name_to_load* onto ``self.device``; no-op if it is already loaded.

        Raises:
            RuntimeError: If ``MusicGen.get_pretrained`` fails. The model state is
                reset to "nothing loaded" before raising.
        """
        if self.model is not None and self.loaded_model_name == model_name_to_load:
            ASCIIColors.info(f"AudioCraft model '{model_name_to_load}' already loaded.")
            return

        ASCIIColors.info(f"Loading AudioCraft (MusicGen) model: '{model_name_to_load}' on device '{self.device}'...")
        try:
            self.model = MusicGen.get_pretrained(model_name_to_load, device=self.device)
            self.loaded_model_name = model_name_to_load
            ASCIIColors.green(f"AudioCraft model '{model_name_to_load}' loaded successfully.")
        except Exception as e:
            self.model = None
            self.loaded_model_name = None
            ASCIIColors.error(f"Failed to load AudioCraft model '{model_name_to_load}': {e}")
            trace_exception(e)
            raise RuntimeError(f"Failed to load AudioCraft model '{model_name_to_load}'") from e

    def generate_music(self,
                       prompt: str,
                       duration: int = 8,
                       temperature: float = 1.0,
                       top_k: int = 250,
                       top_p: float = 0.0,
                       cfg_coef: float = 3.0,
                       progress: bool = True,
                       **kwargs) -> bytes:
        """Generate music for *prompt* and return it encoded in ``self.output_format``.

        Args:
            prompt: Text description of the desired music.
            duration: Requested clip length in seconds.
            temperature, top_k, top_p, cfg_coef: Sampling parameters forwarded to
                ``MusicGen.set_generation_params`` (together with any extra ``kwargs``).
            progress: Forwarded to ``MusicGen.generate``.

        Returns:
            Encoded audio bytes ("wav" or "mp3", per ``self.output_format``).

        Raises:
            RuntimeError: If no model is loaded, or if generation/encoding fails.
        """
        if self.model is None:
            raise RuntimeError("AudioCraft model is not loaded. Cannot generate music.")

        self.model.set_generation_params(
            duration=duration,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            cfg_coef=cfg_coef,
            **kwargs
        )

        ASCIIColors.info(f"Generating music for prompt: '{prompt[:50]}...' (Duration: {duration}s, Temp: {temperature}, TopK: {top_k}, TopP: {top_p}, CFG: {cfg_coef})")
        try:
            wav_tensor = self.model.generate(descriptions=[prompt], progress=progress)

            if wav_tensor is None or wav_tensor.numel() == 0:
                raise RuntimeError("MusicGen returned empty audio data.")

            # A single description was passed, so a 3-D result must have batch size 1;
            # drop that axis to get the 2-D tensor torchaudio.save expects.
            if wav_tensor.ndim == 3 and wav_tensor.shape[0] == 1:
                wav_tensor_single = wav_tensor.squeeze(0)
            elif wav_tensor.ndim == 2:
                wav_tensor_single = wav_tensor
            else:
                raise ValueError(f"Unexpected tensor shape from MusicGen: {wav_tensor.shape}")

            # Writing to a file-like object requires an explicit format; use the
            # configured one so output_format="mp3" really yields MP3 bytes.
            # torchaudio.save needs the tensor on the CPU.
            with io.BytesIO() as buffer:
                torchaudio.save(
                    buffer,
                    wav_tensor_single.cpu(),
                    self.model.sample_rate,
                    format=self.output_format,
                )
                audio_bytes = buffer.getvalue()

            ASCIIColors.green("Music generation successful.")
            return audio_bytes

        except Exception as e:
            ASCIIColors.error(f"AudioCraft music generation failed: {e}")
            trace_exception(e)
            # Provide more specific feedback for common issues
            if "out of memory" in str(e).lower() and self.device == "cuda":
                ASCIIColors.yellow("CUDA out of memory. Consider using a smaller model (e.g., 'facebook/musicgen-small'), a shorter duration, or ensure your GPU has sufficient VRAM (medium models might need ~10-12GB, large ~16GB+).")
            elif "ffmpeg" in str(e).lower() and self.output_format == "mp3":
                ASCIIColors.yellow("An FFmpeg error occurred. Ensure FFmpeg is installed and accessible in your system's PATH if you are generating MP3s.")
            raise RuntimeError(f"AudioCraft music generation error: {e}") from e

    def list_models(self, **kwargs) -> List[str]:
        """Return a copy of the common MusicGen model IDs."""
        return DEFAULT_AUDIOCRAFT_MODELS.copy()

    def __del__(self):
        """Release the model and, when CUDA is available, PyTorch's cached GPU memory."""
        # hasattr guards cover instances whose __init__ raised before these attributes were set.
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            self.model = None
            if torch and hasattr(torch, 'cuda') and torch.cuda.is_available():
                torch.cuda.empty_cache()
        if hasattr(self, 'loaded_model_name') and self.loaded_model_name:
            ASCIIColors.info(f"AudioCraftTTMBinding for model '{self.loaded_model_name}' destroyed and resources released.")
        else:
            ASCIIColors.info("AudioCraftTTMBinding destroyed (no model was fully loaded or name not set).")
218
+
219
+
220
# --- Main Test Block (Example Usage) ---
if __name__ == '__main__':
    # sys.exit is always available, unlike the interactive-only exit() helper
    # that the site module injects.
    import sys

    if not _audiocraft_installed_with_correct_torch:
        print(f"{ASCIIColors.RED}AudioCraft dependencies not met or import failed. Skipping tests. Error: {_audiocraft_installation_error}{ASCIIColors.RESET}")
        sys.exit()

    ASCIIColors.yellow("--- AudioCraftTTMBinding Test ---")
    test_model_id = "facebook/musicgen-small"  # Smallest model for quicker testing
    test_output_dir = Path("./test_audiocraft_output")
    test_output_dir.mkdir(exist_ok=True)
    ttm_binding = None

    try:
        ASCIIColors.cyan(f"\n--- Initializing AudioCraftTTMBinding (model: '{test_model_id}') ---")
        # The device is auto-detected; pass device="cpu" to force CPU for this test.
        ttm_binding = AudioCraftTTMBinding(model_name=test_model_id, output_format="wav")

        ASCIIColors.cyan("\n--- Listing common MusicGen models ---")
        models = ttm_binding.list_models()
        print(f"Common MusicGen models: {models}")

        prompts_to_test = [
            ("lofi_chill", "A lo-fi hip hop beat with a chill piano melody and soft drums, perfect for studying."),
            ("epic_battle", "Epic orchestral score for a fantasy battle scene, with choirs and horns."),
        ]

        for name, prompt in prompts_to_test:
            ASCIIColors.cyan(f"\n--- Generating music for: '{name}' (duration 3s) ---")
            print(f"Prompt: {prompt}")
            try:
                music_bytes = ttm_binding.generate_music(prompt, duration=3, progress=True)

                if music_bytes:
                    output_filename = f"test_{name}_{test_model_id.split('/')[-1]}.{ttm_binding.output_format}"
                    output_path = test_output_dir / output_filename
                    with open(output_path, "wb") as f:
                        f.write(music_bytes)
                    ASCIIColors.green(f"Music for '{name}' saved to: {output_path} ({len(music_bytes) / 1024:.2f} KB)")
                else:
                    ASCIIColors.error(f"Music generation for '{name}' returned empty bytes.")
            except Exception as e_gen:
                # Details were already logged by generate_music.
                ASCIIColors.error(f"Failed to generate music for '{name}': {e_gen}")

    except ImportError as e_imp:
        ASCIIColors.error(f"Import error during test setup: {e_imp}")
    except RuntimeError as e_rt:  # Runtime errors from init or generate
        ASCIIColors.error(f"Runtime error during test: {e_rt}")
    except Exception as e:
        ASCIIColors.error(f"An unexpected error occurred during testing: {e}")
        trace_exception(e)
    finally:
        if ttm_binding:
            del ttm_binding
        ASCIIColors.info(f"Test artifacts (if any) are in: {test_output_dir.resolve()}")
        print(f"{ASCIIColors.YELLOW}Remember to check the audio files in '{test_output_dir.resolve()}'!{ASCIIColors.RESET}")

    ASCIIColors.yellow("\n--- AudioCraftTTMBinding Test Finished ---")
@@ -0,0 +1,339 @@
1
+ # lollms_client/ttm_bindings/bark/__init__.py
2
+ import io
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional, List, Union, Dict, Any
6
+
7
+ from ascii_colors import trace_exception, ASCIIColors
8
+
9
# --- Package Management and Conditional Imports ---
# Bark's dependencies are ensured via pipmaster at import time. Failures are
# captured in _bark_installation_error instead of propagating, so the binding
# class can surface them when it is constructed.
_bark_deps_installed_with_correct_torch = False
_bark_installation_error = ""
try:
    import pipmaster as pm
    import platform

    # Map the host OS to the accelerator whose torch build should be installed.
    preferred_torch_device_for_install = {
        "Linux": "cuda",
        "Windows": "cuda",
        "Darwin": "mps",
    }.get(platform.system(), "cpu")

    torch_pkgs = ["torch"]
    bark_core_pkgs = ["transformers", "accelerate", "sentencepiece"]
    other_deps = ["scipy", "numpy"]

    # Only CUDA installs need the dedicated PyTorch wheel index.
    torch_index_url = (
        "https://download.pytorch.org/whl/cu126"
        if preferred_torch_device_for_install == "cuda"
        else None
    )

    if torch_index_url is None:
        ASCIIColors.info("Ensuring PyTorch, Bark dependencies, and others using default PyPI index for Bark binding.")
        pm.ensure_packages(torch_pkgs + bark_core_pkgs + other_deps)
    else:
        ASCIIColors.info(f"Attempting to ensure PyTorch with CUDA support (target index: {torch_index_url}) for Bark binding.")
        # torch first from the CUDA index, then everything else from PyPI.
        pm.ensure_packages(torch_pkgs, index_url=torch_index_url)
        pm.ensure_packages(bark_core_pkgs + other_deps)

    import torch
    from transformers import AutoProcessor, BarkModel, GenerationConfig
    import scipy.io.wavfile
    import numpy as np

    _bark_deps_installed_with_correct_torch = True
except Exception as e:
    _bark_installation_error = str(e)
    # Placeholders so later references see None instead of raising NameError.
    AutoProcessor = BarkModel = GenerationConfig = None
    torch = scipy = np = None
# --- End Package Management ---
46
+
47
+ from lollms_client.lollms_ttm_binding import LollmsTTMBinding
48
+
49
BindingName = "BarkTTMBinding"

# Bark checkpoints under the "suno" namespace on the Hugging Face Hub.
DEFAULT_BARK_MODELS = [f"suno/{checkpoint}" for checkpoint in ("bark", "bark-small")]

# Example speaker presets: the ten v2 English speakers, followed by speaker 0
# for German, Spanish and French.
BARK_VOICE_PRESETS_EXAMPLES = (
    [f"v2/en_speaker_{speaker}" for speaker in range(10)]
    + [f"v2/{lang}_speaker_0" for lang in ("de", "es", "fr")]
)
62
+
63
+
64
class BarkTTMBinding(LollmsTTMBinding):
    """Speech / sound-effect generation binding built on transformers' Bark models.

    Extra keyword arguments passed to the constructor become default generation
    parameters when they name a ``GenerationConfig`` attribute. The name may be
    bare (applied to every Bark sub-model) or prefixed with ``semantic_``,
    ``coarse_`` or ``fine_`` to target one sub-model. This is the prefix
    convention of ``BarkModel.generate``. Any other keyword arguments are ignored.
    """

    # Prefixes that BarkModel.generate uses to route a kwarg to a single sub-model.
    _SUBMODEL_PREFIXES = ("semantic_", "coarse_", "fine_")

    def __init__(self,
                 model_name: str = "suno/bark-small",
                 device: Optional[str] = None,
                 default_voice_preset: Optional[str] = "v2/en_speaker_6",
                 enable_better_transformer: bool = True,
                 **kwargs):
        """Resolve the compute device, record generation defaults and load the model.

        Raises:
            ImportError: If Bark's dependencies could not be imported.
            RuntimeError: If the model or processor fails to load.
        """
        super().__init__(binding_name="bark")

        if not _bark_deps_installed_with_correct_torch:
            raise ImportError(f"Bark TTM binding dependencies not met. Error: {_bark_installation_error}")

        self.device = device
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                ASCIIColors.info("CUDA device detected by PyTorch for Bark.")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                self.device = "mps"
                ASCIIColors.info("MPS device detected for Bark.")
            else:
                self.device = "cpu"
                ASCIIColors.info("No GPU (CUDA/MPS) by PyTorch, using CPU for Bark.")
        elif self.device == "cuda" and not torch.cuda.is_available():
            self.device = "cpu"
            ASCIIColors.warning("CUDA req, not avail. CPU for Bark.")
        elif self.device == "mps" and not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
            self.device = "cpu"
            ASCIIColors.warning("MPS req, not avail. CPU for Bark.")

        ASCIIColors.info(f"BarkTTMBinding: Using device '{self.device}'.")

        self.loaded_model_name = None
        self.model: Optional[BarkModel] = None
        self.processor: Optional[AutoProcessor] = None
        self.default_voice_preset = default_voice_preset
        self.enable_better_transformer = enable_better_transformer

        # Sub-model-prefixed keys (e.g. semantic_temperature) are kept as well;
        # a plain hasattr(GenerationConfig(), key) check would silently drop them.
        self.default_generation_params = self._filter_generation_kwargs(kwargs)

        self._load_bark_model(model_name)

    @classmethod
    def _filter_generation_kwargs(cls, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the entries of *params* that BarkModel.generate can consume.

        A key is kept when, after stripping an optional ``semantic_``/``coarse_``/
        ``fine_`` prefix, it names a ``GenerationConfig`` attribute.
        """
        template = GenerationConfig()
        accepted = {}
        for key, value in params.items():
            prefix = next((p for p in cls._SUBMODEL_PREFIXES if key.startswith(p)), "")
            if hasattr(template, key[len(prefix):]):
                accepted[key] = value
        return accepted

    def _load_bark_model(self, model_name_to_load: str):
        """Load processor and model for *model_name_to_load*; no-op if already loaded.

        On CUDA the model is loaded in float16 and, if requested, BetterTransformer
        is attempted. For non-"small" models on a GPU, CPU offload is attempted.
        Optimisation failures only log a warning.

        Raises:
            RuntimeError: If loading fails. Model and processor are reset to None first.
        """
        if self.model is not None and self.loaded_model_name == model_name_to_load:
            ASCIIColors.info(f"Bark model '{model_name_to_load}' already loaded.")
            return

        ASCIIColors.info(f"Loading Bark model: '{model_name_to_load}' on device '{self.device}'...")
        try:
            dtype_for_bark = torch.float16 if self.device == "cuda" else None

            self.processor = AutoProcessor.from_pretrained(model_name_to_load)
            self.model = BarkModel.from_pretrained(
                model_name_to_load,
                torch_dtype=dtype_for_bark,
                low_cpu_mem_usage=self.device != "cpu",
            ).to(self.device)

            if self.enable_better_transformer and self.device == "cuda":
                try:
                    self.model = self.model.to_bettertransformer()
                    ASCIIColors.info("Applied BetterTransformer optimization to Bark model.")
                except Exception as e_bt:
                    ASCIIColors.warning(f"Failed to apply BetterTransformer: {e_bt}. Proceeding without it.")

            if "small" not in model_name_to_load and self.device == "cpu":
                ASCIIColors.warning("Using full Bark model on CPU. Generation might be slow.")
            elif self.device != "cpu" and "small" not in model_name_to_load:
                if hasattr(self.model, "enable_model_cpu_offload"):
                    try:
                        self.model.enable_model_cpu_offload()
                        ASCIIColors.info("Enabled model_cpu_offload for Bark.")
                    except Exception as e:
                        ASCIIColors.warning(f"Could not enable model_cpu_offload: {e}")
                elif hasattr(self.model, "enable_cpu_offload"):
                    try:
                        self.model.enable_cpu_offload()
                        ASCIIColors.info("Enabled cpu_offload for Bark (older API).")
                    except Exception as e:
                        ASCIIColors.warning(f"Could not enable cpu_offload (older API): {e}")
                else:
                    ASCIIColors.info("CPU offload not explicitly enabled.")

            self.loaded_model_name = model_name_to_load
            ASCIIColors.green(f"Bark model '{model_name_to_load}' loaded successfully.")
        except Exception as e:
            self.model, self.processor, self.loaded_model_name = None, None, None
            ASCIIColors.error(f"Failed to load Bark model '{model_name_to_load}': {e}")
            trace_exception(e)
            raise RuntimeError(f"Failed to load Bark model '{model_name_to_load}'") from e

    def generate_music(self,
                       prompt: str,
                       voice_preset: Optional[str] = None,
                       do_sample: Optional[bool] = None,
                       temperature: Optional[float] = None,
                       **kwargs) -> bytes:
        """Generate audio (speech or sound effects) for *prompt* as 16-bit PCM WAV bytes.

        Args:
            prompt: Text to render.
            voice_preset: Speaker preset such as "v2/en_speaker_6". If None,
                ``default_voice_preset`` is used.
            do_sample: When given, overrides sampling on/off for all sub-models.
            temperature: When given, applied to every sub-model. A sub-model-specific
                value (``semantic_temperature`` etc.) in ``kwargs`` takes precedence
                for that sub-model.
            **kwargs: Further generation parameters, filtered like the constructor's
                extra kwargs. Per-call values override the constructor defaults.

        Returns:
            WAV file bytes at the model's configured sample rate, or 24 kHz when
            none is configured.

        Raises:
            RuntimeError: If the model is not loaded or generation fails.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("Bark model or processor not loaded.")

        effective_voice_preset = voice_preset if voice_preset is not None else self.default_voice_preset

        ASCIIColors.info(f"Generating SFX/audio with Bark: '{prompt[:60]}...' (Preset: {effective_voice_preset})")
        try:
            # The processor tokenizes the text. When a voice preset is given, it also
            # adds the speaker prompt (history_prompt). All of it must reach
            # generate(), otherwise the preset is silently ignored.
            inputs = self.processor(
                text=[prompt],  # Processor expects a list of texts
                voice_preset=effective_voice_preset,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Precedence: constructor defaults < explicit do_sample/temperature < per-call kwargs.
            # BarkModel.generate itself builds per-sub-model generation configs and
            # routes bare keys to all sub-models and prefixed keys to one of them.
            gen_kwargs = dict(self.default_generation_params)
            if do_sample is not None:
                gen_kwargs["do_sample"] = do_sample
            if temperature is not None:
                gen_kwargs["temperature"] = temperature
            gen_kwargs.update(self._filter_generation_kwargs(kwargs))

            ASCIIColors.debug(f"Bark generation kwargs: {gen_kwargs}")

            with torch.no_grad():
                output = self.model.generate(**inputs, **gen_kwargs)

            if isinstance(output, torch.Tensor):
                speech_output_tensor = output
            elif isinstance(output, dict) and "audio_features" in output:
                speech_output_tensor = output["audio_features"]
            elif isinstance(output, dict) and "waveform" in output:
                speech_output_tensor = output["waveform"]
            else:
                raise TypeError(f"Unexpected output type from BarkModel.generate: {type(output)}. Content: {output}")

            # Convert to float32 (the model may run in float16) before scaling.
            audio_array_np = speech_output_tensor.float().cpu().numpy().squeeze()
            if audio_array_np.ndim == 0 or audio_array_np.size == 0:
                raise RuntimeError("Bark model returned empty audio data.")

            # Clip to [-1, 1] so out-of-range samples saturate instead of wrapping
            # around when cast to int16.
            audio_int16 = (np.clip(audio_array_np, -1.0, 1.0) * 32767).astype(np.int16)

            configured_rate = getattr(self.model.generation_config, 'sample_rate', None)
            sample_rate_to_use = int(configured_rate) if configured_rate else 24_000

            with io.BytesIO() as buffer:
                scipy.io.wavfile.write(buffer, rate=sample_rate_to_use, data=audio_int16)
                audio_bytes = buffer.getvalue()

            ASCIIColors.green("Bark audio generation successful.")
            return audio_bytes
        except Exception as e:
            ASCIIColors.error(f"Bark audio generation failed: {e}")
            trace_exception(e)
            if "out of memory" in str(e).lower() and self.device == "cuda":
                ASCIIColors.yellow("CUDA out of memory. Consider using suno/bark-small or ensure GPU has sufficient VRAM.")
            raise RuntimeError(f"Bark audio generation error: {e}") from e

    def list_models(self, **kwargs) -> List[str]:
        """Return a copy of the common Bark model IDs."""
        return DEFAULT_BARK_MODELS.copy()

    def list_voice_presets(self) -> List[str]:
        """Return a copy of the example Bark voice preset names."""
        return BARK_VOICE_PRESETS_EXAMPLES.copy()

    def __del__(self):
        """Drop model and processor references and empty PyTorch's CUDA cache if available."""
        # hasattr guards cover instances whose __init__ raised before these attributes were set.
        if hasattr(self, 'model') and self.model is not None:
            del self.model
            self.model = None
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
            self.processor = None
        if torch and hasattr(torch, 'cuda') and torch.cuda.is_available():
            torch.cuda.empty_cache()
        loaded_name = getattr(self, 'loaded_model_name', None)
        msg = f"BarkTTMBinding for model '{loaded_name}' destroyed." if loaded_name else "BarkTTMBinding destroyed."
        ASCIIColors.info(msg)
278
+
279
# --- Main Test Block ---
if __name__ == '__main__':
    # sys.exit is always available, unlike the interactive-only exit() helper
    # that the site module injects.
    import sys

    if not _bark_deps_installed_with_correct_torch:
        print(f"{ASCIIColors.RED}Bark TTM binding dependencies not met. Skipping tests. Error: {_bark_installation_error}{ASCIIColors.RESET}")
        sys.exit()

    ASCIIColors.yellow("--- BarkTTMBinding Test ---")
    test_model_id = "suno/bark-small"
    test_output_dir = Path("./test_bark_sfx_output")
    test_output_dir.mkdir(exist_ok=True)
    ttm_binding = None

    try:
        ASCIIColors.cyan(f"\n--- Initializing BarkTTMBinding (model: '{test_model_id}') ---")
        ttm_binding = BarkTTMBinding(model_name=test_model_id)

        ASCIIColors.cyan("\n--- Listing common Bark models ---")
        models = ttm_binding.list_models()
        print(f"Common Bark models: {models}")
        ASCIIColors.cyan("\n--- Listing example Bark voice presets ---")
        presets = ttm_binding.list_voice_presets()
        print(f"Example presets: {presets[:5]}...")

        sfx_prompts_to_test = [
            ("laser_blast", "A short, sharp laser blast sound effect [SFX]"),
            ("footsteps_gravel", "Footsteps walking on gravel [footsteps]."),
            ("explosion_distant", "A distant explosion [boom] with a slight echo."),
            ("interface_click", "A clean, quick digital interface click sound. [click]"),
            ("creature_roar_short", "[roar] A short, guttural creature roar."),
            ("ambient_wind", "[wind] Gentle wind blowing through trees."),
            ("speech_hello", "Hello, this is a test of Bark's speech capabilities."),
        ]

        for name, prompt in sfx_prompts_to_test:
            ASCIIColors.cyan(f"\n--- Generating SFX/Audio for: '{name}' ---")
            print(f"Prompt: {prompt}")
            try:
                # Prefixed keys target individual Bark sub-models.
                if "speech" in name:
                    call_kwargs = {"semantic_temperature": 0.6, "coarse_temperature": 0.8, "fine_temperature": 0.5, "do_sample": True}
                elif name == "laser_blast":
                    call_kwargs = {"semantic_temperature": 0.5, "coarse_temperature": 0.6, "fine_temperature": 0.4, "do_sample": True}
                else:  # Moderate, uniform temperatures for the remaining SFX prompts
                    call_kwargs = {"do_sample": True, "semantic_temperature": 0.7, "coarse_temperature": 0.7, "fine_temperature": 0.7}

                sfx_bytes = ttm_binding.generate_music(prompt, voice_preset=None, **call_kwargs)
                if sfx_bytes:
                    output_filename = f"sfx_{name}_{test_model_id.split('/')[-1]}.wav"
                    output_path = test_output_dir / output_filename
                    with open(output_path, "wb") as f:
                        f.write(sfx_bytes)
                    ASCIIColors.green(f"SFX for '{name}' saved to: {output_path} ({len(sfx_bytes) / 1024:.2f} KB)")
                else:
                    ASCIIColors.error(f"SFX generation for '{name}' returned empty bytes.")
            except Exception as e_gen:
                ASCIIColors.error(f"Failed to generate SFX for '{name}': {e_gen}")

    except ImportError as e_imp:
        ASCIIColors.error(f"Import error: {e_imp}")
    except RuntimeError as e_rt:
        ASCIIColors.error(f"Runtime error: {e_rt}")
    except Exception as e:
        ASCIIColors.error(f"Unexpected error: {e}")
        trace_exception(e)
    finally:
        if ttm_binding:
            del ttm_binding
        ASCIIColors.info(f"Test SFX (if any) are in: {test_output_dir.resolve()}")
        print(f"{ASCIIColors.YELLOW}Check the audio files in '{test_output_dir.resolve()}'!{ASCIIColors.RESET}")

    ASCIIColors.yellow("\n--- BarkTTMBinding Test Finished ---")