lollms-client 1.5.6__py3-none-any.whl → 1.7.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. lollms_client/__init__.py +1 -1
  2. lollms_client/llm_bindings/azure_openai/__init__.py +2 -2
  3. lollms_client/llm_bindings/claude/__init__.py +125 -35
  4. lollms_client/llm_bindings/gemini/__init__.py +261 -159
  5. lollms_client/llm_bindings/grok/__init__.py +52 -15
  6. lollms_client/llm_bindings/groq/__init__.py +2 -2
  7. lollms_client/llm_bindings/hugging_face_inference_api/__init__.py +2 -2
  8. lollms_client/llm_bindings/litellm/__init__.py +1 -1
  9. lollms_client/llm_bindings/llama_cpp_server/__init__.py +605 -0
  10. lollms_client/llm_bindings/llamacpp/__init__.py +18 -11
  11. lollms_client/llm_bindings/lollms/__init__.py +76 -21
  12. lollms_client/llm_bindings/lollms_webui/__init__.py +1 -1
  13. lollms_client/llm_bindings/mistral/__init__.py +2 -2
  14. lollms_client/llm_bindings/novita_ai/__init__.py +142 -6
  15. lollms_client/llm_bindings/ollama/__init__.py +345 -89
  16. lollms_client/llm_bindings/open_router/__init__.py +2 -2
  17. lollms_client/llm_bindings/openai/__init__.py +81 -20
  18. lollms_client/llm_bindings/openllm/__init__.py +362 -506
  19. lollms_client/llm_bindings/openwebui/__init__.py +333 -171
  20. lollms_client/llm_bindings/perplexity/__init__.py +2 -2
  21. lollms_client/llm_bindings/pythonllamacpp/__init__.py +3 -3
  22. lollms_client/llm_bindings/tensor_rt/__init__.py +1 -1
  23. lollms_client/llm_bindings/transformers/__init__.py +428 -632
  24. lollms_client/llm_bindings/vllm/__init__.py +1 -1
  25. lollms_client/lollms_agentic.py +4 -2
  26. lollms_client/lollms_base_binding.py +61 -0
  27. lollms_client/lollms_core.py +512 -1890
  28. lollms_client/lollms_discussion.py +65 -39
  29. lollms_client/lollms_llm_binding.py +126 -261
  30. lollms_client/lollms_mcp_binding.py +49 -77
  31. lollms_client/lollms_stt_binding.py +99 -52
  32. lollms_client/lollms_tti_binding.py +38 -38
  33. lollms_client/lollms_ttm_binding.py +38 -42
  34. lollms_client/lollms_tts_binding.py +43 -18
  35. lollms_client/lollms_ttv_binding.py +38 -42
  36. lollms_client/lollms_types.py +4 -2
  37. lollms_client/stt_bindings/whisper/__init__.py +108 -23
  38. lollms_client/stt_bindings/whispercpp/__init__.py +7 -1
  39. lollms_client/tti_bindings/diffusers/__init__.py +464 -803
  40. lollms_client/tti_bindings/diffusers/server/main.py +1062 -0
  41. lollms_client/tti_bindings/gemini/__init__.py +182 -239
  42. lollms_client/tti_bindings/leonardo_ai/__init__.py +6 -3
  43. lollms_client/tti_bindings/lollms/__init__.py +4 -1
  44. lollms_client/tti_bindings/novita_ai/__init__.py +5 -2
  45. lollms_client/tti_bindings/openai/__init__.py +10 -11
  46. lollms_client/tti_bindings/stability_ai/__init__.py +5 -3
  47. lollms_client/ttm_bindings/audiocraft/__init__.py +7 -12
  48. lollms_client/ttm_bindings/beatoven_ai/__init__.py +7 -3
  49. lollms_client/ttm_bindings/lollms/__init__.py +4 -17
  50. lollms_client/ttm_bindings/replicate/__init__.py +7 -4
  51. lollms_client/ttm_bindings/stability_ai/__init__.py +7 -4
  52. lollms_client/ttm_bindings/topmediai/__init__.py +6 -3
  53. lollms_client/tts_bindings/bark/__init__.py +7 -10
  54. lollms_client/tts_bindings/lollms/__init__.py +6 -1
  55. lollms_client/tts_bindings/piper_tts/__init__.py +8 -11
  56. lollms_client/tts_bindings/xtts/__init__.py +157 -74
  57. lollms_client/tts_bindings/xtts/server/main.py +241 -280
  58. {lollms_client-1.5.6.dist-info → lollms_client-1.7.13.dist-info}/METADATA +113 -5
  59. lollms_client-1.7.13.dist-info/RECORD +90 -0
  60. lollms_client-1.5.6.dist-info/RECORD +0 -87
  61. {lollms_client-1.5.6.dist-info → lollms_client-1.7.13.dist-info}/WHEEL +0 -0
  62. {lollms_client-1.5.6.dist-info → lollms_client-1.7.13.dist-info}/licenses/LICENSE +0 -0
  63. {lollms_client-1.5.6.dist-info → lollms_client-1.7.13.dist-info}/top_level.txt +0 -0
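The largest single addition in this release is a new standalone FastAPI server for the Diffusers text-to-image binding, lollms_client/tti_bindings/diffusers/server/main.py (file 40 above, +1062 lines); its full content follows. As a quick orientation before the diff, here is a minimal sketch of how that server could be probed once it is running. It assumes the argparse defaults visible at the end of the new module (localhost, port 9630) and uses only endpoints the module defines; the exact responses depend on the local configuration.

    # Minimal sketch: probe the new Diffusers TTI server (assumed running on its defaults).
    import requests

    BASE = "http://localhost:9630"

    print(requests.get(f"{BASE}/status").json())
    # e.g. {'status': 'running', 'diffusers_available': True, 'model_loaded': False}

    print(requests.get(f"{BASE}/list_models").json())    # locally discovered pipelines / .safetensors files
    print(requests.get(f"{BASE}/get_settings").json())   # current server configuration, with option lists
    print(requests.get(f"{BASE}/ps").json())              # active model managers and their queues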
lollms_client/tti_bindings/diffusers/server/main.py (new file)
@@ -0,0 +1,1062 @@
1
+ import os
2
+ import importlib
3
+ from io import BytesIO
4
+ from typing import Optional, List, Dict, Any, Union, Tuple
5
+ from pathlib import Path
6
+ import base64
7
+ import threading
8
+ import queue
9
+ from concurrent.futures import Future
10
+ import time
11
+ import hashlib
12
+ import requests
13
+ from tqdm import tqdm
14
+ import json
15
+ import shutil
16
+ import numpy as np
17
+ import gc
18
+ import argparse
19
+ import uvicorn
20
+ from fastapi import FastAPI, APIRouter, HTTPException, UploadFile, Form
21
+ from fastapi import Request, Response
22
+ from fastapi.responses import Response
23
+ from pydantic import BaseModel, Field
24
+ import sys
25
+ import platform
26
+ import inspect
27
+
28
+ class PullModelRequest(BaseModel):
29
+ hf_id: Optional[str] = Field(default=None, description="Hugging Face repo id or URL, e.g. 'stabilityai/sdxl-turbo'")
30
+ safetensors_url: Optional[str] = Field(default=None, description="Direct URL to a .safetensors file")
31
+ local_name: Optional[str] = Field(default=None, description="Optional name/folder under models/")
32
+
33
+ # Add binding root to sys.path to ensure local modules can be imported if structured that way.
34
+ binding_root = Path(__file__).resolve().parent.parent
35
+ sys.path.insert(0, str(binding_root))
36
+
37
+ # --- Dependency Check and Imports ---
38
+ try:
39
+ import torch
40
+ from diffusers import (
41
+ AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting,
42
+ DiffusionPipeline, StableDiffusionPipeline, QwenImageEditPipeline, QwenImageEditPlusPipeline
43
+ )
44
+ from diffusers.utils import load_image
45
+ from PIL import Image
46
+ from ascii_colors import trace_exception, ASCIIColors
47
+ DIFFUSERS_AVAILABLE = True
48
+ except ImportError as e:
49
+ print(f"FATAL: A required package is missing from the server's venv: {e}.")
50
+ DIFFUSERS_AVAILABLE = False
51
+ # Define dummy classes to allow server to start and report error via API
52
+ class Dummy: pass
53
+ torch = Dummy()
54
+ torch.cuda = Dummy()
55
+ torch.cuda.is_available = lambda: False
56
+ torch.backends = Dummy()
57
+ torch.backends.mps = Dummy()
58
+ torch.backends.mps.is_available = lambda: False
59
+ AutoPipelineForText2Image = AutoPipelineForImage2Image = AutoPipelineForInpainting = DiffusionPipeline = StableDiffusionPipeline = QwenImageEditPipeline = QwenImageEditPlusPipeline = Image = load_image = ASCIIColors = trace_exception = Dummy
60
+
61
+ # --- Server Setup ---
62
+ app = FastAPI(title="Diffusers TTI Server")
63
+ router = APIRouter()
64
+ MODELS_PATH = Path("./models")
65
+
66
+ # --- START: Core Logic (Complete and Unabridged) ---
67
+ CIVITAI_MODELS = {
68
+ "DreamShaper-8": {
69
+ "display_name": "DreamShaper 8", "url": "https://civitai.com/api/download/models/128713",
70
+ "filename": "dreamshaper_8.safetensors", "description": "Versatile SD1.5 style model.", "owned_by": "civitai"
71
+ },
72
+ "Juggernaut-xl": {
73
+ "display_name": "Juggernaut XL", "url": "https://civitai.com/api/download/models/133005",
74
+ "filename": "juggernautXL_version6Rundiffusion.safetensors", "description": "Artistic SDXL.", "owned_by": "civitai"
75
+ },
76
+ }
77
+
78
+ HF_PUBLIC_MODELS = {
79
+ "General Purpose & SDXL": [
80
+ {"model_name": "stabilityai/stable-diffusion-xl-base-1.0", "display_name": "Stable Diffusion XL 1.0", "desc": "Official 1024x1024 text-to-image model from Stability AI."},
81
+ {"model_name": "stabilityai/sdxl-turbo", "display_name": "SDXL Turbo", "desc": "A fast, real-time text-to-image model based on SDXL."},
82
+ {"model_name": "kandinsky-community/kandinsky-3", "display_name": "Kandinsky 3", "desc": "A powerful multilingual model with strong prompt understanding and aesthetic quality."},
83
+ {"model_name": "playgroundai/playground-v2.5-1024px-aesthetic", "display_name": "Playground v2.5", "desc": "A high-quality model focused on aesthetic outputs."},
84
+ ],
85
+ "Photorealistic": [
86
+ {"model_name": "emilianJR/epiCRealism", "display_name": "epiCRealism", "desc": "A popular community model for generating photorealistic images."},
87
+ {"model_name": "SG161222/Realistic_Vision_V5.1_noVAE", "display_name": "Realistic Vision 5.1", "desc": "One of the most popular realistic models, great for portraits and scenes."},
88
+ {"model_name": "Photon-v1", "display_name": "Photon", "desc": "A model known for high-quality, realistic images with good lighting and detail."},
89
+ ],
90
+ "Anime & Illustration": [
91
+ {"model_name": "hakurei/waifu-diffusion", "display_name": "Waifu Diffusion 1.4", "desc": "A widely-used model for generating high-quality anime-style images."},
92
+ {"model_name": "gsdf/Counterfeit-V3.0", "display_name": "Counterfeit V3.0", "desc": "A strong model for illustrative and 2.5D anime styles."},
93
+ {"model_name": "cagliostrolab/animagine-xl-3.0", "display_name": "Animagine XL 3.0", "desc": "A state-of-the-art anime model on the SDXL architecture."},
94
+ ],
95
+ "Artistic & Stylized": [
96
+ {"model_name": "wavymulder/Analog-Diffusion", "display_name": "Analog Diffusion", "desc": "Creates images with a vintage, analog film aesthetic."},
97
+ {"model_name": "dreamlike-art/dreamlike-photoreal-2.0", "display_name": "Dreamlike Photoreal 2.0", "desc": "Produces stunning, artistic, and photorealistic images."},
98
+ ],
99
+ "Image Editing Tools": [
100
+ {"model_name": "stabilityai/stable-diffusion-xl-refiner-1.0", "display_name": "SDXL Refiner 1.0", "desc": "A dedicated refiner model to improve details in SDXL generations."},
101
+ {"model_name": "timbrooks/instruct-pix2pix", "display_name": "Instruct-Pix2Pix", "desc": "The original instruction-based image editing model (SD 1.5)."},
102
+ {"model_name": "kandinsky-community/kandinsky-2-2-instruct-pix2pix", "display_name": "Kandinsky 2.2 Instruct", "desc": "An instruction-based model with strong prompt adherence, based on Kandinsky 2.2."},
103
+ {"model_name": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", "display_name": "SDXL Inpainting", "desc": "A dedicated inpainting model based on SDXL 1.0 for filling in masked areas."},
104
+ {"model_name": "Qwen/Qwen-Image-Edit", "display_name": "Qwen Image Edit", "desc": "An instruction-based model for various image editing tasks. (Review License)."},
105
+ {"model_name": "Qwen/Qwen-Image-Edit-2509", "display_name": "Qwen Image Edit Plus", "desc": "Advanced multi-image editing and fusion. (Review License)."},
106
+ ],
107
+ "Legacy & Base Models": [
108
+ {"model_name": "runwayml/stable-diffusion-v1-5", "display_name": "Stable Diffusion 1.5", "desc": "The classic and versatile SD1.5 base model."},
109
+ {"model_name": "stabilityai/stable-diffusion-2-1", "display_name": "Stable Diffusion 2.1", "desc": "The 768x768 base model from the SD2.x series."},
110
+ ]
111
+ }
112
+
113
+ HF_GATED_MODELS = {
114
+ "Next-Generation (Gated Access Required)": [
115
+ {"model_name": "stabilityai/stable-diffusion-3-medium-diffusers", "display_name": "Stable Diffusion 3 Medium", "desc": "State-of-the-art model with advanced prompt understanding. Requires free registration."},
116
+ {"model_name": "black-forest-labs/FLUX.1-schnell", "display_name": "FLUX.1 Schnell", "desc": "A powerful and extremely fast next-generation model. Requires access request."},
117
+ {"model_name": "black-forest-labs/FLUX.1-dev", "display_name": "FLUX.1 Dev", "desc": "The larger developer version of the FLUX.1 model. Requires access request."},
118
+ ]
119
+ }
120
+
121
+
122
+ TORCH_DTYPE_MAP_STR_TO_OBJ = {
123
+ "float16": getattr(torch, 'float16', 'float16'), "bfloat16": getattr(torch, 'bfloat16', 'bfloat16'),
124
+ "float32": getattr(torch, 'float32', 'float32'), "auto": "auto"
125
+ }
126
+
127
+ SCHEDULER_MAPPING = {
128
+ "default": None, "ddim": "DDIMScheduler", "ddpm": "DDPMScheduler", "deis_multistep": "DEISMultistepScheduler",
129
+ "dpm_multistep": "DPMSolverMultistepScheduler", "dpm_multistep_karras": "DPMSolverMultistepScheduler", "dpm_single": "DPMSolverSinglestepScheduler",
130
+ "dpm_adaptive": "DPMSolverPlusPlusScheduler", "dpm++_2m": "DPMSolverMultistepScheduler", "dpm++_2m_karras": "DPMSolverMultistepScheduler",
131
+ "dpm++_2s_ancestral": "DPMSolverAncestralDiscreteScheduler", "dpm++_2s_ancestral_karras": "DPMSolverAncestralDiscreteScheduler", "dpm++_sde": "DPMSolverSDEScheduler",
132
+ "dpm++_sde_karras": "DPMSolverSDEScheduler", "euler_ancestral_discrete": "EulerAncestralDiscreteScheduler", "euler_discrete": "EulerDiscreteScheduler",
133
+ "heun_discrete": "HeunDiscreteScheduler", "heun_karras": "HeunDiscreteScheduler", "lms_discrete": "LMSDiscreteScheduler",
134
+ "lms_karras": "LMSDiscreteScheduler", "pndm": "PNDMScheduler", "unipc_multistep": "UniPCMultistepScheduler",
135
+ "dpm++_2m_sde": "DPMSolverMultistepScheduler", "dpm++_2m_sde_karras": "DPMSolverMultistepScheduler", "dpm2": "KDPM2DiscreteScheduler",
136
+ "dpm2_karras": "KDPM2DiscreteScheduler", "dpm2_a": "KDPM2AncestralDiscreteScheduler", "dpm2_a_karras": "KDPM2AncestralDiscreteScheduler",
137
+ "euler": "EulerDiscreteScheduler", "euler_a": "EulerAncestralDiscreteScheduler", "heun": "HeunDiscreteScheduler", "lms": "LMSDiscreteScheduler"
138
+ }
139
+
140
+ SCHEDULER_USES_KARRAS_SIGMAS = [
141
+ "dpm_multistep_karras","dpm++_2m_karras","dpm++_2s_ancestral_karras", "dpm++_sde_karras","heun_karras","lms_karras",
142
+ "dpm++_2m_sde_karras","dpm2_karras","dpm2_a_karras"
143
+ ]
144
+
145
+
146
+ class ModelManager:
147
+ def __init__(self, config: Dict[str, Any], models_path: Path, registry: 'PipelineRegistry'):
148
+ self.config = config
149
+ self.models_path = models_path
150
+ self.registry = registry
151
+ self.pipeline: Optional[DiffusionPipeline] = None
152
+ self.current_task: Optional[str] = None
153
+ self.ref_count = 0
154
+ self.lock = threading.Lock()
155
+ self.queue = queue.Queue()
156
+ self.is_loaded = False
157
+ self.last_used_time = time.time()
158
+ self._stop_event = threading.Event()
159
+ self.worker_thread = threading.Thread(target=self._generation_worker, daemon=True)
160
+ self.worker_thread.start()
161
+ self._stop_monitor_event = threading.Event()
162
+ self._unload_monitor_thread = None
163
+ self._start_unload_monitor()
164
+ self.supported_args: Optional[set] = None
165
+
166
+ def acquire(self):
167
+ with self.lock:
168
+ self.ref_count += 1
169
+ return self
170
+
171
+ def release(self):
172
+ with self.lock:
173
+ self.ref_count -= 1
174
+ return self.ref_count
175
+
176
+ def stop(self):
177
+ self._stop_event.set()
178
+ if self._unload_monitor_thread:
179
+ self._stop_monitor_event.set()
180
+ self._unload_monitor_thread.join(timeout=2)
181
+ self.queue.put(None)
182
+ self.worker_thread.join(timeout=5)
183
+
184
+ def _start_unload_monitor(self):
185
+ unload_after = self.config.get("unload_inactive_model_after", 0)
186
+ if unload_after > 0 and self._unload_monitor_thread is None:
187
+ self._stop_monitor_event.clear()
188
+ self._unload_monitor_thread = threading.Thread(target=self._unload_monitor, daemon=True)
189
+ self._unload_monitor_thread.start()
190
+
191
+ def _unload_monitor(self):
192
+ unload_after = self.config.get("unload_inactive_model_after", 0)
193
+ if unload_after <= 0:
194
+ return
195
+ ASCIIColors.info(f"Starting inactivity monitor for '{self.config['model_name']}' (timeout: {unload_after}s).")
196
+ while not self._stop_monitor_event.wait(timeout=5.0):
197
+ with self.lock:
198
+ if not self.is_loaded:
199
+ continue
200
+ if time.time() - self.last_used_time > unload_after:
201
+ ASCIIColors.info(f"Model '{self.config['model_name']}' has been inactive. Unloading.")
202
+ self._unload_pipeline()
203
+
204
+ def _resolve_model_path(self, model_name: str) -> Union[str, Path]:
205
+ path_obj = Path(model_name)
206
+ if path_obj.is_absolute() and path_obj.exists():
207
+ return model_name
208
+ if model_name in CIVITAI_MODELS:
209
+ filename = CIVITAI_MODELS[model_name]["filename"]
210
+ local_path = self.models_path / filename
211
+ if not local_path.exists():
212
+ self._download_civitai_model(model_name)
213
+ return local_path
214
+
215
+ # Search in extra models path
216
+ if state.extra_models_path and state.extra_models_path.exists():
217
+ found_paths = list(state.extra_models_path.rglob(model_name))
218
+ if found_paths:
219
+ ASCIIColors.info(f"Found model in extra path: {found_paths[0]}")
220
+ return found_paths[0]
221
+
222
+ # Search in primary models path
223
+ found_paths = list(self.models_path.rglob(model_name))
224
+ if found_paths:
225
+ ASCIIColors.info(f"Found model in primary path: {found_paths[0]}")
226
+ return found_paths[0]
227
+
228
+ # Fallback for HF hub models that are folders, not single files.
229
+ local_path = self.models_path / model_name
230
+ if local_path.exists():
231
+ return local_path
232
+
233
+ return model_name
234
+
235
+ def _download_civitai_model(self, model_key: str):
236
+ model_info = CIVITAI_MODELS[model_key]
237
+ url = model_info["url"]
238
+ filename = model_info["filename"]
239
+ dest_path = self.models_path / filename
240
+ temp_path = dest_path.with_suffix(".temp")
241
+ ASCIIColors.cyan(f"Downloading '{filename}' from Civitai... to {dest_path}")
242
+ try:
243
+ with requests.get(url, stream=True) as r:
244
+ r.raise_for_status()
245
+ total_size = int(r.headers.get('content-length', 0))
246
+ with open(temp_path, 'wb') as f, tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {filename}") as bar:
247
+ for chunk in r.iter_content(chunk_size=8192):
248
+ f.write(chunk)
249
+ bar.update(len(chunk))
250
+ shutil.move(temp_path, dest_path)
251
+ ASCIIColors.green(f"Model '{filename}' downloaded successfully.")
252
+ except Exception as e:
253
+ if temp_path.exists():
254
+ temp_path.unlink()
255
+ raise Exception(f"Failed to download model {filename}: {e}")
256
+
257
+ def _set_scheduler(self):
258
+ if not self.pipeline:
259
+ return
260
+ if "Qwen" in self.config.get("model_name", "") or "FLUX" in self.config.get("model_name", ""):
261
+ ASCIIColors.info("Special model detected, skipping custom scheduler setup.")
262
+ return
263
+ scheduler_name_key = self.config["scheduler_name"].lower()
264
+ if scheduler_name_key == "default":
265
+ return
266
+ scheduler_class_name = SCHEDULER_MAPPING.get(scheduler_name_key)
267
+ if scheduler_class_name:
268
+ try:
269
+ SchedulerClass = getattr(importlib.import_module("diffusers.schedulers"), scheduler_class_name)
270
+ scheduler_config = self.pipeline.scheduler.config
271
+ scheduler_config["use_karras_sigmas"] = scheduler_name_key in SCHEDULER_USES_KARRAS_SIGMAS
272
+ self.pipeline.scheduler = SchedulerClass.from_config(scheduler_config)
273
+ ASCIIColors.info(f"Switched scheduler to {scheduler_class_name}")
274
+ except Exception as e:
275
+ ASCIIColors.warning(f"Could not switch scheduler to {scheduler_name_key}: {e}. Using current default.")
276
+
277
+ def _execute_load_pipeline(self, task: str, model_path: Union[str, Path], torch_dtype: Any):
278
+ if platform.system() == "Windows":
279
+ os.environ["HF_HUB_ENABLE_SYMLINKS"] = "0"
280
+
281
+ model_name_from_config = self.config.get("model_name", "")
282
+ use_device_map = False
283
+
284
+ try:
285
+ load_params = {}
286
+ if self.config.get("hf_cache_path"):
287
+ load_params["cache_dir"] = str(self.config["hf_cache_path"])
288
+ load_params["torch_dtype"] = torch_dtype
289
+
290
+ is_qwen_model = "Qwen" in model_name_from_config
291
+ is_flux_model = "FLUX" in model_name_from_config
292
+
293
+ if is_qwen_model or is_flux_model:
294
+ ASCIIColors.info(f"Special model '{model_name_from_config}' detected. Using dedicated pipeline loader.")
295
+ load_params.update({
296
+ "use_safetensors": self.config["use_safetensors"],
297
+ "token": self.config["hf_token"],
298
+ "local_files_only": self.config["local_files_only"]
299
+ })
300
+ if self.config["hf_variant"]:
301
+ load_params["variant"] = self.config["hf_variant"]
302
+ if not self.config["safety_checker_on"]:
303
+ load_params["safety_checker"] = None
304
+
305
+ should_offload = self.config["enable_cpu_offload"] or self.config["enable_sequential_cpu_offload"]
306
+ if should_offload:
307
+ ASCIIColors.info(f"Offload enabled. Forcing device_map='auto' for {model_name_from_config}.")
308
+ use_device_map = True
309
+ load_params["device_map"] = "auto"
310
+
311
+ if is_flux_model:
312
+ self.pipeline = AutoPipelineForText2Image.from_pretrained(model_name_from_config, **load_params)
313
+ elif "Qwen-Image-Edit-2509" in model_name_from_config:
314
+ self.pipeline = QwenImageEditPlusPipeline.from_pretrained(model_name_from_config, **load_params)
315
+ elif "Qwen-Image-Edit" in model_name_from_config:
316
+ self.pipeline = QwenImageEditPipeline.from_pretrained(model_name_from_config, **load_params)
317
+ elif "Qwen/Qwen-Image" in model_name_from_config:
318
+ self.pipeline = DiffusionPipeline.from_pretrained(model_name_from_config, **load_params)
319
+
320
+ else:
321
+ is_safetensors_file = str(model_path).endswith(".safetensors")
322
+ if is_safetensors_file:
323
+ ASCIIColors.info(f"Loading standard model from local .safetensors file: {model_path}")
324
+ try:
325
+ self.pipeline = AutoPipelineForText2Image.from_single_file(model_path, **load_params)
326
+ except Exception as e:
327
+ ASCIIColors.warning(f"Failed to load with AutoPipeline, falling back to StableDiffusionPipeline: {e}")
328
+ self.pipeline = StableDiffusionPipeline.from_single_file(model_path, **load_params)
329
+ else:
330
+ ASCIIColors.info(f"Loading standard model from Hub: {model_path}")
331
+ load_params.update({
332
+ "use_safetensors": self.config["use_safetensors"],
333
+ "token": self.config["hf_token"],
334
+ "local_files_only": self.config["local_files_only"]
335
+ })
336
+ if self.config["hf_variant"]:
337
+ load_params["variant"] = self.config["hf_variant"]
338
+ if not self.config["safety_checker_on"]:
339
+ load_params["safety_checker"] = None
340
+
341
+ is_large_model = "stable-diffusion-3" in str(model_path)
342
+ should_offload = self.config["enable_cpu_offload"] or self.config["enable_sequential_cpu_offload"]
343
+ if is_large_model and should_offload:
344
+ ASCIIColors.info(f"Large model '{model_path}' detected with offload enabled. Using device_map='auto'.")
345
+ use_device_map = True
346
+ load_params["device_map"] = "auto"
347
+
348
+ if task == "text2image":
349
+ self.pipeline = AutoPipelineForText2Image.from_pretrained(model_path, **load_params)
350
+ elif task == "image2image":
351
+ self.pipeline = AutoPipelineForImage2Image.from_pretrained(model_path, **load_params)
352
+ elif task == "inpainting":
353
+ self.pipeline = AutoPipelineForInpainting.from_pretrained(model_path, **load_params)
354
+
355
+ except Exception as e:
356
+ error_str = str(e).lower()
357
+ if "401" in error_str or "gated" in error_str or "authorization" in error_str:
358
+ msg = (f"AUTHENTICATION FAILED for model '{model_name_from_config}'. Please ensure you accepted the model license and provided a valid HF token.")
359
+ raise RuntimeError(msg)
360
+ raise e
361
+
362
+ self._set_scheduler()
363
+
364
+ if not use_device_map:
365
+ self.pipeline.to(self.config["device"])
366
+ if self.config["enable_xformers"]:
367
+ try:
368
+ self.pipeline.enable_xformers_memory_efficient_attention()
369
+ except Exception as e:
370
+ ASCIIColors.warning(f"Could not enable xFormers: {e}.")
371
+
372
+ if self.config["enable_cpu_offload"] and self.config["device"] != "cpu":
373
+ self.pipeline.enable_model_cpu_offload()
374
+ elif self.config["enable_sequential_cpu_offload"] and self.config["device"] != "cpu":
375
+ self.pipeline.enable_sequential_cpu_offload()
376
+ else:
377
+ ASCIIColors.info("Device map handled device placement. Skipping manual pipeline.to() and offload calls.")
378
+
379
+ if self.pipeline:
380
+ sig = inspect.signature(self.pipeline.__call__)
381
+ self.supported_args = {p.name for p in sig.parameters.values()}
382
+ ASCIIColors.info(f"Pipeline supported arguments detected: {self.supported_args}")
383
+
384
+ self.is_loaded = True
385
+ self.current_task = task
386
+ self.last_used_time = time.time()
387
+ ASCIIColors.green(f"Model '{model_name_from_config}' loaded successfully using '{'device_map' if use_device_map else 'standard'}' mode for task '{task}'.")
388
+
389
+ def _load_pipeline_for_task(self, task: str):
390
+ if self.pipeline and self.current_task == task:
391
+ return
392
+ if self.pipeline:
393
+ self._unload_pipeline()
394
+
395
+ model_name = self.config.get("model_name", "")
396
+ if not model_name:
397
+ raise ValueError("Model name cannot be empty for loading.")
398
+
399
+ ASCIIColors.info(f"Loading Diffusers model: {model_name} for task: {task}")
400
+ model_path = self._resolve_model_path(model_name)
401
+ torch_dtype = TORCH_DTYPE_MAP_STR_TO_OBJ.get(self.config["torch_dtype_str"].lower())
402
+
403
+ try:
404
+ self._execute_load_pipeline(task, model_path, torch_dtype)
405
+ return
406
+ except Exception as e:
407
+ is_oom = "out of memory" in str(e).lower()
408
+ if not is_oom or not hasattr(self, 'registry'):
409
+ raise e
410
+
411
+ ASCIIColors.warning(f"Failed to load '{model_name}' due to OOM. Attempting to unload other models to free VRAM.")
412
+
413
+ candidates_to_unload = [m for m in self.registry.get_all_managers() if m is not self and m.is_loaded]
414
+ candidates_to_unload.sort(key=lambda m: m.last_used_time)
415
+
416
+ if not candidates_to_unload:
417
+ ASCIIColors.error("OOM error, but no other models are available to unload.")
418
+ raise Exception("OOM error, but no other models are available to unload.")
419
+
420
+ for victim in candidates_to_unload:
421
+ ASCIIColors.info(f"Unloading '{victim.config['model_name']}' (last used: {time.ctime(victim.last_used_time)}) to free VRAM.")
422
+ victim._unload_pipeline()
423
+
424
+ try:
425
+ ASCIIColors.info(f"Retrying to load '{model_name}'...")
426
+ self._execute_load_pipeline(task, model_path, torch_dtype)
427
+ ASCIIColors.green(f"Successfully loaded '{model_name}' after freeing VRAM.")
428
+ return
429
+ except Exception as retry_e:
430
+ is_oom_retry = "out of memory" in str(retry_e).lower()
431
+ if not is_oom_retry:
432
+ raise retry_e
433
+
434
+ ASCIIColors.error(f"Could not load '{model_name}' even after unloading all other models.")
435
+ raise e
436
+
437
+ def _unload_pipeline(self):
438
+ if self.pipeline:
439
+ model_name = self.config.get('model_name', 'Unknown')
440
+ del self.pipeline
441
+ self.pipeline = None
442
+ self.supported_args = None
443
+ gc.collect()
444
+ if torch and torch.cuda.is_available():
445
+ torch.cuda.empty_cache()
446
+ self.is_loaded = False
447
+ self.current_task = None
448
+ ASCIIColors.info(f"Model '{model_name}' unloaded and VRAM cleared.")
449
+
450
+ def _generation_worker(self):
451
+ while not self._stop_event.is_set():
452
+ try:
453
+ job = self.queue.get(timeout=1)
454
+ if job is None:
455
+ break
456
+ future, task, pipeline_args = job
457
+ output = None
458
+ try:
459
+ with self.lock:
460
+ self.last_used_time = time.time()
461
+ if not self.is_loaded or self.current_task != task:
462
+ self._load_pipeline_for_task(task)
463
+
464
+ if self.supported_args:
465
+ filtered_args = {k: v for k, v in pipeline_args.items() if k in self.supported_args}
466
+ else:
467
+ ASCIIColors.warning("Supported argument set not found. Using unfiltered arguments.")
468
+ filtered_args = pipeline_args
469
+
470
+ with torch.no_grad():
471
+ output = self.pipeline(**filtered_args)
472
+
473
+ pil = output.images[0]
474
+ buf = BytesIO()
475
+ pil.save(buf, format="PNG")
476
+ future.set_result(buf.getvalue())
477
+ except Exception as e:
478
+ trace_exception(e)
479
+ future.set_exception(e)
480
+ finally:
481
+ self.queue.task_done()
482
+ if output is not None:
483
+ del output
484
+ gc.collect()
485
+ if torch.cuda.is_available():
486
+ torch.cuda.empty_cache()
487
+ except queue.Empty:
488
+ continue
489
+
490
+ class PipelineRegistry:
491
+ _instance = None
492
+ _lock = threading.Lock()
493
+ def __new__(cls, *args, **kwargs):
494
+ with cls._lock:
495
+ if cls._instance is None:
496
+ cls._instance = super().__new__(cls)
497
+ cls._instance._managers = {}
498
+ cls._instance._registry_lock = threading.Lock()
499
+ return cls._instance
500
+ @staticmethod
501
+ def _get_critical_keys():
502
+ return [
503
+ "model_name","device","torch_dtype_str","use_safetensors",
504
+ "safety_checker_on","hf_variant","enable_cpu_offload",
505
+ "enable_sequential_cpu_offload","enable_xformers",
506
+ "local_files_only","hf_cache_path","unload_inactive_model_after"
507
+ ]
508
+ def _get_config_key(self, config: Dict[str, Any]) -> str:
509
+ key_data = tuple(sorted((k, config.get(k)) for k in self._get_critical_keys()))
510
+ return hashlib.sha256(str(key_data).encode('utf-8')).hexdigest()
511
+ def get_manager(self, config: Dict[str, Any], models_path: Path) -> ModelManager:
512
+ key = self._get_config_key(config)
513
+ with self._registry_lock:
514
+ if key not in self._managers:
515
+ self._managers[key] = ModelManager(config.copy(), models_path, self)
516
+ return self._managers[key].acquire()
517
+ def release_manager(self, config: Dict[str, Any]):
518
+ key = self._get_config_key(config)
519
+ with self._registry_lock:
520
+ if key in self._managers:
521
+ manager = self._managers[key]
522
+ ref_count = manager.release()
523
+ if ref_count == 0:
524
+ ASCIIColors.info(f"Reference count for model '{config.get('model_name')}' is zero. Cleaning up manager.")
525
+ manager.stop()
526
+ with manager.lock:
527
+ manager._unload_pipeline()
528
+ del self._managers[key]
529
+ def get_active_managers(self) -> List[ModelManager]:
530
+ with self._registry_lock:
531
+ return [m for m in self._managers.values() if m.is_loaded]
532
+ def get_all_managers(self) -> List[ModelManager]:
533
+ with self._registry_lock:
534
+ return list(self._managers.values())
535
+
536
+ class ServerState:
537
+ def __init__(self, models_path: Path, extra_models_path: Optional[Path] = None):
538
+ self.models_path = models_path
539
+ self.extra_models_path = extra_models_path
540
+ self.models_path.mkdir(parents=True, exist_ok=True)
541
+ if self.extra_models_path:
542
+ self.extra_models_path.mkdir(parents=True, exist_ok=True)
543
+ self.config_path = self.models_path.parent / "diffusers_server_config.json"
544
+ self.registry = PipelineRegistry()
545
+ self.manager: Optional[ModelManager] = None
546
+ self.config = {}
547
+ self.load_config()
548
+ self._resolve_device_and_dtype()
549
+ if self.config.get("model_name"):
550
+ try:
551
+ ASCIIColors.info(f"Acquiring initial model manager for '{self.config['model_name']}' on startup.")
552
+ self.manager = self.registry.get_manager(self.config, self.models_path)
553
+ except Exception as e:
554
+ ASCIIColors.error(f"Failed to acquire model manager on startup: {e}")
555
+ self.manager = None
556
+
557
+ def get_default_config(self) -> Dict[str, Any]:
558
+ return {
559
+ "model_name": "", "device": "auto", "torch_dtype_str": "auto", "use_safetensors": True,
560
+ "scheduler_name": "default", "safety_checker_on": True, "num_inference_steps": 25,
561
+ "guidance_scale": 7.0, "width": 1024, "height": 1024, "seed": -1,
562
+ "enable_cpu_offload": False, "enable_sequential_cpu_offload": False, "enable_xformers": False,
563
+ "hf_variant": None, "hf_token": None, "hf_cache_path": None, "local_files_only": False,
564
+ "unload_inactive_model_after": 0
565
+ }
566
+
567
+ def save_config(self):
568
+ try:
569
+ with open(self.config_path, 'w') as f:
570
+ json.dump(self.config, f, indent=4)
571
+ ASCIIColors.info(f"Server config saved to {self.config_path}")
572
+ except Exception as e:
573
+ ASCIIColors.error(f"Failed to save server config: {e}")
574
+
575
+ def load_config(self):
576
+ default_config = self.get_default_config()
577
+ if self.config_path.exists():
578
+ try:
579
+ with open(self.config_path, 'r') as f:
580
+ loaded_config = json.load(f)
581
+ default_config.update(loaded_config)
582
+ self.config = default_config
583
+ ASCIIColors.info(f"Loaded server configuration from {self.config_path}")
584
+ except (json.JSONDecodeError, IOError) as e:
585
+ ASCIIColors.warning(f"Could not load config file, using defaults. Error: {e}")
586
+ self.config = default_config
587
+ else:
588
+ self.config = default_config
589
+ self.save_config()
590
+
591
+ def _resolve_device_and_dtype(self):
592
+ if self.config.get("device", "auto").lower() == "auto":
593
+ self.config["device"] = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
594
+
595
+ if ("Qwen" in self.config.get("model_name", "") or "FLUX" in self.config.get("model_name", "")) and self.config["device"] == "cuda":
596
+ if hasattr(torch.cuda, 'is_bf16_supported') and torch.cuda.is_bf16_supported():
597
+ self.config["torch_dtype_str"] = "bfloat16"
598
+ ASCIIColors.info("Special model detected on compatible hardware. Forcing dtype to bfloat16 for stability.")
599
+ return
600
+
601
+ if self.config["torch_dtype_str"].lower() == "auto":
602
+ self.config["torch_dtype_str"] = "float16" if self.config["device"] != "cpu" else "float32"
603
+
604
+ def update_settings(self, new_settings: Dict[str, Any]):
605
+ if 'model' in new_settings and 'model_name' not in new_settings:
606
+ new_settings['model_name'] = new_settings.pop('model')
607
+
608
+ if self.config.get("model_name") and not new_settings.get("model_name"):
609
+ ASCIIColors.info("Incoming settings have no model_name. Preserving existing model.")
610
+ new_settings["model_name"] = self.config["model_name"]
611
+
612
+ if self.manager:
613
+ self.registry.release_manager(self.manager.config)
614
+ self.manager = None
615
+
616
+ self.config.update(new_settings)
617
+ ASCIIColors.info(f"Server config updated. Current model_name: {self.config.get('model_name')}")
618
+
619
+ self._resolve_device_and_dtype()
620
+
621
+ if self.config.get("model_name"):
622
+ ASCIIColors.info("Acquiring model manager with updated configuration...")
623
+ self.manager = self.registry.get_manager(self.config, self.models_path)
624
+ else:
625
+ ASCIIColors.warning("No model_name in config after update, manager not acquired.")
626
+
627
+ self.save_config()
628
+ return True
629
+
630
+ def get_active_manager(self) -> ModelManager:
631
+ if self.manager:
632
+ return self.manager
633
+ raise HTTPException(status_code=400, detail="No model is configured or manager is not active. Please set a model using the /set_settings endpoint.")
634
+
635
+ state: Optional[ServerState] = None
636
+
637
+ # --- Pydantic Models for API ---
638
+ class T2IRequest(BaseModel):
639
+ prompt: str
640
+ negative_prompt: str = ""
641
+ params: Dict[str, Any] = Field(default_factory=dict)
642
+
643
+ class EditRequestPayload(BaseModel):
644
+ prompt: str
645
+ image_paths: List[str] = Field(default_factory=list)
646
+ params: Dict[str, Any] = Field(default_factory=dict)
647
+
648
+ class EditRequestJSON(BaseModel):
649
+ prompt: str
650
+ images_b64: List[str] = Field(description="A list of Base64 encoded image strings.")
651
+ params: Dict[str, Any] = Field(default_factory=dict)
652
+ def get_sanitized_request_for_logging(request_data: Any) -> Dict[str, Any]:
653
+ """
654
+ Takes a request object (Pydantic model or dict) and returns a 'safe' dictionary
655
+ for logging, with long base64 strings replaced by placeholders.
656
+ """
657
+ import copy
658
+
659
+ try:
660
+ if hasattr(request_data, 'model_dump'):
661
+ data = request_data.model_dump()
662
+ elif isinstance(request_data, dict):
663
+ data = copy.deepcopy(request_data)
664
+ else:
665
+ return {"error": "Unsupported data type for sanitization"}
666
+
667
+ # Sanitize the main list of images
668
+ if 'images_b64' in data and isinstance(data['images_b64'], list):
669
+ count = len(data['images_b64'])
670
+ data['images_b64'] = f"[<{count} base64 image(s) truncated>]"
671
+
672
+ # Sanitize a potential mask in the 'params' dictionary
673
+ if 'params' in data and isinstance(data.get('params'), dict):
674
+ if 'mask_image' in data['params'] and isinstance(data['params']['mask_image'], str):
675
+ original_len = len(data['params']['mask_image'])
676
+ data['params']['mask_image'] = f"[<base64 mask truncated, len={original_len}>]"
677
+
678
+ return data
679
+ except Exception:
680
+ return {"error": "Failed to sanitize request data."}
681
+
682
+ # --- API Endpoints ---
683
+ @router.post("/generate_image")
684
+ async def generate_image(request: T2IRequest):
685
+ manager = None
686
+ temp_config = None
687
+ try:
688
+ # Determine which model manager to use for this specific request
689
+ if "model_name" in request.params and request.params["model_name"]:
690
+ temp_config = state.config.copy()
691
+ temp_config["model_name"] = request.params.pop("model_name") # Remove from params to avoid being passed to pipeline
692
+ manager = state.registry.get_manager(temp_config, state.models_path)
693
+ ASCIIColors.info(f"Using per-request model: {temp_config['model_name']}")
694
+ else:
695
+ manager = state.get_active_manager()
696
+ ASCIIColors.info(f"Using session-configured model: {manager.config.get('model_name')}")
697
+
698
+ # Start with the manager's config (base settings)
699
+ pipeline_args = manager.config.copy()
700
+ # Override with per-request parameters
701
+ pipeline_args.update(request.params)
702
+
703
+ # Add prompts and ensure types for specific args
704
+ pipeline_args["prompt"] = request.prompt
705
+ pipeline_args["negative_prompt"] = request.negative_prompt
706
+ width = pipeline_args.get("width", 1024)
707
+ height = pipeline_args.get("height", 1024)
708
+ num_inference_steps = pipeline_args.get("num_inference_steps", 25)
709
+ seed = pipeline_args.get("seed", -1)
710
+ guidance_scale = pipeline_args.get("guidance_scale", 7.0)
711
+ pipeline_args["width"] = int(width if width else 1024)
712
+ pipeline_args["height"] = int(height if height else 1024)
713
+ pipeline_args["num_inference_steps"] = int(num_inference_steps if num_inference_steps else 25)
714
+ pipeline_args["guidance_scale"] = float(guidance_scale if guidance_scale else 7.0)
715
+
716
+ seed = int(seed if seed is not None else -1)
717
+ pipeline_args["generator"] = None
718
+ if seed != -1:
719
+ pipeline_args["generator"] = torch.Generator(device=manager.config["device"]).manual_seed(seed)
720
+
721
+ model_name = manager.config.get("model_name", "")
722
+ task = "text2image"
723
+
724
+ if "Qwen-Image-Edit" in model_name:
725
+ rng_seed = seed if seed != -1 else None
726
+ rng = np.random.default_rng(seed=rng_seed)
727
+ random_pixels = rng.integers(0, 256, size=(pipeline_args["height"], pipeline_args["width"], 3), dtype=np.uint8)
728
+ placeholder_image = Image.fromarray(random_pixels, 'RGB')
729
+ pipeline_args["image"] = placeholder_image
730
+ pipeline_args["strength"] = float(pipeline_args.get("strength", 1.0))
731
+ task = "image2image"
732
+
733
+ log_args = {k: v for k, v in pipeline_args.items() if k not in ['generator', 'image']}
734
+ if pipeline_args.get("generator"): log_args['generator'] = f"<torch.Generator(seed={seed})>"
735
+ if pipeline_args.get("image"): log_args['image'] = "<PIL Image object>"
736
+
737
+ ASCIIColors.cyan("--- Generating Image with Settings ---")
738
+ try:
739
+ print(json.dumps(log_args, indent=2, default=str))
740
+ except Exception as e:
741
+ ASCIIColors.warning(f"Could not print all settings: {e}")
742
+ print(log_args)
743
+ ASCIIColors.cyan("------------------------------------")
744
+
745
+ future = Future()
746
+ manager.queue.put((future, task, pipeline_args))
747
+ result_bytes = future.result()
748
+ return Response(content=result_bytes, media_type="image/png")
749
+ except Exception as e:
750
+ trace_exception(e)
751
+ raise HTTPException(status_code=500, detail=str(e))
752
+ finally:
753
+ if temp_config and manager:
754
+ state.registry.release_manager(temp_config)
755
+ ASCIIColors.info(f"Released per-request model: {temp_config['model_name']}")
756
+
757
+
758
+ @router.post("/edit_image")
759
+ async def edit_image(request: EditRequestJSON):
760
+ manager = None
761
+ temp_config = None
762
+ ASCIIColors.info(f"Received /edit_image request with {len(request.images_b64)} image(s).")
763
+ ASCIIColors.info(request.params)
764
+ try:
765
+ if "model_name" in request.params and request.params["model_name"]:
766
+ temp_config = state.config.copy()
767
+ temp_config["model_name"] = request.params.pop("model_name")
768
+ manager = state.registry.get_manager(temp_config, state.models_path)
769
+ ASCIIColors.info(f"Using per-request model: {temp_config['model_name']}")
770
+ else:
771
+ manager = state.get_active_manager()
772
+ ASCIIColors.info(f"Using session-configured model: {manager.config.get('model_name')}")
773
+
774
+ # Start with manager's config, then override with request params
775
+ pipeline_args = manager.config.copy()
776
+ pipeline_args.update(request.params)
777
+
778
+ pipeline_args["prompt"] = request.prompt
779
+ model_name = manager.config.get("model_name", "")
780
+
781
+ pil_images = []
782
+ for b64_string in request.images_b64:
783
+ b64_data = b64_string.split(";base64,")[1] if ";base64," in b64_string else b64_string
784
+ image_bytes = base64.b64decode(b64_data)
785
+ pil_images.append(Image.open(BytesIO(image_bytes)).convert("RGB"))
786
+
787
+ if not pil_images: raise HTTPException(status_code=400, detail="No valid images provided.")
788
+
789
+ seed = int(pipeline_args.get("seed", -1))
790
+ pipeline_args["generator"] = None
791
+ if seed != -1: pipeline_args["generator"] = torch.Generator(device=manager.config["device"]).manual_seed(seed)
792
+
793
+ if "mask_image" in pipeline_args and pipeline_args["mask_image"]:
794
+ b64_mask = pipeline_args["mask_image"]
795
+ b64_data = b64_mask.split(";base64,")[1] if ";base64," in b64_mask else b64_mask
796
+ mask_bytes = base64.b64decode(b64_data)
797
+ pipeline_args["mask_image"] = Image.open(BytesIO(mask_bytes)).convert("L")
798
+
799
+ task = "inpainting" if "mask_image" in pipeline_args and pipeline_args["mask_image"] else "image2image"
800
+
801
+ if "Qwen-Image-Edit-2509" in model_name:
802
+ task = "image2image"
803
+ pipeline_args.update({"true_cfg_scale": 4.0, "guidance_scale": 1.0, "num_inference_steps": 40, "negative_prompt": " "})
804
+ edit_mode = pipeline_args.get("edit_mode", "fusion")
805
+ if edit_mode == "fusion": pipeline_args["image"] = pil_images
806
+ else:
807
+ pipeline_args.update({"image": pil_images[0]})
808
+
809
+ log_args = {k: v for k, v in pipeline_args.items() if k not in ['generator', 'image', 'mask_image']}
810
+ if pipeline_args.get("generator"): log_args['generator'] = f"<torch.Generator(seed={seed})>"
811
+ if 'image' in pipeline_args: log_args['image'] = f"[<{len(pil_images)} PIL Image(s)>]"
812
+ if 'mask_image' in pipeline_args and pipeline_args['mask_image']: log_args['mask_image'] = "<PIL Mask Image>"
813
+
814
+ ASCIIColors.cyan("--- Editing Image with Settings ---")
815
+ try:
816
+ print(json.dumps(log_args, indent=2, default=str))
817
+ except Exception as e:
818
+ ASCIIColors.warning(f"Could not print all settings: {e}")
819
+ print(log_args)
820
+ ASCIIColors.cyan("---------------------------------")
821
+
822
+ future = Future(); manager.queue.put((future, task, pipeline_args))
823
+ return Response(content=future.result(), media_type="image/png")
824
+ except Exception as e:
825
+ sanitized_payload = get_sanitized_request_for_logging(request)
826
+ ASCIIColors.error(f"Exception in /edit_image. Sanitized Payload: {json.dumps(sanitized_payload, indent=2)}")
827
+ trace_exception(e)
828
+ raise HTTPException(status_code=500, detail=str(e))
829
+ finally:
830
+ if temp_config and manager:
831
+ state.registry.release_manager(temp_config)
832
+ ASCIIColors.info(f"Released per-request model: {temp_config['model_name']}")
833
+
834
+ @router.post("/pull_model")
835
+ def pull_model_endpoint(payload: PullModelRequest):
836
+ if not payload.hf_id and not payload.safetensors_url:
837
+ raise HTTPException(status_code=400, detail="Provide either 'hf_id' or 'safetensors_url'.")
838
+
839
+ # 1) Pull Hugging Face model into a folder
840
+ if payload.hf_id:
841
+ model_id = payload.hf_id.strip()
842
+ folder_name = payload.local_name or model_id.replace("/", "__")
843
+ dest_dir = state.models_path / folder_name
844
+ dest_dir.mkdir(parents=True, exist_ok=True)
845
+
846
+ try:
847
+ ASCIIColors.cyan(f"Pulling HF model '{model_id}' into {dest_dir}")
848
+ # Reuse config options for HF access
849
+ load_params: Dict[str, Any] = {}
850
+ if state.config.get("hf_cache_path"):
851
+ load_params["cache_dir"] = str(state.config["hf_cache_path"])
852
+ if state.config.get("hf_token"):
853
+ load_params["token"] = state.config["hf_token"]
854
+ # For a pull, we want to actually download:
855
+ load_params["local_files_only"] = False
856
+
857
+ # Use DiffusionPipeline (or AutoPipelineForText2Image) to download, then save_pretrained
858
+ pipe = DiffusionPipeline.from_pretrained(model_id, **load_params)
859
+ pipe.save_pretrained(dest_dir)
860
+ del pipe
861
+ gc.collect()
862
+ if torch.cuda.is_available():
863
+ torch.cuda.empty_cache()
864
+ ASCIIColors.green(f"Model '{model_id}' pulled to {dest_dir}")
865
+ return {"status": "ok", "model_name": folder_name}
866
+ except Exception as e:
867
+ trace_exception(e)
868
+ raise HTTPException(status_code=500, detail=f"Failed to pull HF model: {e}")
869
+
870
+ # 2) Pull raw .safetensors from URL
871
+ if payload.safetensors_url:
872
+ url = payload.safetensors_url.strip()
873
+ default_name = url.split("/")[-1] or "model.safetensors"
874
+ if not default_name.endswith(".safetensors"):
875
+ default_name += ".safetensors"
876
+ filename = payload.local_name or default_name
877
+
878
+ dest_path = state.models_path / filename
879
+ temp_path = dest_path.with_suffix(".temp")
880
+
881
+ ASCIIColors.cyan(f"Downloading safetensors from {url} to {dest_path}")
882
+ try:
883
+ with requests.get(url, stream=True) as r:
884
+ r.raise_for_status()
885
+ total_size = int(r.headers.get("content-length", 0))
886
+ with open(temp_path, "wb") as f, tqdm(total=total_size, unit="iB", unit_scale=True, desc=f"Downloading {filename}") as bar:
887
+ for chunk in r.iter_content(chunk_size=8192):
888
+ if not chunk:
889
+ continue
890
+ f.write(chunk)
891
+ bar.update(len(chunk))
892
+ shutil.move(temp_path, dest_path)
893
+ ASCIIColors.green(f"Safetensors file downloaded to {dest_path}")
894
+ return {"status": "ok", "model_name": filename}
895
+ except Exception as e:
896
+ if temp_path.exists():
897
+ temp_path.unlink()
898
+ trace_exception(e)
899
+ raise HTTPException(status_code=500, detail=f"Failed to download safetensors: {e}")
900
+
901
+
902
+ @router.get("/list_local_models")
903
+ def list_local_models_endpoint():
904
+ local_models = set()
905
+ models_root = Path(args.models_path)
906
+ extra_root = Path(args.extra_models_path) if args.extra_models_path else None
907
+
908
+ def scan_root(root: Path):
909
+ if not root or not root.exists():
910
+ return
911
+
912
+ # 1. Diffusers folders (Recursive)
913
+ for model_index in root.rglob("model_index.json"):
914
+ # For listing just the name, we probably want the folder name or relative path
915
+ # Keeping it simple: folder name.
916
+ local_models.add(model_index.parent.name)
917
+
918
+ # 2. Safetensors files (Recursive)
919
+ for safepath in root.rglob("*.safetensors"):
920
+ if (safepath.parent / "model_index.json").exists():
921
+ continue
922
+ local_models.add(safepath.name)
923
+
924
+ scan_root(models_root)
925
+ scan_root(extra_root)
926
+
927
+ return sorted(list(local_models))
928
+
929
+ @app.get("/list_models")
930
+ def list_models() -> list[dict]:
931
+ models_root = Path(args.models_path)
932
+ extra_root = Path(args.extra_models_path) if args.extra_models_path else None
933
+ result = []
934
+ seen_paths = set()
935
+
936
+ def scan_root(root: Path):
937
+ if not root or not root.exists():
938
+ return
939
+
940
+ # 1. Diffusers folders (Recursive)
941
+ # We look for model_index.json
942
+ for model_index in root.rglob("model_index.json"):
943
+ folder = model_index.parent
944
+ resolved_path = str(folder.resolve())
945
+ if resolved_path in seen_paths:
946
+ continue
947
+ seen_paths.add(resolved_path)
948
+
949
+ result.append({
950
+ "model_name": resolved_path,
951
+ "display_name": folder.name,
952
+ "description": "Local Diffusers pipeline"
953
+ })
954
+
955
+ # 2. Safetensors files (Recursive)
956
+ for safepath in root.rglob("*.safetensors"):
957
+ # Skip if part of a diffusers folder
958
+ if (safepath.parent / "model_index.json").exists():
959
+ continue
960
+
961
+ resolved_path = str(safepath.resolve())
962
+ if resolved_path in seen_paths:
963
+ continue
964
+ seen_paths.add(resolved_path)
965
+
966
+ result.append({
967
+ "model_name": resolved_path,
968
+ "display_name": safepath.stem,
969
+ "description": "Local .safetensors checkpoint"
970
+ })
971
+
972
+ scan_root(models_root)
973
+ scan_root(extra_root)
974
+ return result
975
+
976
+
977
+
978
+
979
+ @router.get("/list_available_models")
980
+ def list_available_models_endpoint():
981
+ # Use list_models() to get all available models (dicts) then extract names
982
+ models_dicts = list_models()
983
+ discoverable = [m['model_name'] for m in models_dicts]
984
+ return sorted(list(set(discoverable)))
985
+
986
+ @router.get("/get_settings")
987
+ def get_settings_endpoint():
988
+ settings_list = []
989
+ available_models = list_available_models_endpoint()
990
+ schedulers = list(SCHEDULER_MAPPING.keys())
991
+ config_to_display = state.config or state.get_default_config()
992
+ for name, value in config_to_display.items():
993
+ setting = {"name": name, "type": str(type(value).__name__), "value": value}
994
+ if name == "model_name": setting["options"] = available_models
995
+ if name == "scheduler_name": setting["options"] = schedulers
996
+ settings_list.append(setting)
997
+ return settings_list
998
+
999
+ @router.post("/set_settings")
1000
+ def set_settings_endpoint(settings: Dict[str, Any]):
1001
+ try:
1002
+ success = state.update_settings(settings)
1003
+ return {"success": success}
1004
+ except Exception as e:
1005
+ trace_exception(e)
1006
+ raise HTTPException(status_code=500, detail=str(e))
1007
+
1008
+ @router.get("/status")
1009
+ def status_endpoint():
1010
+ return {"status": "running", "diffusers_available": DIFFUSERS_AVAILABLE, "model_loaded": state.manager.is_loaded if state.manager else False}
1011
+
1012
+ @router.post("/unload_model")
1013
+ def unload_model_endpoint():
1014
+ if state.manager:
1015
+ state.manager._unload_pipeline()
1016
+ state.registry.release_manager(state.manager.config)
1017
+ state.manager = None
1018
+ return {"status": "unloaded"}
1019
+
1020
+ @router.get("/ps")
1021
+ def ps_endpoint():
1022
+ managers = state.registry.get_all_managers()
1023
+ return [{
1024
+ "model_name": m.config.get("model_name"), "is_loaded": m.is_loaded,
1025
+ "task": m.current_task, "device": m.config.get("device"), "ref_count": m.ref_count,
1026
+ "queue_size": m.queue.qsize(), "last_used": time.ctime(m.last_used_time)
1027
+ } for m in managers]
1028
+
1029
+ app.include_router(router)
1030
+
1031
+ if __name__ == "__main__":
1032
+ parser = argparse.ArgumentParser(description="Diffusers TTI Server")
1033
+ parser.add_argument("--host", type=str, default="localhost", help="Host to bind to.")
1034
+ parser.add_argument("--port", type=int, default=9630, help="Port to bind to.")
1035
+ parser.add_argument("--models-path", type=str, required=True, help="Path to the models directory.")
1036
+ parser.add_argument("--extra-models-path", type=str, default=None, help="Path to an extra models directory.")
1037
+ parser.add_argument(
1038
+ "--hf-token",
1039
+ type=str,
1040
+ default=None,
1041
+ help="Optional Hugging Face access token used to download private or gated repos."
1042
+ )
1043
+
1044
+ args = parser.parse_args()
1045
+
1046
+ MODELS_PATH = Path(args.models_path)
1047
+ EXTRA_MODELS_PATH = Path(args.extra_models_path) if args.extra_models_path else None
1048
+ state = ServerState(MODELS_PATH, EXTRA_MODELS_PATH)
1049
+ if args.hf_token:
1050
+ state.config["hf_token"] = args.hf_token
1051
+ ASCIIColors.info("Hugging Face token received via CLI and stored in server config.")
1052
+ ASCIIColors.cyan(f"--- Diffusers TTI Server ---")
1053
+ ASCIIColors.green(f"Starting server on http://{args.host}:{args.port}")
1054
+ ASCIIColors.green(f"Serving models from: {MODELS_PATH.resolve()}")
1055
+ if EXTRA_MODELS_PATH:
1056
+ ASCIIColors.green(f"Serving extra models from: {EXTRA_MODELS_PATH.resolve()}")
1057
+ if not DIFFUSERS_AVAILABLE:
1058
+ ASCIIColors.error("Diffusers or its dependencies are not installed correctly in the server's environment!")
1059
+ else:
1060
+ ASCIIColors.info(f"Detected device: {state.config['device']}")
1061
+
1062
+ uvicorn.run(app, host=args.host, port=args.port, reload=False)
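
For completeness, below is a hedged end-to-end sketch of calling the new server's generation and editing endpoints from Python. It assumes the server is already running on the default localhost:9630 with a valid --models-path, and that the chosen model name (here stabilityai/sdxl-turbo, one of the curated Hub entries listed in the module) can be resolved or downloaded. Payload shapes follow the T2IRequest and EditRequestJSON models defined above, and both endpoints return raw PNG bytes.

    # Hedged usage sketch against the new Diffusers TTI server; model name and
    # prompts are illustrative, endpoints and payload shapes come from the diff.
    import base64
    import requests

    BASE = "http://localhost:9630"

    # 1) Point the server at a model; /set_settings takes a flat dict of overrides.
    requests.post(f"{BASE}/set_settings", json={"model_name": "stabilityai/sdxl-turbo"}).raise_for_status()

    # 2) Text-to-image: /generate_image expects the T2IRequest shape, returns PNG bytes.
    t2i = {
        "prompt": "a watercolor lighthouse at dawn",
        "negative_prompt": "",
        "params": {"width": 768, "height": 768, "num_inference_steps": 25, "seed": 42},
    }
    resp = requests.post(f"{BASE}/generate_image", json=t2i)
    resp.raise_for_status()
    with open("lighthouse.png", "wb") as f:
        f.write(resp.content)

    # 3) Image editing: /edit_image takes base64-encoded source images (EditRequestJSON).
    with open("lighthouse.png", "rb") as f:
        src_b64 = base64.b64encode(f.read()).decode("utf-8")
    edit = {
        "prompt": "turn the scene into a night shot with a full moon",
        "images_b64": [src_b64],
        "params": {"seed": 42},
    }
    resp = requests.post(f"{BASE}/edit_image", json=edit)
    resp.raise_for_status()
    with open("lighthouse_night.png", "wb") as f:
        f.write(resp.content)

Per-request model overrides are also supported: placing a model_name key inside params makes the server acquire a dedicated manager for that request and release it again once the response is returned.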