ollamadiffuser 1.2.3__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. ollamadiffuser/__init__.py +1 -1
  2. ollamadiffuser/api/server.py +312 -312
  3. ollamadiffuser/cli/config_commands.py +119 -0
  4. ollamadiffuser/cli/lora_commands.py +169 -0
  5. ollamadiffuser/cli/main.py +85 -1233
  6. ollamadiffuser/cli/model_commands.py +664 -0
  7. ollamadiffuser/cli/recommend_command.py +205 -0
  8. ollamadiffuser/cli/registry_commands.py +197 -0
  9. ollamadiffuser/core/config/model_registry.py +562 -11
  10. ollamadiffuser/core/config/settings.py +24 -2
  11. ollamadiffuser/core/inference/__init__.py +5 -0
  12. ollamadiffuser/core/inference/base.py +182 -0
  13. ollamadiffuser/core/inference/engine.py +204 -1405
  14. ollamadiffuser/core/inference/strategies/__init__.py +1 -0
  15. ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
  16. ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
  17. ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
  18. ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
  19. ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
  20. ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
  21. ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
  22. ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
  23. ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
  24. ollamadiffuser/mcp/__init__.py +0 -0
  25. ollamadiffuser/mcp/server.py +184 -0
  26. ollamadiffuser/ui/templates/index.html +62 -1
  27. ollamadiffuser/ui/web.py +116 -54
  28. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.1.dist-info}/METADATA +317 -108
  29. ollamadiffuser-2.0.1.dist-info/RECORD +61 -0
  30. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.1.dist-info}/WHEEL +1 -1
  31. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.1.dist-info}/entry_points.txt +1 -0
  32. ollamadiffuser/core/models/registry.py +0 -384
  33. ollamadiffuser/ui/samples/.DS_Store +0 -0
  34. ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
  35. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.1.dist-info}/licenses/LICENSE +0 -0
  36. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.1.dist-info}/top_level.txt +0 -0
@@ -1,450 +1,450 @@
1
- from fastapi import FastAPI, HTTPException, File, UploadFile, Form
2
- from fastapi.responses import Response
3
- from fastapi.middleware.cors import CORSMiddleware
4
- import uvicorn
1
+ """OllamaDiffuser REST API Server"""
2
+
3
+ import asyncio
4
+ import base64
5
5
  import io
6
6
  import logging
7
- from typing import Dict, Any, Optional
8
- from pydantic import BaseModel
7
+ from typing import Any, Dict, Optional
8
+
9
+ import uvicorn
10
+ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.responses import Response
9
13
  from PIL import Image
14
+ from pydantic import BaseModel
10
15
 
11
- from ..core.models.manager import model_manager
12
16
  from ..core.config.settings import settings
17
+ from ..core.models.manager import model_manager
13
18
 
14
- # Setup logging
15
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
19
+ logging.basicConfig(
20
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
21
+ )
16
22
  logger = logging.getLogger(__name__)
17
23
 
18
- # API request models
24
+
25
+ # --- Request models ---
26
+
27
+
19
28
  class GenerateRequest(BaseModel):
20
29
  prompt: str
21
30
  negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
22
31
  num_inference_steps: Optional[int] = None
23
- steps: Optional[int] = None # Alias for num_inference_steps for convenience
32
+ steps: Optional[int] = None
24
33
  guidance_scale: Optional[float] = None
25
- cfg_scale: Optional[float] = None # Alias for guidance_scale for convenience
34
+ cfg_scale: Optional[float] = None
26
35
  width: int = 1024
27
36
  height: int = 1024
28
- control_image_path: Optional[str] = None # Path to control image file
29
- controlnet_conditioning_scale: float = 1.0
30
- control_guidance_start: float = 0.0
31
- control_guidance_end: float = 1.0
37
+ seed: Optional[int] = None
38
+ response_format: Optional[str] = None # "b64_json" for JSON with base64, None for raw PNG
39
+
40
+
41
+ class Img2ImgRequest(BaseModel):
42
+ prompt: str
43
+ negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
44
+ num_inference_steps: Optional[int] = None
45
+ guidance_scale: Optional[float] = None
46
+ seed: Optional[int] = None
47
+ strength: float = 0.75
48
+
32
49
 
33
50
  class LoadModelRequest(BaseModel):
34
51
  model_name: str
35
52
 
53
+
36
54
  class LoadLoRARequest(BaseModel):
