ollamadiffuser 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/api/server.py +147 -2
- ollamadiffuser/cli/main.py +325 -25
- ollamadiffuser/core/inference/engine.py +180 -9
- ollamadiffuser/core/models/manager.py +136 -2
- ollamadiffuser/core/utils/controlnet_preprocessors.py +317 -0
- ollamadiffuser/core/utils/download_utils.py +209 -60
- ollamadiffuser/ui/templates/index.html +384 -7
- ollamadiffuser/ui/web.py +181 -100
- ollamadiffuser-1.1.1.dist-info/METADATA +470 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/RECORD +14 -13
- ollamadiffuser-1.0.0.dist-info/METADATA +0 -493
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/WHEEL +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/entry_points.txt +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.0.0.dist-info → ollamadiffuser-1.1.1.dist-info}/top_level.txt +0 -0
ollamadiffuser/api/server.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from fastapi import FastAPI, HTTPException
|
|
1
|
+
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
|
2
2
|
from fastapi.responses import Response
|
|
3
3
|
from fastapi.middleware.cors import CORSMiddleware
|
|
4
4
|
import uvicorn
|
|
@@ -6,6 +6,7 @@ import io
|
|
|
6
6
|
import logging
|
|
7
7
|
from typing import Dict, Any, Optional
|
|
8
8
|
from pydantic import BaseModel
|
|
9
|
+
from PIL import Image
|
|
9
10
|
|
|
10
11
|
from ..core.models.manager import model_manager
|
|
11
12
|
from ..core.config.settings import settings
|
|
@@ -22,6 +23,10 @@ class GenerateRequest(BaseModel):
|
|
|
22
23
|
guidance_scale: Optional[float] = None
|
|
23
24
|
width: int = 1024
|
|
24
25
|
height: int = 1024
|
|
26
|
+
control_image_path: Optional[str] = None # Path to control image file
|
|
27
|
+
controlnet_conditioning_scale: float = 1.0
|
|
28
|
+
control_guidance_start: float = 0.0
|
|
29
|
+
control_guidance_end: float = 1.0
|
|
25
30
|
|
|
26
31
|
class LoadModelRequest(BaseModel):
|
|
27
32
|
model_name: str
|
|
@@ -234,7 +239,11 @@ def create_app() -> FastAPI:
|
|
|
234
239
|
num_inference_steps=request.num_inference_steps,
|
|
235
240
|
guidance_scale=request.guidance_scale,
|
|
236
241
|
width=request.width,
|
|
237
|
-
height=request.height
|
|
242
|
+
height=request.height,
|
|
243
|
+
control_image=request.control_image_path,
|
|
244
|
+
controlnet_conditioning_scale=request.controlnet_conditioning_scale,
|
|
245
|
+
control_guidance_start=request.control_guidance_start,
|
|
246
|
+
control_guidance_end=request.control_guidance_end
|
|
238
247
|
)
|
|
239
248
|
|
|
240
249
|
# Convert PIL image to bytes
|
|
@@ -248,6 +257,142 @@ def create_app() -> FastAPI:
|
|
|
248
257
|
logger.error(f"Image generation failed: {e}")
|
|
249
258
|
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
|
|
250
259
|
|
|
260
|
+
@app.post("/api/generate/controlnet")
|
|
261
|
+
async def generate_image_with_controlnet(
|
|
262
|
+
prompt: str = Form(...),
|
|
263
|
+
negative_prompt: str = Form("low quality, bad anatomy, worst quality, low resolution"),
|
|
264
|
+
num_inference_steps: Optional[int] = Form(None),
|
|
265
|
+
guidance_scale: Optional[float] = Form(None),
|
|
266
|
+
width: int = Form(1024),
|
|
267
|
+
height: int = Form(1024),
|
|
268
|
+
controlnet_conditioning_scale: float = Form(1.0),
|
|
269
|
+
control_guidance_start: float = Form(0.0),
|
|
270
|
+
control_guidance_end: float = Form(1.0),
|
|
271
|
+
control_image: Optional[UploadFile] = File(None)
|
|
272
|
+
):
|
|
273
|
+
"""Generate image with ControlNet using uploaded control image"""
|
|
274
|
+
# Check if model is loaded
|
|
275
|
+
if not model_manager.is_model_loaded():
|
|
276
|
+
raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
|
|
277
|
+
|
|
278
|
+
try:
|
|
279
|
+
# Get current loaded inference engine
|
|
280
|
+
engine = model_manager.loaded_model
|
|
281
|
+
|
|
282
|
+
# Check if this is a ControlNet model
|
|
283
|
+
if not engine.is_controlnet_pipeline:
|
|
284
|
+
raise HTTPException(status_code=400, detail="Current model is not a ControlNet model")
|
|
285
|
+
|
|
286
|
+
# Process control image if provided
|
|
287
|
+
control_image_pil = None
|
|
288
|
+
if control_image:
|
|
289
|
+
# Read uploaded image
|
|
290
|
+
image_data = await control_image.read()
|
|
291
|
+
control_image_pil = Image.open(io.BytesIO(image_data)).convert('RGB')
|
|
292
|
+
|
|
293
|
+
# Generate image
|
|
294
|
+
image = engine.generate_image(
|
|
295
|
+
prompt=prompt,
|
|
296
|
+
negative_prompt=negative_prompt,
|
|
297
|
+
num_inference_steps=num_inference_steps,
|
|
298
|
+
guidance_scale=guidance_scale,
|
|
299
|
+
width=width,
|
|
300
|
+
height=height,
|
|
301
|
+
control_image=control_image_pil,
|
|
302
|
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
|
303
|
+
control_guidance_start=control_guidance_start,
|
|
304
|
+
control_guidance_end=control_guidance_end
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
# Convert PIL image to bytes
|
|
308
|
+
img_byte_arr = io.BytesIO()
|
|
309
|
+
image.save(img_byte_arr, format='PNG')
|
|
310
|
+
img_byte_arr = img_byte_arr.getvalue()
|
|
311
|
+
|
|
312
|
+
return Response(content=img_byte_arr, media_type="image/png")
|
|
313
|
+
|
|
314
|
+
except Exception as e:
|
|
315
|
+
logger.error(f"ControlNet image generation failed: {e}")
|
|
316
|
+
raise HTTPException(status_code=500, detail=f"ControlNet image generation failed: {str(e)}")
|
|
317
|
+
|
|
318
|
+
@app.post("/api/controlnet/initialize")
|
|
319
|
+
async def initialize_controlnet():
|
|
320
|
+
"""Initialize ControlNet preprocessors"""
|
|
321
|
+
try:
|
|
322
|
+
from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
|
|
323
|
+
|
|
324
|
+
logger.info("Explicitly initializing ControlNet preprocessors...")
|
|
325
|
+
success = controlnet_preprocessor.initialize(force=True)
|
|
326
|
+
|
|
327
|
+
return {
|
|
328
|
+
"success": success,
|
|
329
|
+
"initialized": controlnet_preprocessor.is_initialized(),
|
|
330
|
+
"available_types": controlnet_preprocessor.get_available_types(),
|
|
331
|
+
"message": "ControlNet preprocessors initialized successfully" if success else "Failed to initialize ControlNet preprocessors"
|
|
332
|
+
}
|
|
333
|
+
except Exception as e:
|
|
334
|
+
logger.error(f"Failed to initialize ControlNet preprocessors: {e}")
|
|
335
|
+
raise HTTPException(status_code=500, detail=f"Failed to initialize ControlNet preprocessors: {str(e)}")
|
|
336
|
+
|
|
337
|
+
@app.get("/api/controlnet/preprocessors")
|
|
338
|
+
async def get_controlnet_preprocessors():
|
|
339
|
+
"""Get available ControlNet preprocessors"""
|
|
340
|
+
try:
|
|
341
|
+
from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
|
|
342
|
+
return {
|
|
343
|
+
"available_types": controlnet_preprocessor.get_available_types(),
|
|
344
|
+
"available": controlnet_preprocessor.is_available(),
|
|
345
|
+
"initialized": controlnet_preprocessor.is_initialized(),
|
|
346
|
+
"description": {
|
|
347
|
+
"canny": "Edge detection using Canny algorithm",
|
|
348
|
+
"depth": "Depth estimation for depth-based control",
|
|
349
|
+
"openpose": "Human pose detection for pose control",
|
|
350
|
+
"scribble": "Scribble/sketch detection for artistic control",
|
|
351
|
+
"hed": "Holistically-nested edge detection",
|
|
352
|
+
"mlsd": "Mobile line segment detection",
|
|
353
|
+
"normal": "Surface normal estimation",
|
|
354
|
+
"lineart": "Line art detection",
|
|
355
|
+
"lineart_anime": "Anime-style line art detection",
|
|
356
|
+
"shuffle": "Content shuffling for style transfer"
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
except Exception as e:
|
|
360
|
+
logger.error(f"Failed to get ControlNet preprocessors: {e}")
|
|
361
|
+
raise HTTPException(status_code=500, detail=f"Failed to get ControlNet preprocessors: {str(e)}")
|
|
362
|
+
|
|
363
|
+
@app.post("/api/controlnet/preprocess")
|
|
364
|
+
async def preprocess_control_image(
|
|
365
|
+
control_type: str = Form(...),
|
|
366
|
+
image: UploadFile = File(...)
|
|
367
|
+
):
|
|
368
|
+
"""Preprocess image for ControlNet"""
|
|
369
|
+
try:
|
|
370
|
+
from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
|
|
371
|
+
|
|
372
|
+
# Initialize ControlNet preprocessors if needed
|
|
373
|
+
if not controlnet_preprocessor.is_initialized():
|
|
374
|
+
logger.info("Initializing ControlNet preprocessors for image preprocessing...")
|
|
375
|
+
if not controlnet_preprocessor.initialize():
|
|
376
|
+
raise HTTPException(status_code=500, detail="Failed to initialize ControlNet preprocessors")
|
|
377
|
+
|
|
378
|
+
# Read uploaded image
|
|
379
|
+
image_data = await image.read()
|
|
380
|
+
input_image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
|
381
|
+
|
|
382
|
+
# Preprocess image
|
|
383
|
+
processed_image = controlnet_preprocessor.preprocess(input_image, control_type)
|
|
384
|
+
|
|
385
|
+
# Convert PIL image to bytes
|
|
386
|
+
img_byte_arr = io.BytesIO()
|
|
387
|
+
processed_image.save(img_byte_arr, format='PNG')
|
|
388
|
+
img_byte_arr = img_byte_arr.getvalue()
|
|
389
|
+
|
|
390
|
+
return Response(content=img_byte_arr, media_type="image/png")
|
|
391
|
+
|
|
392
|
+
except Exception as e:
|
|
393
|
+
logger.error(f"Image preprocessing failed: {e}")
|
|
394
|
+
raise HTTPException(status_code=500, detail=f"Image preprocessing failed: {str(e)}")
|
|
395
|
+
|
|
251
396
|
# Health check endpoints
|
|
252
397
|
@app.get("/api/health")
|
|
253
398
|
async def health_check():
|
ollamadiffuser/cli/main.py
CHANGED
|
@@ -5,8 +5,9 @@ import logging
|
|
|
5
5
|
from typing import Optional
|
|
6
6
|
from rich.console import Console
|
|
7
7
|
from rich.table import Table
|
|
8
|
-
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
8
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
|
|
9
9
|
from rich import print as rprint
|
|
10
|
+
import time
|
|
10
11
|
|
|
11
12
|
from ..core.models.manager import model_manager
|
|
12
13
|
from ..core.config.settings import settings
|
|
@@ -14,6 +15,44 @@ from ..api.server import run_server
|
|
|
14
15
|
|
|
15
16
|
console = Console()
|
|
16
17
|
|
|
18
|
+
class OllamaStyleProgress:
|
|
19
|
+
"""Enhanced progress tracker that mimics Ollama's progress display"""
|
|
20
|
+
|
|
21
|
+
def __init__(self, console: Console):
|
|
22
|
+
self.console = console
|
|
23
|
+
self.last_message = ""
|
|
24
|
+
|
|
25
|
+
def update(self, message: str):
|
|
26
|
+
"""Update progress with a message"""
|
|
27
|
+
# Skip duplicate messages
|
|
28
|
+
if message == self.last_message:
|
|
29
|
+
return
|
|
30
|
+
|
|
31
|
+
self.last_message = message
|
|
32
|
+
|
|
33
|
+
# Handle different types of messages
|
|
34
|
+
if message.startswith("pulling ") and ":" in message and "%" in message:
|
|
35
|
+
# This is a file progress message from download_utils
|
|
36
|
+
# Format: "pulling e6a7edc1a4d7: 12% ▕██ ▏ 617 MB/5200 MB 44 MB/s 1m44s"
|
|
37
|
+
self.console.print(message)
|
|
38
|
+
elif message.startswith("pulling manifest"):
|
|
39
|
+
self.console.print(message)
|
|
40
|
+
elif message.startswith("📦 Repository:"):
|
|
41
|
+
# Repository info
|
|
42
|
+
self.console.print(f"[dim]{message}[/dim]")
|
|
43
|
+
elif message.startswith("📁 Found"):
|
|
44
|
+
# Existing files info
|
|
45
|
+
self.console.print(f"[dim]{message}[/dim]")
|
|
46
|
+
elif message.startswith("✅") and "download completed" in message:
|
|
47
|
+
self.console.print(f"[green]{message}[/green]")
|
|
48
|
+
elif message.startswith("❌"):
|
|
49
|
+
self.console.print(f"[red]{message}[/red]")
|
|
50
|
+
elif message.startswith("⚠️"):
|
|
51
|
+
self.console.print(f"[yellow]{message}[/yellow]")
|
|
52
|
+
else:
|
|
53
|
+
# For other messages, print with dimmed style
|
|
54
|
+
self.console.print(f"[dim]{message}[/dim]")
|
|
55
|
+
|
|
17
56
|
@click.group()
|
|
18
57
|
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
|
|
19
58
|
def cli(verbose):
|
|
@@ -30,24 +69,26 @@ def pull(model_name: str, force: bool):
|
|
|
30
69
|
"""Download model"""
|
|
31
70
|
rprint(f"[blue]Downloading model: {model_name}[/blue]")
|
|
32
71
|
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
"""Update progress display with download status"""
|
|
42
|
-
progress.update(task, description=message)
|
|
43
|
-
|
|
72
|
+
# Use the new Ollama-style progress tracker
|
|
73
|
+
progress_tracker = OllamaStyleProgress(console)
|
|
74
|
+
|
|
75
|
+
def progress_callback(message: str):
|
|
76
|
+
"""Enhanced progress callback with Ollama-style display"""
|
|
77
|
+
progress_tracker.update(message)
|
|
78
|
+
|
|
79
|
+
try:
|
|
44
80
|
if model_manager.pull_model(model_name, force=force, progress_callback=progress_callback):
|
|
45
|
-
|
|
81
|
+
progress_tracker.update("✅ download completed")
|
|
46
82
|
rprint(f"[green]Model {model_name} downloaded successfully![/green]")
|
|
47
83
|
else:
|
|
48
|
-
progress.update(task, description=f"❌ {model_name} download failed")
|
|
49
84
|
rprint(f"[red]Model {model_name} download failed![/red]")
|
|
50
85
|
sys.exit(1)
|
|
86
|
+
except KeyboardInterrupt:
|
|
87
|
+
rprint("\n[yellow]Download cancelled by user[/yellow]")
|
|
88
|
+
sys.exit(1)
|
|
89
|
+
except Exception as e:
|
|
90
|
+
rprint(f"[red]Download failed: {str(e)}[/red]")
|
|
91
|
+
sys.exit(1)
|
|
51
92
|
|
|
52
93
|
@cli.command()
|
|
53
94
|
@click.argument('model_name')
|
|
@@ -226,21 +267,280 @@ def check(model_name: str, list: bool):
|
|
|
226
267
|
rprint("[dim] ollamadiffuser check --list[/dim]")
|
|
227
268
|
return
|
|
228
269
|
|
|
229
|
-
#
|
|
270
|
+
# Check model download status directly
|
|
271
|
+
status = _check_download_status(model_name)
|
|
272
|
+
|
|
273
|
+
rprint("\n" + "="*60)
|
|
274
|
+
|
|
275
|
+
if status is True:
|
|
276
|
+
rprint(f"[green]🎉 {model_name} is ready to use![/green]")
|
|
277
|
+
rprint(f"\n[blue]💡 You can now run:[/blue]")
|
|
278
|
+
rprint(f" [cyan]ollamadiffuser run {model_name}[/cyan]")
|
|
279
|
+
elif status == "needs_config":
|
|
280
|
+
rprint(f"[yellow]⚠️ {model_name} files are complete but model needs configuration[/yellow]")
|
|
281
|
+
rprint(f"\n[blue]💡 Try reinstalling:[/blue]")
|
|
282
|
+
rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
|
|
283
|
+
elif status == "downloading":
|
|
284
|
+
rprint(f"[yellow]🔄 {model_name} is currently downloading[/yellow]")
|
|
285
|
+
rprint(f"\n[blue]💡 Wait for download to complete or check progress[/blue]")
|
|
286
|
+
elif status == "incomplete":
|
|
287
|
+
rprint(f"[yellow]⚠️ Download is incomplete[/yellow]")
|
|
288
|
+
rprint(f"\n[blue]💡 Resume download with:[/blue]")
|
|
289
|
+
rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
|
|
290
|
+
rprint(f"\n[blue]💡 Or force fresh download with:[/blue]")
|
|
291
|
+
rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
|
|
292
|
+
else:
|
|
293
|
+
rprint(f"[red]❌ {model_name} is not downloaded[/red]")
|
|
294
|
+
rprint(f"\n[blue]💡 Download with:[/blue]")
|
|
295
|
+
rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
|
|
296
|
+
|
|
297
|
+
_show_model_specific_help(model_name)
|
|
298
|
+
|
|
299
|
+
rprint(f"\n[dim]📚 For more help: ollamadiffuser --help[/dim]")
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def _check_download_status(model_name: str):
|
|
303
|
+
"""Check the current download status of any model"""
|
|
304
|
+
from ..core.utils.download_utils import check_download_integrity, get_repo_file_list, format_size
|
|
230
305
|
import subprocess
|
|
231
|
-
import sys
|
|
232
|
-
from pathlib import Path
|
|
233
306
|
|
|
234
|
-
|
|
235
|
-
|
|
307
|
+
rprint(f"[blue]🔍 Checking {model_name} download status...[/blue]\n")
|
|
308
|
+
|
|
309
|
+
# Check if model is in registry
|
|
310
|
+
if model_name not in model_manager.model_registry:
|
|
311
|
+
rprint(f"[red]❌ {model_name} not found in model registry[/red]")
|
|
312
|
+
available_models = model_manager.list_available_models()
|
|
313
|
+
rprint(f"[blue]📋 Available models: {', '.join(available_models)}[/blue]")
|
|
314
|
+
return False
|
|
315
|
+
|
|
316
|
+
model_info = model_manager.model_registry[model_name]
|
|
317
|
+
repo_id = model_info["repo_id"]
|
|
318
|
+
model_path = settings.get_model_path(model_name)
|
|
319
|
+
|
|
320
|
+
rprint(f"[cyan]📦 Model: {model_name}[/cyan]")
|
|
321
|
+
rprint(f"[cyan]🔗 Repository: {repo_id}[/cyan]")
|
|
322
|
+
rprint(f"[cyan]📁 Local path: {model_path}[/cyan]")
|
|
323
|
+
|
|
324
|
+
# Show model-specific info
|
|
325
|
+
license_info = model_info.get("license_info", {})
|
|
326
|
+
if license_info:
|
|
327
|
+
rprint(f"[yellow]📄 License: {license_info.get('type', 'Unknown')}[/yellow]")
|
|
328
|
+
rprint(f"[yellow]🔑 HF Token Required: {'Yes' if license_info.get('requires_agreement', False) else 'No'}[/yellow]")
|
|
329
|
+
rprint(f"[yellow]💼 Commercial Use: {'Allowed' if license_info.get('commercial_use', False) else 'Not Allowed'}[/yellow]")
|
|
330
|
+
|
|
331
|
+
# Show optimal parameters
|
|
332
|
+
params = model_info.get("parameters", {})
|
|
333
|
+
if params:
|
|
334
|
+
rprint(f"[green]⚡ Optimal Settings:[/green]")
|
|
335
|
+
rprint(f" Steps: {params.get('num_inference_steps', 'N/A')}")
|
|
336
|
+
rprint(f" Guidance: {params.get('guidance_scale', 'N/A')}")
|
|
337
|
+
if 'max_sequence_length' in params:
|
|
338
|
+
rprint(f" Max Seq Length: {params['max_sequence_length']}")
|
|
339
|
+
|
|
340
|
+
rprint()
|
|
341
|
+
|
|
342
|
+
# Check if directory exists
|
|
343
|
+
if not model_path.exists():
|
|
344
|
+
rprint("[yellow]📂 Status: Not downloaded[/yellow]")
|
|
345
|
+
return False
|
|
346
|
+
|
|
347
|
+
# Get repository file list
|
|
348
|
+
rprint("[blue]🌐 Getting repository information...[/blue]")
|
|
349
|
+
try:
|
|
350
|
+
file_sizes = get_repo_file_list(repo_id)
|
|
351
|
+
total_expected_size = sum(file_sizes.values())
|
|
352
|
+
total_files_expected = len(file_sizes)
|
|
353
|
+
|
|
354
|
+
rprint(f"[blue]📊 Expected: {total_files_expected} files, {format_size(total_expected_size)} total[/blue]")
|
|
355
|
+
except Exception as e:
|
|
356
|
+
rprint(f"[yellow]⚠️ Could not get repository info: {e}[/yellow]")
|
|
357
|
+
file_sizes = {}
|
|
358
|
+
total_expected_size = 0
|
|
359
|
+
total_files_expected = 0
|
|
360
|
+
|
|
361
|
+
# Check local files
|
|
362
|
+
local_files = []
|
|
363
|
+
local_size = 0
|
|
364
|
+
|
|
365
|
+
for file_path in model_path.rglob('*'):
|
|
366
|
+
if file_path.is_file():
|
|
367
|
+
rel_path = file_path.relative_to(model_path)
|
|
368
|
+
file_size = file_path.stat().st_size
|
|
369
|
+
local_files.append((str(rel_path), file_size))
|
|
370
|
+
local_size += file_size
|
|
371
|
+
|
|
372
|
+
rprint(f"[blue]💾 Downloaded: {len(local_files)} files, {format_size(local_size)} total[/blue]")
|
|
373
|
+
|
|
374
|
+
if total_expected_size > 0:
|
|
375
|
+
progress_percent = (local_size / total_expected_size) * 100
|
|
376
|
+
rprint(f"[blue]📈 Progress: {progress_percent:.1f}%[/blue]")
|
|
377
|
+
|
|
378
|
+
rprint()
|
|
379
|
+
|
|
380
|
+
# Check for missing files
|
|
381
|
+
if file_sizes:
|
|
382
|
+
# Check if we have size information from the API
|
|
383
|
+
has_size_info = any(size > 0 for size in file_sizes.values())
|
|
384
|
+
|
|
385
|
+
if has_size_info:
|
|
386
|
+
# Normal case: we have size information, do detailed comparison
|
|
387
|
+
missing_files = []
|
|
388
|
+
incomplete_files = []
|
|
389
|
+
|
|
390
|
+
for expected_file, expected_size in file_sizes.items():
|
|
391
|
+
local_file_path = model_path / expected_file
|
|
392
|
+
if not local_file_path.exists():
|
|
393
|
+
missing_files.append(expected_file)
|
|
394
|
+
elif expected_size > 0 and local_file_path.stat().st_size != expected_size:
|
|
395
|
+
local_size_actual = local_file_path.stat().st_size
|
|
396
|
+
incomplete_files.append((expected_file, local_size_actual, expected_size))
|
|
397
|
+
|
|
398
|
+
if missing_files:
|
|
399
|
+
rprint(f"[red]❌ Missing files ({len(missing_files)}):[/red]")
|
|
400
|
+
for missing_file in missing_files[:10]: # Show first 10
|
|
401
|
+
rprint(f" - {missing_file}")
|
|
402
|
+
if len(missing_files) > 10:
|
|
403
|
+
rprint(f" ... and {len(missing_files) - 10} more")
|
|
404
|
+
rprint()
|
|
405
|
+
|
|
406
|
+
if incomplete_files:
|
|
407
|
+
rprint(f"[yellow]⚠️ Incomplete files ({len(incomplete_files)}):[/yellow]")
|
|
408
|
+
for incomplete_file, actual_size, expected_size in incomplete_files[:5]:
|
|
409
|
+
rprint(f" - {incomplete_file}: {format_size(actual_size)}/{format_size(expected_size)}")
|
|
410
|
+
if len(incomplete_files) > 5:
|
|
411
|
+
rprint(f" ... and {len(incomplete_files) - 5} more")
|
|
412
|
+
rprint()
|
|
413
|
+
|
|
414
|
+
if not missing_files and not incomplete_files:
|
|
415
|
+
rprint("[green]✅ All files present and complete![/green]")
|
|
416
|
+
|
|
417
|
+
# Check integrity
|
|
418
|
+
rprint("[blue]🔍 Checking download integrity...[/blue]")
|
|
419
|
+
if check_download_integrity(str(model_path), repo_id):
|
|
420
|
+
rprint("[green]✅ Download integrity verified![/green]")
|
|
421
|
+
|
|
422
|
+
# Check if model is in configuration
|
|
423
|
+
if model_manager.is_model_installed(model_name):
|
|
424
|
+
rprint("[green]✅ Model is properly configured[/green]")
|
|
425
|
+
return True
|
|
426
|
+
else:
|
|
427
|
+
rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
|
|
428
|
+
return "needs_config"
|
|
429
|
+
else:
|
|
430
|
+
rprint("[red]❌ Download integrity check failed[/red]")
|
|
431
|
+
return False
|
|
432
|
+
else:
|
|
433
|
+
rprint("[yellow]⚠️ Download is incomplete[/yellow]")
|
|
434
|
+
return "incomplete"
|
|
435
|
+
else:
|
|
436
|
+
# No size information available from API (common with gated repos)
|
|
437
|
+
rprint("[blue]ℹ️ Repository API doesn't provide file sizes (common with gated models)[/blue]")
|
|
438
|
+
rprint("[blue]🔍 Checking essential model files instead...[/blue]")
|
|
439
|
+
|
|
440
|
+
# Check for essential model files
|
|
441
|
+
# Determine model type based on repo_id
|
|
442
|
+
is_controlnet = 'controlnet' in repo_id.lower()
|
|
443
|
+
|
|
444
|
+
if is_controlnet:
|
|
445
|
+
# ControlNet models have different essential files
|
|
446
|
+
essential_files = ['config.json']
|
|
447
|
+
essential_dirs = [] # ControlNet models don't have complex directory structure
|
|
448
|
+
else:
|
|
449
|
+
# Regular diffusion models
|
|
450
|
+
essential_files = ['model_index.json']
|
|
451
|
+
essential_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'vae', 'scheduler']
|
|
452
|
+
|
|
453
|
+
missing_essential = []
|
|
454
|
+
for essential_file in essential_files:
|
|
455
|
+
if not (model_path / essential_file).exists():
|
|
456
|
+
missing_essential.append(essential_file)
|
|
457
|
+
|
|
458
|
+
existing_dirs = []
|
|
459
|
+
for essential_dir in essential_dirs:
|
|
460
|
+
if (model_path / essential_dir).exists():
|
|
461
|
+
existing_dirs.append(essential_dir)
|
|
462
|
+
|
|
463
|
+
if missing_essential:
|
|
464
|
+
rprint(f"[red]❌ Missing essential files: {', '.join(missing_essential)}[/red]")
|
|
465
|
+
return "incomplete"
|
|
466
|
+
|
|
467
|
+
if existing_dirs:
|
|
468
|
+
rprint(f"[green]✅ Found model components: {', '.join(existing_dirs)}[/green]")
|
|
469
|
+
|
|
470
|
+
# Check integrity
|
|
471
|
+
rprint("[blue]🔍 Checking download integrity...[/blue]")
|
|
472
|
+
if check_download_integrity(str(model_path), repo_id):
|
|
473
|
+
rprint("[green]✅ Download integrity verified![/green]")
|
|
474
|
+
|
|
475
|
+
# Check if model is in configuration
|
|
476
|
+
if model_manager.is_model_installed(model_name):
|
|
477
|
+
rprint("[green]✅ Model is properly configured and functional[/green]")
|
|
478
|
+
return True
|
|
479
|
+
else:
|
|
480
|
+
rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
|
|
481
|
+
return "needs_config"
|
|
482
|
+
else:
|
|
483
|
+
rprint("[red]❌ Download integrity check failed[/red]")
|
|
484
|
+
return False
|
|
485
|
+
|
|
486
|
+
# Check if download process is running
|
|
487
|
+
rprint("[blue]🔍 Checking for active download processes...[/blue]")
|
|
236
488
|
try:
|
|
237
|
-
result = subprocess.run([
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
489
|
+
result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
|
|
490
|
+
if f'ollamadiffuser pull {model_name}' in result.stdout:
|
|
491
|
+
rprint("[yellow]🔄 Download process is currently running[/yellow]")
|
|
492
|
+
return "downloading"
|
|
493
|
+
else:
|
|
494
|
+
rprint("[blue]💤 No active download process found[/blue]")
|
|
242
495
|
except Exception as e:
|
|
243
|
-
rprint(f"[
|
|
496
|
+
rprint(f"[yellow]⚠️ Could not check processes: {e}[/yellow]")
|
|
497
|
+
|
|
498
|
+
return "incomplete"
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def _show_model_specific_help(model_name: str):
|
|
502
|
+
"""Show model-specific help and recommendations"""
|
|
503
|
+
model_info = model_manager.get_model_info(model_name)
|
|
504
|
+
if not model_info:
|
|
505
|
+
return
|
|
506
|
+
|
|
507
|
+
rprint(f"\n[bold blue]💡 {model_name} Specific Tips:[/bold blue]")
|
|
508
|
+
|
|
509
|
+
# License-specific help
|
|
510
|
+
license_info = model_info.get("license_info", {})
|
|
511
|
+
if license_info.get("requires_agreement", False):
|
|
512
|
+
rprint(f" [yellow]🔑 Requires HuggingFace token and license agreement[/yellow]")
|
|
513
|
+
rprint(f" [blue]📝 Visit: https://huggingface.co/{model_info['repo_id']}[/blue]")
|
|
514
|
+
rprint(f" [cyan]🔧 Set token: export HF_TOKEN=your_token_here[/cyan]")
|
|
515
|
+
else:
|
|
516
|
+
rprint(f" [green]✅ No HuggingFace token required![/green]")
|
|
517
|
+
|
|
518
|
+
# Model-specific optimizations
|
|
519
|
+
if "schnell" in model_name.lower():
|
|
520
|
+
rprint(f" [green]⚡ FLUX.1-schnell is 12x faster than FLUX.1-dev[/green]")
|
|
521
|
+
rprint(f" [green]🎯 Optimized for 4-step generation[/green]")
|
|
522
|
+
rprint(f" [green]💼 Commercial use allowed (Apache 2.0)[/green]")
|
|
523
|
+
elif "flux.1-dev" in model_name.lower():
|
|
524
|
+
rprint(f" [blue]🎨 Best quality FLUX model[/blue]")
|
|
525
|
+
rprint(f" [blue]🔬 Requires 50 steps for optimal results[/blue]")
|
|
526
|
+
rprint(f" [yellow]⚠️ Non-commercial license only[/yellow]")
|
|
527
|
+
elif "stable-diffusion-1.5" in model_name.lower():
|
|
528
|
+
rprint(f" [green]🚀 Great for learning and quick tests[/green]")
|
|
529
|
+
rprint(f" [green]💾 Smallest model, runs on most hardware[/green]")
|
|
530
|
+
elif "stable-diffusion-3.5" in model_name.lower():
|
|
531
|
+
rprint(f" [green]🏆 Excellent quality-to-speed ratio[/green]")
|
|
532
|
+
rprint(f" [green]🔄 Great LoRA ecosystem[/green]")
|
|
533
|
+
|
|
534
|
+
# Hardware recommendations
|
|
535
|
+
hw_req = model_info.get("hardware_requirements", {})
|
|
536
|
+
if hw_req:
|
|
537
|
+
min_vram = hw_req.get("min_vram_gb", 0)
|
|
538
|
+
if min_vram >= 12:
|
|
539
|
+
rprint(f" [yellow]🖥️ Requires high-end GPU (RTX 4070+ or M2 Pro+)[/yellow]")
|
|
540
|
+
elif min_vram >= 8:
|
|
541
|
+
rprint(f" [blue]🖥️ Requires mid-range GPU (RTX 3080+ or M1 Pro+)[/blue]")
|
|
542
|
+
else:
|
|
543
|
+
rprint(f" [green]🖥️ Runs on most modern GPUs[/green]")
|
|
244
544
|
|
|
245
545
|
@cli.command()
|
|
246
546
|
@click.argument('model_name')
|