lollms-client 1.4.1__py3-none-any.whl → 1.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/azure_openai/__init__.py +2 -2
- lollms_client/llm_bindings/claude/__init__.py +125 -34
- lollms_client/llm_bindings/gemini/__init__.py +261 -159
- lollms_client/llm_bindings/grok/__init__.py +52 -14
- lollms_client/llm_bindings/groq/__init__.py +2 -2
- lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +2 -2
- lollms_client/llm_bindings/litellm/__init__.py +1 -1
- lollms_client/llm_bindings/llamacpp/__init__.py +18 -11
- lollms_client/llm_bindings/lollms/__init__.py +151 -32
- lollms_client/llm_bindings/lollms_webui/__init__.py +1 -1
- lollms_client/llm_bindings/mistral/__init__.py +2 -2
- lollms_client/llm_bindings/novita_ai/__init__.py +439 -0
- lollms_client/llm_bindings/ollama/__init__.py +309 -93
- lollms_client/llm_bindings/open_router/__init__.py +2 -2
- lollms_client/llm_bindings/openai/__init__.py +148 -29
- lollms_client/llm_bindings/openllm/__init__.py +362 -506
- lollms_client/llm_bindings/openwebui/__init__.py +465 -0
- lollms_client/llm_bindings/perplexity/__init__.py +326 -0
- lollms_client/llm_bindings/pythonllamacpp/__init__.py +3 -3
- lollms_client/llm_bindings/tensor_rt/__init__.py +1 -1
- lollms_client/llm_bindings/transformers/__init__.py +428 -632
- lollms_client/llm_bindings/vllm/__init__.py +1 -1
- lollms_client/lollms_agentic.py +4 -2
- lollms_client/lollms_base_binding.py +61 -0
- lollms_client/lollms_core.py +516 -1890
- lollms_client/lollms_discussion.py +55 -18
- lollms_client/lollms_llm_binding.py +112 -261
- lollms_client/lollms_mcp_binding.py +34 -75
- lollms_client/lollms_personality.py +5 -2
- lollms_client/lollms_stt_binding.py +85 -52
- lollms_client/lollms_tti_binding.py +23 -37
- lollms_client/lollms_ttm_binding.py +24 -42
- lollms_client/lollms_tts_binding.py +28 -17
- lollms_client/lollms_ttv_binding.py +24 -42
- lollms_client/lollms_types.py +4 -2
- lollms_client/stt_bindings/whisper/__init__.py +108 -23
- lollms_client/stt_bindings/whispercpp/__init__.py +7 -1
- lollms_client/tti_bindings/diffusers/__init__.py +418 -810
- lollms_client/tti_bindings/diffusers/server/main.py +1051 -0
- lollms_client/tti_bindings/gemini/__init__.py +182 -239
- lollms_client/tti_bindings/leonardo_ai/__init__.py +127 -0
- lollms_client/tti_bindings/lollms/__init__.py +4 -1
- lollms_client/tti_bindings/novita_ai/__init__.py +105 -0
- lollms_client/tti_bindings/openai/__init__.py +10 -11
- lollms_client/tti_bindings/stability_ai/__init__.py +178 -0
- lollms_client/ttm_bindings/audiocraft/__init__.py +7 -12
- lollms_client/ttm_bindings/beatoven_ai/__init__.py +129 -0
- lollms_client/ttm_bindings/lollms/__init__.py +4 -17
- lollms_client/ttm_bindings/replicate/__init__.py +115 -0
- lollms_client/ttm_bindings/stability_ai/__init__.py +117 -0
- lollms_client/ttm_bindings/topmediai/__init__.py +96 -0
- lollms_client/tts_bindings/bark/__init__.py +7 -10
- lollms_client/tts_bindings/lollms/__init__.py +6 -1
- lollms_client/tts_bindings/piper_tts/__init__.py +8 -11
- lollms_client/tts_bindings/xtts/__init__.py +157 -74
- lollms_client/tts_bindings/xtts/server/main.py +241 -280
- {lollms_client-1.4.1.dist-info → lollms_client-1.7.10.dist-info}/METADATA +316 -6
- lollms_client-1.7.10.dist-info/RECORD +89 -0
- lollms_client/ttm_bindings/bark/__init__.py +0 -339
- lollms_client-1.4.1.dist-info/RECORD +0 -78
- {lollms_client-1.4.1.dist-info → lollms_client-1.7.10.dist-info}/WHEEL +0 -0
- {lollms_client-1.4.1.dist-info → lollms_client-1.7.10.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-1.4.1.dist-info → lollms_client-1.7.10.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Dict, Any, Union
|
|
7
|
+
|
|
8
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
9
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
10
|
+
import pipmaster as pm
|
|
11
|
+
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "NovitaAITTIBinding"
|
|
15
|
+
|
|
16
|
+
# Sourced from https://docs.novita.ai/image-generation/models
# Each entry is expanded from a compact (model_name, display_name, description) row.
NOVITA_AI_MODELS = [
    {"model_name": checkpoint, "display_name": label, "description": blurb}
    for checkpoint, label, blurb in (
        ("sd_xl_base_1.0.safetensors", "Stable Diffusion XL 1.0", "Official SDXL 1.0 Base model."),
        ("dreamshaper_xl_1_0.safetensors", "DreamShaper XL 1.0", "Versatile artistic SDXL model."),
        ("juggernaut_xl_v9_rundiffusion.safetensors", "Juggernaut XL v9", "High-quality realistic and cinematic model."),
        ("realistic_vision_v5.1.safetensors", "Realistic Vision v5.1", "Popular photorealistic SD1.5 model."),
        ("absolutereality_v1.8.1.safetensors", "Absolute Reality v1.8.1", "General-purpose realistic SD1.5 model."),
        ("meinamix_meina_v11.safetensors", "MeinaMix v11", "High-quality anime illustration model."),
    )
]
|
|
25
|
+
|
|
26
|
+
class NovitaAITTIBinding(LollmsTTIBinding):
    """Novita.ai TTI binding for LoLLMS.

    Posts text-to-image requests to ``{base_url}/text2img`` and returns the
    first image of the response decoded from base64.
    """

    def __init__(self, **kwargs):
        """Configure the binding from keyword arguments.

        Recognized keys:
            api_key: Novita.ai API key. Falls back to the ``NOVITA_API_KEY``
                environment variable when missing or empty.
            model_name: Checkpoint file name to request. ``model`` is accepted
                as an alias when ``model_name`` is not given.

        Raises:
            ValueError: If no API key can be found.
        """
        # Prioritize 'model_name' but accept 'model' as an alias from config files.
        if 'model' in kwargs and 'model_name' not in kwargs:
            kwargs['model_name'] = kwargs.pop('model')
        super().__init__(binding_name=BindingName, config=kwargs)
        self.config = kwargs
        self.api_key = self.config.get("api_key") or os.environ.get("NOVITA_API_KEY")
        if not self.api_key:
            raise ValueError("Novita.ai API key is required.")
        # `or` (rather than a .get default) so an explicit model_name=None
        # still resolves to the default checkpoint.
        self.model_name = self.config.get("model_name") or "juggernaut_xl_v9_rundiffusion.safetensors"
        self.base_url = "https://api.novita.ai/v3"
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def list_models(self) -> list:
        """Return the static catalogue of known Novita.ai checkpoints."""
        return NOVITA_AI_MODELS

    @staticmethod
    def _api_error_body(response):
        """Return the JSON body of a failed HTTP response, or None.

        None is returned when no response was received, when the response
        was successful, or when its body is not valid JSON.
        """
        if response is None or response.ok:
            return None
        try:
            return response.json()
        except ValueError:
            return None

    def generate_image(self, prompt: str, negative_prompt: str = "", width: int = 1024, height: int = 1024, **kwargs) -> bytes:
        """Generate one image from a text prompt.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what to avoid.
            width: Requested image width, sent as-is to the API.
            height: Requested image height, sent as-is to the API.
            **kwargs: Optional ``guidance_scale`` (default 7.0),
                ``num_inference_steps`` (default 25) and ``seed`` (default -1).

        Returns:
            The first returned image, base64-decoded to raw bytes.

        Raises:
            Exception: On transport failure, HTTP error status, or a
                response without images. The original error is chained.
        """
        url = f"{self.base_url}/text2img"
        payload = {
            "model_name": self.model_name,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "sampler_name": "DPM++ 2M Karras",
            "cfg_scale": kwargs.get("guidance_scale", 7.0),
            "steps": kwargs.get("num_inference_steps", 25),
            "seed": kwargs.get("seed", -1),
            "n_iter": 1,
            "batch_size": 1
        }

        # Bound before the try so the handler can tell "no response at all"
        # (connection failure) apart from "server answered with an error".
        response = None
        try:
            ASCIIColors.info(f"Requesting image from Novita.ai ({self.model_name})...")
            response = requests.post(url, json=payload, headers=self.headers)
            response.raise_for_status()
            data = response.json()
            if "images" not in data or not data["images"]:
                raise Exception(f"API returned no images. Response: {data}")

            b64_image = data["images"][0]["image_base64"]
            return base64.b64decode(b64_image)

        except Exception as e:
            trace_exception(e)
            # Only surface the body for HTTP error statuses; a successful
            # response body may contain the (large) base64 image itself.
            error_msg = self._api_error_body(response)
            if error_msg is not None:
                raise Exception(f"Novita.ai API error: {error_msg}") from e
            raise Exception(f"Novita.ai API request failed: {e}") from e

    def edit_image(self, **kwargs) -> bytes:
        """Image editing is not supported by this binding.

        Raises:
            NotImplementedError: Always.
        """
        ASCIIColors.warning("Novita.ai edit_image (inpainting/img2img) is not yet implemented in this binding.")
        raise NotImplementedError("This binding does not yet support image editing.")
|
|
83
|
+
|
|
84
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real NOVITA_API_KEY and network access.
    ASCIIColors.magenta("--- Novita.ai TTI Binding Test ---")
    if os.environ.get("NOVITA_API_KEY") is None:
        ASCIIColors.error("NOVITA_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        tti = NovitaAITTIBinding()

        ASCIIColors.cyan("\n--- Test: Text-to-Image ---")
        test_prompt = "A cute capybara wearing a top hat, sitting in a field of flowers, painterly style"
        png_data = tti.generate_image(test_prompt, width=1024, height=1024, num_inference_steps=30)

        # A real image is far larger than 1 kB; anything smaller is suspicious.
        assert len(png_data) > 1000
        output_path = Path(__file__).with_name("tmp_novita_t2i.png")
        output_path.write_bytes(png_data)
        ASCIIColors.green(f"Text-to-Image generation OK. Image saved to {output_path}")

    except Exception as err:
        trace_exception(err)
        ASCIIColors.error(f"Novita.ai binding test failed: {err}")
|
|
@@ -6,7 +6,7 @@ from io import BytesIO
|
|
|
6
6
|
from ascii_colors import trace_exception
|
|
7
7
|
from openai import OpenAI
|
|
8
8
|
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
9
|
-
|
|
9
|
+
import os
|
|
10
10
|
BindingName = "OpenAITTIBinding"
|
|
11
11
|
|
|
12
12
|
|
|
@@ -50,19 +50,18 @@ class OpenAITTIBinding(LollmsTTIBinding):
|
|
|
50
50
|
|
|
51
51
|
def __init__(
    self,
    **kwargs,
):
    """Configure the OpenAI image binding from keyword arguments.

    Recognized keys:
        api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
            environment variable when missing or empty.
        model_name: Image model to use (default ``"gpt-image-1"``).
            ``model`` is accepted as an alias when ``model_name`` is not given.
        size: Default image size (default ``"1024x1024"``).
        n: Default number of images (default 1).
        quality: Default quality setting (default ``"standard"``).
    """
    # Prioritize 'model_name' but accept 'model' as an alias from config files.
    if 'model' in kwargs and 'model_name' not in kwargs:
        kwargs['model_name'] = kwargs.pop('model')
    super().__init__(binding_name=BindingName, config=kwargs)
    # The `or` must sit outside kwargs.get(): `"api_key" or <anything>` is
    # always "api_key", which made the environment fallback unreachable.
    api_key = kwargs.get("api_key") or os.environ.get("OPENAI_API_KEY")
    self.client = OpenAI(api_key=api_key)
    self.global_params = {
        "model": kwargs.get("model_name") or "gpt-image-1",
        "size": kwargs.get("size", "1024x1024"),
        "n": kwargs.get("n", 1),
        "quality": kwargs.get("quality", "standard"),
    }
|
|
67
66
|
|
|
68
67
|
def _resolve_param(self, name: str, kwargs: Dict[str, Any], default: Any) -> Any:
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Dict, Any, Union
|
|
7
|
+
|
|
8
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
9
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
10
|
+
import pipmaster as pm
|
|
11
|
+
|
|
12
|
+
pm.ensure_packages(["requests", "Pillow"])
|
|
13
|
+
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
BindingName = "StabilityAITTIBinding"
|
|
17
|
+
|
|
18
|
+
# Sourced from https://platform.stability.ai/docs/getting-started/models
# Each entry is expanded from a compact (model_name, display_name, description)
# row; every model in this table is owned by Stability AI.
STABILITY_AI_MODELS = [
    {"model_name": model_id, "display_name": label, "description": blurb, "owned_by": "Stability AI"}
    for model_id, label, blurb in (
        # SD3
        ("stable-diffusion-3-medium", "Stable Diffusion 3 Medium", "Most advanced text-to-image model."),
        ("stable-diffusion-3-large", "Stable Diffusion 3 Large", "Most advanced model with higher quality."),
        ("stable-diffusion-3-large-turbo", "Stable Diffusion 3 Large Turbo", "Fast, high-quality generation."),
        # SDXL
        ("stable-diffusion-xl-1024-v1-0", "Stable Diffusion XL 1.0", "High-quality 1024x1024 generation."),
        ("stable-diffusion-xl-beta-v2-2-2", "SDXL Beta", "Legacy anime-focused SDXL model."),
        # SD 1.x & 2.x
        ("stable-diffusion-v1-6", "Stable Diffusion 1.6", "Improved version of SD 1.5."),
        ("stable-diffusion-2-1", "Stable Diffusion 2.1", "768x768 native resolution model."),
    )
]
|
|
31
|
+
|
|
32
|
+
class StabilityAITTIBinding(LollmsTTIBinding):
    """Stability AI TTI binding for LoLLMS.

    Supports text-to-image generation plus image-to-image and inpainting
    edits through the ``v2beta/stable-image`` REST endpoints.
    """

    # Ratios the SD3 endpoint is documented to accept for `aspect_ratio`.
    # NOTE(review): taken from the public API reference — confirm it is current.
    _SD3_ASPECT_RATIOS = ("16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21")

    def __init__(self, **kwargs):
        """Configure the binding from keyword arguments.

        Recognized keys:
            api_key: Stability AI API key. Falls back to the
                ``STABILITY_API_KEY`` environment variable.
            model_name: Model identifier (default
                ``"stable-diffusion-3-medium"``). ``model`` is accepted as an
                alias when ``model_name`` is not given.

        Raises:
            ValueError: If no API key can be found.
        """
        # Prioritize 'model_name' but accept 'model' as an alias from config files.
        if 'model' in kwargs and 'model_name' not in kwargs:
            kwargs['model_name'] = kwargs.pop('model')
        super().__init__(binding_name=BindingName, config=kwargs)
        self.api_key = self.config.get("api_key") or os.environ.get("STABILITY_API_KEY")
        if not self.api_key:
            raise ValueError("Stability AI API key is required. Please set it in the configuration or as STABILITY_API_KEY environment variable.")
        # `or` so that an explicit model_name=None still yields the default.
        self.model_name = self.config.get("model_name") or "stable-diffusion-3-medium"

    def list_models(self) -> list:
        """Return the static catalogue of known Stability AI image models."""
        return STABILITY_AI_MODELS

    def _get_api_url(self, task: str) -> str:
        """Return the endpoint URL for ``task`` with the configured model.

        Args:
            task: One of ``"text2image"``, ``"image2image"``,
                ``"inpainting"`` or ``"upscale"``.

        Raises:
            ValueError: If ``task`` is unknown (non-SD3 models only).
        """
        base_url = "https://api.stability.ai/v2beta/stable-image"
        # SD3 models use a different endpoint structure
        # NOTE(review): this returns the SD3 generate endpoint for every task,
        # including edits — confirm that is intended for inpainting.
        if "stable-diffusion-3" in self.model_name:
            return f"{base_url}/generate/sd3"

        task_map = {
            "text2image": "generate/core",
            "image2image": "edit/image-to-image",
            "inpainting": "edit/in-painting",
            "upscale": "edit/upscale"
        }
        if task not in task_map:
            raise ValueError(f"Unsupported task for this model family: {task}")
        return f"{base_url}/{task_map[task]}"

    @classmethod
    def _closest_aspect_ratio(cls, width: int, height: int) -> str:
        """Return the supported ``"W:H"`` ratio closest to ``width / height``.

        Falls back to ``"1:1"`` when either dimension is not positive.
        """
        if width <= 0 or height <= 0:
            return "1:1"
        target = width / height

        def distance(ratio: str) -> float:
            w, h = ratio.split(":")
            return abs(target - int(w) / int(h))

        return min(cls._SD3_ASPECT_RATIOS, key=distance)

    @staticmethod
    def _api_error_body(response):
        """Return the JSON body of a failed HTTP response, or None.

        None is returned when no response was received, when the response
        was successful, or when its body is not valid JSON.
        """
        if response is None or response.ok:
            return None
        try:
            return response.json()
        except ValueError:
            return None

    def _decode_image_input(self, item: Union[str, Path, bytes]) -> Image.Image:
        """Load an image from bytes, a data URL, a file path, a URL or raw base64.

        The forms are tried in that order; anything that matches none of the
        earlier forms is treated as raw base64.
        """
        if isinstance(item, bytes):
            return Image.open(BytesIO(item))
        s = str(item).strip()
        if s.startswith("data:image/") and ";base64," in s:
            b64 = s.split(";base64,")[-1]
            return Image.open(BytesIO(base64.b64decode(b64)))
        try:
            p = Path(s)
            if p.exists():
                return Image.open(p)
        except (OSError, ValueError):
            # Not a usable path (e.g. a long base64 string, or a file that is
            # not an image); fall through to the remaining forms.
            pass
        if s.startswith(("http://", "https://")):
            response = requests.get(s)
            response.raise_for_status()
            # Use the decoded body: `response.raw` bypasses content-encoding
            # handling and can hand compressed bytes to PIL.
            return Image.open(BytesIO(response.content))
        # Fallback for raw base64
        return Image.open(BytesIO(base64.b64decode(s)))

    def generate_image(self, prompt: str, negative_prompt: str = "", width: int = 1024, height: int = 1024, **kwargs) -> bytes:
        """Generate one PNG image from a text prompt.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what to avoid.
            width: Desired width. For SD3 models only the ratio to
                ``height`` is used.
            height: Desired height. See ``width``.
            **kwargs: Optional ``seed`` (default 0) and, for non-SD3 models,
                ``style_preset`` (default ``"photographic"``).

        Returns:
            The raw response body (image bytes).

        Raises:
            Exception: On transport failure or HTTP error status. The
                original error is chained.
        """
        url = self._get_api_url("text2image")

        data = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "output_format": "png",
            "seed": kwargs.get("seed", 0)
        }

        # SD3 uses aspect_ratio, older models use width/height
        if "stable-diffusion-3" in self.model_name:
            # The API expects a ratio such as "1:1", not pixel sizes such as
            # "1024:1024", so map the request onto the nearest supported one.
            data["aspect_ratio"] = self._closest_aspect_ratio(width, height)
            # NOTE(review): the configured model_name is sent verbatim —
            # confirm it matches the identifiers the endpoint accepts.
            data["model"] = self.model_name
        else:
            data["width"] = width
            data["height"] = height
            data["style_preset"] = kwargs.get("style_preset", "photographic")

        headers = {"authorization": f"Bearer {self.api_key}", "accept": "image/*"}

        response = None
        try:
            ASCIIColors.info(f"Requesting image from Stability AI ({self.model_name})...")
            # The dummy `files` entry forces a multipart/form-data body.
            response = requests.post(url, headers=headers, files={"none": ''}, data=data)
            response.raise_for_status()
            return response.content
        except Exception as e:
            trace_exception(e)
            error_msg = self._api_error_body(response)
            if error_msg is not None:
                raise Exception(f"Stability AI API error: {error_msg}") from e
            raise Exception(f"Stability AI API request failed: {e}") from e

    def edit_image(self, images: Union[str, List[str]], prompt: str, negative_prompt: Optional[str] = "", mask: Optional[str] = None, **kwargs) -> bytes:
        """Edit an image: inpainting when ``mask`` is given, else image-to-image.

        Args:
            images: A single image or a list of images; only the first is
                used. Accepted forms are those of ``_decode_image_input``.
            prompt: Text describing the desired result.
            negative_prompt: Text describing what to avoid.
            mask: Optional mask image; its presence selects inpainting.
            **kwargs: Optional ``seed`` (default 0) and, for image-to-image,
                ``strength`` (default 0.6).

        Returns:
            The raw response body (image bytes).

        Raises:
            ValueError: If ``images`` is an empty list.
            Exception: On transport failure or HTTP error status. The
                original error is chained.
        """
        if isinstance(images, list):
            if not images:
                raise ValueError("edit_image requires at least one input image.")
            source = images[0]
        else:
            source = images

        init_image_bytes = BytesIO()
        init_image = self._decode_image_input(source)
        init_image.save(init_image_bytes, format="PNG")

        task = "inpainting" if mask else "image2image"
        url = self._get_api_url(task)

        files = {"image": init_image_bytes.getvalue()}
        data = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or "",
            "output_format": "png",
            "seed": kwargs.get("seed", 0)
        }

        if task == "inpainting":
            mask_image_bytes = BytesIO()
            mask_image = self._decode_image_input(mask)
            mask_image.save(mask_image_bytes, format="PNG")
            files["mask"] = mask_image_bytes.getvalue()
        else:  # image2image
            data["strength"] = kwargs.get("strength", 0.6)  # mode IMAGE_STRENGTH

        headers = {"authorization": f"Bearer {self.api_key}", "accept": "image/*"}

        response = None
        try:
            ASCIIColors.info(f"Requesting image edit from Stability AI ({self.model_name})...")
            response = requests.post(url, headers=headers, files=files, data=data)
            response.raise_for_status()
            return response.content
        except Exception as e:
            trace_exception(e)
            error_msg = self._api_error_body(response)
            if error_msg is not None:
                raise Exception(f"Stability AI API error: {error_msg}") from e
            raise Exception(f"Stability AI API request failed: {e}") from e