37
55
  lora_name: str
38
56
  repo_id: str
39
57
  weight_name: Optional[str] = None
40
58
  scale: float = 1.0
41
59
 
60
+
61
+ # --- Helpers ---
62
+
63
+
64
+ def _image_to_response(image: Image.Image) -> Response:
65
+ """Convert PIL Image to PNG Response."""
66
+ buf = io.BytesIO()
67
+ image.save(buf, format="PNG")
68
+ return Response(content=buf.getvalue(), media_type="image/png")
69
+
70
+
71
+ def _image_to_json_response(image: Image.Image) -> Dict[str, Any]:
72
+ """Convert PIL Image to JSON response with base64-encoded PNG."""
73
+ buf = io.BytesIO()
74
+ image.save(buf, format="PNG")
75
+ b64_data = base64.b64encode(buf.getvalue()).decode("utf-8")
76
+ return {
77
+ "image": b64_data,
78
+ "format": "png",
79
+ "width": image.width,
80
+ "height": image.height,
81
+ }
82
+
83
+
84
+ def _get_engine():
85
+ """Get the currently loaded inference engine or raise 400."""
86
+ if not model_manager.is_model_loaded():
87
+ raise HTTPException(status_code=400, detail="No model loaded")
88
+ return model_manager.loaded_model
89
+
90
+
91
+ # --- App factory ---
92
+
93
+
42
94
  def create_app() -> FastAPI:
43
95
  """Create FastAPI application"""
44
96
  app = FastAPI(
45
97
  title="OllamaDiffuser API",
46
98
  description="Image generation model management and inference API",
47
- version="1.0.0"
99
+ version="2.0.0",
48
100
  )
49
-
50
- # Add CORS middleware
101
+
51
102
  if settings.server.enable_cors:
52
103
  app.add_middleware(
53
104
  CORSMiddleware,
54
105
  allow_origins=["*"],
55
- allow_credentials=True,
56
106
  allow_methods=["*"],
57
107
  allow_headers=["*"],
58
108
  )
59
-
60
- # Root endpoint
109
+
110
+ # --- Root ---
111
+
61
112
  @app.get("/")
62
113
  async def root():
63
- """Root endpoint with API information"""
64
114
  return {
65
115
  "name": "OllamaDiffuser API",
66
- "version": "1.0.0",
67
- "description": "Image generation model management and inference API",
116
+ "version": "2.0.0",
68
117
  "status": "running",
69
118
  "endpoints": {
70
- "documentation": "/docs",
71
- "openapi_schema": "/openapi.json",
72
- "health_check": "/api/health",
119
+ "docs": "/docs",
120
+ "health": "/api/health",
73
121
  "models": "/api/models",
74
- "generate": "/api/generate"
122
+ "generate": "/api/generate",
123
+ "img2img": "/api/generate/img2img",
124
+ "inpaint": "/api/generate/inpaint",
75
125
  },
76
- "usage": {
77
- "web_ui": "Use 'ollamadiffuser --mode ui' to start the web interface",
78
- "cli": "Use 'ollamadiffuser --help' for command line options",
79
- "api_docs": "Visit /docs for interactive API documentation"
80
- }
81
126
  }
82
-
83
- # Model management endpoints
127
+
128
+ # --- Health ---
129
+
130
+ @app.get("/api/health")
131
+ async def health_check():
132
+ return {
133
+ "status": "healthy",
134
+ "model_loaded": model_manager.is_model_loaded(),
135
+ "current_model": model_manager.get_current_model(),
136
+ }
137
+
138
+ # --- Model management ---
139
+
84
140
  @app.get("/api/models")
85
141
  async def list_models():
86
- """List all models"""
87
142
  return {
88
143
  "available": model_manager.list_available_models(),
89
144
  "installed": model_manager.list_installed_models(),
90
- "current": model_manager.get_current_model()
145
+ "current": model_manager.get_current_model(),
91
146
  }
92
-
147
+
93
148
  @app.get("/api/models/running")
94
149
  async def get_running_model():
95
- """Get currently running model"""
96
150
  if model_manager.is_model_loaded():
97
151
  engine = model_manager.loaded_model
98
152
  return {
99
153
  "model": model_manager.get_current_model(),
100
154
  "info": engine.get_model_info(),
101
- "loaded": True
155
+ "loaded": True,
102
156
  }
