ollamadiffuser 1.2.3__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. ollamadiffuser/__init__.py +1 -1
  2. ollamadiffuser/api/server.py +312 -312
  3. ollamadiffuser/cli/config_commands.py +119 -0
  4. ollamadiffuser/cli/lora_commands.py +169 -0
  5. ollamadiffuser/cli/main.py +85 -1233
  6. ollamadiffuser/cli/model_commands.py +664 -0
  7. ollamadiffuser/cli/recommend_command.py +205 -0
  8. ollamadiffuser/cli/registry_commands.py +197 -0
  9. ollamadiffuser/core/config/model_registry.py +562 -11
  10. ollamadiffuser/core/config/settings.py +24 -2
  11. ollamadiffuser/core/inference/__init__.py +5 -0
  12. ollamadiffuser/core/inference/base.py +182 -0
  13. ollamadiffuser/core/inference/engine.py +204 -1405
  14. ollamadiffuser/core/inference/strategies/__init__.py +1 -0
  15. ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
  16. ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
  17. ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
  18. ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
  19. ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
  20. ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
  21. ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
  22. ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
  23. ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
  24. ollamadiffuser/mcp/__init__.py +0 -0
  25. ollamadiffuser/mcp/server.py +184 -0
  26. ollamadiffuser/ui/templates/index.html +62 -1
  27. ollamadiffuser/ui/web.py +116 -54
  28. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
  29. ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
  30. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
  31. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
  32. ollamadiffuser/core/models/registry.py +0 -384
  33. ollamadiffuser/ui/samples/.DS_Store +0 -0
  34. ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
  35. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
  36. {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
@@ -1,1309 +1,161 @@
1
- #!/usr/bin/env python3
2
- import click
1
+ """OllamaDiffuser CLI - Main entry point"""
2
+
3
3
  import sys
4
4
  import logging
5
- from typing import Optional
6
- from rich.console import Console
7
- from rich.table import Table
8
- from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
5
+
6
+ import click
9
7
  from rich import print as rprint
10
- import time
11
8
 
12
9
  from .. import __version__, print_version
13
- from ..core.models.manager import model_manager
14
10
  from ..core.config.settings import settings
15
- from ..core.config.model_registry import model_registry
16
11
  from ..api.server import run_server
17
12
 
18
- console = Console()
19
-
20
- class OllamaStyleProgress:
21
- """Enhanced progress tracker that mimics Ollama's progress display"""
22
-
23
- def __init__(self, console: Console):
24
- self.console = console
25
- self.last_message = ""
26
-
27
- def update(self, message: str):
28
- """Update progress with a message"""
29
- # Skip duplicate messages
30
- if message == self.last_message:
31
- return
32
-
33
- self.last_message = message
34
-
35
- # Handle different types of messages
36
- if message.startswith("pulling ") and ":" in message and "%" in message:
37
- # This is a file progress message from download_utils
38
- # Format: "pulling e6a7edc1a4d7: 12% ▕██ ▏ 617 MB/5200 MB 44 MB/s 1m44s"
39
- self.console.print(message)
40
- elif message.startswith("pulling manifest"):
41
- self.console.print(message)
42
- elif message.startswith("📦 Repository:"):
43
- # Repository info
44
- self.console.print(f"[dim]{message}[/dim]")
45
- elif message.startswith("📁 Found"):
46
- # Existing files info
47
- self.console.print(f"[dim]{message}[/dim]")
48
- elif message.startswith("✅") and "download completed" in message:
49
- self.console.print(f"[green]{message}[/green]")
50
- elif message.startswith("❌"):
51
- self.console.print(f"[red]{message}[/red]")
52
- elif message.startswith("⚠️"):
53
- self.console.print(f"[yellow]{message}[/yellow]")
54
- else:
55
- # For other messages, print with dimmed style
56
- self.console.print(f"[dim]{message}[/dim]")
57
13
 
58
14
  @click.group(invoke_without_command=True)
59
- @click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
60
- @click.option('--version', '-V', is_flag=True, help='Show version and exit')
61
- @click.option('--mode', type=click.Choice(['cli', 'api', 'ui']), help='Running mode: cli (command line), api (API server), ui (Web interface)')
62
- @click.option('--host', default=None, help='Server host address (for api/ui modes)')
63
- @click.option('--port', type=int, default=None, help='Server port (for api/ui modes)')
15
+ @click.option("--verbose", "-v", is_flag=True, help="Enable verbose output")
16
+ @click.option("--version", "-V", is_flag=True, help="Show version and exit")
17
+ @click.option(
18
+ "--mode",
19
+ type=click.Choice(["cli", "api", "ui"]),
20
+ help="Running mode: cli, api (server), ui (web interface)",
21
+ )
22
+ @click.option("--host", default=None, help="Server host address (for api/ui modes)")
23
+ @click.option("--port", type=int, default=None, help="Server port (for api/ui modes)")
64
24
  @click.pass_context
65
25
  def cli(ctx, verbose, version, mode, host, port):
66
26
  """OllamaDiffuser - Image generation model management tool"""
67
27
  if version:
68
28
  print_version()
69
29
  sys.exit(0)
70
-
30
+
71
31
  if verbose:
72
32
  logging.basicConfig(level=logging.DEBUG)
73
33
  else:
74
34
  logging.basicConfig(level=logging.WARNING)
75
-
76
- # Handle mode-based execution
35
+
77
36
  if mode:
78
- if mode == 'api':
37
+ if mode == "api":
79
38
  rprint("[blue]Starting OllamaDiffuser API server...[/blue]")
80
39
  run_server(host=host, port=port)
81
40
  sys.exit(0)
82
- elif mode == 'ui':
41
+ elif mode == "ui":
83
42
  rprint("[blue]Starting OllamaDiffuser Web UI...[/blue]")
84
43
  import uvicorn
85
44
  from ..ui.web import create_ui_app
45
+
86
46
  app = create_ui_app()
87
47
  ui_host = host or settings.server.host
88
- ui_port = port or (settings.server.port + 1) # Web UI uses different port
48
+ ui_port = port or (settings.server.port + 1)
89
49
  uvicorn.run(app, host=ui_host, port=ui_port)
90
50
  sys.exit(0)
91
- elif mode == 'cli':
92
- # Continue with normal CLI processing
93
- pass
94
-
95
- # If no subcommand is provided and no mode/version flag, show help
51
+
96
52
  if ctx.invoked_subcommand is None and not version and not mode:
97
53
  rprint(ctx.get_help())
98
54
  sys.exit(0)
99
55
 
100
- @cli.command()
101
- @click.argument('model_name')
102
- @click.option('--force', '-f', is_flag=True, help='Force re-download')
103
- def pull(model_name: str, force: bool):
104
- """Download model"""
105
- rprint(f"[blue]Downloading model: {model_name}[/blue]")
106
-
107
- # Use the new Ollama-style progress tracker
108
- progress_tracker = OllamaStyleProgress(console)
109
-
110
- def progress_callback(message: str):
111
- """Enhanced progress callback with Ollama-style display"""
112
- progress_tracker.update(message)
113
-
114
- try:
115
- if model_manager.pull_model(model_name, force=force, progress_callback=progress_callback):
116
- progress_tracker.update("✅ download completed")
117
- rprint(f"[green]Model {model_name} downloaded successfully![/green]")
118
- else:
119
- rprint(f"[red]Model {model_name} download failed![/red]")
120
- sys.exit(1)
121
- except KeyboardInterrupt:
122
- rprint("\n[yellow]Download cancelled by user[/yellow]")
123
- sys.exit(1)
124
- except Exception as e:
125
- rprint(f"[red]Download failed: {str(e)}[/red]")
126
- sys.exit(1)
127
-
128
- @cli.command()
129
- @click.argument('model_name')
130
- @click.option('--host', '-h', default=None, help='Server host address')
131
- @click.option('--port', '-p', default=None, type=int, help='Server port')
132
- def run(model_name: str, host: Optional[str], port: Optional[int]):
133
- """Run model service"""
134
- rprint(f"[blue]Starting model service: {model_name}[/blue]")
135
-
136
- # Check if model is installed
137
- if not model_manager.is_model_installed(model_name):
138
- rprint(f"[red]Model {model_name} is not installed. Please run first: ollamadiffuser pull {model_name}[/red]")
139
- sys.exit(1)
140
-
141
- # Load model
142
- rprint("[yellow]Loading model...[/yellow]")
143
- if not model_manager.load_model(model_name):
144
- rprint(f"[red]Failed to load model {model_name}![/red]")
145
- sys.exit(1)
146
-
147
- rprint(f"[green]Model {model_name} loaded successfully![/green]")
148
-
149
- # Start server
150
- try:
151
- run_server(host=host, port=port)
152
- except KeyboardInterrupt:
153
- rprint("\n[yellow]Server stopped[/yellow]")
154
- model_manager.unload_model()
155
- # Clear the current model from settings when server stops
156
- settings.current_model = None
157
- settings.save_config()
158
56
 
159
- @cli.command()
160
- @click.option('--hardware', '-hw', is_flag=True, help='Show hardware requirements')
161
- def list(hardware: bool):
162
- """List installed models only"""
163
- installed_models = model_manager.list_installed_models()
164
- current_model = model_manager.get_current_model()
165
-
166
- if not installed_models:
167
- rprint("[yellow]No models installed[/yellow]")
168
- rprint("\n[dim]💡 Download models with: ollamadiffuser pull <model-name>[/dim]")
169
- rprint("[dim]💡 See all available models: ollamadiffuser registry list[/dim]")
170
- rprint("[dim]💡 See only available models: ollamadiffuser registry list --available-only[/dim]")
171
- return
172
-
173
- if hardware:
174
- # Show detailed hardware requirements
175
- for model_name in installed_models:
176
- info = model_manager.get_model_info(model_name)
177
- if not info:
178
- continue
179
-
180
- # Check installation status
181
- status = "✅ Installed"
182
- if model_name == current_model:
183
- status += " (current)"
184
- size = info.get('size', 'Unknown')
185
-
186
- # Create individual table for each model
187
- table = Table(title=f"[bold cyan]{model_name}[/bold cyan] - {status}")
188
- table.add_column("Property", style="yellow", no_wrap=True)
189
- table.add_column("Value", style="white")
190
-
191
- # Basic info
192
- table.add_row("Type", info.get('model_type', 'Unknown'))
193
- table.add_row("Size", size)
194
-
195
- # Hardware requirements
196
- hw_req = info.get('hardware_requirements', {})
197
- if hw_req:
198
- table.add_row("Min VRAM", f"{hw_req.get('min_vram_gb', 'Unknown')} GB")
199
- table.add_row("Recommended VRAM", f"{hw_req.get('recommended_vram_gb', 'Unknown')} GB")
200
- table.add_row("Min RAM", f"{hw_req.get('min_ram_gb', 'Unknown')} GB")
201
- table.add_row("Recommended RAM", f"{hw_req.get('recommended_ram_gb', 'Unknown')} GB")
202
- table.add_row("Disk Space", f"{hw_req.get('disk_space_gb', 'Unknown')} GB")
203
- table.add_row("Supported Devices", ", ".join(hw_req.get('supported_devices', [])))
204
- table.add_row("Performance Notes", hw_req.get('performance_notes', 'N/A'))
205
-
206
- console.print(table)
207
- console.print() # Add spacing between models
208
- else:
209
- # Show compact table
210
- table = Table(title="Installed Models")
211
- table.add_column("Model Name", style="cyan", no_wrap=True)
212
- table.add_column("Status", style="green")
213
- table.add_column("Size", style="blue")
214
- table.add_column("Type", style="magenta")
215
- table.add_column("Min VRAM", style="yellow")
216
-
217
- for model_name in installed_models:
218
- # Check installation status
219
- status = "✅ Installed"
220
- if model_name == current_model:
221
- status += " (current)"
222
-
223
- # Get model information
224
- info = model_manager.get_model_info(model_name)
225
- size = info.get('size', 'Unknown') if info else 'Unknown'
226
- model_type = info.get('model_type', 'Unknown') if info else 'Unknown'
227
-
228
- # Get hardware requirements
229
- hw_req = info.get('hardware_requirements', {}) if info else {}
230
- min_vram = f"{hw_req.get('min_vram_gb', '?')} GB" if hw_req else "Unknown"
231
-
232
- table.add_row(model_name, status, size, model_type, min_vram)
233
-
234
- console.print(table)
235
-
236
- # Get counts for summary
237
- available_models = model_registry.get_available_models()
238
- external_models = model_registry.get_external_api_models_only()
239
-
240
- console.print(f"\n[dim]💡 Installed: {len(installed_models)} models[/dim]")
241
- console.print(f"[dim]💡 Available for download: {len(available_models)} models[/dim]")
242
- if external_models:
243
- console.print(f"[dim]💡 External API models: {len(external_models)} models[/dim]")
244
- console.print("\n[dim]💡 Use --hardware flag to see detailed hardware requirements[/dim]")
245
- console.print("[dim]💡 See all models: ollamadiffuser registry list[/dim]")
246
- console.print("[dim]💡 See available models: ollamadiffuser registry list --available-only[/dim]")
247
-
248
- @cli.command()
249
- @click.argument('model_name')
250
- def show(model_name: str):
251
- """Show model detailed information"""
252
- info = model_manager.get_model_info(model_name)
253
-
254
- if info is None:
255
- rprint(f"[red]Model {model_name} does not exist[/red]")
256
- sys.exit(1)
257
-
258
- rprint(f"[bold cyan]Model Information: {model_name}[/bold cyan]")
259
- rprint(f"Type: {info.get('model_type', 'Unknown')}")
260
- rprint(f"Variant: {info.get('variant', 'Unknown')}")
261
- rprint(f"Installed: {'Yes' if info.get('installed', False) else 'No'}")
262
-
263
- if info.get('installed', False):
264
- rprint(f"Local Path: {info.get('local_path', 'Unknown')}")
265
- rprint(f"Size: {info.get('size', 'Unknown')}")
266
-
267
- # Hardware requirements
268
- if 'hardware_requirements' in info and info['hardware_requirements']:
269
- hw_req = info['hardware_requirements']
270
- rprint("\n[bold]Hardware Requirements:[/bold]")
271
- rprint(f" Min VRAM: {hw_req.get('min_vram_gb', 'Unknown')} GB")
272
- rprint(f" Recommended VRAM: {hw_req.get('recommended_vram_gb', 'Unknown')} GB")
273
- rprint(f" Min RAM: {hw_req.get('min_ram_gb', 'Unknown')} GB")
274
- rprint(f" Recommended RAM: {hw_req.get('recommended_ram_gb', 'Unknown')} GB")
275
- rprint(f" Disk Space: {hw_req.get('disk_space_gb', 'Unknown')} GB")
276
- rprint(f" Supported Devices: {', '.join(hw_req.get('supported_devices', []))}")
277
- if hw_req.get('performance_notes'):
278
- rprint(f" Performance Notes: {hw_req.get('performance_notes')}")
279
-
280
- if 'parameters' in info and info['parameters']:
281
- rprint("\n[bold]Default Parameters:[/bold]")
282
- for key, value in info['parameters'].items():
283
- rprint(f" {key}: {value}")
284
-
285
- if 'components' in info and info['components']:
286
- rprint("\n[bold]Components:[/bold]")
287
- for key, value in info['components'].items():
288
- rprint(f" {key}: {value}")
57
+ # --- Register model commands ---
58
+ from .model_commands import (
59
+ pull,
60
+ run,
61
+ list,
62
+ show,
63
+ check,
64
+ rm,
65
+ ps,
66
+ serve,
67
+ load,
68
+ unload,
69
+ stop,
70
+ )
289
71
 
290
- @cli.command()
291
- @click.argument('model_name', required=False)
292
- @click.option('--list', '-l', is_flag=True, help='List all available models')
293
- def check(model_name: str, list: bool):
294
- """Check model download status and integrity"""
295
- if list:
296
- rprint("[bold blue]📋 Available Models:[/bold blue]")
297
- available_models = model_manager.list_available_models()
298
- for model in available_models:
299
- model_info = model_manager.get_model_info(model)
300
- status = "✅ Installed" if model_manager.is_model_installed(model) else "⬇️ Available"
301
- license_type = model_info.get("license_info", {}).get("type", "Unknown")
302
- rprint(f" {model:<30} {status:<15} ({license_type})")
303
- return
304
-
305
- if not model_name:
306
- rprint("[bold red]❌ Please specify a model name or use --list[/bold red]")
307
- rprint("[dim]Usage: ollamadiffuser check MODEL_NAME[/dim]")
308
- rprint("[dim] ollamadiffuser check --list[/dim]")
309
- return
310
-
311
- # Check model download status directly
312
- status = _check_download_status(model_name)
313
-
314
- rprint("\n" + "="*60)
315
-
316
- if status is True:
317
- rprint(f"[green]🎉 {model_name} is ready to use![/green]")
318
- rprint(f"\n[blue]💡 You can now run:[/blue]")
319
- rprint(f" [cyan]ollamadiffuser run {model_name}[/cyan]")
320
- elif status == "needs_config":
321
- rprint(f"[yellow]⚠️ {model_name} files are complete but model needs configuration[/yellow]")
322
- rprint(f"\n[blue]💡 Try reinstalling:[/blue]")
323
- rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
324
- elif status == "downloading":
325
- rprint(f"[yellow]🔄 {model_name} is currently downloading[/yellow]")
326
- rprint(f"\n[blue]💡 Wait for download to complete or check progress[/blue]")
327
- elif status == "incomplete":
328
- rprint(f"[yellow]⚠️ Download is incomplete[/yellow]")
329
- rprint(f"\n[blue]💡 Resume download with:[/blue]")
330
- rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
331
- rprint(f"\n[blue]💡 Or force fresh download with:[/blue]")
332
- rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
333
- else:
334
- rprint(f"[red]❌ {model_name} is not downloaded[/red]")
335
- rprint(f"\n[blue]💡 Download with:[/blue]")
336
- rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
337
-
338
- _show_model_specific_help(model_name)
339
-
340
- rprint(f"\n[dim]📚 For more help: ollamadiffuser --help[/dim]")
341
-
342
-
343
- def _check_download_status(model_name: str):
344
- """Check the current download status of any model"""
345
- from ..core.utils.download_utils import check_download_integrity, get_repo_file_list, format_size
346
- import subprocess
347
-
348
- rprint(f"[blue]🔍 Checking {model_name} download status...[/blue]\n")
349
-
350
- # Check if model is in registry
351
- if model_name not in model_manager.model_registry:
352
- rprint(f"[red]❌ {model_name} not found in model registry[/red]")
353
- available_models = model_manager.list_available_models()
354
- rprint(f"[blue]📋 Available models: {', '.join(available_models)}[/blue]")
355
- return False
356
-
357
- model_info = model_manager.model_registry[model_name]
358
- repo_id = model_info["repo_id"]
359
- model_path = settings.get_model_path(model_name)
360
-
361
- rprint(f"[cyan]📦 Model: {model_name}[/cyan]")
362
- rprint(f"[cyan]🔗 Repository: {repo_id}[/cyan]")
363
- rprint(f"[cyan]📁 Local path: {model_path}[/cyan]")
364
-
365
- # Show model-specific info
366
- license_info = model_info.get("license_info", {})
367
- if license_info:
368
- rprint(f"[yellow]📄 License: {license_info.get('type', 'Unknown')}[/yellow]")
369
- rprint(f"[yellow]🔑 HF Token Required: {'Yes' if license_info.get('requires_agreement', False) else 'No'}[/yellow]")
370
- rprint(f"[yellow]💼 Commercial Use: {'Allowed' if license_info.get('commercial_use', False) else 'Not Allowed'}[/yellow]")
371
-
372
- # Show optimal parameters
373
- params = model_info.get("parameters", {})
374
- if params:
375
- rprint(f"[green]⚡ Optimal Settings:[/green]")
376
- rprint(f" Steps: {params.get('num_inference_steps', 'N/A')}")
377
- rprint(f" Guidance: {params.get('guidance_scale', 'N/A')}")
378
- if 'max_sequence_length' in params:
379
- rprint(f" Max Seq Length: {params['max_sequence_length']}")
380
-
381
- rprint()
382
-
383
- # Check if directory exists
384
- if not model_path.exists():
385
- rprint("[yellow]📂 Status: Not downloaded[/yellow]")
386
- return False
387
-
388
- # Get repository file list
389
- rprint("[blue]🌐 Getting repository information...[/blue]")
390
- try:
391
- file_sizes = get_repo_file_list(repo_id)
392
- total_expected_size = sum(file_sizes.values())
393
- total_files_expected = len(file_sizes)
394
-
395
- rprint(f"[blue]📊 Expected: {total_files_expected} files, {format_size(total_expected_size)} total[/blue]")
396
- except Exception as e:
397
- rprint(f"[yellow]⚠️ Could not get repository info: {e}[/yellow]")
398
- file_sizes = {}
399
- total_expected_size = 0
400
- total_files_expected = 0
401
-
402
- # Check local files
403
- local_files = []
404
- local_size = 0
405
-
406
- for file_path in model_path.rglob('*'):
407
- if file_path.is_file():
408
- rel_path = file_path.relative_to(model_path)
409
- file_size = file_path.stat().st_size
410
- local_files.append((str(rel_path), file_size))
411
- local_size += file_size
412
-
413
- rprint(f"[blue]💾 Downloaded: {len(local_files)} files, {format_size(local_size)} total[/blue]")
414
-
415
- if total_expected_size > 0:
416
- progress_percent = (local_size / total_expected_size) * 100
417
- rprint(f"[blue]📈 Progress: {progress_percent:.1f}%[/blue]")
418
-
419
- rprint()
420
-
421
- # Check for missing files
422
- if file_sizes:
423
- # Check if we have size information from the API
424
- has_size_info = any(size > 0 for size in file_sizes.values())
425
-
426
- if has_size_info:
427
- # Normal case: we have size information, do detailed comparison
428
- missing_files = []
429
- incomplete_files = []
430
-
431
- for expected_file, expected_size in file_sizes.items():
432
- local_file_path = model_path / expected_file
433
- if not local_file_path.exists():
434
- missing_files.append(expected_file)
435
- elif expected_size > 0 and local_file_path.stat().st_size != expected_size:
436
- local_size_actual = local_file_path.stat().st_size
437
- incomplete_files.append((expected_file, local_size_actual, expected_size))
438
-
439
- if missing_files:
440
- rprint(f"[red]❌ Missing files ({len(missing_files)}):[/red]")
441
- for missing_file in missing_files[:10]: # Show first 10
442
- rprint(f" - {missing_file}")
443
- if len(missing_files) > 10:
444
- rprint(f" ... and {len(missing_files) - 10} more")
445
- rprint()
446
-
447
- if incomplete_files:
448
- rprint(f"[yellow]⚠️ Incomplete files ({len(incomplete_files)}):[/yellow]")
449
- for incomplete_file, actual_size, expected_size in incomplete_files[:5]:
450
- rprint(f" - {incomplete_file}: {format_size(actual_size)}/{format_size(expected_size)}")
451
- if len(incomplete_files) > 5:
452
- rprint(f" ... and {len(incomplete_files) - 5} more")
453
- rprint()
454
-
455
- if not missing_files and not incomplete_files:
456
- rprint("[green]✅ All files present and complete![/green]")
457
-
458
- # Check integrity
459
- rprint("[blue]🔍 Checking download integrity...[/blue]")
460
- if check_download_integrity(str(model_path), repo_id):
461
- rprint("[green]✅ Download integrity verified![/green]")
462
-
463
- # Check if model is in configuration
464
- if model_manager.is_model_installed(model_name):
465
- rprint("[green]✅ Model is properly configured[/green]")
466
- return True
467
- else:
468
- rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
469
- return "needs_config"
470
- else:
471
- rprint("[red]❌ Download integrity check failed[/red]")
472
- return False
473
- else:
474
- rprint("[yellow]⚠️ Download is incomplete[/yellow]")
475
- return "incomplete"
476
- else:
477
- # No size information available from API (common with gated repos)
478
- rprint("[blue]ℹ️ Repository API doesn't provide file sizes (common with gated models)[/blue]")
479
- rprint("[blue]🔍 Checking essential model files instead...[/blue]")
480
-
481
- # Check for essential model files
482
- # Determine model type based on repo_id
483
- is_controlnet = 'controlnet' in repo_id.lower()
484
-
485
- if is_controlnet:
486
- # ControlNet models have different essential files
487
- essential_files = ['config.json']
488
- essential_dirs = [] # ControlNet models don't have complex directory structure
489
- else:
490
- # Regular diffusion models
491
- essential_files = ['model_index.json']
492
- essential_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'vae', 'scheduler']
493
-
494
- missing_essential = []
495
- for essential_file in essential_files:
496
- if not (model_path / essential_file).exists():
497
- missing_essential.append(essential_file)
498
-
499
- existing_dirs = []
500
- for essential_dir in essential_dirs:
501
- if (model_path / essential_dir).exists():
502
- existing_dirs.append(essential_dir)
503
-
504
- if missing_essential:
505
- rprint(f"[red]❌ Missing essential files: {', '.join(missing_essential)}[/red]")
506
- return "incomplete"
507
-
508
- if existing_dirs:
509
- rprint(f"[green]✅ Found model components: {', '.join(existing_dirs)}[/green]")
510
-
511
- # Check integrity
512
- rprint("[blue]🔍 Checking download integrity...[/blue]")
513
- if check_download_integrity(str(model_path), repo_id):
514
- rprint("[green]✅ Download integrity verified![/green]")
515
-
516
- # Check if model is in configuration
517
- if model_manager.is_model_installed(model_name):
518
- rprint("[green]✅ Model is properly configured and functional[/green]")
519
- return True
520
- else:
521
- rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
522
- return "needs_config"
523
- else:
524
- rprint("[red]❌ Download integrity check failed[/red]")
525
- return False
526
-
527
- # Check if download process is running
528
- rprint("[blue]🔍 Checking for active download processes...[/blue]")
529
- try:
530
- result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
531
- if f'ollamadiffuser pull {model_name}' in result.stdout:
532
- rprint("[yellow]🔄 Download process is currently running[/yellow]")
533
- return "downloading"
534
- else:
535
- rprint("[blue]💤 No active download process found[/blue]")
536
- except Exception as e:
537
- rprint(f"[yellow]⚠️ Could not check processes: {e}[/yellow]")
538
-
539
- return "incomplete"
540
-
541
-
542
- def _show_model_specific_help(model_name: str):
543
- """Show model-specific help and recommendations"""
544
- model_info = model_manager.get_model_info(model_name)
545
- if not model_info:
546
- return
547
-
548
- rprint(f"\n[bold blue]💡 {model_name} Specific Tips:[/bold blue]")
549
-
550
- # License-specific help
551
- license_info = model_info.get("license_info", {})
552
- if license_info.get("requires_agreement", False):
553
- rprint(f" [yellow]🔑 Requires HuggingFace token and license agreement[/yellow]")
554
- rprint(f" [blue]📝 Visit: https://huggingface.co/{model_info['repo_id']}[/blue]")
555
- rprint(f" [cyan]🔧 Set token: export HF_TOKEN=your_token_here[/cyan]")
556
- else:
557
- rprint(f" [green]✅ No HuggingFace token required![/green]")
558
-
559
- # Model-specific optimizations
560
- if "schnell" in model_name.lower():
561
- rprint(f" [green]⚡ FLUX.1-schnell is 12x faster than FLUX.1-dev[/green]")
562
- rprint(f" [green]🎯 Optimized for 4-step generation[/green]")
563
- rprint(f" [green]💼 Commercial use allowed (Apache 2.0)[/green]")
564
- elif "flux.1-dev" in model_name.lower():
565
- rprint(f" [blue]🎨 Best quality FLUX model[/blue]")
566
- rprint(f" [blue]🔬 Requires 50 steps for optimal results[/blue]")
567
- rprint(f" [yellow]⚠️ Non-commercial license only[/yellow]")
568
- elif "stable-diffusion-1.5" in model_name.lower():
569
- rprint(f" [green]🚀 Great for learning and quick tests[/green]")
570
- rprint(f" [green]💾 Smallest model, runs on most hardware[/green]")
571
- elif "stable-diffusion-3.5" in model_name.lower():
572
- rprint(f" [green]🏆 Excellent quality-to-speed ratio[/green]")
573
- rprint(f" [green]🔄 Great LoRA ecosystem[/green]")
574
-
575
- # Hardware recommendations
576
- hw_req = model_info.get("hardware_requirements", {})
577
- if hw_req:
578
- min_vram = hw_req.get("min_vram_gb", 0)
579
- if min_vram >= 12:
580
- rprint(f" [yellow]🖥️ Requires high-end GPU (RTX 4070+ or M2 Pro+)[/yellow]")
581
- elif min_vram >= 8:
582
- rprint(f" [blue]🖥️ Requires mid-range GPU (RTX 3080+ or M1 Pro+)[/blue]")
583
- else:
584
- rprint(f" [green]🖥️ Runs on most modern GPUs[/green]")
585
-
586
- @cli.command()
587
- @click.argument('model_name')
588
- @click.confirmation_option(prompt='Are you sure you want to delete this model?')
589
- def rm(model_name: str):
590
- """Remove model"""
591
- if model_manager.remove_model(model_name):
592
- rprint(f"[green]Model {model_name} removed successfully![/green]")
593
- else:
594
- rprint(f"[red]Failed to remove model {model_name}![/red]")
595
- sys.exit(1)
596
-
597
- @cli.command()
598
- def ps():
599
- """Show currently running model"""
600
- current_model = model_manager.get_current_model()
601
- server_running = model_manager.is_server_running()
602
-
603
- if current_model:
604
- rprint(f"[green]Current model: {current_model}[/green]")
605
-
606
- # Check server status
607
- if server_running:
608
- rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
609
-
610
- # Try to get model info from the running server
611
- try:
612
- import requests
613
- response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
614
- if response.status_code == 200:
615
- data = response.json()
616
- if data.get('loaded'):
617
- info = data.get('info', {})
618
- rprint(f"Device: {info.get('device', 'Unknown')}")
619
- rprint(f"Type: {info.get('type', 'Unknown')}")
620
- rprint(f"Variant: {info.get('variant', 'Unknown')}")
621
- else:
622
- rprint("[yellow]Model loaded but not active in server[/yellow]")
623
- except:
624
- pass
625
- else:
626
- rprint("[yellow]Server status: Not running[/yellow]")
627
- rprint("[dim]Model is set as current but server is not active[/dim]")
628
-
629
- # Show model info from local config
630
- model_info = model_manager.get_model_info(current_model)
631
- if model_info:
632
- rprint(f"Model type: {model_info.get('model_type', 'Unknown')}")
633
- if model_info.get('installed'):
634
- rprint(f"Size: {model_info.get('size', 'Unknown')}")
635
- else:
636
- if server_running:
637
- rprint("[yellow]Server is running but no model is loaded[/yellow]")
638
- rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
639
- else:
640
- rprint("[yellow]No model is currently running[/yellow]")
641
- rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
642
-
643
- @cli.command()
644
- @click.option('--host', '-h', default=None, help='Server host address')
645
- @click.option('--port', '-p', default=None, type=int, help='Server port')
646
- def serve(host: Optional[str], port: Optional[int]):
647
- """Start API server (without loading model)"""
648
- rprint("[blue]Starting OllamaDiffuser API server...[/blue]")
649
-
650
- try:
651
- run_server(host=host, port=port)
652
- except KeyboardInterrupt:
653
- rprint("\n[yellow]Server stopped[/yellow]")
654
-
655
- @cli.command()
656
- @click.argument('model_name')
657
- def load(model_name: str):
658
- """Load model into memory"""
659
- rprint(f"[blue]Loading model: {model_name}[/blue]")
660
-
661
- if model_manager.load_model(model_name):
662
- rprint(f"[green]Model {model_name} loaded successfully![/green]")
663
- else:
664
- rprint(f"[red]Failed to load model {model_name}![/red]")
665
- sys.exit(1)
72
+ cli.add_command(pull)
73
+ cli.add_command(run)
74
+ cli.add_command(list)
75
+ cli.add_command(show)
76
+ cli.add_command(check)
77
+ cli.add_command(rm)
78
+ cli.add_command(ps)
79
+ cli.add_command(serve)
80
+ cli.add_command(load)
81
+ cli.add_command(unload)
82
+ cli.add_command(stop)
666
83
 
