lollms-client 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of lollms-client might be problematic.

Files changed (41)
  1. lollms_client/__init__.py +1 -1
  2. lollms_client/llm_bindings/azure_openai/__init__.py +2 -2
  3. lollms_client/llm_bindings/claude/__init__.py +2 -2
  4. lollms_client/llm_bindings/gemini/__init__.py +2 -2
  5. lollms_client/llm_bindings/grok/__init__.py +2 -2
  6. lollms_client/llm_bindings/groq/__init__.py +2 -2
  7. lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +2 -2
  8. lollms_client/llm_bindings/litellm/__init__.py +1 -1
  9. lollms_client/llm_bindings/llamacpp/__init__.py +2 -2
  10. lollms_client/llm_bindings/lollms/__init__.py +1 -1
  11. lollms_client/llm_bindings/lollms_webui/__init__.py +1 -1
  12. lollms_client/llm_bindings/mistral/__init__.py +2 -2
  13. lollms_client/llm_bindings/novita_ai/__init__.py +2 -2
  14. lollms_client/llm_bindings/ollama/__init__.py +7 -4
  15. lollms_client/llm_bindings/open_router/__init__.py +2 -2
  16. lollms_client/llm_bindings/openai/__init__.py +1 -1
  17. lollms_client/llm_bindings/openllm/__init__.py +2 -2
  18. lollms_client/llm_bindings/openwebui/__init__.py +1 -1
  19. lollms_client/llm_bindings/perplexity/__init__.py +2 -2
  20. lollms_client/llm_bindings/pythonllamacpp/__init__.py +3 -3
  21. lollms_client/llm_bindings/tensor_rt/__init__.py +1 -1
  22. lollms_client/llm_bindings/transformers/__init__.py +4 -4
  23. lollms_client/llm_bindings/vllm/__init__.py +1 -1
  24. lollms_client/lollms_core.py +4 -1442
  25. lollms_client/lollms_llm_binding.py +1 -1
  26. lollms_client/lollms_tti_binding.py +1 -1
  27. lollms_client/tti_bindings/diffusers/__init__.py +276 -856
  28. lollms_client/tti_bindings/diffusers/server/main.py +730 -0
  29. lollms_client/tti_bindings/gemini/__init__.py +1 -1
  30. lollms_client/tti_bindings/leonardo_ai/__init__.py +1 -1
  31. lollms_client/tti_bindings/novita_ai/__init__.py +1 -1
  32. lollms_client/tti_bindings/stability_ai/__init__.py +1 -1
  33. lollms_client/tts_bindings/lollms/__init__.py +6 -1
  34. lollms_client/tts_bindings/piper_tts/__init__.py +1 -1
  35. lollms_client/tts_bindings/xtts/__init__.py +20 -14
  36. lollms_client/tts_bindings/xtts/server/main.py +17 -4
  37. {lollms_client-1.6.2.dist-info → lollms_client-1.6.4.dist-info}/METADATA +2 -2
  38. {lollms_client-1.6.2.dist-info → lollms_client-1.6.4.dist-info}/RECORD +41 -40
  39. {lollms_client-1.6.2.dist-info → lollms_client-1.6.4.dist-info}/WHEEL +0 -0
  40. {lollms_client-1.6.2.dist-info → lollms_client-1.6.4.dist-info}/licenses/LICENSE +0 -0
  41. {lollms_client-1.6.2.dist-info → lollms_client-1.6.4.dist-info}/top_level.txt +0 -0