103
- else:
104
- return {"loaded": False}
105
-
157
+ return {"loaded": False}
158
+
106
159
  @app.get("/api/models/{model_name}")
107
160
  async def get_model_info(model_name: str):
108
- """Get model detailed information"""
109
161
  info = model_manager.get_model_info(model_name)
110
162
  if info is None:
111
- raise HTTPException(status_code=404, detail="Model does not exist")
163
+ raise HTTPException(status_code=404, detail="Model not found")
112
164
  return info
113
-
165
+
114
166
  @app.post("/api/models/pull")
115
167
  async def pull_model(model_name: str):
116
- """Download model"""
117
- if model_manager.pull_model(model_name):
168
+ success = await asyncio.to_thread(model_manager.pull_model, model_name)
169
+ if success:
118
170
  return {"message": f"Model {model_name} downloaded successfully"}
119
- else:
120
- raise HTTPException(status_code=400, detail=f"Failed to download model {model_name}")
121
-
171
+ raise HTTPException(
172
+ status_code=400, detail=f"Failed to download model {model_name}"
173
+ )
174
+
122
175
  @app.post("/api/models/load")
123
176
  async def load_model(request: LoadModelRequest):
124
- """Load model"""
125
- if model_manager.load_model(request.model_name):
177
+ success = await asyncio.to_thread(model_manager.load_model, request.model_name)
178
+ if success:
126
179
  return {"message": f"Model {request.model_name} loaded successfully"}
127
- else:
128
- raise HTTPException(status_code=400, detail=f"Failed to load model {request.model_name}")
129
-
180
+ raise HTTPException(
181
+ status_code=400, detail=f"Failed to load model {request.model_name}"
182
+ )
183
+
130
184
  @app.post("/api/models/unload")
131
185
  async def unload_model():
132
- """Unload current model"""
133
186
  model_manager.unload_model()
134
187
  return {"message": "Model unloaded"}
135
-
188
+
136
189
  @app.delete("/api/models/{model_name}")
137
190
  async def remove_model(model_name: str):
138
- """Remove model"""
139
191
  if model_manager.remove_model(model_name):
140
- return {"message": f"Model {model_name} removed successfully"}
141
- else:
142
- raise HTTPException(status_code=400, detail=f"Failed to remove model {model_name}")
143
-
144
- # LoRA management endpoints
145
- @app.post("/api/lora/load")
146
- async def load_lora(request: LoadLoRARequest):
147
- """Load LoRA weights into current model"""
148
- # Check if model is loaded
149
- if not model_manager.is_model_loaded():
150
- raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
151
-
152
- try:
153
- # Get current loaded inference engine
154
- engine = model_manager.loaded_model
155
-
156
- # Load LoRA weights
157
- success = engine.load_lora_runtime(
158
- repo_id=request.repo_id,
159
- weight_name=request.weight_name,
160
- scale=request.scale
161
- )
162
-
163
- if success:
164
- return {"message": f"LoRA {request.lora_name} loaded successfully with scale {request.scale}"}
165
- else:
166
- raise HTTPException(status_code=400, detail=f"Failed to load LoRA {request.lora_name}")
167
-
168
- except Exception as e:
169
- logger.error(f"LoRA loading failed: {e}")
170
- raise HTTPException(status_code=500, detail=f"LoRA loading failed: {str(e)}")
171
-
172
- @app.post("/api/lora/unload")
173
- async def unload_lora():
174
- """Unload current LoRA weights"""
175
- # Check if model is loaded
176
- if not model_manager.is_model_loaded():
177
- raise HTTPException(status_code=400, detail="No model loaded")
178
-
179
- try:
180
- # Get current loaded inference engine
181
- engine = model_manager.loaded_model
182
-
183
- # Unload LoRA weights
184
- success = engine.unload_lora()
185
-
186
- if success:
187
- return {"message": "LoRA weights unloaded successfully"}
188
- else:
189
- raise HTTPException(status_code=400, detail="Failed to unload LoRA weights")
190
-
191
- except Exception as e:
192
- logger.error(f"LoRA unloading failed: {e}")
193
- raise HTTPException(status_code=500, detail=f"LoRA unloading failed: {str(e)}")
194
-
195
- @app.get("/api/lora/status")
196
- async def get_lora_status():
197
- """Get current LoRA status"""
198
- # Check if model is loaded
199
- if not model_manager.is_model_loaded():
200
- return {"loaded": False, "message": "No model loaded"}
201
-
202
- try:
203
- # Get current loaded inference engine
204
- engine = model_manager.loaded_model
205
-
206
- # Check tracked LoRA state
207
- if hasattr(engine, 'current_lora') and engine.current_lora:
208
- lora_info = engine.current_lora.copy()
209
- return {
210
- "loaded": True,
211
- "info": lora_info,
212
- "message": "LoRA loaded"
213
- }
214
- else:
215
- return {
216
- "loaded": False,
217
- "info": None,
218
- "message": "No LoRA loaded"
219
- }
220
-
221
- except Exception as e:
222
- logger.error(f"Failed to get LoRA status: {e}")
223
- raise HTTPException(status_code=500, detail=f"Failed to get LoRA status: {str(e)}")
224
-
225
- # Image generation endpoints
192
+ return {"message": f"Model {model_name} removed"}
193
+ raise HTTPException(
194
+ status_code=400, detail=f"Failed to remove model {model_name}"
195
+ )
196
+
197
+ # --- Image generation (async via thread pool) ---
198
+
226
199
  @app.post("/api/generate")