|
|
156
|
+
|
|
157
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real STABILITY_API_KEY and network access.
    ASCIIColors.magenta("--- Stability AI TTI Binding Test ---")
    if os.environ.get("STABILITY_API_KEY") is None:
        ASCIIColors.error("STABILITY_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        tti = StabilityAITTIBinding(model_name="stable-diffusion-3-medium")

        ASCIIColors.cyan("\n--- Test: Text-to-Image ---")
        test_prompt = "a cinematic photo of a robot drinking coffee in a Parisian cafe"
        png_data = tti.generate_image(test_prompt, width=1024, height=1024)

        # A real image is far larger than 1 kB; anything smaller is suspicious.
        assert len(png_data) > 1000, "Generated image bytes are too small."
        output_path = Path(__file__).with_name("tmp_stability_t2i.png")
        output_path.write_bytes(png_data)
        ASCIIColors.green(f"Text-to-Image generation OK. Image saved to {output_path}")

    except Exception as err:
        trace_exception(err)
        ASCIIColors.error(f"Stability AI binding test failed: {err}")
|
|
@@ -75,21 +75,16 @@ DEFAULT_AUDIOCRAFT_MODELS = [
|
|
|
75
75
|
|
|
76
76
|
class AudioCraftTTMBinding(LollmsTTMBinding):
|
|
77
77
|
def __init__(self,
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
service_key: Optional[str] = None, # Not used by local binding
|
|
84
|
-
verify_ssl_certificate: bool = True,# Not used by local binding
|
|
85
|
-
**kwargs): # Catch-all for future compatibility or specific audiocraft params
|
|
86
|
-
|
|
87
|
-
super().__init__(binding_name="audiocraft")
|
|
78
|
+
**kwargs):
|
|
79
|
+
# Prioritize 'model_name' but accept 'model' as an alias from config files.
|
|
80
|
+
if 'model' in kwargs and 'model_name' not in kwargs:
|
|
81
|
+
kwargs['model_name'] = kwargs.pop('model')
|
|
82
|
+
super().__init__(binding_name=BindingName, config=kwargs)
|
|
88
83
|
|
|
89
84
|
if not _audiocraft_installed_with_correct_torch:
|
|
90
85
|
raise ImportError(f"AudioCraft TTM binding dependencies not met. Please ensure 'audiocraft', 'torch', 'torchaudio', 'scipy', 'numpy' are installed. Error: {_audiocraft_installation_error}")
|
|
91
86
|
|
|
92
|
-
self.device = device
|
|
87
|
+
self.device = kwargs.get("device") # "cuda", "mps", "cpu", or None for auto-detect
|
|
93
88
|
if self.device is None: # Auto-detect if not specified by user
|
|
94
89
|
if torch.cuda.is_available():
|
|
95
90
|
self.device = "cuda"
|
|
@@ -117,7 +112,7 @@ class AudioCraftTTMBinding(LollmsTTMBinding):
|
|
|
117
112
|
ASCIIColors.warning(f"Unsupported output_format '{self.output_format}'. Defaulting to 'wav'.")
|
|
118
113
|
self.output_format = "wav"
|
|
119
114
|
|
|
120
|
-
self._load_audiocraft_model(model_name)
|
|
115
|
+
self._load_audiocraft_model(kwargs.get("model_name") or "facebook/musicgen-small")
|
|
121
116
|
|
|
122
117
|
def _load_audiocraft_model(self, model_name_to_load: str):
|
|
123
118
|
if self.model is not None and self.loaded_model_name == model_name_to_load:
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, List, Dict, Any
|
|
6
|
+
|
|
7
|
+
from lollms_client.lollms_ttm_binding import LollmsTTMBinding
|
|
8
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
9
|
+
import pipmaster as pm
|
|
10
|
+
|
|
11
|
+
# Ensure required packages are installed
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "BeatovenAITTMBinding"
|
|
15
|
+
|
|
16
|
+
class BeatovenAITTMBinding(LollmsTTMBinding):
    """A Text-to-Music binding for the Beatoven.ai API."""

    def __init__(self,
                 **kwargs):
        """Configure the binding from keyword arguments.

        Recognized keys:
            api_key: Beatoven.ai API key. Falls back to the
                ``BEATOVEN_API_KEY`` environment variable.
            model_name: Stored in the config only. ``model`` is accepted as
                an alias when ``model_name`` is not given.

        Raises:
            ValueError: If no API key can be found.
        """
        # Prioritize 'model_name' but accept 'model' as an alias from config files.
        if 'model' in kwargs and 'model_name' not in kwargs:
            kwargs['model_name'] = kwargs.pop('model')
        super().__init__(binding_name=BindingName, config=kwargs)
        self.api_key = self.config.get("api_key") or os.environ.get("BEATOVEN_API_KEY")
        if not self.api_key:
            raise ValueError("Beatoven.ai API key is required. Please set it in config or as BEATOVEN_API_KEY env var.")
        self.base_url = "https://api.beatoven.ai/api/v1"
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def list_models(self, **kwargs) -> List[str]:
        """Return the single placeholder model name."""
        # Beatoven.ai does not expose different models via the API.
        # Customization is done via genre, mood, and tempo.
        return ["default"]

    def _poll_for_completion(self, task_id: str, poll_interval: float = 5.0, timeout: Optional[float] = 600.0) -> Dict[str, Any]:
        """Poll the tasks endpoint until the composition is complete.

        Args:
            task_id: Identifier returned when the track was created.
            poll_interval: Seconds to wait between two status checks.
            timeout: Maximum total seconds to wait; ``None`` waits forever.

        Returns:
            The decoded task payload once its status is ``"success"``.

        Raises:
            TimeoutError: If the task is still pending after ``timeout``.
            Exception: If the task reports ``"failed"`` or a status request
                returns an HTTP error.
        """
        poll_url = f"{self.base_url}/tasks/{task_id}"
        # Monotonic clock so wall-clock adjustments cannot affect the deadline.
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            try:
                response = requests.get(poll_url, headers=self.headers)
                response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                raise Exception(f"Failed to poll task status: {e.response.text}") from e

            data = response.json()
            status = data.get("status")

            if status == "success":
                ASCIIColors.green("Composition task successful.")
                return data
            if status == "failed":
                error_info = data.get("error", "Unknown error.")
                raise Exception(f"Beatoven.ai task failed: {error_info}")

            # Without a deadline, a task stuck in a pending state would block
            # the caller forever.
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"Beatoven.ai task {task_id} did not complete within {timeout} seconds (last status: '{status}').")
            ASCIIColors.info(f"Task status is '{status}'. Waiting...")
            time.sleep(poll_interval)

    def generate_music(self, prompt: str, **kwargs) -> bytes:
        """
        Generates music by creating a track, waiting for composition, and downloading the result.

        Args:
            prompt: Text describing the desired music; its first 100
                characters are also used as the track title.
            **kwargs: Optional ``duration`` in seconds (default 30),
                ``genre`` (default ``"Cinematic"``) and ``tempo``
                (default ``"medium"``).

        Returns:
            The downloaded WAV render as raw bytes.

        Raises:
            Exception: On any API error, missing identifier or missing
                render URL.
        """
        # Step 1: Create a track
        create_track_url = f"{self.base_url}/tracks"
        payload = {
            "title": prompt[:100],  # Use prompt as title, truncated
            "duration_in_seconds": kwargs.get("duration", 30),
            "genre": kwargs.get("genre", "Cinematic"),
            "tempo": kwargs.get("tempo", "medium"),
            "prompt": prompt
        }

        try:
            ASCIIColors.info("Submitting music track request to Beatoven.ai...")
            create_response = requests.post(create_track_url, json=payload, headers=self.headers)
            create_response.raise_for_status()
            task_id = create_response.json().get("task_id")
            if not task_id:
                # Polling "/tasks/None" would only yield a confusing error.
                raise Exception("Track creation response did not include a task_id.")
            ASCIIColors.info(f"Track creation submitted. Task ID: {task_id}")

            # Step 2: Poll for task completion
            task_result = self._poll_for_completion(task_id)
            track_id = task_result.get("track_id")
            if not track_id:
                raise Exception("Task completed but did not return a track_id.")

            # Step 3: Get track details to find the audio URL
            track_url = f"{self.base_url}/tracks/{track_id}"
            track_response = requests.get(track_url, headers=self.headers)
            track_response.raise_for_status()

            audio_url = track_response.json().get("renders", {}).get("wav")
            if not audio_url:
                raise Exception("Could not find WAV render URL in the completed track details.")

            # Step 4: Download the audio file
            ASCIIColors.info(f"Downloading generated audio from {audio_url}")
            audio_response = requests.get(audio_url)
            audio_response.raise_for_status()

            return audio_response.content

        except requests.exceptions.HTTPError as e:
            # The error body is not guaranteed to be JSON (e.g. a proxy page),
            # so fall back to the raw text rather than failing in the handler.
            failed = e.response
            if failed is None:
                error_details = str(e)
            else:
                try:
                    error_details = failed.json()
                except ValueError:
                    error_details = failed.text
            raise Exception(f"Beatoven.ai API HTTP Error: {error_details}") from e
        except Exception as e:
            trace_exception(e)
            raise
