ollamadiffuser-1.0.0-py3-none-any.whl → ollamadiffuser-1.1.1-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
ollamadiffuser/ui/web.py CHANGED
@@ -4,10 +4,16 @@ from fastapi.staticfiles import StaticFiles
  from fastapi.templating import Jinja2Templates
  import io
  import base64
+ import logging
+ import json
  from pathlib import Path
+ from PIL import Image

  from ..core.models.manager import model_manager
  from ..core.utils.lora_manager import lora_manager
+ from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+ logger = logging.getLogger(__name__)

  # Get templates directory
  templates_dir = Path(__file__).parent / "templates"
@@ -17,9 +23,18 @@ def create_ui_app() -> FastAPI:
  """Create Web UI application"""
  app = FastAPI(title="OllamaDiffuser Web UI")

- @app.get("/", response_class=HTMLResponse)
- async def home(request: Request):
- """Home page"""
+ # Mount static files for samples
+ samples_dir = Path(__file__).parent / "samples"
+ logger.info(f"Samples directory: {samples_dir}")
+ logger.info(f"Samples directory exists: {samples_dir.exists()}")
+ if samples_dir.exists():
+ logger.info(f"Mounting samples directory: {samples_dir}")
+ app.mount("/samples", StaticFiles(directory=str(samples_dir)), name="samples")
+ else:
+ logger.warning(f"Samples directory not found: {samples_dir}")
+
+ def get_template_context(request: Request):
+ """Get common template context"""
  models = model_manager.list_available_models()
  installed_models = model_manager.list_installed_models()
  current_model = model_manager.get_current_model()
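
With the samples directory mounted at /samples, the bundled sample assets (including the metadata.json read a few lines further down) are served as plain static files. A minimal sketch of probing that mount from a client, assuming the requests library and a UI served at 127.0.0.1:8001 (both assumptions, not part of the package):

import requests

# /samples is only mounted when the directory ships with the package,
# so treat a 404 as "no samples available" rather than as an error.
resp = requests.get("http://127.0.0.1:8001/samples/metadata.json")
if resp.status_code == 200:
    print(resp.json())
else:
    print("no sample metadata mounted:", resp.status_code)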
@@ -29,17 +44,58 @@ def create_ui_app() -> FastAPI:
  installed_loras = lora_manager.list_installed_loras()
  current_lora = lora_manager.get_current_lora()

- # Don't auto-load model on startup - let user choose
+ # Check if current model is ControlNet
+ is_controlnet_model = False
+ controlnet_type = None
+ model_parameters = {}
+ if current_model and model_loaded:
+ engine = model_manager.loaded_model
+ if hasattr(engine, 'is_controlnet_pipeline'):
+ is_controlnet_model = engine.is_controlnet_pipeline
+ if is_controlnet_model:
+ # Get ControlNet type from model info
+ model_info = model_manager.get_model_info(current_model)
+ controlnet_type = model_info.get('controlnet_type', 'canny') if model_info else 'canny'
+
+ # Get model parameters for current model
+ model_info = model_manager.get_model_info(current_model)
+ if model_info and 'parameters' in model_info:
+ model_parameters = model_info['parameters']
+
+ # Get available ControlNet preprocessors (without initializing)
+ available_preprocessors = controlnet_preprocessor.get_available_types()
+
+ # Load sample metadata
+ sample_metadata = {}
+ metadata_file = samples_dir / "metadata.json"
+ if metadata_file.exists():
+ try:
+ with open(metadata_file, 'r') as f:
+ sample_metadata = json.load(f)
+ except Exception as e:
+ logger.warning(f"Failed to load sample metadata: {e}")

- return templates.TemplateResponse("index.html", {
+ return {
  "request": request,
  "models": models,
  "installed_models": installed_models,
  "current_model": current_model,
  "model_loaded": model_loaded,
  "installed_loras": installed_loras,
- "current_lora": current_lora
- })
+ "current_lora": current_lora,
+ "is_controlnet_model": is_controlnet_model,
+ "controlnet_type": controlnet_type,
+ "available_preprocessors": available_preprocessors,
+ "controlnet_available": controlnet_preprocessor.is_available(),
+ "controlnet_initialized": controlnet_preprocessor.is_initialized(),
+ "sample_metadata": sample_metadata,
+ "model_parameters": model_parameters
+ }
+
+ @app.get("/", response_class=HTMLResponse)
+ async def home(request: Request):
+ """Home page"""
+ return templates.TemplateResponse("index.html", get_template_context(request))

  @app.post("/generate")
  async def generate_image_ui(
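
Every route below now builds its page through this shared get_template_context() helper instead of reassembling the context dict by hand. A minimal sketch of serving the factory, assuming uvicorn is installed (the package's own CLI entry point may wire this up differently):

import uvicorn
from ollamadiffuser.ui.web import create_ui_app

# create_ui_app() is the factory shown in this diff; the host and port
# are arbitrary illustration values, not package defaults.
app = create_ui_app()

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8001)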
@@ -49,11 +105,16 @@ def create_ui_app() -> FastAPI:
  num_inference_steps: int = Form(28),
  guidance_scale: float = Form(3.5),
  width: int = Form(1024),
- height: int = Form(1024)
+ height: int = Form(1024),
+ control_image: UploadFile = File(None),
+ controlnet_conditioning_scale: float = Form(1.0),
+ control_guidance_start: float = Form(0.0),
+ control_guidance_end: float = Form(1.0)
  ):
  """Generate image (Web UI)"""
  error_message = None
  image_b64 = None
+ control_image_b64 = None

  try:
  # Check if model is actually loaded in memory
@@ -67,6 +128,26 @@ def create_ui_app() -> FastAPI:
  if engine is None:
  error_message = "Model engine is not available. Please reload the model."
  else:
+ # Process control image if provided
+ control_image_pil = None
+ if control_image and control_image.filename:
+ # Initialize ControlNet preprocessors if needed
+ if not controlnet_preprocessor.is_initialized():
+ logger.info("Initializing ControlNet preprocessors for image processing...")
+ if not controlnet_preprocessor.initialize():
+ error_message = "Failed to initialize ControlNet preprocessors. Please check your installation."
+
+ if not error_message:
+ # Read uploaded image
+ image_data = await control_image.read()
+ control_image_pil = Image.open(io.BytesIO(image_data)).convert('RGB')
+
+ # Convert control image to base64 for display
+ img_buffer = io.BytesIO()
+ control_image_pil.save(img_buffer, format='PNG')
+ img_buffer.seek(0)
+ control_image_b64 = base64.b64encode(img_buffer.getvalue()).decode()
+
  # Generate image
  image = engine.generate_image(
  prompt=prompt,
@@ -74,7 +155,11 @@ def create_ui_app() -> FastAPI:
  num_inference_steps=num_inference_steps,
  guidance_scale=guidance_scale,
  width=width,
- height=height
+ height=height,
+ control_image=control_image_pil,
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
+ control_guidance_start=control_guidance_start,
+ control_guidance_end=control_guidance_end
  )

  # Convert to base64
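
The /generate form now takes an optional control image plus three ControlNet tuning fields; without an upload, the new arguments are passed through with their defaults and generation behaves as before. A hedged sketch of driving the form from a script (requests assumed installed; the host, port, and pose.png file are illustrative):

import requests

with open("pose.png", "rb") as f:  # hypothetical control image
    resp = requests.post(
        "http://127.0.0.1:8001/generate",
        data={
            "prompt": "a watercolor city skyline",
            "num_inference_steps": 28,
            "guidance_scale": 3.5,
            "width": 1024,
            "height": 1024,
            "controlnet_conditioning_scale": 1.0,
            "control_guidance_start": 0.0,
            "control_guidance_end": 1.0,
        },
        files={"control_image": ("pose.png", f, "image/png")},
    )

# The endpoint renders index.html, so the result comes back embedded in the
# page as base64 (image_b64 / control_image_b64) rather than as binary data.
print(resp.status_code, len(resp.text))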
@@ -87,29 +172,55 @@ def create_ui_app() -> FastAPI:
  error_message = f"Image generation failed: {str(e)}"

  # Return result page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ context = get_template_context(request)
+ context.update({
  "prompt": prompt,
  "negative_prompt": negative_prompt,
  "num_inference_steps": num_inference_steps,
  "guidance_scale": guidance_scale,
  "width": width,
  "height": height,
+ "controlnet_conditioning_scale": controlnet_conditioning_scale,
+ "control_guidance_start": control_guidance_start,
+ "control_guidance_end": control_guidance_end,
  "image_b64": image_b64,
+ "control_image_b64": control_image_b64,
  "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)
+
+ @app.post("/preprocess_control_image")
+ async def preprocess_control_image_ui(
+ request: Request,
+ control_type: str = Form(...),
+ image: UploadFile = File(...)
+ ):
+ """Preprocess control image (Web UI)"""
+ try:
+ # Initialize ControlNet preprocessors if needed
+ if not controlnet_preprocessor.is_initialized():
+ logger.info("Initializing ControlNet preprocessors for image preprocessing...")
+ if not controlnet_preprocessor.initialize():
+ return {"error": "Failed to initialize ControlNet preprocessors. Please check your installation."}
+
+ # Read uploaded image
+ image_data = await image.read()
+ input_image = Image.open(io.BytesIO(image_data)).convert('RGB')
+
+ # Preprocess image
+ processed_image = controlnet_preprocessor.preprocess(input_image, control_type)
+
+ # Convert to base64
+ img_buffer = io.BytesIO()
+ processed_image.save(img_buffer, format='PNG')
+ img_buffer.seek(0)
+
+ return StreamingResponse(io.BytesIO(img_buffer.getvalue()), media_type="image/png")
+
+ except Exception as e:
+ # Return error as JSON
+ return {"error": f"Image preprocessing failed: {str(e)}"}

  @app.post("/load_model")
  async def load_model_ui(request: Request, model_name: str = Form(...)):
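
The new /preprocess_control_image route streams the processed control map back as a PNG on success but returns a JSON {"error": ...} dict on failure, so callers need to branch on the content type. A sketch, with the host, port, and canny control type as illustrative assumptions:

import requests

with open("input.jpg", "rb") as f:  # hypothetical source image
    resp = requests.post(
        "http://127.0.0.1:8001/preprocess_control_image",
        data={"control_type": "canny"},
        files={"image": ("input.jpg", f, "image/jpeg")},
    )

if resp.headers.get("content-type", "").startswith("image/png"):
    with open("control_canny.png", "wb") as out:
        out.write(resp.content)  # preprocessed control map
else:
    print(resp.json().get("error"))  # failures come back as JSON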
@@ -125,24 +236,14 @@ def create_ui_app() -> FastAPI:
  except Exception as e:
  error_message = f"Error loading model: {str(e)}"

- # Redirect back to home page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ # Return result page
+ context = get_template_context(request)
+ context.update({
  "success_message": f"Model {model_name} loaded successfully!" if success else None,
  "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)

  @app.post("/unload_model")
  async def unload_model_ui(request: Request):
@@ -151,28 +252,19 @@ def create_ui_app() -> FastAPI:
  current_model = model_manager.get_current_model()
  model_manager.unload_model()
  success_message = f"Model {current_model} unloaded successfully!" if current_model else "Model unloaded!"
+ error_message = None
  except Exception as e:
  success_message = None
  error_message = f"Error unloading model: {str(e)}"

- # Redirect back to home page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ # Return result page
+ context = get_template_context(request)
+ context.update({
  "success_message": success_message,
- "error_message": error_message if 'error_message' in locals() else None
+ "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)

  @app.post("/load_lora")
  async def load_lora_ui(request: Request, lora_name: str = Form(...), scale: float = Form(1.0)):
@@ -188,24 +280,14 @@ def create_ui_app() -> FastAPI:
  except Exception as e:
  error_message = f"Error loading LoRA: {str(e)}"

- # Redirect back to home page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ # Return result page
+ context = get_template_context(request)
+ context.update({
  "success_message": f"LoRA {lora_name} loaded successfully with scale {scale}!" if success else None,
  "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)

  @app.post("/unload_lora")
  async def unload_lora_ui(request: Request):
@@ -214,28 +296,19 @@ def create_ui_app() -> FastAPI:
  current_lora_name = lora_manager.get_current_lora()
  lora_manager.unload_lora()
  success_message = f"LoRA {current_lora_name} unloaded successfully!" if current_lora_name else "LoRA unloaded!"
+ error_message = None
  except Exception as e:
  success_message = None
  error_message = f"Error unloading LoRA: {str(e)}"

- # Redirect back to home page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ # Return result page
+ context = get_template_context(request)
+ context.update({
  "success_message": success_message,
- "error_message": error_message if 'error_message' in locals() else None
+ "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)

  @app.post("/pull_lora")
  async def pull_lora_ui(request: Request, repo_id: str = Form(...), weight_name: str = Form(""), alias: str = Form("")):
@@ -256,23 +329,31 @@ def create_ui_app() -> FastAPI:
  except Exception as e:
  error_message = f"Error downloading LoRA: {str(e)}"

- # Redirect back to home page
- models = model_manager.list_available_models()
- installed_models = model_manager.list_installed_models()
- current_model = model_manager.get_current_model()
- installed_loras = lora_manager.list_installed_loras()
- current_lora = lora_manager.get_current_lora()
-
- return templates.TemplateResponse("index.html", {
- "request": request,
- "models": models,
- "installed_models": installed_models,
- "current_model": current_model,
- "model_loaded": model_manager.is_model_loaded(),
- "installed_loras": installed_loras,
- "current_lora": current_lora,
+ # Return result page
+ context = get_template_context(request)
+ context.update({
  "success_message": f"LoRA {final_name if success else repo_id} downloaded successfully!" if success else None,
  "error_message": error_message
  })
+
+ return templates.TemplateResponse("index.html", context)
+
+ @app.post("/api/controlnet/initialize")
+ async def initialize_controlnet_api():
+ """Initialize ControlNet preprocessors (API endpoint)"""
+ try:
+ success = controlnet_preprocessor.initialize()
+ return {
+ "success": success,
+ "initialized": controlnet_preprocessor.is_initialized(),
+ "message": "ControlNet preprocessors initialized successfully!" if success else "Failed to initialize ControlNet preprocessors"
+ }
+ except Exception as e:
+ logger.error(f"Error initializing ControlNet: {e}")
+ return {
+ "success": False,
+ "initialized": False,
+ "message": f"Error initializing ControlNet: {str(e)}"
+ }

  return app
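
Because preprocessor initialization is deferred until first use, the new /api/controlnet/initialize endpoint lets a client warm it up explicitly and inspect the outcome. A sketch (host and port assumed as above):

import requests

resp = requests.post("http://127.0.0.1:8001/api/controlnet/initialize")
info = resp.json()
# Keys mirror the handler's response dict: success, initialized, message.
print(info["success"], info["initialized"], info["message"])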