lollms_client/tti_bindings/diffusers/server/main.py (new file)
@@ -0,0 +1,730 @@
+ import os
+ import importlib
+ from io import BytesIO
+ from typing import Optional, List, Dict, Any, Union, Tuple
+ from pathlib import Path
+ import base64
+ import threading
+ import queue
+ from concurrent.futures import Future
+ import time
+ import hashlib
+ import requests
+ from tqdm import tqdm
+ import json
+ import shutil
+ import numpy as np
+ import gc
+ import argparse
+ import uvicorn
+ from fastapi import FastAPI, APIRouter, HTTPException, UploadFile, Form
+ from fastapi.responses import Response
+ from pydantic import BaseModel, Field
+ import sys
+
+ # Add binding root to sys.path to ensure local modules can be imported if structured that way.
+ binding_root = Path(__file__).resolve().parent.parent
+ sys.path.insert(0, str(binding_root))
+
+ # --- Dependency Check and Imports ---
+ try:
+     import torch
+     from diffusers import (
+         AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting,
+         DiffusionPipeline, StableDiffusionPipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline
+     )
+     from diffusers.utils import load_image
+     from PIL import Image
+     from ascii_colors import trace_exception, ASCIIColors
+     DIFFUSERS_AVAILABLE = True
+ except ImportError as e:
+     print(f"FATAL: A required package is missing from the server's venv: {e}.")
+     DIFFUSERS_AVAILABLE = False
+     # Define dummy classes to allow server to start and report error via API
+     class Dummy: pass
+     torch = Dummy()
+     torch.cuda = Dummy()
+     torch.cuda.is_available = lambda: False
+     torch.backends = Dummy()
+     torch.backends.mps = Dummy()
+     torch.backends.mps.is_available = lambda: False
+     AutoPipelineForText2Image = AutoPipelineForImage2Image = AutoPipelineForInpainting = DiffusionPipeline = StableDiffusionPipeline = QwenImageEditPipeline = QwenImageEditPlusPipeline = Image = load_image = ASCIIColors = trace_exception = Dummy
+
+ # --- Server Setup ---
+ app = FastAPI(title="Diffusers TTI Server")
+ router = APIRouter()
+ MODELS_PATH = Path("./models") # Default, will be overridden by command-line arg
+
+ # --- START: Core Logic (Complete and Unabridged) ---
+ CIVITAI_MODELS = {
+     "realistic-vision-v6": {
+         "display_name": "Realistic Vision V6.0", "url": "https://civitai.com/api/download/models/501240?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "realisticVisionV60_v60B1.safetensors", "description": "Photorealistic SD1.5 checkpoint.", "owned_by": "civitai"
+     },
+     "absolute-reality": {
+         "display_name": "Absolute Reality", "url": "https://civitai.com/api/download/models/132760?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "absolutereality_v181.safetensors", "description": "General realistic SD1.5.", "owned_by": "civitai"
+     },
+     "dreamshaper-8": {
+         "display_name": "DreamShaper 8", "url": "https://civitai.com/api/download/models/128713",
+         "filename": "dreamshaper_8.safetensors", "description": "Versatile SD1.5 style model.", "owned_by": "civitai"
+     },
+     "juggernaut-xl": {
+         "display_name": "Juggernaut XL", "url": "https://civitai.com/api/download/models/133005",
+         "filename": "juggernautXL_version6Rundiffusion.safetensors", "description": "Artistic SDXL.", "owned_by": "civitai"
+     },
+     "lyriel-v1.6": {
+         "display_name": "Lyriel v1.6", "url": "https://civitai.com/api/download/models/72396?type=Model&format=SafeTensor&size=full&fp=fp16",
+         "filename": "lyriel_v16.safetensors", "description": "Fantasy/stylized SD1.5.", "owned_by": "civitai"
+     },
+     "ui_icons": {
+         "display_name": "UI Icons", "url": "https://civitai.com/api/download/models/367044?type=Model&format=SafeTensor&size=full&fp=fp16",
+         "filename": "uiIcons_v10.safetensors", "description": "A model for generating UI icons.", "owned_by": "civitai"
+     },
+     "meinamix": {
+         "display_name": "MeinaMix", "url": "https://civitai.com/api/download/models/948574?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "meinamix_meinaV11.safetensors", "description": "Anime/illustration SD1.5.", "owned_by": "civitai"
+     },
+     "rpg-v5": {
+         "display_name": "RPG v5", "url": "https://civitai.com/api/download/models/124626?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "rpg_v5.safetensors", "description": "RPG assets SD1.5.", "owned_by": "civitai"
+     },
+     "pixel-art-xl": {
+         "display_name": "Pixel Art XL", "url": "https://civitai.com/api/download/models/135931?type=Model&format=SafeTensor",
+         "filename": "pixelartxl_v11.safetensors", "description": "Pixel art SDXL.", "owned_by": "civitai"
+     },
+     "lowpoly-world": {
+         "display_name": "Lowpoly World", "url": "https://civitai.com/api/download/models/146502?type=Model&format=SafeTensor",
+         "filename": "LowpolySDXL.safetensors", "description": "Lowpoly style SD1.5.", "owned_by": "civitai"
+     },
+     "toonyou": {
+         "display_name": "ToonYou", "url": "https://civitai.com/api/download/models/125771?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "toonyou_beta6.safetensors", "description": "Cartoon/Disney SD1.5.", "owned_by": "civitai"
+     },
+     "papercut": {
+         "display_name": "Papercut", "url": "https://civitai.com/api/download/models/133503?type=Model&format=SafeTensor",
+         "filename": "papercut.safetensors", "description": "Paper cutout SD1.5.", "owned_by": "civitai"
+     },
+     "fantassifiedIcons": {
+         "display_name": "Fantassified Icons", "url": "https://civitai.com/api/download/models/67584?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "fantassifiedIcons_fantassifiedIconsV20.safetensors", "description": "Flat, modern Icons.", "owned_by": "civitai"
+     },
+     "game_icon_institute": {
+         "display_name": "Game icon institute", "url": "https://civitai.com/api/download/models/158776?type=Model&format=SafeTensor&size=full&fp=fp16",
+         "filename": "gameIconInstituteV10_v10.safetensors", "description": "Flat, modern game Icons.", "owned_by": "civitai"
+     },
+     "M4RV3LS_DUNGEONS": {
+         "display_name": "M4RV3LS & DUNGEONS", "url": "https://civitai.com/api/download/models/139417?type=Model&format=SafeTensor&size=pruned&fp=fp16",
+         "filename": "M4RV3LSDUNGEONSNEWV40COMICS_mD40.safetensors", "description": "comics.", "owned_by": "civitai"
+     },
+ }
+
+ TORCH_DTYPE_MAP_STR_TO_OBJ = {
+     "float16": getattr(torch, 'float16', 'float16'), "bfloat16": getattr(torch, 'bfloat16', 'bfloat16'),
+     "float32": getattr(torch, 'float32', 'float32'), "auto": "auto"
+ }
+
+ SCHEDULER_MAPPING = {
+     "default": None, "ddim": "DDIMScheduler", "ddpm": "DDPMScheduler", "deis_multistep": "DEISMultistepScheduler",
+     "dpm_multistep": "DPMSolverMultistepScheduler", "dpm_multistep_karras": "DPMSolverMultistepScheduler", "dpm_single": "DPMSolverSinglestepScheduler",
+     "dpm_adaptive": "DPMSolverPlusPlusScheduler", "dpm++_2m": "DPMSolverMultistepScheduler", "dpm++_2m_karras": "DPMSolverMultistepScheduler",
+     "dpm++_2s_ancestral": "DPMSolverAncestralDiscreteScheduler", "dpm++_2s_ancestral_karras": "DPMSolverAncestralDiscreteScheduler", "dpm++_sde": "DPMSolverSDEScheduler",
+     "dpm++_sde_karras": "DPMSolverSDEScheduler", "euler_ancestral_discrete": "EulerAncestralDiscreteScheduler", "euler_discrete": "EulerDiscreteScheduler",
+     "heun_discrete": "HeunDiscreteScheduler", "heun_karras": "HeunDiscreteScheduler", "lms_discrete": "LMSDiscreteScheduler",
+     "lms_karras": "LMSDiscreteScheduler", "pndm": "PNDMScheduler", "unipc_multistep": "UniPCMultistepScheduler",
+     "dpm++_2m_sde": "DPMSolverMultistepScheduler", "dpm++_2m_sde_karras": "DPMSolverMultistepScheduler", "dpm2": "KDPM2DiscreteScheduler",
+     "dpm2_karras": "KDPM2DiscreteScheduler", "dpm2_a": "KDPM2AncestralDiscreteScheduler", "dpm2_a_karras": "KDPM2AncestralDiscreteScheduler",
+     "euler": "EulerDiscreteScheduler", "euler_a": "EulerAncestralDiscreteScheduler", "heun": "HeunDiscreteScheduler", "lms": "LMSDiscreteScheduler"
+ }
+ SCHEDULER_USES_KARRAS_SIGMAS = [
+     "dpm_multistep_karras","dpm++_2m_karras","dpm++_2s_ancestral_karras", "dpm++_sde_karras","heun_karras","lms_karras",
+     "dpm++_2m_sde_karras","dpm2_karras","dpm2_a_karras"
+ ]
+
+ class ModelManager:
+     def __init__(self, config: Dict[str, Any], models_path: Path, registry: 'PipelineRegistry'):
+         self.config = config
+         self.models_path = models_path
+         self.registry = registry
+         self.pipeline: Optional[DiffusionPipeline] = None
+         self.current_task: Optional[str] = None
+         self.ref_count = 0
+         self.lock = threading.Lock()
+         self.queue = queue.Queue()
+         self.is_loaded = False
+         self.last_used_time = time.time()
+         self._stop_event = threading.Event()
+         self.worker_thread = threading.Thread(target=self._generation_worker, daemon=True)
+         self.worker_thread.start()
+         self._stop_monitor_event = threading.Event()
+         self._unload_monitor_thread = None
+         self._start_unload_monitor()
+
+     def acquire(self):
+         with self.lock:
+             self.ref_count += 1
+             return self
+
+     def release(self):
+         with self.lock:
+             self.ref_count -= 1
+             return self.ref_count
+
+     def stop(self):
+         self._stop_event.set()
+         if self._unload_monitor_thread:
+             self._stop_monitor_event.set()
+             self._unload_monitor_thread.join(timeout=2)
+         self.queue.put(None)
+         self.worker_thread.join(timeout=5)
+
+     def _start_unload_monitor(self):
+         unload_after = self.config.get("unload_inactive_model_after", 0)
+         if unload_after > 0 and self._unload_monitor_thread is None:
+             self._stop_monitor_event.clear()
+             self._unload_monitor_thread = threading.Thread(target=self._unload_monitor, daemon=True)
+             self._unload_monitor_thread.start()
+
+     def _unload_monitor(self):
+         unload_after = self.config.get("unload_inactive_model_after", 0)
+         if unload_after <= 0:
+             return
+         ASCIIColors.info(f"Starting inactivity monitor for '{self.config['model_name']}' (timeout: {unload_after}s).")
+         while not self._stop_monitor_event.wait(timeout=5.0):
+             with self.lock:
+                 if not self.is_loaded:
+                     continue
+                 if time.time() - self.last_used_time > unload_after:
+                     ASCIIColors.info(f"Model '{self.config['model_name']}' has been inactive. Unloading.")
+                     self._unload_pipeline()
+
+     def _resolve_model_path(self, model_name: str) -> Union[str, Path]:
+         path_obj = Path(model_name)
+         if path_obj.is_absolute() and path_obj.exists():
+             return model_name
+         if model_name in CIVITAI_MODELS:
+             filename = CIVITAI_MODELS[model_name]["filename"]
+             local_path = self.models_path / filename
+             if not local_path.exists():
+                 self._download_civitai_model(model_name)
+             return local_path
+         local_path = self.models_path / model_name
+         if local_path.exists():
+             return local_path
+         return model_name
+
+     def _download_civitai_model(self, model_key: str):
+         model_info = CIVITAI_MODELS[model_key]
+         url = model_info["url"]
+         filename = model_info["filename"]
+         dest_path = self.models_path / filename
+         temp_path = dest_path.with_suffix(".temp")
+         ASCIIColors.cyan(f"Downloading '{filename}' from Civitai...")
+         try:
+             with requests.get(url, stream=True) as r:
+                 r.raise_for_status()
+                 total_size = int(r.headers.get('content-length', 0))
+                 with open(temp_path, 'wb') as f, tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {filename}") as bar:
+                     for chunk in r.iter_content(chunk_size=8192):
+                         f.write(chunk)
+                         bar.update(len(chunk))
+             shutil.move(temp_path, dest_path)
+             ASCIIColors.green(f"Model '{filename}' downloaded successfully.")
+         except Exception as e:
+             if temp_path.exists():
+                 temp_path.unlink()
+             raise Exception(f"Failed to download model {filename}: {e}")
+
+     def _set_scheduler(self):
+         if not self.pipeline:
+             return
+         if "Qwen" in self.config.get("model_name", ""):
+             ASCIIColors.info("Qwen model detected, skipping custom scheduler setup.")
+             return
+         scheduler_name_key = self.config["scheduler_name"].lower()
+         if scheduler_name_key == "default":
+             return
+         scheduler_class_name = SCHEDULER_MAPPING.get(scheduler_name_key)
+         if scheduler_class_name:
+             try:
+                 SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), scheduler_class_name)
+                 scheduler_config = self.pipeline.scheduler.config
+                 scheduler_config["use_karras_sigmas"] = scheduler_name_key in SCHEDULER_USES_KARRAS_SIGMAS
+                 self.pipeline.scheduler = SchedulerClass.from_config(scheduler_config)
+                 ASCIIColors.info(f"Switched scheduler to {scheduler_class_name}")
+             except Exception as e:
+                 ASCIIColors.warning(f"Could not switch scheduler to {scheduler_name_key}: {e}. Using current default.")
+
+     def _execute_load_pipeline(self, task: str, model_path: Union[str, Path], torch_dtype: Any):
+         model_name = self.config.get("model_name", "")
+         try:
+             load_args = {}
+             if self.config.get("hf_cache_path"):
+                 load_args["cache_dir"] = str(self.config["hf_cache_path"])
+             if str(model_path).endswith(".safetensors"):
+                 if task == "text2image":
+                     try:
+                         self.pipeline = AutoPipelineForText2Image.from_single_file(model_path, torch_dtype=torch_dtype, cache_dir=load_args.get("cache_dir"))
+                     except AttributeError:
+                         self.pipeline = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch_dtype, cache_dir=load_args.get("cache_dir"))
+                 elif task == "image2image":
+                     self.pipeline = AutoPipelineForImage2Image.from_single_file(model_path, torch_dtype=torch_dtype, cache_dir=load_args.get("cache_dir"))
+                 elif task == "inpainting":
+                     self.pipeline = AutoPipelineForInpainting.from_single_file(model_path, torch_dtype=torch_dtype, cache_dir=load_args.get("cache_dir"))
+             else:
+                 common_args = {
+                     "torch_dtype": torch_dtype,
+                     "use_safetensors": self.config["use_safetensors"],
+                     "token": self.config["hf_token"],
+                     "local_files_only": self.config["local_files_only"]
+                 }
+                 if self.config["hf_variant"]:
+                     common_args["variant"] = self.config["hf_variant"]
+                 if not self.config["safety_checker_on"]:
+                     common_args["safety_checker"] = None
+                 if self.config.get("hf_cache_path"):
+                     common_args["cache_dir"] = str(self.config["hf_cache_path"])
+
+                 if "Qwen-Image-Edit-2509" in str(model_path):
+                     self.pipeline = QwenImageEditPlusPipeline.from_pretrained(model_path, **common_args)
+                 elif "Qwen-Image-Edit" in str(model_path):
+                     self.pipeline = QwenImageEditPipeline.from_pretrained(model_path, **common_args)
+                 elif "Qwen/Qwen-Image" in str(model_path):
+                     self.pipeline = DiffusionPipeline.from_pretrained(model_path, **common_args)
+                 elif task == "text2image":
+                     self.pipeline = AutoPipelineForText2Image.from_pretrained(model_path, **common_args)
+                 elif task == "image2image":
+                     self.pipeline = AutoPipelineForImage2Image.from_pretrained(model_path, **common_args)
+                 elif task == "inpainting":
+                     self.pipeline = AutoPipelineForInpainting.from_pretrained(model_path, **common_args)
+         except Exception as e:
+             error_str = str(e).lower()
+             if "401" in error_str or "gated" in error_str or "authorization" in error_str:
+                 msg = (
+                     f"AUTHENTICATION FAILED for model '{model_name}'. "
+                     "Please ensure you accepted the model license and provided a valid HF token."
+                 )
+                 raise RuntimeError(msg)
+             raise e
+         self._set_scheduler()
+         self.pipeline.to(self.config["device"])
+         if self.config["enable_xformers"]:
+             try:
+                 self.pipeline.enable_xformers_memory_efficient_attention()
+             except Exception as e:
+                 ASCIIColors.warning(f"Could not enable xFormers: {e}.")
+         if self.config["enable_cpu_offload"] and self.config["device"] != "cpu":
+             self.pipeline.enable_model_cpu_offload()
+         elif self.config["enable_sequential_cpu_offload"] and self.config["device"] != "cpu":
+             self.pipeline.enable_sequential_cpu_offload()
+         self.is_loaded = True
+         self.current_task = task
+         self.last_used_time = time.time()
+         ASCIIColors.green(f"Model '{model_name}' loaded successfully on '{self.config['device']}' for task '{task}'.")
+
+     def _load_pipeline_for_task(self, task: str):
+         if self.pipeline and self.current_task == task:
+             return
+         if self.pipeline:
+             self._unload_pipeline()
+
+         model_name = self.config.get("model_name", "")
+         if not model_name:
+             raise ValueError("Model name cannot be empty for loading.")
+
+         ASCIIColors.info(f"Loading Diffusers model: {model_name} for task: {task}")
+         model_path = self._resolve_model_path(model_name)
+         torch_dtype = TORCH_DTYPE_MAP_STR_TO_OBJ.get(self.config["torch_dtype_str"].lower())
+
+         try:
+             self._execute_load_pipeline(task, model_path, torch_dtype)
+             return
+         except Exception as e:
+             is_oom = "out of memory" in str(e).lower()
+             if not is_oom or not hasattr(self, 'registry'):
+                 raise e
+
+             ASCIIColors.warning(f"Failed to load '{model_name}' due to OOM. Attempting to unload other models to free VRAM.")
+
+             candidates_to_unload = [
+                 m for m in self.registry.get_all_managers()
+                 if m is not self and m.is_loaded
+             ]
+             candidates_to_unload.sort(key=lambda m: m.last_used_time)
+
+             if not candidates_to_unload:
+                 ASCIIColors.error("OOM error, but no other models are available to unload.")
+                 raise e
+
+             for victim in candidates_to_unload:
+                 ASCIIColors.info(f"Unloading '{victim.config['model_name']}' (last used: {time.ctime(victim.last_used_time)}) to free VRAM.")
+                 victim._unload_pipeline()
+
+             try:
+                 ASCIIColors.info(f"Retrying to load '{model_name}'...")
+                 self._execute_load_pipeline(task, model_path, torch_dtype)
+                 ASCIIColors.green(f"Successfully loaded '{model_name}' after freeing VRAM.")
+                 return
+             except Exception as retry_e:
+                 is_oom_retry = "out of memory" in str(retry_e).lower()
+                 if not is_oom_retry:
+                     raise retry_e
+
+             ASCIIColors.error(f"Could not load '{model_name}' even after unloading all other models.")
+             raise e
+
+     def _unload_pipeline(self):
+         if self.pipeline:
+             model_name = self.config.get('model_name', 'Unknown')
+             del self.pipeline
+             self.pipeline = None
+             gc.collect()
+             if torch and torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             self.is_loaded = False
+             self.current_task = None
+             ASCIIColors.info(f"Model '{model_name}' unloaded and VRAM cleared.")
+
+     def _generation_worker(self):
+         while not self._stop_event.is_set():
+             try:
+                 job = self.queue.get(timeout=1)
+                 if job is None:
+                     break
+                 future, task, pipeline_args = job
+                 output = None
+                 try:
+                     with self.lock:
+                         self.last_used_time = time.time()
+                         if not self.is_loaded or self.current_task != task:
+                             self._load_pipeline_for_task(task)
+                         with torch.no_grad():
+                             output = self.pipeline(**pipeline_args)
+                         pil = output.images[0]
+                         buf = BytesIO()
+                         pil.save(buf, format="PNG")
+                         future.set_result(buf.getvalue())
+                 except Exception as e:
+                     trace_exception(e)
+                     future.set_exception(e)
+                 finally:
+                     self.queue.task_done()
+                     # Aggressive cleanup
+                     if output is not None:
+                         del output
+                     gc.collect()
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+             except queue.Empty:
+                 continue
+
+ class PipelineRegistry:
+     _instance = None
+     _lock = threading.Lock()
+     def __new__(cls, *args, **kwargs):
+         with cls._lock:
+             if cls._instance is None:
+                 cls._instance = super().__new__(cls)
+                 cls._instance._managers = {}
+                 cls._instance._registry_lock = threading.Lock()
+             return cls._instance
+     @staticmethod
+     def _get_critical_keys():
+         return [
+             "model_name","device","torch_dtype_str","use_safetensors",
+             "safety_checker_on","hf_variant","enable_cpu_offload",
+             "enable_sequential_cpu_offload","enable_xformers",
+             "local_files_only","hf_cache_path","unload_inactive_model_after"
+         ]
+     def _get_config_key(self, config: Dict[str, Any]) -> str:
+         key_data = tuple(sorted((k, config.get(k)) for k in self._get_critical_keys()))
+         return hashlib.sha256(str(key_data).encode('utf-8')).hexdigest()
+     def get_manager(self, config: Dict[str, Any], models_path: Path) -> ModelManager:
+         key = self._get_config_key(config)
+         with self._registry_lock:
+             if key not in self._managers:
+                 self._managers[key] = ModelManager(config.copy(), models_path, self)
+             return self._managers[key].acquire()
+     def release_manager(self, config: Dict[str, Any]):
+         key = self._get_config_key(config)
+         with self._registry_lock:
+             if key in self._managers:
+                 manager = self._managers[key]
+                 ref_count = manager.release()
+                 if ref_count == 0:
+                     ASCIIColors.info(f"Reference count for model '{config.get('model_name')}' is zero. Cleaning up manager.")
+                     manager.stop()
+                     with manager.lock:
+                         manager._unload_pipeline()
+                     del self._managers[key]
+     def get_active_managers(self) -> List[ModelManager]:
+         with self._registry_lock:
+             return [m for m in self._managers.values() if m.is_loaded]
+     def get_all_managers(self) -> List[ModelManager]:
+         with self._registry_lock:
+             return list(self._managers.values())
+
+ class ServerState:
+     def __init__(self, models_path: Path):
+         self.models_path = models_path
+         self.models_path.mkdir(parents=True, exist_ok=True)
+         self.config_path = self.models_path.parent / "diffusers_server_config.json"
+         self.registry = PipelineRegistry()
+         self.manager: Optional[ModelManager] = None
+         self.config = {}
+         self.load_config() # This will set self.config
+         self._resolve_device_and_dtype()
+
+         # Eagerly acquire manager at startup if a model is configured
+         if self.config.get("model_name"):
+             try:
+                 ASCIIColors.info(f"Acquiring initial model manager for '{self.config['model_name']}' on startup.")
+                 self.manager = self.registry.get_manager(self.config, self.models_path)
+             except Exception as e:
+                 ASCIIColors.error(f"Failed to acquire model manager on startup: {e}")
+                 self.manager = None # Ensure manager is None on failure
+
+     def get_default_config(self) -> Dict[str, Any]:
+         return {
+             "model_name": "", "device": "auto", "torch_dtype_str": "auto", "use_safetensors": True,
+             "scheduler_name": "default", "safety_checker_on": True, "num_inference_steps": 25,
+             "guidance_scale": 7.0, "width": 512, "height": 512, "seed": -1,
+             "enable_cpu_offload": False, "enable_sequential_cpu_offload": False, "enable_xformers": False,
+             "hf_variant": None, "hf_token": None, "hf_cache_path": None, "local_files_only": False,
+             "unload_inactive_model_after": 0
+         }
+
+     def save_config(self):
+         """Saves the current configuration to a JSON file."""
+         try:
+             with open(self.config_path, 'w') as f:
+                 json.dump(self.config, f, indent=4)
+             ASCIIColors.info(f"Server config saved to {self.config_path}")
+         except Exception as e:
+             ASCIIColors.error(f"Failed to save server config: {e}")
+
+     def load_config(self):
+         """Loads configuration from JSON file, falling back to defaults."""
+         default_config = self.get_default_config()
+         if self.config_path.exists():
+             try:
+                 with open(self.config_path, 'r') as f:
+                     loaded_config = json.load(f)
+                 # Merge loaded config into defaults to ensure all keys are present
+                 default_config.update(loaded_config)
+                 self.config = default_config
+                 ASCIIColors.info(f"Loaded server configuration from {self.config_path}")
+             except (json.JSONDecodeError, IOError) as e:
+                 ASCIIColors.warning(f"Could not load config file, using defaults. Error: {e}")
+                 self.config = default_config
+         else:
+             self.config = default_config
+             # Save back to ensure file exists and is up-to-date with all keys
+             self.save_config()
+
+     def _resolve_device_and_dtype(self):
+         if self.config.get("device", "auto").lower() == "auto":
+             self.config["device"] = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+         # Prioritize bfloat16 for Qwen models on supported hardware, as it's more stable
+         if "Qwen" in self.config.get("model_name", "") and self.config["device"] == "cuda":
+             if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
+                 self.config["torch_dtype_str"] = "bfloat16"
+                 ASCIIColors.info("Qwen model detected on compatible hardware. Forcing dtype to bfloat16 for stability.")
+                 return
+
+         if self.config["torch_dtype_str"].lower() == "auto":
+             self.config["torch_dtype_str"] = "float16" if self.config["device"] != "cpu" else "float32"
+
+     def update_settings(self, new_settings: Dict[str, Any]):
+         """Updates settings, swaps the manager if critical settings change, and saves the config."""
+         if 'model' in new_settings and 'model_name' not in new_settings:
+             new_settings['model_name'] = new_settings.pop('model')
+
+         # Safeguard: If a model is already configured and the new settings don't specify one,
+         # keep the old one. This prevents a misconfigured client from wiping a valid server state.
+         if self.config.get("model_name") and not new_settings.get("model_name"):
+             ASCIIColors.info("Incoming settings have no model_name. Preserving existing model.")
+             new_settings["model_name"] = self.config["model_name"]
+
+         # Release old manager if it exists
+         if self.manager:
+             self.registry.release_manager(self.manager.config)
+             self.manager = None
+
+         # Update the config in memory
+         self.config.update(new_settings)
+         ASCIIColors.info(f"Server config updated. Current model_name: {self.config.get('model_name')}")
+
+         self._resolve_device_and_dtype()
+
+         # Acquire new manager with the updated config
+         if self.config.get("model_name"):
+             ASCIIColors.info("Acquiring model manager with updated configuration...")
+             self.manager = self.registry.get_manager(self.config, self.models_path)
+         else:
+             ASCIIColors.warning("No model_name in config after update, manager not acquired.")
+
+         self.save_config() # Persist the new state
+         return True
+
+     def get_active_manager(self) -> ModelManager:
+         if self.manager:
+             return self.manager
+         raise HTTPException(status_code=400, detail="No model is configured or manager is not active. Please set a model using the /set_settings endpoint.")
+
+ state: Optional[ServerState] = None
+
+ # --- Pydantic Models for API ---
+ class T2IRequest(BaseModel):
+     prompt: str
+     negative_prompt: str = ""
+     params: Dict[str, Any] = Field(default_factory=dict)
+
+ class EditRequestPayload(BaseModel):
+     prompt: str
+     image_paths: List[str] = Field(default_factory=list)
+     params: Dict[str, Any] = Field(default_factory=dict)
+
+ # --- API Endpoints ---
+ @router.post("/generate_image")
+ async def generate_image(request: T2IRequest):
+     try:
+         manager = state.get_active_manager()
+         params = request.params
+         seed = int(params.get("seed", state.config.get("seed", -1)))
+         generator = None
+         if seed != -1:
+             generator = torch.Generator(device=state.config["device"]).manual_seed(seed)
+
+         pipeline_args = {
+             "prompt": request.prompt, "negative_prompt": request.negative_prompt,
+             "width": int(params.get("width", state.config.get("width", 512))),
+             "height": int(params.get("height", state.config.get("height", 512))),
+             "num_inference_steps": int(params.get("num_inference_steps", state.config.get("num_inference_steps", 25))),
+             "guidance_scale": float(params.get("guidance_scale", state.config.get("guidance_scale", 7.0))),
+             "generator": generator
+         }
+
+         future = Future()
+         manager.queue.put((future, "text2image", pipeline_args))
+         result_bytes = future.result()
+         return Response(content=result_bytes, media_type="image/png")
+     except Exception as e:
+         trace_exception(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.post("/edit_image")
+ async def edit_image(json_payload: str = Form(...), files: List[UploadFile] = []):
+     try:
+         data = EditRequestPayload.parse_raw(json_payload)
+         manager = state.get_active_manager()
+
+         pil_images = []
+         for file in files:
+             contents = await file.read()
+             pil_images.append(Image.open(BytesIO(contents)).convert("RGB"))
+
+         for path in data.image_paths:
+             pil_images.append(load_image(path).convert("RGB"))
+
+         if not pil_images:
+             raise HTTPException(status_code=400, detail="No images provided for editing.")
+
+         task = "inpainting" if data.params.get("mask") else "image2image"
+
+         pipeline_args = {
+             "prompt": data.prompt,
+             "image": pil_images[0], # Simple i2i for now
+             "strength": float(data.params.get("strength", 0.8)),
+             # Add other params like mask etc.
+         }
+
+         future = Future()
+         manager.queue.put((future, task, pipeline_args))
+         result_bytes = future.result()
+         return Response(content=result_bytes, media_type="image/png")
+     except Exception as e:
+         trace_exception(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.get("/list_models")
+ def list_models_endpoint():
+     civitai = [{'model_name': key, 'display_name': info['display_name'], 'description': info['description'], 'owned_by': info['owned_by']} for key, info in CIVITAI_MODELS.items()]
+     local = [{'model_name': f.name, 'display_name': f.stem, 'description': 'Local safetensors file.', 'owned_by': 'local_user'} for f in state.models_path.glob("*.safetensors")]
+     return civitai + local
+
+ @router.get("/list_local_models")
+ def list_local_models_endpoint():
+     return sorted([f.name for f in state.models_path.glob("*.safetensors")])
+
+ @router.get("/list_available_models")
+ def list_available_models_endpoint():
+     discoverable = [m['model_name'] for m in list_models_endpoint()]
+     return sorted(list(set(discoverable)))
+
+ @router.get("/get_settings")
+ def get_settings_endpoint():
+     settings_list = []
+     # Add options for dropdowns
+     available_models = list_available_models_endpoint()
+     schedulers = list(SCHEDULER_MAPPING.keys())
+     config_to_display = state.config or state.get_default_config()
+     for name, value in config_to_display.items():
+         setting = {"name": name, "type": str(type(value).__name__), "value": value}
+         if name == "model_name": setting["options"] = available_models
+         if name == "scheduler_name": setting["options"] = schedulers
+         settings_list.append(setting)
+     return settings_list
+
+ @router.post("/set_settings")
+ def set_settings_endpoint(settings: Dict[str, Any]):
+     try:
+         success = state.update_settings(settings)
+         return {"success": success}
+     except Exception as e:
+         trace_exception(e)
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.get("/status")
+ def status_endpoint():
+     return {"status": "running", "diffusers_available": DIFFUSERS_AVAILABLE, "model_loaded": state.manager.is_loaded if state.manager else False}
+
+ @router.post("/unload_model")
+ def unload_model_endpoint():
+     if state.manager:
+         state.manager._unload_pipeline()
+         state.registry.release_manager(state.manager.config)
+         state.manager = None
+     return {"status": "unloaded"}
+
+ @router.get("/ps")
+ def ps_endpoint():
+     managers = state.registry.get_all_managers()
+     return [{
+         "model_name": m.config.get("model_name"), "is_loaded": m.is_loaded,
+         "task": m.current_task, "device": m.config.get("device"), "ref_count": m.ref_count,
+         "queue_size": m.queue.qsize(), "last_used": time.ctime(m.last_used_time)
+     } for m in managers]
+
+ app.include_router(router)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Diffusers TTI Server")
+     parser.add_argument("--host", type=str, default="localhost", help="Host to bind to.")
+     parser.add_argument("--port", type=int, default=9630, help="Port to bind to.")
+     parser.add_argument("--models-path", type=str, required=True, help="Path to the models directory.")
+     args = parser.parse_args()
+
+     MODELS_PATH = Path(args.models_path)
+     state = ServerState(MODELS_PATH)
+
+     ASCIIColors.cyan(f"--- Diffusers TTI Server ---")
+     ASCIIColors.green(f"Starting server on http://{args.host}:{args.port}")
+     ASCIIColors.green(f"Serving models from: {MODELS_PATH.resolve()}")
+     if not DIFFUSERS_AVAILABLE:
+         ASCIIColors.error("Diffusers or its dependencies are not installed correctly in the server's environment!")
+     else:
+         ASCIIColors.info(f"Detected device: {state.config['device']}")
+
+     uvicorn.run(app, host=args.host, port=args.port, reload=False)
+ uvicorn.run(app, host=args.host, port=args.port, reload=False)