667
- @cli.command()
668
- def unload():
669
- """Unload current model"""
670
- if model_manager.is_model_loaded():
671
- current_model = model_manager.get_current_model()
672
- model_manager.unload_model()
673
- rprint(f"[green]Model {current_model} unloaded[/green]")
674
- else:
675
- rprint("[yellow]No model to unload[/yellow]")
84
+ # --- Register LoRA commands ---
85
+ from .lora_commands import lora
676
86
 
677
- @cli.command()
678
- def stop():
679
- """Stop running server"""
680
- if not model_manager.is_server_running():
681
- rprint("[yellow]No server is currently running[/yellow]")
682
- return
683
-
684
- try:
685
- import requests
686
- import signal
687
- import psutil
688
-
689
- host = settings.server.host
690
- port = settings.server.port
691
-
692
- # Try graceful shutdown via API first
693
- try:
694
- response = requests.post(f"http://{host}:{port}/api/shutdown", timeout=5)
695
- if response.status_code == 200:
696
- rprint("[green]Server stopped gracefully[/green]")
697
- return
698
- except:
699
- pass
700
-
701
- # Fallback: Find and terminate the process
702
- for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
703
- try:
704
- cmdline = proc.info['cmdline']
705
- if cmdline and any('uvicorn' in arg for arg in cmdline) and any(str(port) in arg for arg in cmdline):
706
- proc.terminate()
707
- proc.wait(timeout=10)
708
- rprint(f"[green]Server process (PID: {proc.info['pid']}) stopped[/green]")
709
- return
710
- except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.TimeoutExpired):
711
- continue
712
-
713
- rprint("[red]Could not find or stop the server process[/red]")
714
-
715
- except ImportError:
716
- rprint("[red]psutil package required for stop command. Install with: pip install psutil[/red]")
717
- except Exception as e:
718
- rprint(f"[red]Failed to stop server: {e}[/red]")
87
+ cli.add_command(lora)
719
88
 