227
200
  async def generate_image(request: GenerateRequest):
228
- """Generate image"""
229
- # Check if model is loaded
230
- if not model_manager.is_model_loaded():
231
- raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
232
-
201
+ """Generate image from text prompt"""
202
+ engine = _get_engine()
203
+
204
+ steps = request.steps if request.steps is not None else request.num_inference_steps
205
+ guidance = (
206
+ request.cfg_scale
207
+ if request.cfg_scale is not None
208
+ else request.guidance_scale
209
+ )
210
+
233
211
  try:
234
- # Get current loaded inference engine
235
- engine = model_manager.loaded_model
236
-
237
- # Handle parameter aliasing - prioritize shorter names for convenience
238
- steps = request.steps if request.steps is not None else request.num_inference_steps
239
- guidance = request.cfg_scale if request.cfg_scale is not None else request.guidance_scale
240
-
241
- # Generate image
242
- image = engine.generate_image(
212
+ image = await asyncio.to_thread(
213
+ engine.generate_image,
243
214
  prompt=request.prompt,
244
215
  negative_prompt=request.negative_prompt,
245
216
  num_inference_steps=steps,
246
- steps=steps, # Pass both for GGUF compatibility
247
217
  guidance_scale=guidance,
248
- cfg_scale=guidance, # Pass both for GGUF compatibility
249
218
  width=request.width,
250
219
  height=request.height,
251
- control_image=request.control_image_path,
252
- controlnet_conditioning_scale=request.controlnet_conditioning_scale,
253
- control_guidance_start=request.control_guidance_start,
254
- control_guidance_end=request.control_guidance_end
220
+ seed=request.seed,
221
+ )
222
+ if request.response_format == "b64_json":
223
+ return _image_to_json_response(image)
224
+ return _image_to_response(image)
225
+ except Exception as e:
226
+ logger.error(f"Generation failed: {e}", exc_info=True)
227
+ raise HTTPException(status_code=500, detail="Image generation failed")
228
+
229
+ @app.post("/api/generate/img2img")
230
+ async def generate_img2img(
231
+ prompt: str = Form(...),
232
+ negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
233
+ num_inference_steps: Optional[int] = Form(None),
234
+ guidance_scale: Optional[float] = Form(None),
235
+ seed: Optional[int] = Form(None),
236
+ strength: float = Form(0.75),
237
+ image: UploadFile = File(...),
238
+ ):
239
+ """Image-to-image generation"""
240
+ engine = _get_engine()
241
+
242
+ image_data = await image.read()
243
+ input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
244
+
245
+ try:
246
+ result = await asyncio.to_thread(
247
+ engine.generate_image,
248
+ prompt=prompt,
249
+ negative_prompt=negative_prompt,
250
+ num_inference_steps=num_inference_steps,
251
+ guidance_scale=guidance_scale,
252
+ width=input_image.width,
253
+ height=input_image.height,
254
+ seed=seed,
255
+ image=input_image,
256
+ strength=strength,
255
257
  )
256
-
257
- # Convert PIL image to bytes
258
- img_byte_arr = io.BytesIO()
259
- image.save(img_byte_arr, format='PNG')
260
- img_byte_arr = img_byte_arr.getvalue()
261
-
262
- return Response(content=img_byte_arr, media_type="image/png")
263
-
258
+ return _image_to_response(result)
264
259
  except Exception as e:
265
- logger.error(f"Image generation failed: {e}")
266
- raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
267
-
260
+ logger.error(f"img2img failed: {e}", exc_info=True)
261
+ raise HTTPException(status_code=500, detail="Image-to-image generation failed")
262
+
263
+ @app.post("/api/generate/inpaint")
264
+ async def generate_inpaint(
265
+ prompt: str = Form(...),
266
+ negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
267
+ num_inference_steps: Optional[int] = Form(None),
268
+ guidance_scale: Optional[float] = Form(None),
269
+ seed: Optional[int] = Form(None),
270
+ strength: float = Form(0.75),
271
+ image: UploadFile = File(...),
272
+ mask: UploadFile = File(...),
273
+ ):
274
+ """Inpainting generation"""
275
+ engine = _get_engine()
276
+
277
+ image_data = await image.read()
278
+ mask_data = await mask.read()
279
+ input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
280
+ mask_image = Image.open(io.BytesIO(mask_data)).convert("RGB")
281
+
282
+ try:
283
+ result = await asyncio.to_thread(
284
+ engine.generate_image,
285
+ prompt=prompt,
286
+ negative_prompt=negative_prompt,
287
+ num_inference_steps=num_inference_steps,
288
+ guidance_scale=guidance_scale,
289
+ width=input_image.width,
290
+ height=input_image.height,
291
+ seed=seed,
292
+ image=input_image,
293
+ mask_image=mask_image,
294
+ strength=strength,
295
+ )
296
+ return _image_to_response(result)
297
+ except Exception as e:
298
+ logger.error(f"Inpainting failed: {e}", exc_info=True)
299
+ raise HTTPException(status_code=500, detail="Inpainting generation failed")
300
+
268
301
  @app.post("/api/generate/controlnet")
269
- async def generate_image_with_controlnet(
302
+ async def generate_controlnet(
270
303
  prompt: str = Form(...),
271
- negative_prompt: str = Form("low quality, bad anatomy, worst quality, low resolution"),
304
+ negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
272
305
  num_inference_steps: Optional[int] = Form(None),
273
306
  guidance_scale: Optional[float] = Form(None),
274
307
  width: int = Form(1024),
275
308
  height: int = Form(1024),
309
+ seed: Optional[int] = Form(None),
276
310
  controlnet_conditioning_scale: float = Form(1.0),
277
311
  control_guidance_start: float = Form(0.0),
278
312
  control_guidance_end: float = Form(1.0),
279
- control_image: Optional[UploadFile] = File(None)
313
+ control_image: Optional[UploadFile] = File(None),
280
314
  ):
281
- """Generate image with ControlNet using uploaded control image"""
282
- # Check if model is loaded
283
- if not model_manager.is_model_loaded():
284
- raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
285
-
315
+ """Generate image with ControlNet"""
316
+ engine = _get_engine()
317
+
318
+ if not engine.is_controlnet_pipeline:
319
+ raise HTTPException(
320
+ status_code=400, detail="Current model is not a ControlNet model"
321
+ )
322
+
323
+ control_pil = None
324
+ if control_image:
325
+ data = await control_image.read()
326
+ control_pil = Image.open(io.BytesIO(data)).convert("RGB")
327
+
286
328
  try:
287
- # Get current loaded inference engine
288
- engine = model_manager.loaded_model
289
-
290
- # Check if this is a ControlNet model
291
- if not engine.is_controlnet_pipeline:
292
- raise HTTPException(status_code=400, detail="Current model is not a ControlNet model")
293
-
294
- # Process control image if provided
295
- control_image_pil = None
296
- if control_image:
297
- # Read uploaded image
298
- image_data = await control_image.read()
299
- control_image_pil = Image.open(io.BytesIO(image_data)).convert('RGB')
300
-
301
- # Generate image
302
- image = engine.generate_image(
329
+ result = await asyncio.to_thread(
330
+ engine.generate_image,
303
331
  prompt=prompt,
304
332
  negative_prompt=negative_prompt,
305
333
  num_inference_steps=num_inference_steps,
306
334
  guidance_scale=guidance_scale,
307
335
  width=width,
308
336
  height=height,
309
- control_image=control_image_pil,
337
+ seed=seed,
338
+ control_image=control_pil,
310
339
  controlnet_conditioning_scale=controlnet_conditioning_scale,
311
340
  control_guidance_start=control_guidance_start,
312
- control_guidance_end=control_guidance_end
341
+ control_guidance_end=control_guidance_end,
313
342
  )
314
-
315
- # Convert PIL image to bytes
316
- img_byte_arr = io.BytesIO()
317
- image.save(img_byte_arr, format='PNG')
318
- img_byte_arr = img_byte_arr.getvalue()
319
-
320
- return Response(content=img_byte_arr, media_type="image/png")
321
-
343
+ return _image_to_response(result)
322
344
  except Exception as e:
323
- logger.error(f"ControlNet image generation failed: {e}")
324
- raise HTTPException(status_code=500, detail=f"ControlNet image generation failed: {str(e)}")
325
-
326
- @app.post("/api/controlnet/initialize")
327
- async def initialize_controlnet():
328
- """Initialize ControlNet preprocessors"""
345
+ logger.error(f"ControlNet generation failed: {e}", exc_info=True)
346
+ raise HTTPException(status_code=500, detail="ControlNet generation failed")
347
+
348
+ # --- LoRA management ---
349
+
350
+ @app.post("/api/lora/load")
351
+ async def load_lora(request: LoadLoRARequest):
352
+ engine = _get_engine()
329
353
  try:
330
- from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
331
-
332
- logger.info("Explicitly initializing ControlNet preprocessors...")
333
- success = controlnet_preprocessor.initialize(force=True)
334
-
335
- return {
336
- "success": success,
337
- "initialized": controlnet_preprocessor.is_initialized(),
338
- "available_types": controlnet_preprocessor.get_available_types(),
339
- "message": "ControlNet preprocessors initialized successfully" if success else "Failed to initialize ControlNet preprocessors"
340
- }
354
+ success = engine.load_lora_runtime(
355
+ repo_id=request.repo_id,
356
+ weight_name=request.weight_name,
357
+ scale=request.scale,
358
+ )
359
+ if success:
360
+ return {"message": f"LoRA {request.lora_name} loaded (scale={request.scale})"}
361
+ raise HTTPException(status_code=400, detail="Failed to load LoRA")
341
362
  except Exception as e:
342
- logger.error(f"Failed to initialize ControlNet preprocessors: {e}")
343
- raise HTTPException(status_code=500, detail=f"Failed to initialize ControlNet preprocessors: {str(e)}")
344
-
363
+ logger.error(f"LoRA load failed: {e}", exc_info=True)
364
+ raise HTTPException(status_code=500, detail="Failed to load LoRA")
365
+
366
+ @app.post("/api/lora/unload")
367
+ async def unload_lora():
368
+ engine = _get_engine()
369
+ if engine.unload_lora():
370
+ return {"message": "LoRA unloaded"}
371
+ raise HTTPException(status_code=400, detail="Failed to unload LoRA")
372
+
373
+ @app.get("/api/lora/status")
374
+ async def get_lora_status():
375
+ if not model_manager.is_model_loaded():
376
+ return {"loaded": False, "message": "No model loaded"}
377
+ engine = model_manager.loaded_model
378
+ if hasattr(engine, "current_lora") and engine.current_lora:
379
+ return {"loaded": True, "info": engine.current_lora}
380
+ return {"loaded": False, "info": None}
381
+
382
+ # --- ControlNet preprocessors ---
383
+
384
+ @app.post("/api/controlnet/initialize")
385
+ async def initialize_controlnet():
386
+ from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
387
+
388
+ success = controlnet_preprocessor.initialize(force=True)
389
+ return {
390
+ "success": success,
391
+ "initialized": controlnet_preprocessor.is_initialized(),
392
+ "available_types": controlnet_preprocessor.get_available_types(),
393
+ }
394
+
345
395
  @app.get("/api/controlnet/preprocessors")
346
396
  async def get_controlnet_preprocessors():
347
- """Get available ControlNet preprocessors"""
348
- try:
349
- from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
350
- return {
351
- "available_types": controlnet_preprocessor.get_available_types(),
352
- "available": controlnet_preprocessor.is_available(),
353
- "initialized": controlnet_preprocessor.is_initialized(),
354
- "description": {
355
- "canny": "Edge detection using Canny algorithm",
356
- "depth": "Depth estimation for depth-based control",
357
- "openpose": "Human pose detection for pose control",
358
- "scribble": "Scribble/sketch detection for artistic control",
359
- "hed": "Holistically-nested edge detection",
360
- "mlsd": "Mobile line segment detection",
361
- "normal": "Surface normal estimation",
362
- "lineart": "Line art detection",
363
- "lineart_anime": "Anime-style line art detection",
364
- "shuffle": "Content shuffling for style transfer"
365
- }
366
- }
367
- except Exception as e:
368
- logger.error(f"Failed to get ControlNet preprocessors: {e}")
369
- raise HTTPException(status_code=500, detail=f"Failed to get ControlNet preprocessors: {str(e)}")
370
-
397
+ from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
398
+
399
+ return {
400
+ "available_types": controlnet_preprocessor.get_available_types(),
401
+ "available": controlnet_preprocessor.is_available(),
402
+ "initialized": controlnet_preprocessor.is_initialized(),
403
+ }
404
+
371
405
  @app.post("/api/controlnet/preprocess")
372
406
  async def preprocess_control_image(
373
407
  control_type: str = Form(...),
374
- image: UploadFile = File(...)
408
+ image: UploadFile = File(...),
375
409
  ):
376
- """Preprocess image for ControlNet"""
377
- try:
378
- from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
379
-
380
- # Initialize ControlNet preprocessors if needed
381
- if not controlnet_preprocessor.is_initialized():
382
- logger.info("Initializing ControlNet preprocessors for image preprocessing...")
383
- if not controlnet_preprocessor.initialize():
384
- raise HTTPException(status_code=500, detail="Failed to initialize ControlNet preprocessors")
385
-
386
- # Read uploaded image
387
- image_data = await image.read()
388
- input_image = Image.open(io.BytesIO(image_data)).convert('RGB')
389
-
390
- # Preprocess image
391
- processed_image = controlnet_preprocessor.preprocess(input_image, control_type)
392
-
393
- # Convert PIL image to bytes
394
- img_byte_arr = io.BytesIO()
395
- processed_image.save(img_byte_arr, format='PNG')
396
- img_byte_arr = img_byte_arr.getvalue()
397
-
398
- return Response(content=img_byte_arr, media_type="image/png")
399
-
400
- except Exception as e:
401
- logger.error(f"Image preprocessing failed: {e}")
402
- raise HTTPException(status_code=500, detail=f"Image preprocessing failed: {str(e)}")
403
-
404
- # Health check endpoints
405
- @app.get("/api/health")
406
- async def health_check():
407
- """Health check"""
408
- return {
409
- "status": "healthy",
410
- "model_loaded": model_manager.is_model_loaded(),
411
- "current_model": model_manager.get_current_model()
412
- }
413
-
414
- # Server management endpoints
410
+ from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
411
+
412
+ if not controlnet_preprocessor.is_initialized():
413
+ if not controlnet_preprocessor.initialize():
414
+ raise HTTPException(status_code=500, detail="Failed to init preprocessors")
415
+
416
+ data = await image.read()
417
+ input_image = Image.open(io.BytesIO(data)).convert("RGB")
418
+ processed = controlnet_preprocessor.preprocess(input_image, control_type)
419
+ return _image_to_response(processed)
420
+
421
+ # --- Server management ---
422
+
415
423
  @app.post("/api/shutdown")
416
424
  async def shutdown_server():
417
- """Gracefully shutdown the server"""
418
425
  import os
419
426
  import signal
420
- import asyncio
421
-
422
- # Unload model first
427
+
423
428
  model_manager.unload_model()
424
-
425
- # Schedule server shutdown
429
+
426
430
  def shutdown():
427
431
  os.kill(os.getpid(), signal.SIGTERM)
428
-
429
- # Delay shutdown to allow response to be sent
432
+
430
433
  asyncio.get_event_loop().call_later(0.5, shutdown)
431
-
432
434
  return {"message": "Server shutting down..."}
433
-
435
+
434
436
  return app
435
437
 
438
+
436
439
  def run_server(host: str = None, port: int = None):
437
- """Start server"""
438
- # Use default values from configuration
440
+ """Start the API server"""
439
441
  host = host or settings.server.host
440
442
  port = port or settings.server.port
441
-
442
- # Create FastAPI application
443
+
443
444
  app = create_app()
444
-
445
- # Run server with uvicorn
446
445
  logger.info(f"Starting server: http://{host}:{port}")
447
446
  uvicorn.run(app, host=host, port=port, log_level="info")
448
447
 
448
+
449
449
  if __name__ == "__main__":
450
- run_server()
450
+ run_server()