lollms-client 1.6.2__py3-none-any.whl → 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of lollms-client might be problematic; see the registry's advisory for details.
Files changed (41)
  1. lollms_client/__init__.py +1 -1
  2. lollms_client/llm_bindings/azure_openai/__init__.py +2 -2
  3. lollms_client/llm_bindings/claude/__init__.py +2 -2
  4. lollms_client/llm_bindings/gemini/__init__.py +2 -2
  5. lollms_client/llm_bindings/grok/__init__.py +2 -2
  6. lollms_client/llm_bindings/groq/__init__.py +2 -2
  7. lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +2 -2
  8. lollms_client/llm_bindings/litellm/__init__.py +1 -1
  9. lollms_client/llm_bindings/llamacpp/__init__.py +2 -2
  10. lollms_client/llm_bindings/lollms/__init__.py +1 -1
  11. lollms_client/llm_bindings/lollms_webui/__init__.py +1 -1
  12. lollms_client/llm_bindings/mistral/__init__.py +2 -2
  13. lollms_client/llm_bindings/novita_ai/__init__.py +2 -2
  14. lollms_client/llm_bindings/ollama/__init__.py +7 -4
  15. lollms_client/llm_bindings/open_router/__init__.py +2 -2
  16. lollms_client/llm_bindings/openai/__init__.py +1 -1
  17. lollms_client/llm_bindings/openllm/__init__.py +2 -2
  18. lollms_client/llm_bindings/openwebui/__init__.py +1 -1
  19. lollms_client/llm_bindings/perplexity/__init__.py +2 -2
  20. lollms_client/llm_bindings/pythonllamacpp/__init__.py +3 -3
  21. lollms_client/llm_bindings/tensor_rt/__init__.py +1 -1
  22. lollms_client/llm_bindings/transformers/__init__.py +4 -4
  23. lollms_client/llm_bindings/vllm/__init__.py +1 -1
  24. lollms_client/lollms_core.py +7 -1443
  25. lollms_client/lollms_llm_binding.py +1 -1
  26. lollms_client/lollms_tti_binding.py +1 -1
  27. lollms_client/tti_bindings/diffusers/__init__.py +320 -853
  28. lollms_client/tti_bindings/diffusers/server/main.py +882 -0
  29. lollms_client/tti_bindings/gemini/__init__.py +179 -239
  30. lollms_client/tti_bindings/leonardo_ai/__init__.py +1 -1
  31. lollms_client/tti_bindings/novita_ai/__init__.py +1 -1
  32. lollms_client/tti_bindings/stability_ai/__init__.py +1 -1
  33. lollms_client/tts_bindings/lollms/__init__.py +6 -1
  34. lollms_client/tts_bindings/piper_tts/__init__.py +1 -1
  35. lollms_client/tts_bindings/xtts/__init__.py +20 -14
  36. lollms_client/tts_bindings/xtts/server/main.py +17 -4
  37. {lollms_client-1.6.2.dist-info → lollms_client-1.6.5.dist-info}/METADATA +2 -2
  38. {lollms_client-1.6.2.dist-info → lollms_client-1.6.5.dist-info}/RECORD +41 -40
  39. {lollms_client-1.6.2.dist-info → lollms_client-1.6.5.dist-info}/WHEEL +0 -0
  40. {lollms_client-1.6.2.dist-info → lollms_client-1.6.5.dist-info}/licenses/LICENSE +0 -0
  41. {lollms_client-1.6.2.dist-info → lollms_client-1.6.5.dist-info}/top_level.txt +0 -0
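The headline change in this release is item 28: the diffusers TTI binding now ships a standalone FastAPI server (lollms_client/tti_bindings/diffusers/server/main.py, +882 lines, reproduced below), while the in-process binding code (item 27) shrinks accordingly. Based on the argparse block at the end of that file, the server is started as a separate process roughly as follows; the exact launch path used by the binding is not part of this diff, so treat this as a sketch:

import subprocess
import sys

# Hypothetical launch of the new Diffusers TTI server.
# --host and --port default to localhost:9630 in main.py; --models-path is required.
server = subprocess.Popen([
    sys.executable,
    "lollms_client/tti_bindings/diffusers/server/main.py",
    "--host", "localhost",
    "--port", "9630",
    "--models-path", "./models",
])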
lollms_client/tti_bindings/diffusers/server/main.py (new file)
@@ -0,0 +1,882 @@
+import os
+import importlib
+from io import BytesIO
+from typing import Optional, List, Dict, Any, Union, Tuple
+from pathlib import Path
+import base64
+import threading
+import queue
+from concurrent.futures import Future
+import time
+import hashlib
+import requests
+from tqdm import tqdm
+import json
+import shutil
+import numpy as np
+import gc
+import argparse
+import uvicorn
+from fastapi import FastAPI, APIRouter, HTTPException, UploadFile, Form
+from fastapi import Request, Response
+from fastapi.responses import Response
+from pydantic import BaseModel, Field
+import sys
+import platform
+import inspect
+
+# Add binding root to sys.path to ensure local modules can be imported if structured that way.
+binding_root = Path(__file__).resolve().parent.parent
+sys.path.insert(0, str(binding_root))
+
+# --- Dependency Check and Imports ---
+try:
+    import torch
+    from diffusers import (
+        AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting,
+        DiffusionPipeline, StableDiffusionPipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline
+    )
+    from diffusers.utils import load_image
+    from PIL import Image
+    from ascii_colors import trace_exception, ASCIIColors
+    DIFFUSERS_AVAILABLE = True
+except ImportError as e:
+    print(f"FATAL: A required package is missing from the server's venv: {e}.")
+    DIFFUSERS_AVAILABLE = False
+    # Define dummy classes to allow server to start and report error via API
+    class Dummy: pass
+    torch = Dummy()
+    torch.cuda = Dummy()
+    torch.cuda.is_available = lambda: False
+    torch.backends = Dummy()
+    torch.backends.mps = Dummy()
+    torch.backends.mps.is_available = lambda: False
+    AutoPipelineForText2Image = AutoPipelineForImage2Image = AutoPipelineForInpainting = DiffusionPipeline = StableDiffusionPipeline = QwenImageEditPipeline = QwenImageEditPlusPipeline = Image = load_image = ASCIIColors = trace_exception = Dummy
+
+# --- Server Setup ---
+app = FastAPI(title="Diffusers TTI Server")
+router = APIRouter()
+MODELS_PATH = Path("./models")
+
+# --- START: Core Logic (Complete and Unabridged) ---
+CIVITAI_MODELS = {
+    "realistic-vision-v6": {
+        "display_name": "Realistic Vision V6.0", "url": "https://civitai.com/api/download/models/501240?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "realisticVisionV60_v60B1.safetensors", "description": "Photorealistic SD1.5 checkpoint.", "owned_by": "civitai"
+    },
+    "absolute-reality": {
+        "display_name": "Absolute Reality", "url": "https://civitai.com/api/download/models/132760?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "absolutereality_v181.safetensors", "description": "General realistic SD1.5.", "owned_by": "civitai"
+    },
+    "dreamshaper-8": {
+        "display_name": "DreamShaper 8", "url": "https://civitai.com/api/download/models/128713",
+        "filename": "dreamshaper_8.safetensors", "description": "Versatile SD1.5 style model.", "owned_by": "civitai"
+    },
+    "juggernaut-xl": {
+        "display_name": "Juggernaut XL", "url": "https://civitai.com/api/download/models/133005",
+        "filename": "juggernautXL_version6Rundiffusion.safetensors", "description": "Artistic SDXL.", "owned_by": "civitai"
+    },
+    "lyriel-v1.6": {
+        "display_name": "Lyriel v1.6", "url": "https://civitai.com/api/download/models/72396?type=Model&format=SafeTensor&size=full&fp=fp16",
+        "filename": "lyriel_v16.safetensors", "description": "Fantasy/stylized SD1.5.", "owned_by": "civitai"
+    },
+    "ui_icons": {
+        "display_name": "UI Icons", "url": "https://civitai.com/api/download/models/367044?type=Model&format=SafeTensor&size=full&fp=fp16",
+        "filename": "uiIcons_v10.safetensors", "description": "A model for generating UI icons.", "owned_by": "civitai"
+    },
+    "meinamix": {
+        "display_name": "MeinaMix", "url": "https://civitai.com/api/download/models/948574?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "meinamix_meinaV11.safetensors", "description": "Anime/illustration SD1.5.", "owned_by": "civitai"
+    },
+    "rpg-v5": {
+        "display_name": "RPG v5", "url": "https://civitai.com/api/download/models/124626?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "rpg_v5.safetensors", "description": "RPG assets SD1.5.", "owned_by": "civitai"
+    },
+    "pixel-art-xl": {
+        "display_name": "Pixel Art XL", "url": "https://civitai.com/api/download/models/135931?type=Model&format=SafeTensor",
+        "filename": "pixelartxl_v11.safetensors", "description": "Pixel art SDXL.", "owned_by": "civitai"
+    },
+    "lowpoly-world": {
+        "display_name": "Lowpoly World", "url": "https://civitai.com/api/download/models/146502?type=Model&format=SafeTensor",
+        "filename": "LowpolySDXL.safetensors", "description": "Lowpoly style SD1.5.", "owned_by": "civitai"
+    },
+    "toonyou": {
+        "display_name": "ToonYou", "url": "https://civitai.com/api/download/models/125771?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "toonyou_beta6.safetensors", "description": "Cartoon/Disney SD1.5.", "owned_by": "civitai"
+    },
+    "papercut": {
+        "display_name": "Papercut", "url": "https://civitai.com/api/download/models/133503?type=Model&format=SafeTensor",
+        "filename": "papercut.safetensors", "description": "Paper cutout SD1.5.", "owned_by": "civitai"
+    },
+    "fantassifiedIcons": {
+        "display_name": "Fantassified Icons", "url": "https://civitai.com/api/download/models/67584?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "fantassifiedIcons_fantassifiedIconsV20.safetensors", "description": "Flat, modern Icons.", "owned_by": "civitai"
+    },
+    "game_icon_institute": {
+        "display_name": "Game icon institute", "url": "https://civitai.com/api/download/models/158776?type=Model&format=SafeTensor&size=full&fp=fp16",
+        "filename": "gameIconInstituteV10_v10.safetensors", "description": "Flat, modern game Icons.", "owned_by": "civitai"
+    },
+    "M4RV3LS_DUNGEONS": {
+        "display_name": "M4RV3LS & DUNGEONS", "url": "https://civitai.com/api/download/models/139417?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+        "filename": "M4RV3LSDUNGEONSNEWV40COMICS_mD40.safetensors", "description": "comics.", "owned_by": "civitai"
+    },
+}
+
+HF_DEFAULT_MODELS = [
+    {"family": "FLUX", "model_name": "black-forest-labs/FLUX.1-schnell", "display_name": "FLUX.1 Schnell", "desc": "A fast and powerful next-gen T2I model."},
+    {"family": "FLUX", "model_name": "black-forest-labs/FLUX.1-dev", "display_name": "FLUX.1 Dev", "desc": "The larger, developer version of the FLUX.1 model."},
+    {"family": "SDXL", "model_name": "stabilityai/stable-diffusion-xl-base-1.0", "display_name": "SDXL Base 1.0", "desc": "Text2Image 1024 native."},
+    {"family": "SDXL", "model_name": "stabilityai/stable-diffusion-xl-refiner-1.0", "display_name": "SDXL Refiner 1.0", "desc": "Refiner for SDXL."},
+    {"family": "SD 1.x", "model_name": "runwayml/stable-diffusion-v1-5", "display_name": "Stable Diffusion 1.5", "desc": "Classic SD1.5."},
+    {"family": "SD 2.x", "model_name": "stabilityai/stable-diffusion-2-1", "display_name": "Stable Diffusion 2.1", "desc": "SD2.1 base."},
+    {"family": "SD3", "model_name": "stabilityai/stable-diffusion-3-medium-diffusers", "display_name": "Stable Diffusion 3 Medium", "desc": "SD3 medium."},
+    {"family": "Qwen", "model_name": "Qwen/Qwen-Image", "display_name": "Qwen Image", "desc": "Dedicated image generation."},
+    {"family": "Specialized", "model_name": "playgroundai/playground-v2.5-1024px-aesthetic", "display_name": "Playground v2.5", "desc": "High aesthetic 1024."},
+    {"family": "Editors", "model_name": "Qwen/Qwen-Image-Edit", "display_name": "Qwen Image Edit", "desc": "Dedicated image editing."},
+    {"family": "Editors", "model_name": "Qwen/Qwen-Image-Edit-2509", "display_name": "Qwen Image Edit Plus (Multi-Image)", "desc": "Advanced multi-image editing, fusion, and pose transfer."}
+]
+
+
+TORCH_DTYPE_MAP_STR_TO_OBJ = {
+    "float16": getattr(torch, 'float16', 'float16'), "bfloat16": getattr(torch, 'bfloat16', 'bfloat16'),
+    "float32": getattr(torch, 'float32', 'float32'), "auto": "auto"
+}
+
+SCHEDULER_MAPPING = {
+    "default": None, "ddim": "DDIMScheduler", "ddpm": "DDPMScheduler", "deis_multistep": "DEISMultistepScheduler",
+    "dpm_multistep": "DPMSolverMultistepScheduler", "dpm_multistep_karras": "DPMSolverMultistepScheduler", "dpm_single": "DPMSolverSinglestepScheduler",
+    "dpm_adaptive": "DPMSolverPlusPlusScheduler", "dpm++_2m": "DPMSolverMultistepScheduler", "dpm++_2m_karras": "DPMSolverMultistepScheduler",
+    "dpm++_2s_ancestral": "DPMSolverAncestralDiscreteScheduler", "dpm++_2s_ancestral_karras": "DPMSolverAncestralDiscreteScheduler", "dpm++_sde": "DPMSolverSDEScheduler",
+    "dpm++_sde_karras": "DPMSolverSDEScheduler", "euler_ancestral_discrete": "EulerAncestralDiscreteScheduler", "euler_discrete": "EulerDiscreteScheduler",
+    "heun_discrete": "HeunDiscreteScheduler", "heun_karras": "HeunDiscreteScheduler", "lms_discrete": "LMSDiscreteScheduler",
+    "lms_karras": "LMSDiscreteScheduler", "pndm": "PNDMScheduler", "unipc_multistep": "UniPCMultistepScheduler",
+    "dpm++_2m_sde": "DPMSolverMultistepScheduler", "dpm++_2m_sde_karras": "DPMSolverMultistepScheduler", "dpm2": "KDPM2DiscreteScheduler",
+    "dpm2_karras": "KDPM2DiscreteScheduler", "dpm2_a": "KDPM2AncestralDiscreteScheduler", "dpm2_a_karras": "KDPM2AncestralDiscreteScheduler",
+    "euler": "EulerDiscreteScheduler", "euler_a": "EulerAncestralDiscreteScheduler", "heun": "HeunDiscreteScheduler", "lms": "LMSDiscreteScheduler"
+}
+SCHEDULER_USES_KARRAS_SIGMAS = [
+    "dpm_multistep_karras","dpm++_2m_karras","dpm++_2s_ancestral_karras", "dpm++_sde_karras","heun_karras","lms_karras",
+    "dpm++_2m_sde_karras","dpm2_karras","dpm2_a_karras"
+]
+
+
+class ModelManager:
+    def __init__(self, config: Dict[str, Any], models_path: Path, registry: 'PipelineRegistry'):
+        self.config = config
+        self.models_path = models_path
+        self.registry = registry
+        self.pipeline: Optional[DiffusionPipeline] = None
+        self.current_task: Optional[str] = None
+        self.ref_count = 0
+        self.lock = threading.Lock()
+        self.queue = queue.Queue()
+        self.is_loaded = False
+        self.last_used_time = time.time()
+        self._stop_event = threading.Event()
+        self.worker_thread = threading.Thread(target=self._generation_worker, daemon=True)
+        self.worker_thread.start()
+        self._stop_monitor_event = threading.Event()
+        self._unload_monitor_thread = None
+        self._start_unload_monitor()
+        self.supported_args: Optional[set] = None
+
+    def acquire(self):
+        with self.lock:
+            self.ref_count += 1
+            return self
+
+    def release(self):
+        with self.lock:
+            self.ref_count -= 1
+            return self.ref_count
+
+    def stop(self):
+        self._stop_event.set()
+        if self._unload_monitor_thread:
+            self._stop_monitor_event.set()
+            self._unload_monitor_thread.join(timeout=2)
+        self.queue.put(None)
+        self.worker_thread.join(timeout=5)
+
+    def _start_unload_monitor(self):
+        unload_after = self.config.get("unload_inactive_model_after", 0)
+        if unload_after > 0 and self._unload_monitor_thread is None:
+            self._stop_monitor_event.clear()
+            self._unload_monitor_thread = threading.Thread(target=self._unload_monitor, daemon=True)
+            self._unload_monitor_thread.start()
+
+    def _unload_monitor(self):
+        unload_after = self.config.get("unload_inactive_model_after", 0)
+        if unload_after <= 0:
+            return
+        ASCIIColors.info(f"Starting inactivity monitor for '{self.config['model_name']}' (timeout: {unload_after}s).")
+        while not self._stop_monitor_event.wait(timeout=5.0):
+            with self.lock:
+                if not self.is_loaded:
+                    continue
+                if time.time() - self.last_used_time > unload_after:
+                    ASCIIColors.info(f"Model '{self.config['model_name']}' has been inactive. Unloading.")
+                    self._unload_pipeline()
+
+    def _resolve_model_path(self, model_name: str) -> Union[str, Path]:
+        path_obj = Path(model_name)
+        if path_obj.is_absolute() and path_obj.exists():
+            return model_name
+        if model_name in CIVITAI_MODELS:
+            filename = CIVITAI_MODELS[model_name]["filename"]
+            local_path = self.models_path / filename
+            if not local_path.exists():
+                self._download_civitai_model(model_name)
+            return local_path
+        local_path = self.models_path / model_name
+        if local_path.exists():
+            return local_path
+        return model_name
+
+    def _download_civitai_model(self, model_key: str):
+        model_info = CIVITAI_MODELS[model_key]
+        url = model_info["url"]
+        filename = model_info["filename"]
+        dest_path = self.models_path / filename
+        temp_path = dest_path.with_suffix(".temp")
+        ASCIIColors.cyan(f"Downloading '{filename}' from Civitai... to {dest_path}")
+        try:
+            with requests.get(url, stream=True) as r:
+                r.raise_for_status()
+                total_size = int(r.headers.get('content-length', 0))
+                with open(temp_path, 'wb') as f, tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {filename}") as bar:
+                    for chunk in r.iter_content(chunk_size=8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            shutil.move(temp_path, dest_path)
+            ASCIIColors.green(f"Model '{filename}' downloaded successfully.")
+        except Exception as e:
+            if temp_path.exists():
+                temp_path.unlink()
+            raise Exception(f"Failed to download model {filename}: {e}")
+
+    def _set_scheduler(self):
+        if not self.pipeline:
+            return
+        if "Qwen" in self.config.get("model_name", "") or "FLUX" in self.config.get("model_name", ""):
+            ASCIIColors.info("Special model detected, skipping custom scheduler setup.")
+            return
+        scheduler_name_key = self.config["scheduler_name"].lower()
+        if scheduler_name_key == "default":
+            return
+        scheduler_class_name = SCHEDULER_MAPPING.get(scheduler_name_key)
+        if scheduler_class_name:
+            try:
+                SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), scheduler_class_name)
+                scheduler_config = self.pipeline.scheduler.config
+                scheduler_config["use_karras_sigmas"] = scheduler_name_key in SCHEDULER_USES_KARRAS_SIGMAS
+                self.pipeline.scheduler = SchedulerClass.from_config(scheduler_config)
+                ASCIIColors.info(f"Switched scheduler to {scheduler_class_name}")
+            except Exception as e:
+                ASCIIColors.warning(f"Could not switch scheduler to {scheduler_name_key}: {e}. Using current default.")
+
+    def _execute_load_pipeline(self, task: str, model_path: Union[str, Path], torch_dtype: Any):
+        if platform.system() == "Windows":
+            os.environ["HF_HUB_ENABLE_SYMLINKS"] = "0"
+
+        model_name_from_config = self.config.get("model_name", "")
+        use_device_map = False
+
+        try:
+            load_params = {}
+            if self.config.get("hf_cache_path"):
+                load_params["cache_dir"] = str(self.config["hf_cache_path"])
+            load_params["torch_dtype"] = torch_dtype
+
+            is_qwen_model = "Qwen" in model_name_from_config
+            is_flux_model = "FLUX" in model_name_from_config
+
+            if is_qwen_model or is_flux_model:
+                ASCIIColors.info(f"Special model '{model_name_from_config}' detected. Using dedicated pipeline loader.")
+                load_params.update({
+                    "use_safetensors": self.config["use_safetensors"],
+                    "token": self.config["hf_token"],
+                    "local_files_only": self.config["local_files_only"]
+                })
+                if self.config["hf_variant"]:
+                    load_params["variant"] = self.config["hf_variant"]
+                if not self.config["safety_checker_on"]:
+                    load_params["safety_checker"] = None
+
+                should_offload = self.config["enable_cpu_offload"] or self.config["enable_sequential_cpu_offload"]
+                if should_offload:
+                    ASCIIColors.info(f"Offload enabled. Forcing device_map='auto' for {model_name_from_config}.")
+                    use_device_map = True
+                    load_params["device_map"] = "auto"
+
+                if is_flux_model:
+                    self.pipeline = AutoPipelineForText2Image.from_pretrained(model_name_from_config, **load_params)
+                elif "Qwen-Image-Edit-2509" in model_name_from_config:
+                    self.pipeline = QwenImageEditPlusPipeline.from_pretrained(model_name_from_config, **load_params)
+                elif "Qwen-Image-Edit" in model_name_from_config:
+                    self.pipeline = QwenImageEditPipeline.from_pretrained(model_name_from_config, **load_params)
+                elif "Qwen/Qwen-Image" in model_name_from_config:
+                    self.pipeline = DiffusionPipeline.from_pretrained(model_name_from_config, **load_params)
+
+            else:
+                is_safetensors_file = str(model_path).endswith(".safetensors")
+                if is_safetensors_file:
+                    ASCIIColors.info(f"Loading standard model from local .safetensors file: {model_path}")
+                    try:
+                        self.pipeline = AutoPipelineForText2Image.from_single_file(model_path, **load_params)
+                    except Exception as e:
+                        ASCIIColors.warning(f"Failed to load with AutoPipeline, falling back to StableDiffusionPipeline: {e}")
+                        self.pipeline = StableDiffusionPipeline.from_single_file(model_path, **load_params)
+                else:
+                    ASCIIColors.info(f"Loading standard model from Hub: {model_path}")
+                    load_params.update({
+                        "use_safetensors": self.config["use_safetensors"],
+                        "token": self.config["hf_token"],
+                        "local_files_only": self.config["local_files_only"]
+                    })
+                    if self.config["hf_variant"]:
+                        load_params["variant"] = self.config["hf_variant"]
+                    if not self.config["safety_checker_on"]:
+                        load_params["safety_checker"] = None
+
+                    is_large_model = "stable-diffusion-3" in str(model_path)
+                    should_offload = self.config["enable_cpu_offload"] or self.config["enable_sequential_cpu_offload"]
+                    if is_large_model and should_offload:
+                        ASCIIColors.info(f"Large model '{model_path}' detected with offload enabled. Using device_map='auto'.")
+                        use_device_map = True
+                        load_params["device_map"] = "auto"
+
+                    if task == "text2image":
+                        self.pipeline = AutoPipelineForText2Image.from_pretrained(model_path, **load_params)
+                    elif task == "image2image":
+                        self.pipeline = AutoPipelineForImage2Image.from_pretrained(model_path, **load_params)
+                    elif task == "inpainting":
+                        self.pipeline = AutoPipelineForInpainting.from_pretrained(model_path, **load_params)
+
+        except Exception as e:
+            error_str = str(e).lower()
+            if "401" in error_str or "gated" in error_str or "authorization" in error_str:
+                msg = (f"AUTHENTICATION FAILED for model '{model_name_from_config}'. Please ensure you accepted the model license and provided a valid HF token.")
+                raise RuntimeError(msg)
+            raise e
+
+        self._set_scheduler()
+
+        if not use_device_map:
+            self.pipeline.to(self.config["device"])
+            if self.config["enable_xformers"]:
+                try:
+                    self.pipeline.enable_xformers_memory_efficient_attention()
+                except Exception as e:
+                    ASCIIColors.warning(f"Could not enable xFormers: {e}.")
+
+            if self.config["enable_cpu_offload"] and self.config["device"] != "cpu":
+                self.pipeline.enable_model_cpu_offload()
+            elif self.config["enable_sequential_cpu_offload"] and self.config["device"] != "cpu":
+                self.pipeline.enable_sequential_cpu_offload()
+        else:
+            ASCIIColors.info("Device map handled device placement. Skipping manual pipeline.to() and offload calls.")
+
+        if self.pipeline:
+            sig = inspect.signature(self.pipeline.__call__)
+            self.supported_args = {p.name for p in sig.parameters.values()}
+            ASCIIColors.info(f"Pipeline supported arguments detected: {self.supported_args}")
+
+        self.is_loaded = True
+        self.current_task = task
+        self.last_used_time = time.time()
+        ASCIIColors.green(f"Model '{model_name_from_config}' loaded successfully using '{'device_map' if use_device_map else 'standard'}' mode for task '{task}'.")
+
+    def _load_pipeline_for_task(self, task: str):
+        if self.pipeline and self.current_task == task:
+            return
+        if self.pipeline:
+            self._unload_pipeline()
+
+        model_name = self.config.get("model_name", "")
+        if not model_name:
+            raise ValueError("Model name cannot be empty for loading.")
+
+        ASCIIColors.info(f"Loading Diffusers model: {model_name} for task: {task}")
+        model_path = self._resolve_model_path(model_name)
+        torch_dtype = TORCH_DTYPE_MAP_STR_TO_OBJ.get(self.config["torch_dtype_str"].lower())
+
+        try:
+            self._execute_load_pipeline(task, model_path, torch_dtype)
+            return
+        except Exception as e:
+            is_oom = "out of memory" in str(e).lower()
+            if not is_oom or not hasattr(self, 'registry'):
+                raise e
+
+            ASCIIColors.warning(f"Failed to load '{model_name}' due to OOM. Attempting to unload other models to free VRAM.")
+
+            candidates_to_unload = [m for m in self.registry.get_all_managers() if m is not self and m.is_loaded]
+            candidates_to_unload.sort(key=lambda m: m.last_used_time)
+
+            if not candidates_to_unload:
+                ASCIIColors.error("OOM error, but no other models are available to unload.")
+                raise e
+
+            for victim in candidates_to_unload:
+                ASCIIColors.info(f"Unloading '{victim.config['model_name']}' (last used: {time.ctime(victim.last_used_time)}) to free VRAM.")
+                victim._unload_pipeline()
+
+            try:
+                ASCIIColors.info(f"Retrying to load '{model_name}'...")
+                self._execute_load_pipeline(task, model_path, torch_dtype)
+                ASCIIColors.green(f"Successfully loaded '{model_name}' after freeing VRAM.")
+                return
+            except Exception as retry_e:
+                is_oom_retry = "out of memory" in str(retry_e).lower()
+                if not is_oom_retry:
+                    raise retry_e
+
+                ASCIIColors.error(f"Could not load '{model_name}' even after unloading all other models.")
+                raise e
+
+    def _unload_pipeline(self):
+        if self.pipeline:
+            model_name = self.config.get('model_name', 'Unknown')
+            del self.pipeline
+            self.pipeline = None
+            self.supported_args = None
+            gc.collect()
+            if torch and torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            self.is_loaded = False
+            self.current_task = None
+            ASCIIColors.info(f"Model '{model_name}' unloaded and VRAM cleared.")
+
+    def _generation_worker(self):
+        while not self._stop_event.is_set():
+            try:
+                job = self.queue.get(timeout=1)
+                if job is None:
+                    break
+                future, task, pipeline_args = job
+                output = None
+                try:
+                    with self.lock:
+                        self.last_used_time = time.time()
+                        if not self.is_loaded or self.current_task != task:
+                            self._load_pipeline_for_task(task)
+
+                        if self.supported_args:
+                            filtered_args = {k: v for k, v in pipeline_args.items() if k in self.supported_args}
+                        else:
+                            ASCIIColors.warning("Supported argument set not found. Using unfiltered arguments.")
+                            filtered_args = pipeline_args
+
+                        with torch.no_grad():
+                            output = self.pipeline(**filtered_args)
+
+                    pil = output.images[0]
+                    buf = BytesIO()
+                    pil.save(buf, format="PNG")
+                    future.set_result(buf.getvalue())
+                except Exception as e:
+                    trace_exception(e)
+                    future.set_exception(e)
+                finally:
+                    self.queue.task_done()
+                    if output is not None:
+                        del output
+                    gc.collect()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+            except queue.Empty:
+                continue
+
+class PipelineRegistry:
+    _instance = None
+    _lock = threading.Lock()
+    def __new__(cls, *args, **kwargs):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance._managers = {}
+                cls._instance._registry_lock = threading.Lock()
+            return cls._instance
+    @staticmethod
+    def _get_critical_keys():
+        return [
+            "model_name","device","torch_dtype_str","use_safetensors",
+            "safety_checker_on","hf_variant","enable_cpu_offload",
+            "enable_sequential_cpu_offload","enable_xformers",
+            "local_files_only","hf_cache_path","unload_inactive_model_after"
+        ]
+    def _get_config_key(self, config: Dict[str, Any]) -> str:
+        key_data = tuple(sorted((k, config.get(k)) for k in self._get_critical_keys()))
+        return hashlib.sha256(str(key_data).encode('utf-8')).hexdigest()
+    def get_manager(self, config: Dict[str, Any], models_path: Path) -> ModelManager:
+        key = self._get_config_key(config)
+        with self._registry_lock:
+            if key not in self._managers:
+                self._managers[key] = ModelManager(config.copy(), models_path, self)
+            return self._managers[key].acquire()
+    def release_manager(self, config: Dict[str, Any]):
+        key = self._get_config_key(config)
+        with self._registry_lock:
+            if key in self._managers:
+                manager = self._managers[key]
+                ref_count = manager.release()
+                if ref_count == 0:
+                    ASCIIColors.info(f"Reference count for model '{config.get('model_name')}' is zero. Cleaning up manager.")
+                    manager.stop()
+                    with manager.lock:
+                        manager._unload_pipeline()
+                    del self._managers[key]
+    def get_active_managers(self) -> List[ModelManager]:
+        with self._registry_lock:
+            return [m for m in self._managers.values() if m.is_loaded]
+    def get_all_managers(self) -> List[ModelManager]:
+        with self._registry_lock:
+            return list(self._managers.values())
+
+class ServerState:
+    def __init__(self, models_path: Path):
+        self.models_path = models_path
+        self.models_path.mkdir(parents=True, exist_ok=True)
+        self.config_path = self.models_path.parent / "diffusers_server_config.json"
+        self.registry = PipelineRegistry()
+        self.manager: Optional[ModelManager] = None
+        self.config = {}
+        self.load_config()
+        self._resolve_device_and_dtype()
+        if self.config.get("model_name"):
+            try:
+                ASCIIColors.info(f"Acquiring initial model manager for '{self.config['model_name']}' on startup.")
+                self.manager = self.registry.get_manager(self.config, self.models_path)
+            except Exception as e:
+                ASCIIColors.error(f"Failed to acquire model manager on startup: {e}")
+                self.manager = None
+
+    def get_default_config(self) -> Dict[str, Any]:
+        return {
+            "model_name": "", "device": "auto", "torch_dtype_str": "auto", "use_safetensors": True,
+            "scheduler_name": "default", "safety_checker_on": True, "num_inference_steps": 25,
+            "guidance_scale": 7.0, "width": 512, "height": 512, "seed": -1,
+            "enable_cpu_offload": False, "enable_sequential_cpu_offload": False, "enable_xformers": False,
+            "hf_variant": None, "hf_token": None, "hf_cache_path": None, "local_files_only": False,
+            "unload_inactive_model_after": 0
+        }
+
+    def save_config(self):
+        try:
+            with open(self.config_path, 'w') as f:
+                json.dump(self.config, f, indent=4)
+            ASCIIColors.info(f"Server config saved to {self.config_path}")
+        except Exception as e:
+            ASCIIColors.error(f"Failed to save server config: {e}")
+
+    def load_config(self):
+        default_config = self.get_default_config()
+        if self.config_path.exists():
+            try:
+                with open(self.config_path, 'r') as f:
+                    loaded_config = json.load(f)
+                default_config.update(loaded_config)
+                self.config = default_config
+                ASCIIColors.info(f"Loaded server configuration from {self.config_path}")
+            except (json.JSONDecodeError, IOError) as e:
+                ASCIIColors.warning(f"Could not load config file, using defaults. Error: {e}")
+                self.config = default_config
+        else:
+            self.config = default_config
+            self.save_config()
+
+    def _resolve_device_and_dtype(self):
+        if self.config.get("device", "auto").lower() == "auto":
+            self.config["device"] = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+        if ("Qwen" in self.config.get("model_name", "") or "FLUX" in self.config.get("model_name", "")) and self.config["device"] == "cuda":
+            if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
+                self.config["torch_dtype_str"] = "bfloat16"
+                ASCIIColors.info("Special model detected on compatible hardware. Forcing dtype to bfloat16 for stability.")
+                return
+
+        if self.config["torch_dtype_str"].lower() == "auto":
+            self.config["torch_dtype_str"] = "float16" if self.config["device"] != "cpu" else "float32"
+
+    def update_settings(self, new_settings: Dict[str, Any]):
+        if 'model' in new_settings and 'model_name' not in new_settings:
+            new_settings['model_name'] = new_settings.pop('model')
+
+        if self.config.get("model_name") and not new_settings.get("model_name"):
+            ASCIIColors.info("Incoming settings have no model_name. Preserving existing model.")
+            new_settings["model_name"] = self.config["model_name"]
+
+        if self.manager:
+            self.registry.release_manager(self.manager.config)
+            self.manager = None
+
+        self.config.update(new_settings)
+        ASCIIColors.info(f"Server config updated. Current model_name: {self.config.get('model_name')}")
+
+        self._resolve_device_and_dtype()
+
+        if self.config.get("model_name"):
+            ASCIIColors.info("Acquiring model manager with updated configuration...")
+            self.manager = self.registry.get_manager(self.config, self.models_path)
+        else:
+            ASCIIColors.warning("No model_name in config after update, manager not acquired.")
+
+        self.save_config()
+        return True
+
+    def get_active_manager(self) -> ModelManager:
+        if self.manager:
+            return self.manager
+        raise HTTPException(status_code=400, detail="No model is configured or manager is not active. Please set a model using the /set_settings endpoint.")
+
+state: Optional[ServerState] = None
+
+# --- Pydantic Models for API ---
+class T2IRequest(BaseModel):
+    prompt: str
+    negative_prompt: str = ""
+    params: Dict[str, Any] = Field(default_factory=dict)
+
+class EditRequestPayload(BaseModel):
+    prompt: str
+    image_paths: List[str] = Field(default_factory=list)
+    params: Dict[str, Any] = Field(default_factory=dict)
+
+class EditRequestJSON(BaseModel):
+    prompt: str
+    images_b64: List[str] = Field(description="A list of Base64 encoded image strings.")
+    params: Dict[str, Any] = Field(default_factory=dict)
+
+def get_sanitized_request_for_logging(request_data: Any) -> Dict[str, Any]:
+    """
+    Takes a request object (Pydantic model or dict) and returns a 'safe' dictionary
+    for logging, with long base64 strings replaced by placeholders.
+    """
+    import copy
+
+    try:
+        if hasattr(request_data, 'model_dump'):
+            data = request_data.model_dump()
+        elif isinstance(request_data, dict):
+            data = copy.deepcopy(request_data)
+        else:
+            return {"error": "Unsupported data type for sanitization"}
+
+        # Sanitize the main list of images
+        if 'images_b64' in data and isinstance(data['images_b64'], list):
+            count = len(data['images_b64'])
+            data['images_b64'] = f"[<{count} base64 image(s) truncated>]"
+
+        # Sanitize a potential mask in the 'params' dictionary
+        if 'params' in data and isinstance(data.get('params'), dict):
+            if 'mask_image' in data['params'] and isinstance(data['params']['mask_image'], str):
+                original_len = len(data['params']['mask_image'])
+                data['params']['mask_image'] = f"[<base64 mask truncated, len={original_len}>]"
+
+        return data
+    except Exception:
+        return {"error": "Failed to sanitize request data."}
+
+# --- API Endpoints ---
+@router.post("/generate_image")
+async def generate_image(request: T2IRequest):
+    manager = None
+    temp_config = None
+    try:
+        params = request.params
+
+        # Determine which model manager to use for this specific request
+        if "model_name" in params and params["model_name"]:
+            temp_config = state.config.copy()
+            temp_config["model_name"] = params.pop("model_name")  # Remove from params to avoid being passed to pipeline
+            manager = state.registry.get_manager(temp_config, state.models_path)
+            ASCIIColors.info(f"Using per-request model: {temp_config['model_name']}")
+        else:
+            manager = state.get_active_manager()
+            ASCIIColors.info(f"Using session-configured model: {manager.config.get('model_name')}")
+
+        seed = int(params.get("seed", manager.config.get("seed", -1)))
+        generator = None
+        if seed != -1:
+            generator = torch.Generator(device=manager.config["device"]).manual_seed(seed)
+
+        width = int(params.get("width", manager.config.get("width", 512)))
+        height = int(params.get("height", manager.config.get("height", 512)))
+
+        pipeline_args = {
+            "prompt": request.prompt,
+            "negative_prompt": request.negative_prompt,
+            "width": width,
+            "height": height,
+            "num_inference_steps": int(params.get("num_inference_steps", manager.config.get("num_inference_steps", 25))),
+            "guidance_scale": float(params.get("guidance_scale", manager.config.get("guidance_scale", 7.0))),
+            "generator": generator
+        }
+        pipeline_args.update(params)
+
+        model_name = manager.config.get("model_name", "")
+        task = "text2image"
+
+        if "Qwen-Image-Edit" in model_name:
+            rng_seed = seed if seed != -1 else None
+            rng = np.random.default_rng(seed=rng_seed)
+            random_pixels = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)
+            placeholder_image = Image.fromarray(random_pixels, 'RGB')
+            pipeline_args["image"] = placeholder_image
+            pipeline_args["strength"] = float(params.get("strength", 1.0))
+            task = "image2image"
+
+        future = Future()
+        manager.queue.put((future, task, pipeline_args))
+        result_bytes = future.result()
+        return Response(content=result_bytes, media_type="image/png")
+    except Exception as e:
+        trace_exception(e)
+        raise HTTPException(status_code=500, detail=str(e))
+    finally:
+        if temp_config and manager:
+            state.registry.release_manager(temp_config)
+            ASCIIColors.info(f"Released per-request model: {temp_config['model_name']}")
+
+
+ @router.post("/edit_image")
743
+ async def edit_image(request: EditRequestJSON):
744
+ manager = None
745
+ temp_config = None
746
+ try:
747
+ params = request.params
748
+
749
+ if "model_name" in params and params["model_name"]:
750
+ temp_config = state.config.copy()
751
+ temp_config["model_name"] = params.pop("model_name")
752
+ manager = state.registry.get_manager(temp_config, state.models_path)
753
+ ASCIIColors.info(f"Using per-request model: {temp_config['model_name']}")
754
+ else:
755
+ manager = state.get_active_manager()
756
+ ASCIIColors.info(f"Using session-configured model: {manager.config.get('model_name')}")
757
+
758
+ model_name = manager.config.get("model_name", "")
759
+
760
+ pil_images = []
761
+ for b64_string in request.images_b64:
762
+ b64_data = b64_string.split(";base64,")[1] if ";base64," in b64_string else b64_string
763
+ image_bytes = base64.b64decode(b64_data)
764
+ pil_images.append(Image.open(BytesIO(image_bytes)).convert("RGB"))
765
+
766
+ if not pil_images: raise HTTPException(status_code=400, detail="No valid images provided.")
767
+
768
+ pipeline_args = {"prompt": request.prompt}
769
+ seed = int(params.get("seed", -1))
770
+ if seed != -1: pipeline_args["generator"] = torch.Generator(device=manager.config["device"]).manual_seed(seed)
771
+
772
+ if "mask_image" in params and params["mask_image"]:
773
+ b64_mask = params["mask_image"]
774
+ b64_data = b64_mask.split(";base64,")[1] if ";base64," in b64_mask else b64_mask
775
+ mask_bytes = base64.b64decode(b64_data)
776
+ pipeline_args["mask_image"] = Image.open(BytesIO(mask_bytes)).convert("L")
777
+
778
+ task = "inpainting" if "mask_image" in pipeline_args else "image2image"
779
+
780
+ if "Qwen-Image-Edit-2509" in model_name:
781
+ task = "image2image"
782
+ pipeline_args.update({"true_cfg_scale": 4.0, "guidance_scale": 1.0, "num_inference_steps": 40, "negative_prompt": " "})
783
+ edit_mode = params.get("edit_mode", "fusion")
784
+ if edit_mode == "fusion": pipeline_args["image"] = pil_images
785
+ else:
786
+ pipeline_args.update({"image": pil_images[0], "strength": 0.8, "guidance_scale": 7.5, "num_inference_steps": 25})
787
+
788
+ pipeline_args.update(params)
789
+
790
+ future = Future(); manager.queue.put((future, task, pipeline_args))
791
+ return Response(content=future.result(), media_type="image/png")
792
+ except Exception as e:
793
+ sanitized_payload = get_sanitized_request_for_logging(request)
794
+ ASCIIColors.error(f"Exception in /edit_image. Sanitized Payload: {json.dumps(sanitized_payload, indent=2)}")
795
+ trace_exception(e)
796
+ raise HTTPException(status_code=500, detail=str(e))
797
+ finally:
798
+ if temp_config and manager:
799
+ state.registry.release_manager(temp_config)
800
+ ASCIIColors.info(f"Released per-request model: {temp_config['model_name']}")
801
+
802
+
803
+ @router.get("/list_models")
804
+ def list_models_endpoint():
805
+ civitai = [{'model_name': key, 'display_name': info['display_name'], 'description': info['description'], 'owned_by': info['owned_by']} for key, info in CIVITAI_MODELS.items()]
806
+ local = [{'model_name': f.name, 'display_name': f.stem, 'description': 'Local safetensors file.', 'owned_by': 'local_user'} for f in state.models_path.glob("*.safetensors")]
807
+ huggingface = [{'model_name': m['model_name'], 'display_name': m['display_name'], 'description': m['desc'], 'owned_by': 'huggingface'} for m in HF_DEFAULT_MODELS]
808
+ return huggingface + civitai + local
809
+
810
+ @router.get("/list_local_models")
811
+ def list_local_models_endpoint():
812
+ return sorted([f.name for f in state.models_path.glob("*.safetensors")])
813
+
814
+ @router.get("/list_available_models")
815
+ def list_available_models_endpoint():
816
+ discoverable = [m['model_name'] for m in list_models_endpoint()]
817
+ return sorted(list(set(discoverable)))
818
+
819
+ @router.get("/get_settings")
820
+ def get_settings_endpoint():
821
+ settings_list = []
822
+ available_models = list_available_models_endpoint()
823
+ schedulers = list(SCHEDULER_MAPPING.keys())
824
+ config_to_display = state.config or state.get_default_config()
825
+ for name, value in config_to_display.items():
826
+ setting = {"name": name, "type": str(type(value).__name__), "value": value}
827
+ if name == "model_name": setting["options"] = available_models
828
+ if name == "scheduler_name": setting["options"] = schedulers
829
+ settings_list.append(setting)
830
+ return settings_list
831
+
832
+ @router.post("/set_settings")
833
+ def set_settings_endpoint(settings: Dict[str, Any]):
834
+ try:
835
+ success = state.update_settings(settings)
836
+ return {"success": success}
837
+ except Exception as e:
838
+ trace_exception(e)
839
+ raise HTTPException(status_code=500, detail=str(e))
840
+
841
+ @router.get("/status")
842
+ def status_endpoint():
843
+ return {"status": "running", "diffusers_available": DIFFUSERS_AVAILABLE, "model_loaded": state.manager.is_loaded if state.manager else False}
844
+
845
+ @router.post("/unload_model")
846
+ def unload_model_endpoint():
847
+ if state.manager:
848
+ state.manager._unload_pipeline()
849
+ state.registry.release_manager(state.manager.config)
850
+ state.manager = None
851
+ return {"status": "unloaded"}
852
+
853
+ @router.get("/ps")
854
+ def ps_endpoint():
855
+ managers = state.registry.get_all_managers()
856
+ return [{
857
+ "model_name": m.config.get("model_name"), "is_loaded": m.is_loaded,
858
+ "task": m.current_task, "device": m.config.get("device"), "ref_count": m.ref_count,
859
+ "queue_size": m.queue.qsize(), "last_used": time.ctime(m.last_used_time)
860
+ } for m in managers]
861
+
862
+ app.include_router(router)
863
+
864
+ if __name__ == "__main__":
865
+ parser = argparse.ArgumentParser(description="Diffusers TTI Server")
866
+ parser.add_argument("--host", type=str, default="localhost", help="Host to bind to.")
867
+ parser.add_argument("--port", type=int, default=9630, help="Port to bind to.")
868
+ parser.add_argument("--models-path", type=str, required=True, help="Path to the models directory.")
869
+ args = parser.parse_args()
870
+
871
+ MODELS_PATH = Path(args.models_path)
872
+ state = ServerState(MODELS_PATH)
873
+
874
+ ASCIIColors.cyan(f"--- Diffusers TTI Server ---")
875
+ ASCIIColors.green(f"Starting server on http://{args.host}:{args.port}")
876
+ ASCIIColors.green(f"Serving models from: {MODELS_PATH.resolve()}")
877
+ if not DIFFUSERS_AVAILABLE:
878
+ ASCIIColors.error("Diffusers or its dependencies are not installed correctly in the server's environment!")
879
+ else:
880
+ ASCIIColors.info(f"Detected device: {state.config['device']}")
881
+
882
+ uvicorn.run(app, host=args.host, port=args.port, reload=False)
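For reference, a minimal client sketch against the endpoints defined above, assuming the server runs on the default localhost:9630; the model name and prompt are placeholders:

import requests

BASE = "http://localhost:9630"

# Select a server-side model (handled by /set_settings and ServerState.update_settings).
requests.post(f"{BASE}/set_settings", json={"model_name": "runwayml/stable-diffusion-v1-5"}).raise_for_status()

# /generate_image returns raw PNG bytes (media_type="image/png").
resp = requests.post(f"{BASE}/generate_image", json={
    "prompt": "a watercolor fox in a forest",
    "negative_prompt": "",
    "params": {"width": 512, "height": 512, "num_inference_steps": 25, "seed": 42},
})
resp.raise_for_status()
with open("out.png", "wb") as f:
    f.write(resp.content)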