lollms-client 1.4.1__py3-none-any.whl → 1.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lollms-client might be problematic. Click here for more details.
- lollms_client/__init__.py +1 -1
- lollms_client/llm_bindings/novita_ai/__init__.py +303 -0
- lollms_client/llm_bindings/perplexity/__init__.py +326 -0
- lollms_client/lollms_discussion.py +11 -1
- lollms_client/tti_bindings/leonardo_ai/__init__.py +124 -0
- lollms_client/tti_bindings/novita_ai/__init__.py +102 -0
- lollms_client/tti_bindings/stability_ai/__init__.py +176 -0
- lollms_client/ttm_bindings/beatoven_ai/__init__.py +125 -0
- lollms_client/ttm_bindings/replicate/__init__.py +112 -0
- lollms_client/ttm_bindings/stability_ai/__init__.py +114 -0
- lollms_client/ttm_bindings/topmediai/__init__.py +93 -0
- {lollms_client-1.4.1.dist-info → lollms_client-1.4.6.dist-info}/METADATA +204 -2
- {lollms_client-1.4.1.dist-info → lollms_client-1.4.6.dist-info}/RECORD +16 -8
- lollms_client/ttm_bindings/bark/__init__.py +0 -339
- {lollms_client-1.4.1.dist-info → lollms_client-1.4.6.dist-info}/WHEEL +0 -0
- {lollms_client-1.4.1.dist-info → lollms_client-1.4.6.dist-info}/licenses/LICENSE +0 -0
- {lollms_client-1.4.1.dist-info → lollms_client-1.4.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import time
|
|
4
|
+
import base64
|
|
5
|
+
from io import BytesIO
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional, List, Dict, Any, Union
|
|
8
|
+
|
|
9
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
10
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
11
|
+
import pipmaster as pm
|
|
12
|
+
|
|
13
|
+
pm.ensure_packages(["requests", "Pillow"])
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
BindingName = "LeonardoAITTIBinding"
|
|
17
|
+
|
|
18
|
+
# Sourced from https://docs.leonardo.ai/docs/models
|
|
19
|
+
LEONARDO_AI_MODELS = [
|
|
20
|
+
{"model_name": "ac4f3991-8a40-42cd-b174-14a8e33738e4", "display_name": "Leonardo Phoenix", "description": "Fast, high-quality photorealism."},
|
|
21
|
+
{"model_name": "1e65d070-22c9-4aed-a5be-ce58a1b65b38", "display_name": "Leonardo Diffusion XL", "description": "The flagship general-purpose SDXL model."},
|
|
22
|
+
{"model_name": "b24e16ff-06e3-43eb-a255-db4322b0f345", "display_name": "AlbedoBase XL", "description": "Versatile model for photorealism and artistic styles."},
|
|
23
|
+
{"model_name": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3", "display_name": "Absolute Reality v1.6", "description": "Classic photorealistic model."},
|
|
24
|
+
{"model_name": "f3296a34-a868-4665-8b2f-f4313f8c8533", "display_name": "RPG v5", "description": "Specialized in RPG characters and assets."},
|
|
25
|
+
{"model_name": "2067ae58-a02e-4318-9742-2b55b2a4c813", "display_name": "DreamShaper v7", "description": "Popular versatile artistic model."},
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
class LeonardoAITTIBinding(LollmsTTIBinding):
    """Leonardo.ai TTI binding for LoLLMS.

    A generation job is submitted to the REST API, its status is polled until
    it reports ``COMPLETE`` or ``FAILED``, and the generated image is then
    downloaded from the URL returned by the service.
    """

    def __init__(self, **kwargs):
        """Store the configuration and build the authenticated headers.

        Args:
            **kwargs: Binding configuration. Keys read here are ``api_key``
                (falls back to the ``LEONARDO_API_KEY`` environment
                variable), ``model_name`` and ``request_timeout`` (seconds
                allowed for each individual HTTP call, default 60).

        Raises:
            ValueError: If no API key is configured.
        """
        super().__init__(binding_name=BindingName)
        self.config = kwargs
        self.api_key = self.config.get("api_key") or os.environ.get("LEONARDO_API_KEY")
        if not self.api_key:
            raise ValueError("Leonardo.ai API key is required.")
        self.model_name = self.config.get("model_name", "ac4f3991-8a40-42cd-b174-14a8e33738e4")
        self.base_url = "https://cloud.leonardo.ai/api/rest/v1"
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        # Without a timeout, requests can block forever on a stalled connection.
        self.request_timeout = self.config.get("request_timeout", 60)

    def listModels(self) -> list:
        """Return the static list of known Leonardo.ai models."""
        # You could also fetch this dynamically from /models endpoint
        return LEONARDO_AI_MODELS

    @staticmethod
    def _response_detail(response) -> Any:
        """Return the decoded JSON body of *response*, or its text if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return response.text

    def _wait_for_generation(self, generation_id: str, max_wait: float = 600.0, poll_interval: float = 3.0) -> List[bytes]:
        """Poll a generation job until it finishes and download its images.

        Args:
            generation_id: Identifier returned when the job was submitted.
            max_wait: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between two status checks.

        Returns:
            The raw bytes of every generated image that exposes a URL.

        Raises:
            TimeoutError: If the job is still running after ``max_wait`` seconds.
            Exception: If the service reports the job as ``FAILED``.
            requests.exceptions.RequestException: On transport or HTTP errors.
        """
        url = f"{self.base_url}/generations/{generation_id}"
        deadline = time.monotonic() + max_wait
        while True:
            response = requests.get(url, headers=self.headers, timeout=self.request_timeout)
            response.raise_for_status()
            # "or {}" also covers an explicit null value for the key.
            data = response.json().get("generations_by_pk") or {}
            status = data.get("status")

            if status == "COMPLETE":
                ASCIIColors.green("Generation complete.")
                images_data = []
                for img in data.get("generated_images") or []:
                    img_url = img.get("url")
                    if img_url:
                        img_response = requests.get(img_url, timeout=self.request_timeout)
                        img_response.raise_for_status()
                        images_data.append(img_response.content)
                return images_data
            if status == "FAILED":
                raise Exception("Leonardo.ai generation failed.")
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Leonardo.ai generation {generation_id} did not finish within {max_wait} seconds."
                )
            ASCIIColors.info(f"Generation status: {status}. Waiting...")
            time.sleep(poll_interval)

    def generate_image(self, prompt: str, negative_prompt: str = "", width: int = 1024, height: int = 1024, **kwargs) -> bytes:
        """Generate one image from a text prompt.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what should be avoided.
            width: Requested image width in pixels.
            height: Requested image height in pixels.
            **kwargs: Optional ``guidance_scale`` (default 7) and ``seed``.

        Returns:
            The bytes of the first generated image.

        Raises:
            Exception: If the request fails, the job fails or times out, or
                no image is returned. The original exception is chained.
        """
        url = f"{self.base_url}/generations"
        payload = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "modelId": self.model_name,
            "width": width,
            "height": height,
            "num_images": 1,
            "guidance_scale": kwargs.get("guidance_scale", 7),
            "sd_version": "SDXL"  # Most models are SDXL based
        }
        # Only send a seed when the caller chose one, instead of a JSON null.
        seed = kwargs.get("seed")
        if seed is not None:
            payload["seed"] = seed

        try:
            ASCIIColors.info(f"Submitting generation job to Leonardo.ai ({self.model_name})...")
            response = requests.post(url, json=payload, headers=self.headers, timeout=self.request_timeout)
            response.raise_for_status()
            generation_id = response.json()["sdGenerationJob"]["generationId"]
            ASCIIColors.info(f"Job submitted with ID: {generation_id}")
            images = self._wait_for_generation(generation_id)
            if not images:
                raise Exception("Leonardo.ai generation completed but returned no images.")
            return images[0]
        except requests.exceptions.HTTPError as e:
            # Report the body of the response that actually failed.
            trace_exception(e)
            raise Exception(f"Leonardo.ai API error: {self._response_detail(e.response)}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"Leonardo.ai API request failed: {e}") from e

    def edit_image(self, **kwargs) -> bytes:
        """Image editing is not supported by this binding.

        Raises:
            NotImplementedError: Always.
        """
        ASCIIColors.warning("Leonardo.ai edit_image (inpainting/img2img) is not yet implemented in this binding.")
        raise NotImplementedError("This binding does not yet support image editing.")
|
|
102
|
+
|
|
103
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real LEONARDO_API_KEY and network access.
    ASCIIColors.magenta("--- Leonardo.ai TTI Binding Test ---")
    if os.environ.get("LEONARDO_API_KEY") is None:
        ASCIIColors.error("LEONARDO_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        tti = LeonardoAITTIBinding()

        ASCIIColors.cyan("\n--- Test: Text-to-Image ---")
        image_data = tti.generate_image(
            "A majestic lion wearing a crown, hyperrealistic, 8k",
            width=1024,
            height=1024,
        )

        # A real image is expected to be larger than a kilobyte.
        assert len(image_data) > 1000
        output_path = Path(__file__).parent / "tmp_leonardo_t2i.png"
        output_path.write_bytes(image_data)
        ASCIIColors.green(f"Text-to-Image generation OK. Image saved to {output_path}")

    except Exception as e:
        trace_exception(e)
        ASCIIColors.error(f"Leonardo.ai binding test failed: {e}")
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Dict, Any, Union
|
|
7
|
+
|
|
8
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
9
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
10
|
+
import pipmaster as pm
|
|
11
|
+
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "NovitaAITTIBinding"
|
|
15
|
+
|
|
16
|
+
# Sourced from https://docs.novita.ai/image-generation/models
|
|
17
|
+
NOVITA_AI_MODELS = [
|
|
18
|
+
{"model_name": "sd_xl_base_1.0.safetensors", "display_name": "Stable Diffusion XL 1.0", "description": "Official SDXL 1.0 Base model."},
|
|
19
|
+
{"model_name": "dreamshaper_xl_1_0.safetensors", "display_name": "DreamShaper XL 1.0", "description": "Versatile artistic SDXL model."},
|
|
20
|
+
{"model_name": "juggernaut_xl_v9_rundiffusion.safetensors", "display_name": "Juggernaut XL v9", "description": "High-quality realistic and cinematic model."},
|
|
21
|
+
{"model_name": "realistic_vision_v5.1.safetensors", "display_name": "Realistic Vision v5.1", "description": "Popular photorealistic SD1.5 model."},
|
|
22
|
+
{"model_name": "absolutereality_v1.8.1.safetensors", "display_name": "Absolute Reality v1.8.1", "description": "General-purpose realistic SD1.5 model."},
|
|
23
|
+
{"model_name": "meinamix_meina_v11.safetensors", "display_name": "MeinaMix v11", "description": "High-quality anime illustration model."},
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
class NovitaAITTIBinding(LollmsTTIBinding):
    """Novita.ai TTI binding for LoLLMS.

    Sends a single request to the ``text2img`` endpoint and decodes the
    base64-encoded image found in the JSON response.
    """

    def __init__(self, **kwargs):
        """Store the configuration and build the authenticated headers.

        Args:
            **kwargs: Binding configuration. Keys read here are ``api_key``
                (falls back to the ``NOVITA_API_KEY`` environment variable),
                ``model_name`` and ``request_timeout`` (seconds allowed for
                the HTTP call, default 300).

        Raises:
            ValueError: If no API key is configured.
        """
        super().__init__(binding_name=BindingName)
        self.config = kwargs
        self.api_key = self.config.get("api_key") or os.environ.get("NOVITA_API_KEY")
        if not self.api_key:
            raise ValueError("Novita.ai API key is required.")
        self.model_name = self.config.get("model_name", "juggernaut_xl_v9_rundiffusion.safetensors")
        self.base_url = "https://api.novita.ai/v3"
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        # Without a timeout, requests can block forever on a stalled connection.
        self.request_timeout = self.config.get("request_timeout", 300)

    def listModels(self) -> list:
        """Return the static list of known Novita.ai models."""
        return NOVITA_AI_MODELS

    @staticmethod
    def _response_detail(response) -> Any:
        """Return the decoded JSON body of *response*, or its text if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return response.text

    def generate_image(self, prompt: str, negative_prompt: str = "", width: int = 1024, height: int = 1024, **kwargs) -> bytes:
        """Generate one image from a text prompt.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what should be avoided.
            width: Requested image width in pixels.
            height: Requested image height in pixels.
            **kwargs: Optional ``guidance_scale`` (default 7.0),
                ``num_inference_steps`` (default 25) and ``seed`` (default -1).

        Returns:
            The decoded bytes of the first returned image.

        Raises:
            Exception: If the request fails or the response holds no image.
                The original exception is chained.
        """
        url = f"{self.base_url}/text2img"
        payload = {
            "model_name": self.model_name,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "sampler_name": "DPM++ 2M Karras",
            "cfg_scale": kwargs.get("guidance_scale", 7.0),
            "steps": kwargs.get("num_inference_steps", 25),
            "seed": kwargs.get("seed", -1),
            "n_iter": 1,
            "batch_size": 1
        }

        try:
            ASCIIColors.info(f"Requesting image from Novita.ai ({self.model_name})...")
            response = requests.post(url, json=payload, headers=self.headers, timeout=self.request_timeout)
            response.raise_for_status()
            data = response.json()
            if "images" not in data or not data["images"]:
                raise Exception(f"API returned no images. Response: {data}")

            b64_image = data["images"][0]["image_base64"]
            return base64.b64decode(b64_image)

        except requests.exceptions.HTTPError as e:
            # Report the body of the response that actually failed.
            trace_exception(e)
            raise Exception(f"Novita.ai API error: {self._response_detail(e.response)}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"Novita.ai API request failed: {e}") from e

    def edit_image(self, **kwargs) -> bytes:
        """Image editing is not supported by this binding.

        Raises:
            NotImplementedError: Always.
        """
        ASCIIColors.warning("Novita.ai edit_image (inpainting/img2img) is not yet implemented in this binding.")
        raise NotImplementedError("This binding does not yet support image editing.")
|
|
80
|
+
|
|
81
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real NOVITA_API_KEY and network access.
    ASCIIColors.magenta("--- Novita.ai TTI Binding Test ---")
    if os.environ.get("NOVITA_API_KEY") is None:
        ASCIIColors.error("NOVITA_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        tti = NovitaAITTIBinding()

        ASCIIColors.cyan("\n--- Test: Text-to-Image ---")
        image_data = tti.generate_image(
            "A cute capybara wearing a top hat, sitting in a field of flowers, painterly style",
            width=1024,
            height=1024,
            num_inference_steps=30,
        )

        # A real image is expected to be larger than a kilobyte.
        assert len(image_data) > 1000
        output_path = Path(__file__).parent / "tmp_novita_t2i.png"
        output_path.write_bytes(image_data)
        ASCIIColors.green(f"Text-to-Image generation OK. Image saved to {output_path}")

    except Exception as e:
        trace_exception(e)
        ASCIIColors.error(f"Novita.ai binding test failed: {e}")
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional, List, Dict, Any, Union
|
|
7
|
+
|
|
8
|
+
from lollms_client.lollms_tti_binding import LollmsTTIBinding
|
|
9
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
10
|
+
import pipmaster as pm
|
|
11
|
+
|
|
12
|
+
pm.ensure_packages(["requests", "Pillow"])
|
|
13
|
+
|
|
14
|
+
from PIL import Image
|
|
15
|
+
|
|
16
|
+
BindingName = "StabilityAITTIBinding"
|
|
17
|
+
|
|
18
|
+
# Sourced from https://platform.stability.ai/docs/getting-started/models
|
|
19
|
+
STABILITY_AI_MODELS = [
|
|
20
|
+
# SD3
|
|
21
|
+
{"model_name": "stable-diffusion-3-medium", "display_name": "Stable Diffusion 3 Medium", "description": "Most advanced text-to-image model.", "owned_by": "Stability AI"},
|
|
22
|
+
{"model_name": "stable-diffusion-3-large", "display_name": "Stable Diffusion 3 Large", "description": "Most advanced model with higher quality.", "owned_by": "Stability AI"},
|
|
23
|
+
{"model_name": "stable-diffusion-3-large-turbo", "display_name": "Stable Diffusion 3 Large Turbo", "description": "Fast, high-quality generation.", "owned_by": "Stability AI"},
|
|
24
|
+
# SDXL
|
|
25
|
+
{"model_name": "stable-diffusion-xl-1024-v1-0", "display_name": "Stable Diffusion XL 1.0", "description": "High-quality 1024x1024 generation.", "owned_by": "Stability AI"},
|
|
26
|
+
{"model_name": "stable-diffusion-xl-beta-v2-2-2", "display_name": "SDXL Beta", "description": "Legacy anime-focused SDXL model.", "owned_by": "Stability AI"},
|
|
27
|
+
# SD 1.x & 2.x
|
|
28
|
+
{"model_name": "stable-diffusion-v1-6", "display_name": "Stable Diffusion 1.6", "description": "Improved version of SD 1.5.", "owned_by": "Stability AI"},
|
|
29
|
+
{"model_name": "stable-diffusion-2-1", "display_name": "Stable Diffusion 2.1", "description": "768x768 native resolution model.", "owned_by": "Stability AI"},
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
class StabilityAITTIBinding(LollmsTTIBinding):
    """Stability AI TTI binding for LoLLMS.

    Supports text-to-image generation and image editing (image-to-image or
    in-painting when a mask is supplied) through the ``v2beta/stable-image``
    REST endpoints.
    """

    def __init__(self, **kwargs):
        """Store the configuration.

        Args:
            **kwargs: Binding configuration. Keys read here are ``api_key``
                (falls back to the ``STABILITY_API_KEY`` environment
                variable), ``model_name`` and ``request_timeout`` (seconds
                allowed for each HTTP call, default 300).

        Raises:
            ValueError: If no API key is configured.
        """
        super().__init__(binding_name=BindingName)
        self.config = kwargs
        self.api_key = self.config.get("api_key") or os.environ.get("STABILITY_API_KEY")
        if not self.api_key:
            raise ValueError("Stability AI API key is required. Please set it in the configuration or as STABILITY_API_KEY environment variable.")
        self.model_name = self.config.get("model_name", "stable-diffusion-3-medium")
        # Without a timeout, requests can block forever on a stalled connection.
        self.request_timeout = self.config.get("request_timeout", 300)

    def listModels(self) -> list:
        """Return the static list of known Stability AI models."""
        return STABILITY_AI_MODELS

    def _get_api_url(self, task: str) -> str:
        """Return the endpoint URL for *task* and the configured model.

        Args:
            task: One of ``text2image``, ``image2image``, ``inpainting`` or
                ``upscale``.

        Raises:
            ValueError: If the model is not an SD3 model and *task* is unknown.
        """
        base_url = "https://api.stability.ai/v2beta/stable-image"
        # SD3 models use a different endpoint structure
        if "stable-diffusion-3" in self.model_name:
            return f"{base_url}/generate/sd3"

        task_map = {
            "text2image": "generate/core",
            "image2image": "edit/image-to-image",
            "inpainting": "edit/in-painting",
            "upscale": "edit/upscale"
        }
        if task not in task_map:
            raise ValueError(f"Unsupported task for this model family: {task}")
        return f"{base_url}/{task_map[task]}"

    @staticmethod
    def _response_detail(response) -> Any:
        """Return the decoded JSON body of *response*, or its text if it is not JSON."""
        try:
            return response.json()
        except ValueError:
            return response.text

    @staticmethod
    def _aspect_ratio(width: int, height: int) -> str:
        """Return ``width:height`` reduced to lowest terms (1024x1024 -> ``1:1``)."""
        import math

        # "or 1" avoids a division by zero when both dimensions are 0.
        divisor = math.gcd(int(width), int(height)) or 1
        return f"{int(width) // divisor}:{int(height) // divisor}"

    def _decode_image_input(self, item: Union[str, Path, bytes]) -> Image.Image:
        """Open an image given as bytes, data URL, file path, HTTP URL or raw base64.

        The forms are tried in that order; anything not recognised as one of
        the first four is decoded as raw base64.
        """
        if isinstance(item, bytes):
            return Image.open(BytesIO(item))
        s = str(item).strip()
        if s.startswith("data:image/") and ";base64," in s:
            b64 = s.split(";base64,")[-1]
            return Image.open(BytesIO(base64.b64decode(b64)))
        p = Path(s)
        try:
            is_file = p.exists()
        except (OSError, ValueError):
            # A long base64 payload is not a usable file name.
            is_file = False
        if is_file:
            return Image.open(p)
        if s.startswith("http"):
            response = requests.get(s, timeout=self.request_timeout)
            response.raise_for_status()
            # response.content is fully downloaded and content-decoded,
            # unlike the raw socket stream.
            return Image.open(BytesIO(response.content))
        # Fallback for raw base64
        return Image.open(BytesIO(base64.b64decode(s)))

    def generate_image(self, prompt: str, negative_prompt: str = "", width: int = 1024, height: int = 1024, **kwargs) -> bytes:
        """Generate one PNG image from a text prompt.

        Args:
            prompt: Text describing the desired image.
            negative_prompt: Text describing what should be avoided.
            width: Requested width in pixels. For SD3 models only the ratio
                between width and height is sent.
            height: Requested height in pixels.
            **kwargs: Optional ``seed`` (default 0) and, for non-SD3 models,
                ``style_preset`` (default ``photographic``).

        Returns:
            The image bytes returned by the API.

        Raises:
            Exception: If the request fails. The original exception is chained.
        """
        url = self._get_api_url("text2image")

        data = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "output_format": "png",
            "seed": kwargs.get("seed", 0)
        }

        # SD3 uses aspect_ratio, older models use width/height
        if "stable-diffusion-3" in self.model_name:
            # The ratio must be reduced: "1:1", not "1024:1024".
            data["aspect_ratio"] = self._aspect_ratio(width, height)
            data["model"] = self.model_name
        else:
            data["width"] = width
            data["height"] = height
            data["style_preset"] = kwargs.get("style_preset", "photographic")

        headers = {"authorization": f"Bearer {self.api_key}", "accept": "image/*"}

        try:
            ASCIIColors.info(f"Requesting image from Stability AI ({self.model_name})...")
            # The dummy "files" entry makes requests encode the body as multipart/form-data.
            response = requests.post(url, headers=headers, files={"none": ''}, data=data, timeout=self.request_timeout)
            response.raise_for_status()
            return response.content
        except requests.exceptions.HTTPError as e:
            # Report the body of the response that actually failed.
            trace_exception(e)
            raise Exception(f"Stability AI API error: {self._response_detail(e.response)}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"Stability AI API request failed: {e}") from e

    def edit_image(self, images: Union[str, List[str]], prompt: str, negative_prompt: Optional[str] = "", mask: Optional[str] = None, **kwargs) -> bytes:
        """Edit an image with a prompt, in-painting when a mask is given.

        Args:
            images: The source image, or a list whose first element is used.
                Accepted forms are those of ``_decode_image_input``.
            prompt: Text describing the desired result.
            negative_prompt: Text describing what should be avoided.
            mask: Optional mask image; when given, in-painting is requested.
            **kwargs: Optional ``seed`` (default 0) and, for image-to-image,
                ``strength`` (default 0.6).

        Returns:
            The edited image bytes returned by the API.

        Raises:
            ValueError: If ``images`` is an empty list.
            Exception: If the request fails. The original exception is chained.
        """
        if isinstance(images, list):
            if not images:
                raise ValueError("edit_image requires at least one input image.")
            source = images[0]
        else:
            source = images

        init_image_bytes = BytesIO()
        init_image = self._decode_image_input(source)
        init_image.save(init_image_bytes, format="PNG")

        task = "inpainting" if mask else "image2image"
        url = self._get_api_url(task)

        files = {"image": init_image_bytes.getvalue()}
        data = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or "",
            "output_format": "png",
            "seed": kwargs.get("seed", 0)
        }

        if task == "inpainting":
            mask_image_bytes = BytesIO()
            mask_image = self._decode_image_input(mask)
            mask_image.save(mask_image_bytes, format="PNG")
            files["mask"] = mask_image_bytes.getvalue()
        else:  # image2image
            data["strength"] = kwargs.get("strength", 0.6)  # mode IMAGE_STRENGTH
            if "stable-diffusion-3" in self.model_name:
                # The SD3 endpoint serves both generation modes and defaults
                # to text-to-image, so the mode must be stated explicitly.
                data["mode"] = "image-to-image"
                data["model"] = self.model_name

        headers = {"authorization": f"Bearer {self.api_key}", "accept": "image/*"}

        try:
            ASCIIColors.info(f"Requesting image edit from Stability AI ({self.model_name})...")
            response = requests.post(url, headers=headers, files=files, data=data, timeout=self.request_timeout)
            response.raise_for_status()
            return response.content
        except requests.exceptions.HTTPError as e:
            # Report the body of the response that actually failed.
            trace_exception(e)
            raise Exception(f"Stability AI API error: {self._response_detail(e.response)}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"Stability AI API request failed: {e}") from e
|
|
154
|
+
|
|
155
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real STABILITY_API_KEY and network access.
    ASCIIColors.magenta("--- Stability AI TTI Binding Test ---")
    if os.environ.get("STABILITY_API_KEY") is None:
        ASCIIColors.error("STABILITY_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        tti = StabilityAITTIBinding(model_name="stable-diffusion-3-medium")

        ASCIIColors.cyan("\n--- Test: Text-to-Image ---")
        image_data = tti.generate_image(
            "a cinematic photo of a robot drinking coffee in a Parisian cafe",
            width=1024,
            height=1024,
        )

        # A real image is expected to be larger than a kilobyte.
        assert len(image_data) > 1000, "Generated image bytes are too small."
        output_path = Path(__file__).parent / "tmp_stability_t2i.png"
        output_path.write_bytes(image_data)
        ASCIIColors.green(f"Text-to-Image generation OK. Image saved to {output_path}")

    except Exception as e:
        trace_exception(e)
        ASCIIColors.error(f"Stability AI binding test failed: {e}")
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import time
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, List, Dict, Any
|
|
6
|
+
|
|
7
|
+
from lollms_client.lollms_ttm_binding import LollmsTTMBinding
|
|
8
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
9
|
+
import pipmaster as pm
|
|
10
|
+
|
|
11
|
+
# Ensure required packages are installed
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "BeatovenAITTMBinding"
|
|
15
|
+
|
|
16
|
+
class BeatovenAITTMBinding(LollmsTTMBinding):
    """A Text-to-Music binding for the Beatoven.ai API.

    A track is created, the resulting task is polled until it succeeds or
    fails, and the WAV render of the finished track is downloaded.
    """

    def __init__(self, **kwargs):
        """Read the API key and prepare the authenticated headers.

        Args:
            **kwargs: Forwarded to the base-class constructor. ``api_key``
                and ``request_timeout`` are then read from ``self.settings``
                (presumably populated by the base class - confirm).

        Raises:
            ValueError: If no API key is found in the settings or in the
                ``BEATOVEN_API_KEY`` environment variable.
        """
        super().__init__(binding_name=BindingName, **kwargs)
        self.api_key = self.settings.get("api_key") or os.environ.get("BEATOVEN_API_KEY")
        if not self.api_key:
            raise ValueError("Beatoven.ai API key is required. Please set it in config or as BEATOVEN_API_KEY env var.")
        self.base_url = "https://api.beatoven.ai/api/v1"
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        # Without a timeout, requests can block forever on a stalled connection.
        self.request_timeout = self.settings.get("request_timeout", 120)

    def list_models(self, **kwargs) -> List[str]:
        """Return the single placeholder model name."""
        # Beatoven.ai does not expose different models via the API.
        # Customization is done via genre, mood, and tempo.
        return ["default"]

    def _poll_for_completion(self, task_id: str, max_wait: float = 900.0, poll_interval: float = 5.0) -> Dict[str, Any]:
        """Poll the tasks endpoint until the composition is complete.

        Args:
            task_id: Identifier returned when the track was created.
            max_wait: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between two status checks.

        Returns:
            The JSON payload of the task once its status is ``success``.

        Raises:
            TimeoutError: If the task is still pending after ``max_wait`` seconds.
            Exception: If the task fails or the status request returns an
                HTTP error.
        """
        poll_url = f"{self.base_url}/tasks/{task_id}"
        deadline = time.monotonic() + max_wait
        while True:
            try:
                response = requests.get(poll_url, headers=self.headers, timeout=self.request_timeout)
                response.raise_for_status()
            except requests.exceptions.HTTPError as e:
                raise Exception(f"Failed to poll task status: {e.response.text}") from e

            data = response.json()
            status = data.get("status")

            if status == "success":
                ASCIIColors.green("Composition task successful.")
                return data
            if status == "failed":
                error_info = data.get("error", "Unknown error.")
                raise Exception(f"Beatoven.ai task failed: {error_info}")
            if time.monotonic() >= deadline:
                raise TimeoutError(f"Beatoven.ai task {task_id} did not finish within {max_wait} seconds.")
            ASCIIColors.info(f"Task status is '{status}'. Waiting...")
            time.sleep(poll_interval)

    def generate_music(self, prompt: str, **kwargs) -> bytes:
        """Generate music by creating a track, waiting for composition, and downloading the result.

        Args:
            prompt: Text describing the desired music. Its first 100
                characters are also used as the track title.
            **kwargs: Optional ``duration`` in seconds (default 30),
                ``genre`` (default ``Cinematic``) and ``tempo`` (default
                ``medium``).

        Returns:
            The downloaded audio bytes of the WAV render.

        Raises:
            Exception: On HTTP errors, task failure or timeout, or when the
                API response lacks the expected identifiers or render URL.
        """
        # Step 1: Create a track
        create_track_url = f"{self.base_url}/tracks"
        payload = {
            "title": prompt[:100],  # Use prompt as title, truncated
            "duration_in_seconds": kwargs.get("duration", 30),
            "genre": kwargs.get("genre", "Cinematic"),
            "tempo": kwargs.get("tempo", "medium"),
            "prompt": prompt
        }

        try:
            ASCIIColors.info("Submitting music track request to Beatoven.ai...")
            create_response = requests.post(create_track_url, json=payload, headers=self.headers, timeout=self.request_timeout)
            create_response.raise_for_status()
            task_id = create_response.json().get("task_id")
            if not task_id:
                # Polling "/tasks/None" would only yield a confusing error.
                raise Exception("Track creation response did not include a task_id.")
            ASCIIColors.info(f"Track creation submitted. Task ID: {task_id}")

            # Step 2: Poll for task completion
            task_result = self._poll_for_completion(task_id)
            track_id = task_result.get("track_id")
            if not track_id:
                raise Exception("Task completed but did not return a track_id.")

            # Step 3: Get track details to find the audio URL
            track_url = f"{self.base_url}/tracks/{track_id}"
            track_response = requests.get(track_url, headers=self.headers, timeout=self.request_timeout)
            track_response.raise_for_status()

            audio_url = (track_response.json().get("renders") or {}).get("wav")
            if not audio_url:
                raise Exception("Could not find WAV render URL in the completed track details.")

            # Step 4: Download the audio file
            ASCIIColors.info(f"Downloading generated audio from {audio_url}")
            audio_response = requests.get(audio_url, timeout=self.request_timeout)
            audio_response.raise_for_status()

            return audio_response.content

        except requests.exceptions.HTTPError as e:
            # The error body is not guaranteed to be JSON; fall back to text
            # so a decoding failure never hides the HTTP error.
            try:
                error_details = e.response.json()
            except ValueError:
                error_details = e.response.text
            raise Exception(f"Beatoven.ai API HTTP Error: {error_details}") from e
        except Exception as e:
            trace_exception(e)
            raise
|
|
103
|
+
|
|
104
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real BEATOVEN_API_KEY and network access.
    ASCIIColors.magenta("--- Beatoven.ai TTM Binding Test ---")
    if os.environ.get("BEATOVEN_API_KEY") is None:
        ASCIIColors.error("BEATOVEN_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        ttm = BeatovenAITTMBinding()

        ASCIIColors.cyan("\n--- Test: Music Generation ---")
        audio_data = ttm.generate_music(
            "A mysterious and suspenseful cinematic track with soft piano and eerie strings, building tension.",
            duration=45,
            genre="Cinematic",
            tempo="slow",
        )

        # A real audio file is expected to be larger than a kilobyte.
        assert len(audio_data) > 1000, "Generated music bytes are too small."
        output_path = Path(__file__).parent / "tmp_beatoven_music.wav"
        output_path.write_bytes(audio_data)
        ASCIIColors.green(f"Music generation OK. Audio saved to {output_path}")

    except Exception as e:
        trace_exception(e)
        ASCIIColors.error(f"Beatoven.ai TTM binding test failed: {e}")
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import requests
|
|
3
|
+
import base64
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional, List, Dict, Any
|
|
6
|
+
|
|
7
|
+
from lollms_client.lollms_ttm_binding import LollmsTTMBinding
|
|
8
|
+
from ascii_colors import trace_exception, ASCIIColors
|
|
9
|
+
import pipmaster as pm
|
|
10
|
+
|
|
11
|
+
# Ensure required packages are installed
|
|
12
|
+
pm.ensure_packages(["requests"])
|
|
13
|
+
|
|
14
|
+
BindingName = "StabilityAITTMBinding"
|
|
15
|
+
|
|
16
|
+
# Models available via the Stability AI Audio API
|
|
17
|
+
# Sourced from: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v1~1generation~1stable-audio-2.0~1text-to-audio/post
|
|
18
|
+
STABILITY_AI_MODELS = [
|
|
19
|
+
{"model_name": "stable-audio-2.0", "display_name": "Stable Audio 2.0", "description": "High-quality, full-track music generation up to 3 minutes."},
|
|
20
|
+
{"model_name": "stable-audio-1.0", "display_name": "Stable Audio 1.0", "description": "Original model, best for short clips and sound effects."},
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
class StabilityAITTMBinding(LollmsTTMBinding):
    """A Text-to-Music binding for Stability AI's Stable Audio API."""

    def __init__(self, **kwargs):
        """Read the API key and model name.

        Args:
            **kwargs: Forwarded to the base-class constructor. ``api_key``,
                ``model_name`` and ``request_timeout`` are then read from
                ``self.settings`` (presumably populated by the base class -
                confirm).

        Raises:
            ValueError: If no API key is found in the settings or in the
                ``STABILITY_API_KEY`` environment variable.
        """
        super().__init__(binding_name=BindingName, **kwargs)
        self.api_key = self.settings.get("api_key") or os.environ.get("STABILITY_API_KEY")
        if not self.api_key:
            raise ValueError("Stability AI API key is required. Please set it in the configuration or as STABILITY_API_KEY environment variable.")
        self.model_name = self.settings.get("model_name", "stable-audio-2.0")
        # Without a timeout, requests can block forever on a stalled connection.
        self.request_timeout = self.settings.get("request_timeout", 300)

    def list_models(self, **kwargs) -> List[Dict[str, str]]:
        """Return the static list of known Stable Audio models."""
        return STABILITY_AI_MODELS

    def generate_music(self, prompt: str, **kwargs) -> bytes:
        """
        Generates music using the Stable Audio API.

        Args:
            prompt (str): The text prompt describing the desired music.
            duration (int): The duration of the audio in seconds. Defaults to 29.
            **kwargs: Additional parameters for the API: ``seed`` (default 0),
                ``steps`` (default 100), ``cfg_scale`` (default 7.0) and
                ``timeout`` (seconds allowed for the HTTP call; defaults to
                the binding's ``request_timeout``).

        Returns:
            bytes: The response body; WAV audio is requested through the
                ``Accept`` header.

        Raises:
            Exception: If the API returns an HTTP error or the request fails
                for any other reason. The original exception is chained.
        """
        url = f"https://api.stability.ai/v1/generation/{self.model_name}/text-to-audio"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept": "audio/wav",
        }

        # Get duration, with a default of 29 seconds as it's a common value
        duration = kwargs.get("duration", 29)

        payload = {
            "text_prompts[0][text]": prompt,
            "text_prompts[0][weight]": 1.0,
            "seed": kwargs.get("seed", 0),  # 0 for random in API
            "steps": kwargs.get("steps", 100),
            "cfg_scale": kwargs.get("cfg_scale", 7.0),
        }

        # Handle different parameter names for duration
        if self.model_name == "stable-audio-2.0":
            payload["duration"] = duration
        else:  # stable-audio-1.0
            # v1 uses sample length; int() keeps it a whole number of samples
            # even when the duration is given as a float.
            payload["sample_length"] = int(duration * 44100)

        try:
            ASCIIColors.info(f"Requesting music from Stability AI ({self.model_name})...")
            response = requests.post(
                url,
                headers=headers,
                data=payload,
                timeout=kwargs.get("timeout", self.request_timeout),
            )
            response.raise_for_status()

            ASCIIColors.green("Successfully generated music from Stability AI.")
            return response.content
        except requests.exceptions.HTTPError as e:
            try:
                error_details = e.response.json()
                error_message = error_details.get("message", e.response.text)
            except (ValueError, AttributeError):
                # Body is not JSON, or is JSON but not an object.
                error_message = e.response.text
            ASCIIColors.error(f"HTTP Error from Stability AI: {e.response.status_code} - {error_message}")
            raise Exception(f"Stability AI API Error: {error_message}") from e
        except Exception as e:
            trace_exception(e)
            raise Exception(f"An unexpected error occurred: {e}") from e
|
|
89
|
+
|
|
90
|
+
if __name__ == '__main__':
    # Manual smoke test: needs a real STABILITY_API_KEY and network access.
    ASCIIColors.magenta("--- Stability AI TTM Binding Test ---")
    if os.environ.get("STABILITY_API_KEY") is None:
        ASCIIColors.error("STABILITY_API_KEY environment variable not set. Cannot run test.")
        exit(1)

    try:
        # Test with default settings
        ttm = StabilityAITTMBinding()

        ASCIIColors.cyan("\n--- Test: Music Generation ---")
        audio_data = ttm.generate_music(
            "80s synthwave, retro futuristic, driving beat, cinematic",
            duration=10,
        )

        # A real audio file is expected to be larger than a kilobyte.
        assert len(audio_data) > 1000, "Generated music bytes are too small."
        output_path = Path(__file__).parent / "tmp_stability_music.wav"
        output_path.write_bytes(audio_data)
        ASCIIColors.green(f"Music generation OK. Audio saved to {output_path}")

    except Exception as e:
        trace_exception(e)
        ASCIIColors.error(f"Stability AI TTM binding test failed: {e}")
|