|
|
107
|
+
|
|
108
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real BEATOVEN_API_KEY and network access.
    ASCIIColors.magenta("--- Beatoven.ai TTM Binding Test ---")
    if os.environ.get("BEATOVEN_API_KEY") is None:
        ASCIIColors.error("BEATOVEN_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        ttm = BeatovenAITTMBinding()

        ASCIIColors.cyan("\n--- Test: Music Generation ---")
        test_prompt = "A mysterious and suspenseful cinematic track with soft piano and eerie strings, building tension."
        wav_data = ttm.generate_music(test_prompt, duration=45, genre="Cinematic", tempo="slow")

        # A real audio file is far larger than 1 kB; anything smaller is suspicious.
        assert len(wav_data) > 1000, "Generated music bytes are too small."
        output_path = Path(__file__).with_name("tmp_beatoven_music.wav")
        output_path.write_bytes(wav_data)
        ASCIIColors.green(f"Music generation OK. Audio saved to {output_path}")

    except Exception as err:
        trace_exception(err)
        ASCIIColors.error(f"Beatoven.ai TTM binding test failed: {err}")
|
|
@@ -11,23 +11,10 @@ class LollmsTTMBinding_Impl(LollmsTTMBinding):
|
|
|
11
11
|
"""Concrete implementation of the LollmsTTMBinding for the standard LOLLMS server (Placeholder)."""
|
|
12
12
|
|
|
13
13
|
def __init__(self,
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
Initialize the LOLLMS TTM binding.
|
|
20
|
-
|
|
21
|
-
Args:
|
|
22
|
-
host_address (Optional[str]): Host address for the LOLLMS service.
|
|
23
|
-
model_name (Optional[str]): Default TTM model identifier.
|
|
24
|
-
service_key (Optional[str]): Authentication key.
|
|
25
|
-
verify_ssl_certificate (bool): Whether to verify SSL certificates.
|
|
26
|
-
"""
|
|
27
|
-
super().__init__(host_address=host_address,
|
|
28
|
-
model_name=model_name,
|
|
29
|
-
service_key=service_key,
|
|
30
|
-
verify_ssl_certificate=verify_ssl_certificate)
|
|
14
|
+
**kwargs):
|
|
15
|
+
# Prioritize 'model_name' but accept 'model' as an alias from config files.
|
|
16
|
+
if 'model' in kwargs and 'model_name' not in kwargs:
|
|
17
|
+
kwargs['model_name'] = kwargs.pop('model')
|
|
31
18
|
ASCIIColors.warning("LOLLMS TTM binding is not yet fully implemented in the client.")
|
|
32
19
|
ASCIIColors.warning("Please ensure your LOLLMS server has a TTM service running.")
|
|
33
20
|
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, List, Dict, Any
|
|
6
|
+
|
|
7
|
+
from lollms_client.lollms_ttm_binding import LollmsTTMBinding
|
|
8
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
9
|
+
import pipmaster as pm
|
|
10
|
+
|
|
11
|
+
# Ensure required packages are installed
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "StabilityAITTMBinding"
|
|
15
|
+
|
|
16
|
+
# Models available via the Stability AI Audio API
# Sourced from: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v1~1generation~1stable-audio-2.0~1text-to-audio/post
# Each entry is expanded from a compact (model_name, display_name, description) row.
STABILITY_AI_MODELS = [
    {"model_name": model_id, "display_name": label, "description": blurb}
    for model_id, label, blurb in (
        ("stable-audio-2.0", "Stable Audio 2.0", "High-quality, full-track music generation up to 3 minutes."),
        ("stable-audio-1.0", "Stable Audio 1.0", "Original model, best for short clips and sound effects."),
    )
]
|
|
22
|
+
|
|
23
|
+
class StabilityAITTMBinding(LollmsTTMBinding):
    """A Text-to-Music binding for Stability AI's Stable Audio API."""

    def __init__(self,
                 **kwargs):
        """Configure the binding from keyword arguments.

        Recognized keys:
            api_key: Stability AI API key. Falls back to the
                ``STABILITY_API_KEY`` environment variable.
            model_name: Audio model identifier (default
                ``"stable-audio-2.0"``). ``model`` is accepted as an alias
                when ``model_name`` is not given.

        Raises:
            ValueError: If no API key can be found.
        """
        # Prioritize 'model_name' but accept 'model' as an alias from config files.
        if 'model' in kwargs and 'model_name' not in kwargs:
            kwargs['model_name'] = kwargs.pop('model')
        # The base initializer must run before `self.config` is read below;
        # this mirrors the other bindings in the package.
        super().__init__(binding_name=BindingName, config=kwargs)
        self.api_key = self.config.get("api_key") or os.environ.get("STABILITY_API_KEY")
        if not self.api_key:
            raise ValueError("Stability AI API key is required. Please set it in the configuration or as STABILITY_API_KEY environment variable.")
        # `or` so that an explicit model_name=None still yields the default.
        self.model_name = self.config.get("model_name") or "stable-audio-2.0"

    def list_models(self, **kwargs) -> List[Dict[str, str]]:
        """Return the static catalogue of known Stable Audio models."""
        return STABILITY_AI_MODELS

    def generate_music(self, prompt: str, **kwargs) -> bytes:
        """
        Generates music using the Stable Audio API.

        Args:
            prompt (str): The text prompt describing the desired music.
            duration (int): The duration of the audio in seconds. Defaults to 29.
            **kwargs: Additional parameters for the API: ``seed`` (default 0),
                ``steps`` (default 100) and ``cfg_scale`` (default 7.0).

        Returns:
            bytes: The raw response body (WAV audio is requested).

        Raises:
            Exception: On HTTP error status or any other failure. The
                original error is chained.
        """
        url = f"https://api.stability.ai/v1/generation/{self.model_name}/text-to-audio"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "audio/wav",
        }

        # Get duration, with a default of 29 seconds as it's a common value
        duration = kwargs.get("duration", 29)

        payload = {
            "text_prompts[0][text]": prompt,
            "text_prompts[0][weight]": 1.0,
            "seed": kwargs.get("seed", 0),  # 0 for random in API
            "steps": kwargs.get("steps", 100),
            "cfg_scale": kwargs.get("cfg_scale", 7.0),
        }

        # Handle different parameter names for duration
        if self.model_name == "stable-audio-2.0":
            payload["duration"] = duration
        else:  # stable-audio-1.0
            payload["sample_length"] = duration * 44100  # v1 uses sample length

        try:
            ASCIIColors.info(f"Requesting music from Stability AI ({self.model_name})...")
            response = requests.post(url, headers=headers, data=payload)
            response.raise_for_status()

            ASCIIColors.green("Successfully generated music from Stability AI.")
            return response.content
        except requests.exceptions.HTTPError as e:
            # Prefer the API's "message" field; fall back to the raw body when
            # it is not JSON or not a JSON object.
            error_message = e.response.text
            try:
                error_details = e.response.json()
            except ValueError:
                error_details = None
            if isinstance(error_details, dict):
                error_message = error_details.get("message", error_message)
            ASCIIColors.error(f"HTTP Error from Stability AI: {e.response.status_code} - {error_message}")
            raise Exception(f"Stability AI API Error: {error_message}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"An unexpected error occurred: {e}") from e
|
|
92
|
+
|
|
93
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real STABILITY_API_KEY and network access.
    ASCIIColors.magenta("--- Stability AI TTM Binding Test ---")
    if os.environ.get("STABILITY_API_KEY") is None:
        ASCIIColors.error("STABILITY_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        # Test with default settings
        ttm = StabilityAITTMBinding()

        ASCIIColors.cyan("\n--- Test: Music Generation ---")
        test_prompt = "80s synthwave, retro futuristic, driving beat, cinematic"
        wav_data = ttm.generate_music(test_prompt, duration=10)

        # A real audio file is far larger than 1 kB; anything smaller is suspicious.
        assert len(wav_data) > 1000, "Generated music bytes are too small."
        output_path = Path(__file__).with_name("tmp_stability_music.wav")
        output_path.write_bytes(wav_data)
        ASCIIColors.green(f"Music generation OK. Audio saved to {output_path}")

    except Exception as err:
        trace_exception(err)
        ASCIIColors.error(f"Stability AI TTM binding test failed: {err}")
|