ollamadiffuser-1.0.0-py3-none-any.whl → ollamadiffuser-1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
- from fastapi import FastAPI, HTTPException
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Form
  from fastapi.responses import Response
  from fastapi.middleware.cors import CORSMiddleware
  import uvicorn
@@ -6,6 +6,7 @@ import io
  import logging
  from typing import Dict, Any, Optional
  from pydantic import BaseModel
+ from PIL import Image

  from ..core.models.manager import model_manager
  from ..core.config.settings import settings
@@ -22,6 +23,10 @@ class GenerateRequest(BaseModel):
      guidance_scale: Optional[float] = None
      width: int = 1024
      height: int = 1024
+     control_image_path: Optional[str] = None  # Path to control image file
+     controlnet_conditioning_scale: float = 1.0
+     control_guidance_start: float = 0.0
+     control_guidance_end: float = 1.0

  class LoadModelRequest(BaseModel):
      model_name: str
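
The four new GenerateRequest fields are plain JSON values, so existing clients of the JSON generation endpoint can opt into ControlNet by pointing control_image_path at a control image that already exists on the server's filesystem. A minimal client sketch, assuming the handler shown in the next hunk is mounted at /api/generate and the server listens on localhost:8000 (endpoint path, host, and port are assumptions; none of them appears in this diff):

import requests

resp = requests.post(
    "http://localhost:8000/api/generate",  # assumed host, port, and path
    json={
        "prompt": "a watercolor lighthouse at dusk",
        "width": 1024,
        "height": 1024,
        # New in 1.1.1: server-side path to an already-prepared control image
        "control_image_path": "/tmp/canny_map.png",
        "controlnet_conditioning_scale": 1.0,
        "control_guidance_start": 0.0,
        "control_guidance_end": 1.0,
    },
)
resp.raise_for_status()
with open("out.png", "wb") as f:
    f.write(resp.content)  # the handler returns raw PNG bytes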
@@ -234,7 +239,11 @@ def create_app() -> FastAPI:
                  num_inference_steps=request.num_inference_steps,
                  guidance_scale=request.guidance_scale,
                  width=request.width,
-                 height=request.height
+                 height=request.height,
+                 control_image=request.control_image_path,
+                 controlnet_conditioning_scale=request.controlnet_conditioning_scale,
+                 control_guidance_start=request.control_guidance_start,
+                 control_guidance_end=request.control_guidance_end
              )

              # Convert PIL image to bytes
@@ -248,6 +257,142 @@ def create_app() -> FastAPI:
              logger.error(f"Image generation failed: {e}")
              raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")

+     @app.post("/api/generate/controlnet")
+     async def generate_image_with_controlnet(
+         prompt: str = Form(...),
+         negative_prompt: str = Form("low quality, bad anatomy, worst quality, low resolution"),
+         num_inference_steps: Optional[int] = Form(None),
+         guidance_scale: Optional[float] = Form(None),
+         width: int = Form(1024),
+         height: int = Form(1024),
+         controlnet_conditioning_scale: float = Form(1.0),
+         control_guidance_start: float = Form(0.0),
+         control_guidance_end: float = Form(1.0),
+         control_image: Optional[UploadFile] = File(None)
+     ):
+         """Generate image with ControlNet using uploaded control image"""
+         # Check if model is loaded
+         if not model_manager.is_model_loaded():
+             raise HTTPException(status_code=400, detail="No model loaded, please load a model first")
+
+         try:
+             # Get current loaded inference engine
+             engine = model_manager.loaded_model
+
+             # Check if this is a ControlNet model
+             if not engine.is_controlnet_pipeline:
+                 raise HTTPException(status_code=400, detail="Current model is not a ControlNet model")
+
+             # Process control image if provided
+             control_image_pil = None
+             if control_image:
+                 # Read uploaded image
+                 image_data = await control_image.read()
+                 control_image_pil = Image.open(io.BytesIO(image_data)).convert('RGB')
+
+             # Generate image
+             image = engine.generate_image(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 width=width,
+                 height=height,
+                 control_image=control_image_pil,
+                 controlnet_conditioning_scale=controlnet_conditioning_scale,
+                 control_guidance_start=control_guidance_start,
+                 control_guidance_end=control_guidance_end
+             )
+
+             # Convert PIL image to bytes
+             img_byte_arr = io.BytesIO()
+             image.save(img_byte_arr, format='PNG')
+             img_byte_arr = img_byte_arr.getvalue()
+
+             return Response(content=img_byte_arr, media_type="image/png")
+
+         except Exception as e:
+             logger.error(f"ControlNet image generation failed: {e}")
+             raise HTTPException(status_code=500, detail=f"ControlNet image generation failed: {str(e)}")
+
+     @app.post("/api/controlnet/initialize")
+     async def initialize_controlnet():
+         """Initialize ControlNet preprocessors"""
+         try:
+             from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+             logger.info("Explicitly initializing ControlNet preprocessors...")
+             success = controlnet_preprocessor.initialize(force=True)
+
+             return {
+                 "success": success,
+                 "initialized": controlnet_preprocessor.is_initialized(),
+                 "available_types": controlnet_preprocessor.get_available_types(),
+                 "message": "ControlNet preprocessors initialized successfully" if success else "Failed to initialize ControlNet preprocessors"
+             }
+         except Exception as e:
+             logger.error(f"Failed to initialize ControlNet preprocessors: {e}")
+             raise HTTPException(status_code=500, detail=f"Failed to initialize ControlNet preprocessors: {str(e)}")
+
+     @app.get("/api/controlnet/preprocessors")
+     async def get_controlnet_preprocessors():
+         """Get available ControlNet preprocessors"""
+         try:
+             from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+             return {
+                 "available_types": controlnet_preprocessor.get_available_types(),
+                 "available": controlnet_preprocessor.is_available(),
+                 "initialized": controlnet_preprocessor.is_initialized(),
+                 "description": {
+                     "canny": "Edge detection using Canny algorithm",
+                     "depth": "Depth estimation for depth-based control",
+                     "openpose": "Human pose detection for pose control",
+                     "scribble": "Scribble/sketch detection for artistic control",
+                     "hed": "Holistically-nested edge detection",
+                     "mlsd": "Mobile line segment detection",
+                     "normal": "Surface normal estimation",
+                     "lineart": "Line art detection",
+                     "lineart_anime": "Anime-style line art detection",
+                     "shuffle": "Content shuffling for style transfer"
+                 }
+             }
+         except Exception as e:
+             logger.error(f"Failed to get ControlNet preprocessors: {e}")
+             raise HTTPException(status_code=500, detail=f"Failed to get ControlNet preprocessors: {str(e)}")
+
+     @app.post("/api/controlnet/preprocess")
+     async def preprocess_control_image(
+         control_type: str = Form(...),
+         image: UploadFile = File(...)
+     ):
+         """Preprocess image for ControlNet"""
+         try:
+             from ..core.utils.controlnet_preprocessors import controlnet_preprocessor
+
+             # Initialize ControlNet preprocessors if needed
+             if not controlnet_preprocessor.is_initialized():
+                 logger.info("Initializing ControlNet preprocessors for image preprocessing...")
+                 if not controlnet_preprocessor.initialize():
+                     raise HTTPException(status_code=500, detail="Failed to initialize ControlNet preprocessors")
+
+             # Read uploaded image
+             image_data = await image.read()
+             input_image = Image.open(io.BytesIO(image_data)).convert('RGB')
+
+             # Preprocess image
+             processed_image = controlnet_preprocessor.preprocess(input_image, control_type)
+
+             # Convert PIL image to bytes
+             img_byte_arr = io.BytesIO()
+             processed_image.save(img_byte_arr, format='PNG')
+             img_byte_arr = img_byte_arr.getvalue()
+
+             return Response(content=img_byte_arr, media_type="image/png")
+
+         except Exception as e:
+             logger.error(f"Image preprocessing failed: {e}")
+             raise HTTPException(status_code=500, detail=f"Image preprocessing failed: {str(e)}")
+
      # Health check endpoints
      @app.get("/api/health")
      async def health_check():
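
Together, the new multipart endpoints support a preprocess-then-generate workflow from any HTTP client: upload a photo to /api/controlnet/preprocess to derive a control map, then feed that map to /api/generate/controlnet. A minimal sketch using the requests library, assuming the server is reachable at localhost:8000 (host and port are assumptions, not part of this diff):

import requests

BASE = "http://localhost:8000"  # assumed host/port

# 1. Derive a Canny edge map from a source photo.
with open("photo.jpg", "rb") as f:
    r = requests.post(f"{BASE}/api/controlnet/preprocess",
                      data={"control_type": "canny"},
                      files={"image": ("photo.jpg", f, "image/jpeg")})
r.raise_for_status()
with open("edges.png", "wb") as f:
    f.write(r.content)

# 2. Generate an image conditioned on the edge map
#    (a ControlNet-capable model must already be loaded).
with open("edges.png", "rb") as f:
    r = requests.post(f"{BASE}/api/generate/controlnet",
                      data={"prompt": "a cozy cabin in the woods",
                            "controlnet_conditioning_scale": 0.8},
                      files={"control_image": ("edges.png", f, "image/png")})
r.raise_for_status()
with open("result.png", "wb") as f:
    f.write(r.content)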
@@ -5,8 +5,9 @@ import logging
  from typing import Optional
  from rich.console import Console
  from rich.table import Table
- from rich.progress import Progress, SpinnerColumn, TextColumn
+ from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
  from rich import print as rprint
+ import time

  from ..core.models.manager import model_manager
  from ..core.config.settings import settings
@@ -14,6 +15,44 @@ from ..api.server import run_server

  console = Console()

+ class OllamaStyleProgress:
+     """Enhanced progress tracker that mimics Ollama's progress display"""
+
+     def __init__(self, console: Console):
+         self.console = console
+         self.last_message = ""
+
+     def update(self, message: str):
+         """Update progress with a message"""
+         # Skip duplicate messages
+         if message == self.last_message:
+             return
+
+         self.last_message = message
+
+         # Handle different types of messages
+         if message.startswith("pulling ") and ":" in message and "%" in message:
+             # This is a file progress message from download_utils
+             # Format: "pulling e6a7edc1a4d7: 12% ▕██      ▏ 617 MB/5200 MB 44 MB/s 1m44s"
+             self.console.print(message)
+         elif message.startswith("pulling manifest"):
+             self.console.print(message)
+         elif message.startswith("📦 Repository:"):
+             # Repository info
+             self.console.print(f"[dim]{message}[/dim]")
+         elif message.startswith("📁 Found"):
+             # Existing files info
+             self.console.print(f"[dim]{message}[/dim]")
+         elif message.startswith("✅") and "download completed" in message:
+             self.console.print(f"[green]{message}[/green]")
+         elif message.startswith("❌"):
+             self.console.print(f"[red]{message}[/red]")
+         elif message.startswith("⚠️"):
+             self.console.print(f"[yellow]{message}[/yellow]")
+         else:
+             # For other messages, print with dimmed style
+             self.console.print(f"[dim]{message}[/dim]")
+
  @click.group()
  @click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
  def cli(verbose):
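
OllamaStyleProgress is a thin router over rich: each callback message is matched by prefix and printed once in an appropriate style, with consecutive duplicates suppressed via last_message. A short illustrative sketch (the messages here are made up; real ones come from the download code):

from rich.console import Console

console = Console()
tracker = OllamaStyleProgress(console)

tracker.update("pulling manifest")                          # printed plain
tracker.update("📦 Repository: some/repo")                  # printed dimmed
tracker.update("pulling e6a7edc1a4d7: 12% 617 MB/5200 MB")  # file progress, printed plain
tracker.update("pulling e6a7edc1a4d7: 12% 617 MB/5200 MB")  # duplicate, suppressed
tracker.update("✅ download completed")                      # printed green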
@@ -30,24 +69,26 @@ def pull(model_name: str, force: bool):
      """Download model"""
      rprint(f"[blue]Downloading model: {model_name}[/blue]")

-     with Progress(
-         SpinnerColumn(),
-         TextColumn("[progress.description]{task.description}"),
-         console=console
-     ) as progress:
-         task = progress.add_task(f"Downloading {model_name}...", total=None)
-
-         def progress_callback(message: str):
-             """Update progress display with download status"""
-             progress.update(task, description=message)
-
+     # Use the new Ollama-style progress tracker
+     progress_tracker = OllamaStyleProgress(console)
+
+     def progress_callback(message: str):
+         """Enhanced progress callback with Ollama-style display"""
+         progress_tracker.update(message)
+
+     try:
          if model_manager.pull_model(model_name, force=force, progress_callback=progress_callback):
-             progress.update(task, description=f"✅ {model_name} download completed")
+             progress_tracker.update("✅ download completed")
              rprint(f"[green]Model {model_name} downloaded successfully![/green]")
          else:
-             progress.update(task, description=f"❌ {model_name} download failed")
              rprint(f"[red]Model {model_name} download failed![/red]")
              sys.exit(1)
+     except KeyboardInterrupt:
+         rprint("\n[yellow]Download cancelled by user[/yellow]")
+         sys.exit(1)
+     except Exception as e:
+         rprint(f"[red]Download failed: {str(e)}[/red]")
+         sys.exit(1)

  @cli.command()
  @click.argument('model_name')
@@ -226,21 +267,280 @@ def check(model_name: str, list: bool):
          rprint("[dim] ollamadiffuser check --list[/dim]")
          return

-     # Import and run the check function
+     # Check model download status directly
+     status = _check_download_status(model_name)
+
+     rprint("\n" + "="*60)
+
+     if status is True:
+         rprint(f"[green]🎉 {model_name} is ready to use![/green]")
+         rprint(f"\n[blue]💡 You can now run:[/blue]")
+         rprint(f" [cyan]ollamadiffuser run {model_name}[/cyan]")
+     elif status == "needs_config":
+         rprint(f"[yellow]⚠️ {model_name} files are complete but model needs configuration[/yellow]")
+         rprint(f"\n[blue]💡 Try reinstalling:[/blue]")
+         rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
+     elif status == "downloading":
+         rprint(f"[yellow]🔄 {model_name} is currently downloading[/yellow]")
+         rprint(f"\n[blue]💡 Wait for download to complete or check progress[/blue]")
+     elif status == "incomplete":
+         rprint(f"[yellow]⚠️ Download is incomplete[/yellow]")
+         rprint(f"\n[blue]💡 Resume download with:[/blue]")
+         rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
+         rprint(f"\n[blue]💡 Or force fresh download with:[/blue]")
+         rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
+     else:
+         rprint(f"[red]❌ {model_name} is not downloaded[/red]")
+         rprint(f"\n[blue]💡 Download with:[/blue]")
+         rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
+
+     _show_model_specific_help(model_name)
+
+     rprint(f"\n[dim]📚 For more help: ollamadiffuser --help[/dim]")
+
+
+ def _check_download_status(model_name: str):
+     """Check the current download status of any model"""
+     from ..core.utils.download_utils import check_download_integrity, get_repo_file_list, format_size
      import subprocess
-     import sys
-     from pathlib import Path

-     # Run the check script
-     script_path = Path(__file__).parent.parent.parent / "examples" / "check_model_download.py"
+     rprint(f"[blue]🔍 Checking {model_name} download status...[/blue]\n")
+
+     # Check if model is in registry
+     if model_name not in model_manager.model_registry:
+         rprint(f"[red]❌ {model_name} not found in model registry[/red]")
+         available_models = model_manager.list_available_models()
+         rprint(f"[blue]📋 Available models: {', '.join(available_models)}[/blue]")
+         return False
+
+     model_info = model_manager.model_registry[model_name]
+     repo_id = model_info["repo_id"]
+     model_path = settings.get_model_path(model_name)
+
+     rprint(f"[cyan]📦 Model: {model_name}[/cyan]")
+     rprint(f"[cyan]🔗 Repository: {repo_id}[/cyan]")
+     rprint(f"[cyan]📁 Local path: {model_path}[/cyan]")
+
+     # Show model-specific info
+     license_info = model_info.get("license_info", {})
+     if license_info:
+         rprint(f"[yellow]📄 License: {license_info.get('type', 'Unknown')}[/yellow]")
+         rprint(f"[yellow]🔑 HF Token Required: {'Yes' if license_info.get('requires_agreement', False) else 'No'}[/yellow]")
+         rprint(f"[yellow]💼 Commercial Use: {'Allowed' if license_info.get('commercial_use', False) else 'Not Allowed'}[/yellow]")
+
+     # Show optimal parameters
+     params = model_info.get("parameters", {})
+     if params:
+         rprint(f"[green]⚡ Optimal Settings:[/green]")
+         rprint(f" Steps: {params.get('num_inference_steps', 'N/A')}")
+         rprint(f" Guidance: {params.get('guidance_scale', 'N/A')}")
+         if 'max_sequence_length' in params:
+             rprint(f" Max Seq Length: {params['max_sequence_length']}")
+
+     rprint()
+
+     # Check if directory exists
+     if not model_path.exists():
+         rprint("[yellow]📂 Status: Not downloaded[/yellow]")
+         return False
+
+     # Get repository file list
+     rprint("[blue]🌐 Getting repository information...[/blue]")
+     try:
+         file_sizes = get_repo_file_list(repo_id)
+         total_expected_size = sum(file_sizes.values())
+         total_files_expected = len(file_sizes)
+
+         rprint(f"[blue]📊 Expected: {total_files_expected} files, {format_size(total_expected_size)} total[/blue]")
+     except Exception as e:
+         rprint(f"[yellow]⚠️ Could not get repository info: {e}[/yellow]")
+         file_sizes = {}
+         total_expected_size = 0
+         total_files_expected = 0
+
+     # Check local files
+     local_files = []
+     local_size = 0
+
+     for file_path in model_path.rglob('*'):
+         if file_path.is_file():
+             rel_path = file_path.relative_to(model_path)
+             file_size = file_path.stat().st_size
+             local_files.append((str(rel_path), file_size))
+             local_size += file_size
+
+     rprint(f"[blue]💾 Downloaded: {len(local_files)} files, {format_size(local_size)} total[/blue]")
+
+     if total_expected_size > 0:
+         progress_percent = (local_size / total_expected_size) * 100
+         rprint(f"[blue]📈 Progress: {progress_percent:.1f}%[/blue]")
+
+     rprint()
+
+     # Check for missing files
+     if file_sizes:
+         # Check if we have size information from the API
+         has_size_info = any(size > 0 for size in file_sizes.values())
+
+         if has_size_info:
+             # Normal case: we have size information, do detailed comparison
+             missing_files = []
+             incomplete_files = []
+
+             for expected_file, expected_size in file_sizes.items():
+                 local_file_path = model_path / expected_file
+                 if not local_file_path.exists():
+                     missing_files.append(expected_file)
+                 elif expected_size > 0 and local_file_path.stat().st_size != expected_size:
+                     local_size_actual = local_file_path.stat().st_size
+                     incomplete_files.append((expected_file, local_size_actual, expected_size))
+
+             if missing_files:
+                 rprint(f"[red]❌ Missing files ({len(missing_files)}):[/red]")
+                 for missing_file in missing_files[:10]:  # Show first 10
+                     rprint(f" - {missing_file}")
+                 if len(missing_files) > 10:
+                     rprint(f" ... and {len(missing_files) - 10} more")
+                 rprint()
+
+             if incomplete_files:
+                 rprint(f"[yellow]⚠️ Incomplete files ({len(incomplete_files)}):[/yellow]")
+                 for incomplete_file, actual_size, expected_size in incomplete_files[:5]:
+                     rprint(f" - {incomplete_file}: {format_size(actual_size)}/{format_size(expected_size)}")
+                 if len(incomplete_files) > 5:
+                     rprint(f" ... and {len(incomplete_files) - 5} more")
+                 rprint()
+
+             if not missing_files and not incomplete_files:
+                 rprint("[green]✅ All files present and complete![/green]")
+
+                 # Check integrity
+                 rprint("[blue]🔍 Checking download integrity...[/blue]")
+                 if check_download_integrity(str(model_path), repo_id):
+                     rprint("[green]✅ Download integrity verified![/green]")
+
+                     # Check if model is in configuration
+                     if model_manager.is_model_installed(model_name):
+                         rprint("[green]✅ Model is properly configured[/green]")
+                         return True
+                     else:
+                         rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
+                         return "needs_config"
+                 else:
+                     rprint("[red]❌ Download integrity check failed[/red]")
+                     return False
+             else:
+                 rprint("[yellow]⚠️ Download is incomplete[/yellow]")
+                 return "incomplete"
+         else:
+             # No size information available from API (common with gated repos)
+             rprint("[blue]ℹ️ Repository API doesn't provide file sizes (common with gated models)[/blue]")
+             rprint("[blue]🔍 Checking essential model files instead...[/blue]")
+
+             # Check for essential model files
+             # Determine model type based on repo_id
+             is_controlnet = 'controlnet' in repo_id.lower()
+
+             if is_controlnet:
+                 # ControlNet models have different essential files
+                 essential_files = ['config.json']
+                 essential_dirs = []  # ControlNet models don't have complex directory structure
+             else:
+                 # Regular diffusion models
+                 essential_files = ['model_index.json']
+                 essential_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'vae', 'scheduler']
+
+             missing_essential = []
+             for essential_file in essential_files:
+                 if not (model_path / essential_file).exists():
+                     missing_essential.append(essential_file)
+
+             existing_dirs = []
+             for essential_dir in essential_dirs:
+                 if (model_path / essential_dir).exists():
+                     existing_dirs.append(essential_dir)
+
+             if missing_essential:
+                 rprint(f"[red]❌ Missing essential files: {', '.join(missing_essential)}[/red]")
+                 return "incomplete"
+
+             if existing_dirs:
+                 rprint(f"[green]✅ Found model components: {', '.join(existing_dirs)}[/green]")
+
+             # Check integrity
+             rprint("[blue]🔍 Checking download integrity...[/blue]")
+             if check_download_integrity(str(model_path), repo_id):
+                 rprint("[green]✅ Download integrity verified![/green]")
+
+                 # Check if model is in configuration
+                 if model_manager.is_model_installed(model_name):
+                     rprint("[green]✅ Model is properly configured and functional[/green]")
+                     return True
+                 else:
+                     rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
+                     return "needs_config"
+             else:
+                 rprint("[red]❌ Download integrity check failed[/red]")
+                 return False
+
+     # Check if download process is running
+     rprint("[blue]🔍 Checking for active download processes...[/blue]")
      try:
-         result = subprocess.run([sys.executable, str(script_path), model_name],
-                                 capture_output=True, text=True)
-         rprint(result.stdout)
-         if result.stderr:
-             rprint(f"[red]{result.stderr}[/red]")
+         result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
+         if f'ollamadiffuser pull {model_name}' in result.stdout:
+             rprint("[yellow]🔄 Download process is currently running[/yellow]")
+             return "downloading"
+         else:
+             rprint("[blue]💤 No active download process found[/blue]")
      except Exception as e:
-         rprint(f"[red] Error running check: {e}[/red]")
+         rprint(f"[yellow]⚠️ Could not check processes: {e}[/yellow]")
+
+     return "incomplete"
+
+
+ def _show_model_specific_help(model_name: str):
+     """Show model-specific help and recommendations"""
+     model_info = model_manager.get_model_info(model_name)
+     if not model_info:
+         return
+
+     rprint(f"\n[bold blue]💡 {model_name} Specific Tips:[/bold blue]")
+
+     # License-specific help
+     license_info = model_info.get("license_info", {})
+     if license_info.get("requires_agreement", False):
+         rprint(f" [yellow]🔑 Requires HuggingFace token and license agreement[/yellow]")
+         rprint(f" [blue]📝 Visit: https://huggingface.co/{model_info['repo_id']}[/blue]")
+         rprint(f" [cyan]🔧 Set token: export HF_TOKEN=your_token_here[/cyan]")
+     else:
+         rprint(f" [green]✅ No HuggingFace token required![/green]")
+
+     # Model-specific optimizations
+     if "schnell" in model_name.lower():
+         rprint(f" [green]⚡ FLUX.1-schnell is 12x faster than FLUX.1-dev[/green]")
+         rprint(f" [green]🎯 Optimized for 4-step generation[/green]")
+         rprint(f" [green]💼 Commercial use allowed (Apache 2.0)[/green]")
+     elif "flux.1-dev" in model_name.lower():
+         rprint(f" [blue]🎨 Best quality FLUX model[/blue]")
+         rprint(f" [blue]🔬 Requires 50 steps for optimal results[/blue]")
+         rprint(f" [yellow]⚠️ Non-commercial license only[/yellow]")
+     elif "stable-diffusion-1.5" in model_name.lower():
+         rprint(f" [green]🚀 Great for learning and quick tests[/green]")
+         rprint(f" [green]💾 Smallest model, runs on most hardware[/green]")
+     elif "stable-diffusion-3.5" in model_name.lower():
+         rprint(f" [green]🏆 Excellent quality-to-speed ratio[/green]")
+         rprint(f" [green]🔄 Great LoRA ecosystem[/green]")
+
+     # Hardware recommendations
+     hw_req = model_info.get("hardware_requirements", {})
+     if hw_req:
+         min_vram = hw_req.get("min_vram_gb", 0)
+         if min_vram >= 12:
+             rprint(f" [yellow]🖥️ Requires high-end GPU (RTX 4070+ or M2 Pro+)[/yellow]")
+         elif min_vram >= 8:
+             rprint(f" [blue]🖥️ Requires mid-range GPU (RTX 3080+ or M1 Pro+)[/blue]")
+         else:
+             rprint(f" [green]🖥️ Runs on most modern GPUs[/green]")

  @cli.command()
  @click.argument('model_name')