720
- @cli.group()
721
- def lora():
722
- """LoRA (Low-Rank Adaptation) management commands"""
723
- pass
89
+ # --- Register registry commands ---
90
+ from .registry_commands import registry
724
91
 
725
- @lora.command()
726
- @click.argument('repo_id')
727
- @click.option('--weight-name', '-w', help='Specific weight file name (e.g., lora.safetensors)')
728
- @click.option('--alias', '-a', help='Local alias name for the LoRA')
729
- def pull(repo_id: str, weight_name: Optional[str], alias: Optional[str]):
730
- """Download LoRA weights from Hugging Face Hub"""
731
- from ..core.utils.lora_manager import lora_manager
732
-
733
- rprint(f"[blue]Downloading LoRA: {repo_id}[/blue]")
734
-
735
- with Progress(
736
- SpinnerColumn(),
737
- TextColumn("[progress.description]{task.description}"),
738
- console=console
739
- ) as progress:
740
- task = progress.add_task(f"Downloading LoRA...", total=None)
741
-
742
- def progress_callback(message: str):
743
- progress.update(task, description=message)
744
-
745
- if lora_manager.pull_lora(repo_id, weight_name=weight_name, alias=alias, progress_callback=progress_callback):
746
- progress.update(task, description=f"✅ LoRA download completed")
747
- rprint(f"[green]LoRA {repo_id} downloaded successfully![/green]")
748
- else:
749
- progress.update(task, description=f"❌ LoRA download failed")
750
- rprint(f"[red]LoRA {repo_id} download failed![/red]")
751
- sys.exit(1)
92
+ cli.add_command(registry)
752
93
 
