ollamadiffuser-1.2.3-py3-none-any.whl → ollamadiffuser-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
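The most visible change in this manifest is the inference refactor: core/inference/engine.py shrinks by roughly 1,400 lines while per-family strategies (sd15, sdxl, sd3, flux, gguf, hidream, controlnet, video, generic) appear under core/inference/strategies/ alongside a shared base.py. As a rough, hypothetical illustration only, that layout suggests a strategy split like the sketch below; every class and method name in it is invented for illustration, not OllamaDiffuser's actual API:

    # Hypothetical sketch only: names below are invented, not the package's
    # real classes; they illustrate the kind of strategy split the new
    # core/inference/ layout implies.
    from abc import ABC, abstractmethod
    from typing import Any


    class InferenceStrategy(ABC):
        """One load/generate recipe per model family (SD1.5, SDXL, FLUX, GGUF, ...)."""

        @abstractmethod
        def load(self, model_path: str, **kwargs: Any) -> None:
            """Build the underlying pipeline for this model family."""

        @abstractmethod
        def generate(self, prompt: str, **params: Any):
            """Run inference and return a PIL image."""


    class SDXLStrategy(InferenceStrategy):
        def load(self, model_path: str, **kwargs: Any) -> None:
            ...  # e.g. construct an SDXL pipeline from model_path

        def generate(self, prompt: str, **params: Any):
            ...  # e.g. run the pipeline and return the first image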
ollamadiffuser/api/server.py
CHANGED
@@ -1,450 +1,450 @@
-
-
-
-import
+"""OllamaDiffuser REST API Server"""
+
+import asyncio
+import base64
 import io
 import logging
-from typing import
-
+from typing import Any, Dict, Optional
+
+import uvicorn
+from fastapi import FastAPI, File, Form, HTTPException, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import Response
 from PIL import Image
+from pydantic import BaseModel

-from ..core.models.manager import model_manager
 from ..core.config.settings import settings
+from ..core.models.manager import model_manager

-
-
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
 logger = logging.getLogger(__name__)

-
+
+# --- Request models ---
+
+
 class GenerateRequest(BaseModel):
     prompt: str
     negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
     num_inference_steps: Optional[int] = None
-    steps: Optional[int] = None
+    steps: Optional[int] = None
     guidance_scale: Optional[float] = None
-    cfg_scale: Optional[float] = None
+    cfg_scale: Optional[float] = None
     width: int = 1024
     height: int = 1024
-
-
-
-
+    seed: Optional[int] = None
+    response_format: Optional[str] = None  # "b64_json" for JSON with base64, None for raw PNG
+
+
+class Img2ImgRequest(BaseModel):
+    prompt: str
+    negative_prompt: str = "low quality, bad anatomy, worst quality, low resolution"
+    num_inference_steps: Optional[int] = None
+    guidance_scale: Optional[float] = None
+    seed: Optional[int] = None
+    strength: float = 0.75
+

 class LoadModelRequest(BaseModel):
     model_name: str

+
 class LoadLoRARequest(BaseModel):
     lora_name: str
     repo_id: str
     weight_name: Optional[str] = None
     scale: float = 1.0

+
+# --- Helpers ---
+
+
+def _image_to_response(image: Image.Image) -> Response:
+    """Convert PIL Image to PNG Response."""
+    buf = io.BytesIO()
+    image.save(buf, format="PNG")
+    return Response(content=buf.getvalue(), media_type="image/png")
+
+
+def _image_to_json_response(image: Image.Image) -> Dict[str, Any]:
+    """Convert PIL Image to JSON response with base64-encoded PNG."""
+    buf = io.BytesIO()
+    image.save(buf, format="PNG")
+    b64_data = base64.b64encode(buf.getvalue()).decode("utf-8")
+    return {
+        "image": b64_data,
+        "format": "png",
+        "width": image.width,
+        "height": image.height,
+    }
+
+
+def _get_engine():
+    """Get the currently loaded inference engine or raise 400."""
+    if not model_manager.is_model_loaded():
+        raise HTTPException(status_code=400, detail="No model loaded")
+    return model_manager.loaded_model
+
+
+# --- App factory ---
+
+
 def create_app() -> FastAPI:
     """Create FastAPI application"""
     app = FastAPI(
         title="OllamaDiffuser API",
         description="Image generation model management and inference API",
-        version="
+        version="2.0.0",
     )
-
-    # Add CORS middleware
+
     if settings.server.enable_cors:
         app.add_middleware(
             CORSMiddleware,
             allow_origins=["*"],
-            allow_credentials=True,
             allow_methods=["*"],
             allow_headers=["*"],
         )
-
-    # Root
+
+    # --- Root ---
+
     @app.get("/")
     async def root():
-        """Root endpoint with API information"""
         return {
             "name": "OllamaDiffuser API",
-            "version": "
-            "description": "Image generation model management and inference API",
+            "version": "2.0.0",
             "status": "running",
             "endpoints": {
-                "
-                "
-                "health_check": "/api/health",
+                "docs": "/docs",
+                "health": "/api/health",
                 "models": "/api/models",
-                "generate": "/api/generate"
+                "generate": "/api/generate",
+                "img2img": "/api/generate/img2img",
+                "inpaint": "/api/generate/inpaint",
             },
-            "usage": {
-                "web_ui": "Use 'ollamadiffuser --mode ui' to start the web interface",
-                "cli": "Use 'ollamadiffuser --help' for command line options",
-                "api_docs": "Visit /docs for interactive API documentation"
-            }
         }
-
-    #
+
+    # --- Health ---
+
+    @app.get("/api/health")
+    async def health_check():
+        return {
+            "status": "healthy",
+            "model_loaded": model_manager.is_model_loaded(),
+            "current_model": model_manager.get_current_model(),
+        }
+
+    # --- Model management ---
+
     @app.get("/api/models")
     async def list_models():
-        """List all models"""
         return {
             "available": model_manager.list_available_models(),
             "installed": model_manager.list_installed_models(),
-            "current": model_manager.get_current_model()
+            "current": model_manager.get_current_model(),
         }
-
+
     @app.get("/api/models/running")
     async def get_running_model():
-        """Get currently running model"""
         if model_manager.is_model_loaded():
             engine = model_manager.loaded_model
             return {
                 "model": model_manager.get_current_model(),
                 "info": engine.get_model_info(),
-                "loaded": True
+                "loaded": True,
             }
-
-
-
+        return {"loaded": False}
+
     @app.get("/api/models/{model_name}")
     async def get_model_info(model_name: str):
-        """Get model detailed information"""
         info = model_manager.get_model_info(model_name)
         if info is None:
-            raise HTTPException(status_code=404, detail="Model
+            raise HTTPException(status_code=404, detail="Model not found")
         return info
-
+
     @app.post("/api/models/pull")
     async def pull_model(model_name: str):
-
-        if
+        success = await asyncio.to_thread(model_manager.pull_model, model_name)
+        if success:
             return {"message": f"Model {model_name} downloaded successfully"}
-
-
-
+        raise HTTPException(
+            status_code=400, detail=f"Failed to download model {model_name}"
+        )
+
     @app.post("/api/models/load")
     async def load_model(request: LoadModelRequest):
-
-        if
+        success = await asyncio.to_thread(model_manager.load_model, request.model_name)
+        if success:
             return {"message": f"Model {request.model_name} loaded successfully"}
-
-
-
+        raise HTTPException(
+            status_code=400, detail=f"Failed to load model {request.model_name}"
+        )
+
     @app.post("/api/models/unload")
     async def unload_model():
-        """Unload current model"""
         model_manager.unload_model()
         return {"message": "Model unloaded"}
-
+
     @app.delete("/api/models/{model_name}")
     async def remove_model(model_name: str):
-        """Remove model"""
         if model_manager.remove_model(model_name):
-            return {"message": f"Model {model_name} removed
-
-
-
-
-
-
-        """Load LoRA weights into current model"""
-        # Check if model is loaded
-        if not model_manager.is_model_loaded():
-            raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
-
-        try:
-            # Get current loaded inference engine
-            engine = model_manager.loaded_model
-
-            # Load LoRA weights
-            success = engine.load_lora_runtime(
-                repo_id=request.repo_id,
-                weight_name=request.weight_name,
-                scale=request.scale
-            )
-
-            if success:
-                return {"message": f"LoRA {request.lora_name} loaded successfully with scale {request.scale}"}
-            else:
-                raise HTTPException(status_code=400, detail=f"Failed to load LoRA {request.lora_name}")
-
-        except Exception as e:
-            logger.error(f"LoRA loading failed: {e}")
-            raise HTTPException(status_code=500, detail=f"LoRA loading failed: {str(e)}")
-
-    @app.post("/api/lora/unload")
-    async def unload_lora():
-        """Unload current LoRA weights"""
-        # Check if model is loaded
-        if not model_manager.is_model_loaded():
-            raise HTTPException(status_code=400, detail="No model loaded")
-
-        try:
-            # Get current loaded inference engine
-            engine = model_manager.loaded_model
-
-            # Unload LoRA weights
-            success = engine.unload_lora()
-
-            if success:
-                return {"message": "LoRA weights unloaded successfully"}
-            else:
-                raise HTTPException(status_code=400, detail="Failed to unload LoRA weights")
-
-        except Exception as e:
-            logger.error(f"LoRA unloading failed: {e}")
-            raise HTTPException(status_code=500, detail=f"LoRA unloading failed: {str(e)}")
-
-    @app.get("/api/lora/status")
-    async def get_lora_status():
-        """Get current LoRA status"""
-        # Check if model is loaded
-        if not model_manager.is_model_loaded():
-            return {"loaded": False, "message": "No model loaded"}
-
-        try:
-            # Get current loaded inference engine
-            engine = model_manager.loaded_model
-
-            # Check tracked LoRA state
-            if hasattr(engine, 'current_lora') and engine.current_lora:
-                lora_info = engine.current_lora.copy()
-                return {
-                    "loaded": True,
-                    "info": lora_info,
-                    "message": "LoRA loaded"
-                }
-            else:
-                return {
-                    "loaded": False,
-                    "info": None,
-                    "message": "No LoRA loaded"
-                }
-
-        except Exception as e:
-            logger.error(f"Failed to get LoRA status: {e}")
-            raise HTTPException(status_code=500, detail=f"Failed to get LoRA status: {str(e)}")
-
-    # Image generation endpoints
+            return {"message": f"Model {model_name} removed"}
+        raise HTTPException(
+            status_code=400, detail=f"Failed to remove model {model_name}"
+        )
+
+    # --- Image generation (async via thread pool) ---
+
     @app.post("/api/generate")
     async def generate_image(request: GenerateRequest):
-        """Generate image"""
-
-
-
-
+        """Generate image from text prompt"""
+        engine = _get_engine()
+
+        steps = request.steps if request.steps is not None else request.num_inference_steps
+        guidance = (
+            request.cfg_scale
+            if request.cfg_scale is not None
+            else request.guidance_scale
+        )
+
         try:
-
-
-
-            # Handle parameter aliasing - prioritize shorter names for convenience
-            steps = request.steps if request.steps is not None else request.num_inference_steps
-            guidance = request.cfg_scale if request.cfg_scale is not None else request.guidance_scale
-
-            # Generate image
-            image = engine.generate_image(
+            image = await asyncio.to_thread(
+                engine.generate_image,
                 prompt=request.prompt,
                 negative_prompt=request.negative_prompt,
                 num_inference_steps=steps,
-                steps=steps,  # Pass both for GGUF compatibility
                 guidance_scale=guidance,
-                cfg_scale=guidance,  # Pass both for GGUF compatibility
                 width=request.width,
                 height=request.height,
-
-
-
-
+                seed=request.seed,
+            )
+            if request.response_format == "b64_json":
+                return _image_to_json_response(image)
+            return _image_to_response(image)
+        except Exception as e:
+            logger.error(f"Generation failed: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="Image generation failed")
+
+    @app.post("/api/generate/img2img")
+    async def generate_img2img(
+        prompt: str = Form(...),
+        negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
+        num_inference_steps: Optional[int] = Form(None),
+        guidance_scale: Optional[float] = Form(None),
+        seed: Optional[int] = Form(None),
+        strength: float = Form(0.75),
+        image: UploadFile = File(...),
+    ):
+        """Image-to-image generation"""
+        engine = _get_engine()
+
+        image_data = await image.read()
+        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        try:
+            result = await asyncio.to_thread(
+                engine.generate_image,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                width=input_image.width,
+                height=input_image.height,
+                seed=seed,
+                image=input_image,
+                strength=strength,
             )
-
-            # Convert PIL image to bytes
-            img_byte_arr = io.BytesIO()
-            image.save(img_byte_arr, format='PNG')
-            img_byte_arr = img_byte_arr.getvalue()
-
-            return Response(content=img_byte_arr, media_type="image/png")
-
+            return _image_to_response(result)
         except Exception as e:
-            logger.error(f"
-            raise HTTPException(status_code=500, detail=
-
+            logger.error(f"img2img failed: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="Image-to-image generation failed")
+
+    @app.post("/api/generate/inpaint")
+    async def generate_inpaint(
+        prompt: str = Form(...),
+        negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
+        num_inference_steps: Optional[int] = Form(None),
+        guidance_scale: Optional[float] = Form(None),
+        seed: Optional[int] = Form(None),
+        strength: float = Form(0.75),
+        image: UploadFile = File(...),
+        mask: UploadFile = File(...),
+    ):
+        """Inpainting generation"""
+        engine = _get_engine()
+
+        image_data = await image.read()
+        mask_data = await mask.read()
+        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+        mask_image = Image.open(io.BytesIO(mask_data)).convert("RGB")
+
+        try:
+            result = await asyncio.to_thread(
+                engine.generate_image,
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                width=input_image.width,
+                height=input_image.height,
+                seed=seed,
+                image=input_image,
+                mask_image=mask_image,
+                strength=strength,
+            )
+            return _image_to_response(result)
+        except Exception as e:
+            logger.error(f"Inpainting failed: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="Inpainting generation failed")
+
     @app.post("/api/generate/controlnet")
-    async def
+    async def generate_controlnet(
         prompt: str = Form(...),
-        negative_prompt: str = Form("low quality, bad anatomy, worst quality
+        negative_prompt: str = Form("low quality, bad anatomy, worst quality"),
         num_inference_steps: Optional[int] = Form(None),
         guidance_scale: Optional[float] = Form(None),
         width: int = Form(1024),
         height: int = Form(1024),
+        seed: Optional[int] = Form(None),
         controlnet_conditioning_scale: float = Form(1.0),
         control_guidance_start: float = Form(0.0),
         control_guidance_end: float = Form(1.0),
-        control_image: Optional[UploadFile] = File(None)
+        control_image: Optional[UploadFile] = File(None),
     ):
-        """Generate image with ControlNet
-
-
-
-
+        """Generate image with ControlNet"""
+        engine = _get_engine()
+
+        if not engine.is_controlnet_pipeline:
+            raise HTTPException(
+                status_code=400, detail="Current model is not a ControlNet model"
+            )
+
+        control_pil = None
+        if control_image:
+            data = await control_image.read()
+            control_pil = Image.open(io.BytesIO(data)).convert("RGB")
+
         try:
-
-
-
-            # Check if this is a ControlNet model
-            if not engine.is_controlnet_pipeline:
-                raise HTTPException(status_code=400, detail="Current model is not a ControlNet model")
-
-            # Process control image if provided
-            control_image_pil = None
-            if control_image:
-                # Read uploaded image
-                image_data = await control_image.read()
-                control_image_pil = Image.open(io.BytesIO(image_data)).convert('RGB')
-
-            # Generate image
-            image = engine.generate_image(
+            result = await asyncio.to_thread(
+                engine.generate_image,
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 num_inference_steps=num_inference_steps,
                 guidance_scale=guidance_scale,
                 width=width,
                 height=height,
-
+                seed=seed,
+                control_image=control_pil,
                 controlnet_conditioning_scale=controlnet_conditioning_scale,
                 control_guidance_start=control_guidance_start,
-                control_guidance_end=control_guidance_end
+                control_guidance_end=control_guidance_end,
             )
-
-            # Convert PIL image to bytes
-            img_byte_arr = io.BytesIO()
-            image.save(img_byte_arr, format='PNG')
-            img_byte_arr = img_byte_arr.getvalue()
-
-            return Response(content=img_byte_arr, media_type="image/png")
-
+            return _image_to_response(result)
         except Exception as e:
-            logger.error(f"ControlNet
-            raise HTTPException(status_code=500, detail=
-
-
-
-
+            logger.error(f"ControlNet generation failed: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="ControlNet generation failed")
+
+    # --- LoRA management ---
+
+    @app.post("/api/lora/load")
+    async def load_lora(request: LoadLoRARequest):
+        engine = _get_engine()
         try:
-
-
-
-
-
-
-                "
-
-                "available_types": controlnet_preprocessor.get_available_types(),
-                "message": "ControlNet preprocessors initialized successfully" if success else "Failed to initialize ControlNet preprocessors"
-            }
+            success = engine.load_lora_runtime(
+                repo_id=request.repo_id,
+                weight_name=request.weight_name,
+                scale=request.scale,
+            )
+            if success:
+                return {"message": f"LoRA {request.lora_name} loaded (scale={request.scale})"}
+            raise HTTPException(status_code=400, detail="Failed to load LoRA")
        except Exception as e:
-            logger.error(f"
-            raise HTTPException(status_code=500, detail=
-
+            logger.error(f"LoRA load failed: {e}", exc_info=True)
+            raise HTTPException(status_code=500, detail="Failed to load LoRA")
+
+    @app.post("/api/lora/unload")
+    async def unload_lora():
+        engine = _get_engine()
+        if engine.unload_lora():
+            return {"message": "LoRA unloaded"}
+        raise HTTPException(status_code=400, detail="Failed to unload LoRA")
+
+    @app.get("/api/lora/status")
+    async def get_lora_status():
+        if not model_manager.is_model_loaded():
+            return {"loaded": False, "message": "No model loaded"}
+        engine = model_manager.loaded_model
+        if hasattr(engine, "current_lora") and engine.current_lora:
+            return {"loaded": True, "info": engine.current_lora}
+        return {"loaded": False, "info": None}
+
+    # --- ControlNet preprocessors ---
+
+    @app.post("/api/controlnet/initialize")
+    async def initialize_controlnet():
+        from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+        success = controlnet_preprocessor.initialize(force=True)
+        return {
+            "success": success,
+            "initialized": controlnet_preprocessor.is_initialized(),
+            "available_types": controlnet_preprocessor.get_available_types(),
+        }
+
     @app.get("/api/controlnet/preprocessors")
     async def get_controlnet_preprocessors():
-
-
-
-
-
-
-
-
-                    "canny": "Edge detection using Canny algorithm",
-                    "depth": "Depth estimation for depth-based control",
-                    "openpose": "Human pose detection for pose control",
-                    "scribble": "Scribble/sketch detection for artistic control",
-                    "hed": "Holistically-nested edge detection",
-                    "mlsd": "Mobile line segment detection",
-                    "normal": "Surface normal estimation",
-                    "lineart": "Line art detection",
-                    "lineart_anime": "Anime-style line art detection",
-                    "shuffle": "Content shuffling for style transfer"
-                }
-            }
-        except Exception as e:
-            logger.error(f"Failed to get ControlNet preprocessors: {e}")
-            raise HTTPException(status_code=500, detail=f"Failed to get ControlNet preprocessors: {str(e)}")
-
+        from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+        return {
+            "available_types": controlnet_preprocessor.get_available_types(),
+            "available": controlnet_preprocessor.is_available(),
+            "initialized": controlnet_preprocessor.is_initialized(),
+        }
+
     @app.post("/api/controlnet/preprocess")
     async def preprocess_control_image(
         control_type: str = Form(...),
-        image: UploadFile = File(...)
+        image: UploadFile = File(...),
     ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Preprocess image
-            processed_image = controlnet_preprocessor.preprocess(input_image, control_type)
-
-            # Convert PIL image to bytes
-            img_byte_arr = io.BytesIO()
-            processed_image.save(img_byte_arr, format='PNG')
-            img_byte_arr = img_byte_arr.getvalue()
-
-            return Response(content=img_byte_arr, media_type="image/png")
-
-        except Exception as e:
-            logger.error(f"Image preprocessing failed: {e}")
-            raise HTTPException(status_code=500, detail=f"Image preprocessing failed: {str(e)}")
-
-    # Health check endpoints
-    @app.get("/api/health")
-    async def health_check():
-        """Health check"""
-        return {
-            "status": "healthy",
-            "model_loaded": model_manager.is_model_loaded(),
-            "current_model": model_manager.get_current_model()
-        }
-
-    # Server management endpoints
+        from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+        if not controlnet_preprocessor.is_initialized():
+            if not controlnet_preprocessor.initialize():
+                raise HTTPException(status_code=500, detail="Failed to init preprocessors")
+
+        data = await image.read()
+        input_image = Image.open(io.BytesIO(data)).convert("RGB")
+        processed = controlnet_preprocessor.preprocess(input_image, control_type)
+        return _image_to_response(processed)
+
+    # --- Server management ---
+
     @app.post("/api/shutdown")
     async def shutdown_server():
-        """Gracefully shutdown the server"""
         import os
         import signal
-
-
-        # Unload model first
+
         model_manager.unload_model()
-
-        # Schedule server shutdown
+
         def shutdown():
             os.kill(os.getpid(), signal.SIGTERM)
-
-        # Delay shutdown to allow response to be sent
+
         asyncio.get_event_loop().call_later(0.5, shutdown)
-
         return {"message": "Server shutting down..."}
-
+
     return app

+
 def run_server(host: str = None, port: int = None):
-    """Start server"""
-    # Use default values from configuration
+    """Start the API server"""
     host = host or settings.server.host
     port = port or settings.server.port
-
-    # Create FastAPI application
+
     app = create_app()
-
-    # Run server with uvicorn
     logger.info(f"Starting server: http://{host}:{port}")
     uvicorn.run(app, host=host, port=port, log_level="info")

+
 if __name__ == "__main__":
-    run_server()
+    run_server()