753
- @lora.command()
754
- @click.argument('lora_name')
755
- @click.option('--scale', '-s', default=1.0, type=float, help='LoRA scale/strength (default: 1.0)')
756
- def load(lora_name: str, scale: float):
757
- """Load LoRA weights into the current model"""
758
- from ..core.utils.lora_manager import lora_manager
759
-
760
- rprint(f"[blue]Loading LoRA: {lora_name} (scale: {scale})[/blue]")
761
-
762
- if lora_manager.load_lora(lora_name, scale=scale):
763
- rprint(f"[green]LoRA {lora_name} loaded successfully![/green]")
764
- else:
765
- rprint(f"[red]Failed to load LoRA {lora_name}![/red]")
766
- sys.exit(1)
94
+ # --- Register config commands ---
95
+ from .config_commands import config
767
96
 
768
- @lora.command()
769
- def unload():
770
- """Unload current LoRA weights"""
771
- from ..core.utils.lora_manager import lora_manager
772
-
773
- rprint("[blue]Unloading LoRA weights...[/blue]")
774
-
775
- if lora_manager.unload_lora():
776
- rprint("[green]LoRA weights unloaded successfully![/green]")
777
- else:
778
- rprint("[red]Failed to unload LoRA weights![/red]")
779
- sys.exit(1)
97
+ cli.add_command(config)
780
98
 
781
- @lora.command()
782
- @click.argument('lora_name')
783
- @click.confirmation_option(prompt='Are you sure you want to delete this LoRA?')
784
- def rm(lora_name: str):
785
- """Remove LoRA weights"""
786
- from ..core.utils.lora_manager import lora_manager
787
-
788
- rprint(f"[blue]Removing LoRA: {lora_name}[/blue]")
789
-
790
- if lora_manager.remove_lora(lora_name):
791
- rprint(f"[green]LoRA {lora_name} removed successfully![/green]")
792
- else:
793
- rprint(f"[red]Failed to remove LoRA {lora_name}![/red]")
794
- sys.exit(1)
99
+ # --- Register recommend command ---
100
+ from .recommend_command import recommend
795
101
 
796
- @lora.command()
797
- def ps():
798
- """Show currently loaded LoRA status"""
799
- from ..core.utils.lora_manager import lora_manager
800
-
801
- # Check if server is running
802
- server_running = lora_manager._is_server_running()
803
- current_lora = lora_manager.get_current_lora()
804
-
805
- if server_running:
806
- rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
807
-
808
- # Try to get LoRA status from the running server
809
- try:
810
- import requests
811
- response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
812
- if response.status_code == 200:
813
- data = response.json()
814
- if data.get('loaded'):
815
- model_info = data.get('info', {})
816
- rprint(f"Model: {data.get('model', 'Unknown')}")
817
- rprint(f"Device: {model_info.get('device', 'Unknown')}")
818
- rprint(f"Type: {model_info.get('type', 'Unknown')}")
819
- else:
820
- rprint("[yellow]No model loaded in server[/yellow]")
821
- return
822
- except Exception as e:
823
- rprint(f"[red]Failed to get server status: {e}[/red]")
824
- return
825
- else:
826
- # Check local model manager
827
- if model_manager.is_model_loaded():
828
- current_model = model_manager.get_current_model()
829
- rprint(f"[green]Model loaded locally: {current_model}[/green]")
830
- else:
831
- rprint("[yellow]No server running and no local model loaded[/yellow]")
832
- rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
833
- return
834
-
835
- # Show LoRA status
836
- lora_status_shown = False
837
- lora_loaded_on_server = False
838
-
839
- # Try to get LoRA status from server if running
840
- if server_running:
841
- try:
842
- import requests
843
- response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/lora/status", timeout=2)
844
- if response.status_code == 200:
845
- lora_data = response.json()
846
- if lora_data.get('loaded'):
847
- lora_info = lora_data.get('info', {})
848
- rprint(f"\n[bold green]🔄 LoRA Status: LOADED (via server)[/bold green]")
849
- rprint(f"Adapter: {lora_info.get('adapter_name', 'Unknown')}")
850
- if 'scale' in lora_info:
851
- rprint(f"Scale: {lora_info.get('scale', 'Unknown')}")
852
- if 'adapters' in lora_info:
853
- rprint(f"Active Adapters: {', '.join(lora_info.get('adapters', []))}")
854
- lora_status_shown = True
855
- lora_loaded_on_server = True
856
- else:
857
- rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded (server)[/dim]")
858
- lora_status_shown = True
859
- except Exception as e:
860
- rprint(f"\n[yellow]⚠️ Failed to get LoRA status from server: {e}[/yellow]")
861
-
862
- # Fallback to local LoRA manager state
863
- if not lora_status_shown:
864
- if current_lora:
865
- lora_info = lora_manager.get_lora_info(current_lora)
866
- if lora_info:
867
- rprint(f"\n[bold green]🔄 LoRA Status: LOADED (local)[/bold green]")
868
- rprint(f"Name: {current_lora}")
869
- rprint(f"Repository: {lora_info.get('repo_id', 'Unknown')}")
870
- rprint(f"Weight File: {lora_info.get('weight_name', 'Unknown')}")
871
- rprint(f"Size: {lora_info.get('size', 'Unknown')}")
872
- rprint(f"Local Path: {lora_info.get('path', 'Unknown')}")
873
- else:
874
- rprint(f"\n[yellow]⚠️ LoRA {current_lora} is set as current but info not found[/yellow]")
875
- else:
876
- rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded[/dim]")
877
-
878
- if not lora_loaded_on_server:
879
- rprint("[dim]Use 'ollamadiffuser lora load <lora_name>' to load a LoRA[/dim]")
102
+ cli.add_command(recommend)
880
103
 
881
- @lora.command()
882
- def list():
883
- """List available and installed LoRA weights"""
884
- from ..core.utils.lora_manager import lora_manager
885
-
886
- installed_loras = lora_manager.list_installed_loras()
887
- current_lora = lora_manager.get_current_lora()
888
-
889
- if not installed_loras:
890
- rprint("[yellow]No LoRA weights installed.[/yellow]")
891
- rprint("\n[dim]💡 Use 'ollamadiffuser lora pull <repo_id>' to download LoRA weights[/dim]")
892
- return
893
-
894
- table = Table(title="Installed LoRA Weights")
895
- table.add_column("Name", style="cyan", no_wrap=True)
896
- table.add_column("Repository", style="blue")
897
- table.add_column("Status", style="green")
898
- table.add_column("Size", style="yellow")
899
-
900
- for lora_name, lora_info in installed_loras.items():
901
- status = "🔄 Loaded" if lora_name == current_lora else "💾 Available"
902
- size = lora_info.get('size', 'Unknown')
903
- repo_id = lora_info.get('repo_id', 'Unknown')
904
-
905
- table.add_row(lora_name, repo_id, status, size)
906
-
907
- console.print(table)
104
+ # --- Utility commands ---
908
105
 
909
- @lora.command()
910
- @click.argument('lora_name')
911
- def show(lora_name: str):
912
- """Show detailed LoRA information"""
913
- from ..core.utils.lora_manager import lora_manager
914
-
915
- lora_info = lora_manager.get_lora_info(lora_name)
916
-
917
- if not lora_info:
918
- rprint(f"[red]LoRA {lora_name} not found.[/red]")
919
- sys.exit(1)
920
-
921
- rprint(f"[bold cyan]LoRA Information: {lora_name}[/bold cyan]")
922
- rprint(f"Repository: {lora_info.get('repo_id', 'Unknown')}")
923
- rprint(f"Weight File: {lora_info.get('weight_name', 'Unknown')}")
924
- rprint(f"Local Path: {lora_info.get('path', 'Unknown')}")
925
- rprint(f"Size: {lora_info.get('size', 'Unknown')}")
926
- rprint(f"Downloaded: {lora_info.get('downloaded_at', 'Unknown')}")
927
-
928
- if lora_info.get('description'):
929
- rprint(f"Description: {lora_info.get('description')}")
930
106
 
931
107
  @cli.command()
932
108
  def version():
933
109
  """Show version information"""
934
110
  print_version()
935
- rprint("\n[bold]Features:[/bold]")
936
- rprint("• 🚀 Fast Startup with lazy loading architecture")
937
- rprint("• 🎛️ ControlNet Support with 10+ control types")
938
- rprint("• 🔄 LoRA Integration with dynamic loading")
939
- rprint("• 🌐 Multiple Interfaces: CLI, Python API, Web UI, REST API")
940
- rprint("• 📦 Easy model management and switching")
941
- rprint("• ⚡ Performance optimized with GPU acceleration")
942
-
943
111
  rprint("\n[bold]Supported Models:[/bold]")
944
- rprint("FLUX.1-schnell (Apache 2.0, Commercial OK, 4-step generation)")
945
- rprint(" FLUX.1-dev (Non-commercial, High quality, 50-step generation)")
946
- rprint("• Stable Diffusion 3.5 Medium")
947
- rprint(" Stable Diffusion XL Base")
948
- rprint("• Stable Diffusion 1.5")
949
- rprint("• ControlNet models for SD15 and SDXL")
950
-
112
+ rprint(" FLUX.1-schnell, FLUX.1-dev, SD 3.5 Medium, SDXL, SD 1.5")
113
+ rprint(" ControlNet (SD15 + SDXL), AnimateDiff, HiDream, GGUF variants")
114
+ rprint("\n[bold]Features:[/bold]")
115
+ rprint(" img2img, inpainting, LoRA, ControlNet, async API, Web UI")
951
116
  rprint("\n[dim]For help: ollamadiffuser --help[/dim]")
952
- rprint("[dim]For diagnostics: ollamadiffuser doctor[/dim]")
953
- rprint("[dim]For ControlNet samples: ollamadiffuser create-samples[/dim]")
954
117
 
955
- @cli.command(name='verify-deps')
118
+
119
+ @cli.command(name="verify-deps")
956
120
  def verify_deps_cmd():
957
121
  """Verify and install missing dependencies"""
958
122
  from .commands import verify_deps
123
+
959
124
  ctx = click.Context(verify_deps)
960
125
  ctx.invoke(verify_deps)
961
126
 
127
+
962
128
  @cli.command()
963
129
  def doctor():
964
130
  """Run comprehensive system diagnostics"""
965
131
  from .commands import doctor
132
+
966
133
  ctx = click.Context(doctor)
967
134
  ctx.invoke(doctor)
968
135
 
969
- @cli.command(name='create-samples')
970
- @click.option('--force', is_flag=True, help='Force recreation of all samples even if they exist')
971
- def create_samples_cmd(force):
972
- """Create ControlNet sample images for the Web UI"""
973
- from .commands import create_samples
974
- ctx = click.Context(create_samples)
975
- ctx.invoke(create_samples, force=force)
976
-
977
- @cli.group(hidden=True)
978
- def registry():
979
- """Manage model registry (internal command)"""
980
- pass
981
136
 
982
- @registry.command()
983
- @click.option('--format', '-f', type=click.Choice(['table', 'json', 'yaml']), default='table', help='Output format')
984
- @click.option('--installed-only', is_flag=True, help='Show only installed models')
985
- @click.option('--available-only', is_flag=True, help='Show only available (not installed) models')
986
- @click.option('--external-only', is_flag=True, help='Show only externally defined models')
987
- def list(format: str, installed_only: bool, available_only: bool, external_only: bool):
988
- """List models in the registry with installation status"""
989
-
990
- # Get different model categories
991
- if installed_only:
992
- models = model_registry.get_installed_models()
993
- title = "Installed Models"
994
- elif available_only:
995
- models = model_registry.get_available_models()
996
- title = "Available Models (Not Installed)"
997
- elif external_only:
998
- models = model_registry.get_external_api_models_only()
999
- title = "External API Models"
1000
- else:
1001
- models = model_registry.get_all_models()
1002
- title = "All Models (Installed + Available)"
1003
-
1004
- installed_model_names = set(model_registry.get_installed_models().keys())
1005
- local_model_names = set(model_registry.get_local_models_only().keys())
1006
- external_model_names = set(model_registry.get_external_api_models_only().keys())
1007
- current_model = model_manager.get_current_model()
1008
-
1009
- if not models:
1010
- rprint(f"[yellow]No models found in category: {title}[/yellow]")
1011
- return
1012
-
1013
- if format == 'table':
1014
- table = Table(title=title)
1015
- table.add_column("Model Name", style="cyan", no_wrap=True)
1016
- table.add_column("Type", style="yellow")
1017
- table.add_column("Repository", style="blue")
1018
- table.add_column("Status", style="green")
1019
- table.add_column("Source", style="magenta")
1020
-
1021
- for model_name, model_info in models.items():
1022
- # Check installation status
1023
- if model_name in installed_model_names:
1024
- status = "✅ Installed"
1025
- if model_name == current_model:
1026
- status += " (current)"
1027
- else:
1028
- status = "⬇️ Available"
1029
-
1030
- # Determine source
1031
- if model_name in local_model_names and model_name in external_model_names:
1032
- source = "Local + External"
1033
- elif model_name in local_model_names:
1034
- source = "Local"
1035
- elif model_name in external_model_names:
1036
- source = "External API"
1037
- else:
1038
- source = "Unknown"
1039
-
1040
- table.add_row(
1041
- model_name,
1042
- model_info.get('model_type', 'Unknown'),
1043
- model_info.get('repo_id', 'Unknown'),
1044
- status,
1045
- source
1046
- )
1047
-
1048
- console.print(table)
1049
-
1050
- # Show summary
1051
- if not (installed_only or available_only or external_only):
1052
- total_count = len(models)
1053
- installed_count = len(installed_model_names)
1054
- available_count = total_count - installed_count
1055
- local_count = len(local_model_names)
1056
- external_count = len(external_model_names)
1057
-
1058
- console.print(f"\n[dim]Summary:[/dim]")
1059
- console.print(f"[dim] • Total: {total_count} models[/dim]")
1060
- console.print(f"[dim] • Installed: {installed_count} models[/dim]")
1061
- console.print(f"[dim] • Available: {available_count} models[/dim]")
1062
- console.print(f"[dim] • Local registry: {local_count} models[/dim]")
1063
- console.print(f"[dim] • External API: {external_count} models[/dim]")
1064
-
1065
- elif format == 'json':
1066
- import json
1067
- print(json.dumps(models, indent=2, ensure_ascii=False))
1068
-
1069
- elif format == 'yaml':
1070
- import yaml
1071
- print(yaml.dump(models, default_flow_style=False, allow_unicode=True))
1072
-
1073
- @registry.command()
1074
- @click.argument('model_name')
1075
- @click.argument('repo_id')
1076
- @click.argument('model_type')
1077
- @click.option('--variant', help='Model variant (e.g., fp16, bf16)')
1078
- @click.option('--license-type', help='License type')
1079
- @click.option('--commercial-use', type=bool, help='Whether commercial use is allowed')
1080
- @click.option('--save', is_flag=True, help='Save to user configuration file')
1081
- def add(model_name: str, repo_id: str, model_type: str, variant: Optional[str],
1082
- license_type: Optional[str], commercial_use: Optional[bool], save: bool):
1083
- """Add a new model to the registry"""
1084
-
1085
- model_config = {
1086
- "repo_id": repo_id,
1087
- "model_type": model_type
1088
- }
1089
-
1090
- if variant:
1091
- model_config["variant"] = variant
1092
-
1093
- if license_type or commercial_use is not None:
1094
- license_info = {}
1095
- if license_type:
1096
- license_info["type"] = license_type
1097
- if commercial_use is not None:
1098
- license_info["commercial_use"] = commercial_use
1099
- model_config["license_info"] = license_info
1100
-
1101
- if model_registry.add_model(model_name, model_config):
1102
- rprint(f"[green]Model '{model_name}' added to registry successfully![/green]")
1103
-
1104
- if save:
1105
- try:
1106
- # Load existing user models and add the new one
1107
- user_models = {}
1108
- config_path = settings.config_dir / "models.json"
1109
- if config_path.exists():
1110
- import json
1111
- with open(config_path, 'r') as f:
1112
- data = json.load(f)
1113
- user_models = data.get('models', {})
1114
-
1115
- user_models[model_name] = model_config
1116
- model_registry.save_user_config(user_models, config_path)
1117
- rprint(f"[green]Model configuration saved to {config_path}[/green]")
1118
- except Exception as e:
1119
- rprint(f"[red]Failed to save configuration: {e}[/red]")
1120
- else:
1121
- rprint(f"[red]Failed to add model '{model_name}' to registry![/red]")
1122
- sys.exit(1)
137
+ @cli.command(name="mcp")
138
+ def mcp_cmd():
139
+ """Start the MCP (Model Context Protocol) server for AI assistant integration."""
140
+ try:
141
+ from ..mcp.server import main as mcp_main
1123
142
 
1124
- @registry.command()
1125
- @click.argument('model_name')
1126
- @click.option('--from-file', is_flag=True, help='Also remove from user configuration file')
1127
- def remove(model_name: str, from_file: bool):
1128
- """Remove a model from the registry"""
1129
-
1130
- if model_registry.remove_model(model_name):
1131
- rprint(f"[green]Model '{model_name}' removed from registry![/green]")
1132
-
1133
- if from_file:
1134
- try:
1135
- config_path = settings.config_dir / "models.json"
1136
- if config_path.exists():
1137
- import json
1138
- with open(config_path, 'r') as f:
1139
- data = json.load(f)
1140
-
1141
- user_models = data.get('models', {})
1142
- if model_name in user_models:
1143
- del user_models[model_name]
1144
- model_registry.save_user_config(user_models, config_path)
1145
- rprint(f"[green]Model removed from configuration file[/green]")
1146
- else:
1147
- rprint(f"[yellow]Model not found in configuration file[/yellow]")
1148
- else:
1149
- rprint(f"[yellow]No user configuration file found[/yellow]")
1150
- except Exception as e:
1151
- rprint(f"[red]Failed to update configuration file: {e}[/red]")
1152
- else:
1153
- rprint(f"[red]Model '{model_name}' not found in registry![/red]")
143
+ mcp_main()
144
+ except ImportError:
145
+ rprint("[red]MCP package not installed. Install with:[/red]")
146
+ rprint("[yellow] pip install 'ollamadiffuser[mcp]'[/yellow]")
1154
147
  sys.exit(1)
1155
148
 
1156
- @registry.command()
1157
- def reload():
1158
- """Reload the model registry from configuration files"""
1159
- try:
1160
- model_registry.reload()
1161
- rprint("[green]Model registry reloaded successfully![/green]")
1162
-
1163
- # Show summary
1164
- models = model_registry.get_all_models()
1165
- external_registries = model_registry.get_external_registries()
1166
-
1167
- rprint(f"[dim]Total models: {len(models)}[/dim]")
1168
- if external_registries:
1169
- rprint(f"[dim]External registries: {len(external_registries)}[/dim]")
1170
- for registry_path in external_registries:
1171
- rprint(f"[dim] • {registry_path}[/dim]")
1172
- else:
1173
- rprint("[dim]No external registries loaded[/dim]")
1174
-
1175
- except Exception as e:
1176
- rprint(f"[red]Failed to reload registry: {e}[/red]")
1177
- sys.exit(1)
1178
149
 
1179
- @registry.command()
1180
- @click.argument('config_file', type=click.Path(exists=True))
1181
- def import_config(config_file: str):
1182
- """Import models from a configuration file"""
1183
- try:
1184
- from pathlib import Path
1185
- import json
1186
- import yaml
1187
-
1188
- config_path = Path(config_file)
1189
-
1190
- with open(config_path, 'r', encoding='utf-8') as f:
1191
- if config_path.suffix.lower() == '.json':
1192
- data = json.load(f)
1193
- elif config_path.suffix.lower() in ['.yaml', '.yml']:
1194
- data = yaml.safe_load(f)
1195
- else:
1196
- rprint(f"[red]Unsupported file format: {config_path.suffix}[/red]")
1197
- sys.exit(1)
1198
-
1199
- if 'models' not in data:
1200
- rprint("[red]Configuration file must contain a 'models' section[/red]")
1201
- sys.exit(1)
1202
-
1203
- imported_count = 0
1204
- for model_name, model_config in data['models'].items():
1205
- if model_registry.add_model(model_name, model_config):
1206
- imported_count += 1
1207
- rprint(f"[green]✓ Imported: {model_name}[/green]")
1208
- else:
1209
- rprint(f"[red]✗ Failed to import: {model_name}[/red]")
1210
-
1211
- rprint(f"[green]Successfully imported {imported_count} models[/green]")
1212
-
1213
- except Exception as e:
1214
- rprint(f"[red]Failed to import configuration: {e}[/red]")
1215
- sys.exit(1)
150
+ @cli.command(name="create-samples")
151
+ @click.option("--force", is_flag=True, help="Force recreation of samples")
152
+ def create_samples_cmd(force):
153
+ """Create ControlNet sample images for the Web UI"""
154
+ from .commands import create_samples
1216
155
 
1217
- @registry.command()
1218
- @click.option('--output', '-o', help='Output file path')
1219
- @click.option('--format', '-f', type=click.Choice(['json', 'yaml']), default='json', help='Output format')
1220
- @click.option('--user-only', is_flag=True, help='Export only user-defined models')
1221
- def export(output: Optional[str], format: str, user_only: bool):
1222
- """Export model registry to a configuration file"""
1223
- try:
1224
- from pathlib import Path
1225
- import json
1226
- import yaml
1227
-
1228
- if user_only:
1229
- # Only export models from external registries
1230
- models = {}
1231
- external_registries = model_registry.get_external_registries()
1232
- if external_registries:
1233
- rprint(f"[yellow]User-only export not fully supported yet. Exporting all models.[/yellow]")
1234
-
1235
- models = model_registry.get_all_models()
1236
-
1237
- config_data = {"models": models}
1238
-
1239
- if output:
1240
- output_path = Path(output)
1241
- else:
1242
- if format == 'json':
1243
- output_path = Path('models.json')
1244
- else:
1245
- output_path = Path('models.yaml')
1246
-
1247
- with open(output_path, 'w', encoding='utf-8') as f:
1248
- if format == 'json':
1249
- json.dump(config_data, f, indent=2, ensure_ascii=False)
1250
- else:
1251
- yaml.safe_dump(config_data, f, default_flow_style=False, allow_unicode=True)
1252
-
1253
- rprint(f"[green]Model registry exported to {output_path}[/green]")
1254
- rprint(f"[dim]Exported {len(models)} models[/dim]")
1255
-
1256
- except Exception as e:
1257
- rprint(f"[red]Failed to export registry: {e}[/red]")
1258
- sys.exit(1)
156
+ ctx = click.Context(create_samples)
157
+ ctx.invoke(create_samples, force=force)
1259
158
 
1260
- @registry.command('check-gguf')
1261
- def check_gguf():
1262
- """Check GGUF support status"""
1263
- from ..core.models.gguf_loader import GGUF_AVAILABLE
1264
-
1265
- if GGUF_AVAILABLE:
1266
- rprint("✅ [green]GGUF Support Available[/green]")
1267
-
1268
- # Show GGUF models
1269
- models = model_registry.get_all_models()
1270
- gguf_models = {name: info for name, info in models.items()
1271
- if model_manager.is_gguf_model(name)}
1272
-
1273
- if gguf_models:
1274
- rprint(f"\n🔥 Found {len(gguf_models)} GGUF models:")
1275
-
1276
- table = Table()
1277
- table.add_column("Model", style="cyan")
1278
- table.add_column("Variant", style="yellow")
1279
- table.add_column("VRAM", style="green")
1280
- table.add_column("Size", style="blue")
1281
- table.add_column("Installed", style="red")
1282
-
1283
- for name, info in gguf_models.items():
1284
- hw_req = info.get('hardware_requirements', {})
1285
- installed = "✅" if model_manager.is_model_installed(name) else "❌"
1286
-
1287
- table.add_row(
1288
- name,
1289
- info.get('variant', 'unknown'),
1290
- f"{hw_req.get('min_vram_gb', '?')}GB",
1291
- f"{hw_req.get('disk_space_gb', '?')}GB",
1292
- installed
1293
- )
1294
-
1295
- console.print(table)
1296
-
1297
- rprint("\n📋 [blue]Usage:[/blue]")
1298
- rprint(" ollamadiffuser pull <model-name> # Download GGUF model")
1299
- rprint(" ollamadiffuser load <model-name> # Load GGUF model")
1300
- rprint("\n💡 [yellow]Tip:[/yellow] Start with flux.1-dev-gguf-q4ks for best balance")
1301
- else:
1302
- rprint("ℹ️ No GGUF models found in registry")
1303
- else:
1304
- rprint("❌ [red]GGUF Support Not Available[/red]")
1305
- rprint("📦 Install with: [yellow]pip install llama-cpp-python gguf[/yellow]")
1306
- rprint("🔧 Or install all dependencies: [yellow]pip install -r requirements.txt[/yellow]")
1307
159
 
1308
- if __name__ == '__main__':
1309
- cli()
160
+ if __name__ == "__main__":
161
+ cli()