ollamadiffuser 1.2.3__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
ollamadiffuser/cli/main.py
CHANGED
|
@@ -1,1309 +1,161 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""OllamaDiffuser CLI - Main entry point"""
|
|
2
|
+
|
|
3
3
|
import sys
|
|
4
4
|
import logging
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
from rich.table import Table
|
|
8
|
-
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
|
|
5
|
+
|
|
6
|
+
import click
|
|
9
7
|
from rich import print as rprint
|
|
10
|
-
import time
|
|
11
8
|
|
|
12
9
|
from .. import __version__, print_version
|
|
13
|
-
from ..core.models.manager import model_manager
|
|
14
10
|
from ..core.config.settings import settings
|
|
15
|
-
from ..core.config.model_registry import model_registry
|
|
16
11
|
from ..api.server import run_server
|
|
17
12
|
|
|
18
|
-
console = Console()
|
|
19
|
-
|
|
20
|
-
class OllamaStyleProgress:
|
|
21
|
-
"""Enhanced progress tracker that mimics Ollama's progress display"""
|
|
22
|
-
|
|
23
|
-
def __init__(self, console: Console):
|
|
24
|
-
self.console = console
|
|
25
|
-
self.last_message = ""
|
|
26
|
-
|
|
27
|
-
def update(self, message: str):
|
|
28
|
-
"""Update progress with a message"""
|
|
29
|
-
# Skip duplicate messages
|
|
30
|
-
if message == self.last_message:
|
|
31
|
-
return
|
|
32
|
-
|
|
33
|
-
self.last_message = message
|
|
34
|
-
|
|
35
|
-
# Handle different types of messages
|
|
36
|
-
if message.startswith("pulling ") and ":" in message and "%" in message:
|
|
37
|
-
# This is a file progress message from download_utils
|
|
38
|
-
# Format: "pulling e6a7edc1a4d7: 12% ▕██ ▏ 617 MB/5200 MB 44 MB/s 1m44s"
|
|
39
|
-
self.console.print(message)
|
|
40
|
-
elif message.startswith("pulling manifest"):
|
|
41
|
-
self.console.print(message)
|
|
42
|
-
elif message.startswith("📦 Repository:"):
|
|
43
|
-
# Repository info
|
|
44
|
-
self.console.print(f"[dim]{message}[/dim]")
|
|
45
|
-
elif message.startswith("📁 Found"):
|
|
46
|
-
# Existing files info
|
|
47
|
-
self.console.print(f"[dim]{message}[/dim]")
|
|
48
|
-
elif message.startswith("✅") and "download completed" in message:
|
|
49
|
-
self.console.print(f"[green]{message}[/green]")
|
|
50
|
-
elif message.startswith("❌"):
|
|
51
|
-
self.console.print(f"[red]{message}[/red]")
|
|
52
|
-
elif message.startswith("⚠️"):
|
|
53
|
-
self.console.print(f"[yellow]{message}[/yellow]")
|
|
54
|
-
else:
|
|
55
|
-
# For other messages, print with dimmed style
|
|
56
|
-
self.console.print(f"[dim]{message}[/dim]")
|
|
57
13
|
|
|
58
14
|
@click.group(invoke_without_command=True)
|
|
59
|
-
@click.option(
|
|
60
|
-
@click.option(
|
|
61
|
-
@click.option(
|
|
62
|
-
|
|
63
|
-
|
|
15
|
+
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose output")
|
|
16
|
+
@click.option("--version", "-V", is_flag=True, help="Show version and exit")
|
|
17
|
+
@click.option(
|
|
18
|
+
"--mode",
|
|
19
|
+
type=click.Choice(["cli", "api", "ui"]),
|
|
20
|
+
help="Running mode: cli, api (server), ui (web interface)",
|
|
21
|
+
)
|
|
22
|
+
@click.option("--host", default=None, help="Server host address (for api/ui modes)")
|
|
23
|
+
@click.option("--port", type=int, default=None, help="Server port (for api/ui modes)")
|
|
64
24
|
@click.pass_context
|
|
65
25
|
def cli(ctx, verbose, version, mode, host, port):
|
|
66
26
|
"""OllamaDiffuser - Image generation model management tool"""
|
|
67
27
|
if version:
|
|
68
28
|
print_version()
|
|
69
29
|
sys.exit(0)
|
|
70
|
-
|
|
30
|
+
|
|
71
31
|
if verbose:
|
|
72
32
|
logging.basicConfig(level=logging.DEBUG)
|
|
73
33
|
else:
|
|
74
34
|
logging.basicConfig(level=logging.WARNING)
|
|
75
|
-
|
|
76
|
-
# Handle mode-based execution
|
|
35
|
+
|
|
77
36
|
if mode:
|
|
78
|
-
if mode ==
|
|
37
|
+
if mode == "api":
|
|
79
38
|
rprint("[blue]Starting OllamaDiffuser API server...[/blue]")
|
|
80
39
|
run_server(host=host, port=port)
|
|
81
40
|
sys.exit(0)
|
|
82
|
-
elif mode ==
|
|
41
|
+
elif mode == "ui":
|
|
83
42
|
rprint("[blue]Starting OllamaDiffuser Web UI...[/blue]")
|
|
84
43
|
import uvicorn
|
|
85
44
|
from ..ui.web import create_ui_app
|
|
45
|
+
|
|
86
46
|
app = create_ui_app()
|
|
87
47
|
ui_host = host or settings.server.host
|
|
88
|
-
ui_port = port or (settings.server.port + 1)
|
|
48
|
+
ui_port = port or (settings.server.port + 1)
|
|
89
49
|
uvicorn.run(app, host=ui_host, port=ui_port)
|
|
90
50
|
sys.exit(0)
|
|
91
|
-
|
|
92
|
-
# Continue with normal CLI processing
|
|
93
|
-
pass
|
|
94
|
-
|
|
95
|
-
# If no subcommand is provided and no mode/version flag, show help
|
|
51
|
+
|
|
96
52
|
if ctx.invoked_subcommand is None and not version and not mode:
|
|
97
53
|
rprint(ctx.get_help())
|
|
98
54
|
sys.exit(0)
|
|
99
55
|
|
|
100
|
-
@cli.command()
|
|
101
|
-
@click.argument('model_name')
|
|
102
|
-
@click.option('--force', '-f', is_flag=True, help='Force re-download')
|
|
103
|
-
def pull(model_name: str, force: bool):
|
|
104
|
-
"""Download model"""
|
|
105
|
-
rprint(f"[blue]Downloading model: {model_name}[/blue]")
|
|
106
|
-
|
|
107
|
-
# Use the new Ollama-style progress tracker
|
|
108
|
-
progress_tracker = OllamaStyleProgress(console)
|
|
109
|
-
|
|
110
|
-
def progress_callback(message: str):
|
|
111
|
-
"""Enhanced progress callback with Ollama-style display"""
|
|
112
|
-
progress_tracker.update(message)
|
|
113
|
-
|
|
114
|
-
try:
|
|
115
|
-
if model_manager.pull_model(model_name, force=force, progress_callback=progress_callback):
|
|
116
|
-
progress_tracker.update("✅ download completed")
|
|
117
|
-
rprint(f"[green]Model {model_name} downloaded successfully![/green]")
|
|
118
|
-
else:
|
|
119
|
-
rprint(f"[red]Model {model_name} download failed![/red]")
|
|
120
|
-
sys.exit(1)
|
|
121
|
-
except KeyboardInterrupt:
|
|
122
|
-
rprint("\n[yellow]Download cancelled by user[/yellow]")
|
|
123
|
-
sys.exit(1)
|
|
124
|
-
except Exception as e:
|
|
125
|
-
rprint(f"[red]Download failed: {str(e)}[/red]")
|
|
126
|
-
sys.exit(1)
|
|
127
|
-
|
|
128
|
-
@cli.command()
|
|
129
|
-
@click.argument('model_name')
|
|
130
|
-
@click.option('--host', '-h', default=None, help='Server host address')
|
|
131
|
-
@click.option('--port', '-p', default=None, type=int, help='Server port')
|
|
132
|
-
def run(model_name: str, host: Optional[str], port: Optional[int]):
|
|
133
|
-
"""Run model service"""
|
|
134
|
-
rprint(f"[blue]Starting model service: {model_name}[/blue]")
|
|
135
|
-
|
|
136
|
-
# Check if model is installed
|
|
137
|
-
if not model_manager.is_model_installed(model_name):
|
|
138
|
-
rprint(f"[red]Model {model_name} is not installed. Please run first: ollamadiffuser pull {model_name}[/red]")
|
|
139
|
-
sys.exit(1)
|
|
140
|
-
|
|
141
|
-
# Load model
|
|
142
|
-
rprint("[yellow]Loading model...[/yellow]")
|
|
143
|
-
if not model_manager.load_model(model_name):
|
|
144
|
-
rprint(f"[red]Failed to load model {model_name}![/red]")
|
|
145
|
-
sys.exit(1)
|
|
146
|
-
|
|
147
|
-
rprint(f"[green]Model {model_name} loaded successfully![/green]")
|
|
148
|
-
|
|
149
|
-
# Start server
|
|
150
|
-
try:
|
|
151
|
-
run_server(host=host, port=port)
|
|
152
|
-
except KeyboardInterrupt:
|
|
153
|
-
rprint("\n[yellow]Server stopped[/yellow]")
|
|
154
|
-
model_manager.unload_model()
|
|
155
|
-
# Clear the current model from settings when server stops
|
|
156
|
-
settings.current_model = None
|
|
157
|
-
settings.save_config()
|
|
158
56
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
if hardware:
|
|
174
|
-
# Show detailed hardware requirements
|
|
175
|
-
for model_name in installed_models:
|
|
176
|
-
info = model_manager.get_model_info(model_name)
|
|
177
|
-
if not info:
|
|
178
|
-
continue
|
|
179
|
-
|
|
180
|
-
# Check installation status
|
|
181
|
-
status = "✅ Installed"
|
|
182
|
-
if model_name == current_model:
|
|
183
|
-
status += " (current)"
|
|
184
|
-
size = info.get('size', 'Unknown')
|
|
185
|
-
|
|
186
|
-
# Create individual table for each model
|
|
187
|
-
table = Table(title=f"[bold cyan]{model_name}[/bold cyan] - {status}")
|
|
188
|
-
table.add_column("Property", style="yellow", no_wrap=True)
|
|
189
|
-
table.add_column("Value", style="white")
|
|
190
|
-
|
|
191
|
-
# Basic info
|
|
192
|
-
table.add_row("Type", info.get('model_type', 'Unknown'))
|
|
193
|
-
table.add_row("Size", size)
|
|
194
|
-
|
|
195
|
-
# Hardware requirements
|
|
196
|
-
hw_req = info.get('hardware_requirements', {})
|
|
197
|
-
if hw_req:
|
|
198
|
-
table.add_row("Min VRAM", f"{hw_req.get('min_vram_gb', 'Unknown')} GB")
|
|
199
|
-
table.add_row("Recommended VRAM", f"{hw_req.get('recommended_vram_gb', 'Unknown')} GB")
|
|
200
|
-
table.add_row("Min RAM", f"{hw_req.get('min_ram_gb', 'Unknown')} GB")
|
|
201
|
-
table.add_row("Recommended RAM", f"{hw_req.get('recommended_ram_gb', 'Unknown')} GB")
|
|
202
|
-
table.add_row("Disk Space", f"{hw_req.get('disk_space_gb', 'Unknown')} GB")
|
|
203
|
-
table.add_row("Supported Devices", ", ".join(hw_req.get('supported_devices', [])))
|
|
204
|
-
table.add_row("Performance Notes", hw_req.get('performance_notes', 'N/A'))
|
|
205
|
-
|
|
206
|
-
console.print(table)
|
|
207
|
-
console.print() # Add spacing between models
|
|
208
|
-
else:
|
|
209
|
-
# Show compact table
|
|
210
|
-
table = Table(title="Installed Models")
|
|
211
|
-
table.add_column("Model Name", style="cyan", no_wrap=True)
|
|
212
|
-
table.add_column("Status", style="green")
|
|
213
|
-
table.add_column("Size", style="blue")
|
|
214
|
-
table.add_column("Type", style="magenta")
|
|
215
|
-
table.add_column("Min VRAM", style="yellow")
|
|
216
|
-
|
|
217
|
-
for model_name in installed_models:
|
|
218
|
-
# Check installation status
|
|
219
|
-
status = "✅ Installed"
|
|
220
|
-
if model_name == current_model:
|
|
221
|
-
status += " (current)"
|
|
222
|
-
|
|
223
|
-
# Get model information
|
|
224
|
-
info = model_manager.get_model_info(model_name)
|
|
225
|
-
size = info.get('size', 'Unknown') if info else 'Unknown'
|
|
226
|
-
model_type = info.get('model_type', 'Unknown') if info else 'Unknown'
|
|
227
|
-
|
|
228
|
-
# Get hardware requirements
|
|
229
|
-
hw_req = info.get('hardware_requirements', {}) if info else {}
|
|
230
|
-
min_vram = f"{hw_req.get('min_vram_gb', '?')} GB" if hw_req else "Unknown"
|
|
231
|
-
|
|
232
|
-
table.add_row(model_name, status, size, model_type, min_vram)
|
|
233
|
-
|
|
234
|
-
console.print(table)
|
|
235
|
-
|
|
236
|
-
# Get counts for summary
|
|
237
|
-
available_models = model_registry.get_available_models()
|
|
238
|
-
external_models = model_registry.get_external_api_models_only()
|
|
239
|
-
|
|
240
|
-
console.print(f"\n[dim]💡 Installed: {len(installed_models)} models[/dim]")
|
|
241
|
-
console.print(f"[dim]💡 Available for download: {len(available_models)} models[/dim]")
|
|
242
|
-
if external_models:
|
|
243
|
-
console.print(f"[dim]💡 External API models: {len(external_models)} models[/dim]")
|
|
244
|
-
console.print("\n[dim]💡 Use --hardware flag to see detailed hardware requirements[/dim]")
|
|
245
|
-
console.print("[dim]💡 See all models: ollamadiffuser registry list[/dim]")
|
|
246
|
-
console.print("[dim]💡 See available models: ollamadiffuser registry list --available-only[/dim]")
|
|
247
|
-
|
|
248
|
-
@cli.command()
|
|
249
|
-
@click.argument('model_name')
|
|
250
|
-
def show(model_name: str):
|
|
251
|
-
"""Show model detailed information"""
|
|
252
|
-
info = model_manager.get_model_info(model_name)
|
|
253
|
-
|
|
254
|
-
if info is None:
|
|
255
|
-
rprint(f"[red]Model {model_name} does not exist[/red]")
|
|
256
|
-
sys.exit(1)
|
|
257
|
-
|
|
258
|
-
rprint(f"[bold cyan]Model Information: {model_name}[/bold cyan]")
|
|
259
|
-
rprint(f"Type: {info.get('model_type', 'Unknown')}")
|
|
260
|
-
rprint(f"Variant: {info.get('variant', 'Unknown')}")
|
|
261
|
-
rprint(f"Installed: {'Yes' if info.get('installed', False) else 'No'}")
|
|
262
|
-
|
|
263
|
-
if info.get('installed', False):
|
|
264
|
-
rprint(f"Local Path: {info.get('local_path', 'Unknown')}")
|
|
265
|
-
rprint(f"Size: {info.get('size', 'Unknown')}")
|
|
266
|
-
|
|
267
|
-
# Hardware requirements
|
|
268
|
-
if 'hardware_requirements' in info and info['hardware_requirements']:
|
|
269
|
-
hw_req = info['hardware_requirements']
|
|
270
|
-
rprint("\n[bold]Hardware Requirements:[/bold]")
|
|
271
|
-
rprint(f" Min VRAM: {hw_req.get('min_vram_gb', 'Unknown')} GB")
|
|
272
|
-
rprint(f" Recommended VRAM: {hw_req.get('recommended_vram_gb', 'Unknown')} GB")
|
|
273
|
-
rprint(f" Min RAM: {hw_req.get('min_ram_gb', 'Unknown')} GB")
|
|
274
|
-
rprint(f" Recommended RAM: {hw_req.get('recommended_ram_gb', 'Unknown')} GB")
|
|
275
|
-
rprint(f" Disk Space: {hw_req.get('disk_space_gb', 'Unknown')} GB")
|
|
276
|
-
rprint(f" Supported Devices: {', '.join(hw_req.get('supported_devices', []))}")
|
|
277
|
-
if hw_req.get('performance_notes'):
|
|
278
|
-
rprint(f" Performance Notes: {hw_req.get('performance_notes')}")
|
|
279
|
-
|
|
280
|
-
if 'parameters' in info and info['parameters']:
|
|
281
|
-
rprint("\n[bold]Default Parameters:[/bold]")
|
|
282
|
-
for key, value in info['parameters'].items():
|
|
283
|
-
rprint(f" {key}: {value}")
|
|
284
|
-
|
|
285
|
-
if 'components' in info and info['components']:
|
|
286
|
-
rprint("\n[bold]Components:[/bold]")
|
|
287
|
-
for key, value in info['components'].items():
|
|
288
|
-
rprint(f" {key}: {value}")
|
|
57
|
+
# --- Register model commands ---
|
|
58
|
+
from .model_commands import (
|
|
59
|
+
pull,
|
|
60
|
+
run,
|
|
61
|
+
list,
|
|
62
|
+
show,
|
|
63
|
+
check,
|
|
64
|
+
rm,
|
|
65
|
+
ps,
|
|
66
|
+
serve,
|
|
67
|
+
load,
|
|
68
|
+
unload,
|
|
69
|
+
stop,
|
|
70
|
+
)
|
|
289
71
|
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
license_type = model_info.get("license_info", {}).get("type", "Unknown")
|
|
302
|
-
rprint(f" {model:<30} {status:<15} ({license_type})")
|
|
303
|
-
return
|
|
304
|
-
|
|
305
|
-
if not model_name:
|
|
306
|
-
rprint("[bold red]❌ Please specify a model name or use --list[/bold red]")
|
|
307
|
-
rprint("[dim]Usage: ollamadiffuser check MODEL_NAME[/dim]")
|
|
308
|
-
rprint("[dim] ollamadiffuser check --list[/dim]")
|
|
309
|
-
return
|
|
310
|
-
|
|
311
|
-
# Check model download status directly
|
|
312
|
-
status = _check_download_status(model_name)
|
|
313
|
-
|
|
314
|
-
rprint("\n" + "="*60)
|
|
315
|
-
|
|
316
|
-
if status is True:
|
|
317
|
-
rprint(f"[green]🎉 {model_name} is ready to use![/green]")
|
|
318
|
-
rprint(f"\n[blue]💡 You can now run:[/blue]")
|
|
319
|
-
rprint(f" [cyan]ollamadiffuser run {model_name}[/cyan]")
|
|
320
|
-
elif status == "needs_config":
|
|
321
|
-
rprint(f"[yellow]⚠️ {model_name} files are complete but model needs configuration[/yellow]")
|
|
322
|
-
rprint(f"\n[blue]💡 Try reinstalling:[/blue]")
|
|
323
|
-
rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
|
|
324
|
-
elif status == "downloading":
|
|
325
|
-
rprint(f"[yellow]🔄 {model_name} is currently downloading[/yellow]")
|
|
326
|
-
rprint(f"\n[blue]💡 Wait for download to complete or check progress[/blue]")
|
|
327
|
-
elif status == "incomplete":
|
|
328
|
-
rprint(f"[yellow]⚠️ Download is incomplete[/yellow]")
|
|
329
|
-
rprint(f"\n[blue]💡 Resume download with:[/blue]")
|
|
330
|
-
rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
|
|
331
|
-
rprint(f"\n[blue]💡 Or force fresh download with:[/blue]")
|
|
332
|
-
rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
|
|
333
|
-
else:
|
|
334
|
-
rprint(f"[red]❌ {model_name} is not downloaded[/red]")
|
|
335
|
-
rprint(f"\n[blue]💡 Download with:[/blue]")
|
|
336
|
-
rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
|
|
337
|
-
|
|
338
|
-
_show_model_specific_help(model_name)
|
|
339
|
-
|
|
340
|
-
rprint(f"\n[dim]📚 For more help: ollamadiffuser --help[/dim]")
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
def _check_download_status(model_name: str):
|
|
344
|
-
"""Check the current download status of any model"""
|
|
345
|
-
from ..core.utils.download_utils import check_download_integrity, get_repo_file_list, format_size
|
|
346
|
-
import subprocess
|
|
347
|
-
|
|
348
|
-
rprint(f"[blue]🔍 Checking {model_name} download status...[/blue]\n")
|
|
349
|
-
|
|
350
|
-
# Check if model is in registry
|
|
351
|
-
if model_name not in model_manager.model_registry:
|
|
352
|
-
rprint(f"[red]❌ {model_name} not found in model registry[/red]")
|
|
353
|
-
available_models = model_manager.list_available_models()
|
|
354
|
-
rprint(f"[blue]📋 Available models: {', '.join(available_models)}[/blue]")
|
|
355
|
-
return False
|
|
356
|
-
|
|
357
|
-
model_info = model_manager.model_registry[model_name]
|
|
358
|
-
repo_id = model_info["repo_id"]
|
|
359
|
-
model_path = settings.get_model_path(model_name)
|
|
360
|
-
|
|
361
|
-
rprint(f"[cyan]📦 Model: {model_name}[/cyan]")
|
|
362
|
-
rprint(f"[cyan]🔗 Repository: {repo_id}[/cyan]")
|
|
363
|
-
rprint(f"[cyan]📁 Local path: {model_path}[/cyan]")
|
|
364
|
-
|
|
365
|
-
# Show model-specific info
|
|
366
|
-
license_info = model_info.get("license_info", {})
|
|
367
|
-
if license_info:
|
|
368
|
-
rprint(f"[yellow]📄 License: {license_info.get('type', 'Unknown')}[/yellow]")
|
|
369
|
-
rprint(f"[yellow]🔑 HF Token Required: {'Yes' if license_info.get('requires_agreement', False) else 'No'}[/yellow]")
|
|
370
|
-
rprint(f"[yellow]💼 Commercial Use: {'Allowed' if license_info.get('commercial_use', False) else 'Not Allowed'}[/yellow]")
|
|
371
|
-
|
|
372
|
-
# Show optimal parameters
|
|
373
|
-
params = model_info.get("parameters", {})
|
|
374
|
-
if params:
|
|
375
|
-
rprint(f"[green]⚡ Optimal Settings:[/green]")
|
|
376
|
-
rprint(f" Steps: {params.get('num_inference_steps', 'N/A')}")
|
|
377
|
-
rprint(f" Guidance: {params.get('guidance_scale', 'N/A')}")
|
|
378
|
-
if 'max_sequence_length' in params:
|
|
379
|
-
rprint(f" Max Seq Length: {params['max_sequence_length']}")
|
|
380
|
-
|
|
381
|
-
rprint()
|
|
382
|
-
|
|
383
|
-
# Check if directory exists
|
|
384
|
-
if not model_path.exists():
|
|
385
|
-
rprint("[yellow]📂 Status: Not downloaded[/yellow]")
|
|
386
|
-
return False
|
|
387
|
-
|
|
388
|
-
# Get repository file list
|
|
389
|
-
rprint("[blue]🌐 Getting repository information...[/blue]")
|
|
390
|
-
try:
|
|
391
|
-
file_sizes = get_repo_file_list(repo_id)
|
|
392
|
-
total_expected_size = sum(file_sizes.values())
|
|
393
|
-
total_files_expected = len(file_sizes)
|
|
394
|
-
|
|
395
|
-
rprint(f"[blue]📊 Expected: {total_files_expected} files, {format_size(total_expected_size)} total[/blue]")
|
|
396
|
-
except Exception as e:
|
|
397
|
-
rprint(f"[yellow]⚠️ Could not get repository info: {e}[/yellow]")
|
|
398
|
-
file_sizes = {}
|
|
399
|
-
total_expected_size = 0
|
|
400
|
-
total_files_expected = 0
|
|
401
|
-
|
|
402
|
-
# Check local files
|
|
403
|
-
local_files = []
|
|
404
|
-
local_size = 0
|
|
405
|
-
|
|
406
|
-
for file_path in model_path.rglob('*'):
|
|
407
|
-
if file_path.is_file():
|
|
408
|
-
rel_path = file_path.relative_to(model_path)
|
|
409
|
-
file_size = file_path.stat().st_size
|
|
410
|
-
local_files.append((str(rel_path), file_size))
|
|
411
|
-
local_size += file_size
|
|
412
|
-
|
|
413
|
-
rprint(f"[blue]💾 Downloaded: {len(local_files)} files, {format_size(local_size)} total[/blue]")
|
|
414
|
-
|
|
415
|
-
if total_expected_size > 0:
|
|
416
|
-
progress_percent = (local_size / total_expected_size) * 100
|
|
417
|
-
rprint(f"[blue]📈 Progress: {progress_percent:.1f}%[/blue]")
|
|
418
|
-
|
|
419
|
-
rprint()
|
|
420
|
-
|
|
421
|
-
# Check for missing files
|
|
422
|
-
if file_sizes:
|
|
423
|
-
# Check if we have size information from the API
|
|
424
|
-
has_size_info = any(size > 0 for size in file_sizes.values())
|
|
425
|
-
|
|
426
|
-
if has_size_info:
|
|
427
|
-
# Normal case: we have size information, do detailed comparison
|
|
428
|
-
missing_files = []
|
|
429
|
-
incomplete_files = []
|
|
430
|
-
|
|
431
|
-
for expected_file, expected_size in file_sizes.items():
|
|
432
|
-
local_file_path = model_path / expected_file
|
|
433
|
-
if not local_file_path.exists():
|
|
434
|
-
missing_files.append(expected_file)
|
|
435
|
-
elif expected_size > 0 and local_file_path.stat().st_size != expected_size:
|
|
436
|
-
local_size_actual = local_file_path.stat().st_size
|
|
437
|
-
incomplete_files.append((expected_file, local_size_actual, expected_size))
|
|
438
|
-
|
|
439
|
-
if missing_files:
|
|
440
|
-
rprint(f"[red]❌ Missing files ({len(missing_files)}):[/red]")
|
|
441
|
-
for missing_file in missing_files[:10]: # Show first 10
|
|
442
|
-
rprint(f" - {missing_file}")
|
|
443
|
-
if len(missing_files) > 10:
|
|
444
|
-
rprint(f" ... and {len(missing_files) - 10} more")
|
|
445
|
-
rprint()
|
|
446
|
-
|
|
447
|
-
if incomplete_files:
|
|
448
|
-
rprint(f"[yellow]⚠️ Incomplete files ({len(incomplete_files)}):[/yellow]")
|
|
449
|
-
for incomplete_file, actual_size, expected_size in incomplete_files[:5]:
|
|
450
|
-
rprint(f" - {incomplete_file}: {format_size(actual_size)}/{format_size(expected_size)}")
|
|
451
|
-
if len(incomplete_files) > 5:
|
|
452
|
-
rprint(f" ... and {len(incomplete_files) - 5} more")
|
|
453
|
-
rprint()
|
|
454
|
-
|
|
455
|
-
if not missing_files and not incomplete_files:
|
|
456
|
-
rprint("[green]✅ All files present and complete![/green]")
|
|
457
|
-
|
|
458
|
-
# Check integrity
|
|
459
|
-
rprint("[blue]🔍 Checking download integrity...[/blue]")
|
|
460
|
-
if check_download_integrity(str(model_path), repo_id):
|
|
461
|
-
rprint("[green]✅ Download integrity verified![/green]")
|
|
462
|
-
|
|
463
|
-
# Check if model is in configuration
|
|
464
|
-
if model_manager.is_model_installed(model_name):
|
|
465
|
-
rprint("[green]✅ Model is properly configured[/green]")
|
|
466
|
-
return True
|
|
467
|
-
else:
|
|
468
|
-
rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
|
|
469
|
-
return "needs_config"
|
|
470
|
-
else:
|
|
471
|
-
rprint("[red]❌ Download integrity check failed[/red]")
|
|
472
|
-
return False
|
|
473
|
-
else:
|
|
474
|
-
rprint("[yellow]⚠️ Download is incomplete[/yellow]")
|
|
475
|
-
return "incomplete"
|
|
476
|
-
else:
|
|
477
|
-
# No size information available from API (common with gated repos)
|
|
478
|
-
rprint("[blue]ℹ️ Repository API doesn't provide file sizes (common with gated models)[/blue]")
|
|
479
|
-
rprint("[blue]🔍 Checking essential model files instead...[/blue]")
|
|
480
|
-
|
|
481
|
-
# Check for essential model files
|
|
482
|
-
# Determine model type based on repo_id
|
|
483
|
-
is_controlnet = 'controlnet' in repo_id.lower()
|
|
484
|
-
|
|
485
|
-
if is_controlnet:
|
|
486
|
-
# ControlNet models have different essential files
|
|
487
|
-
essential_files = ['config.json']
|
|
488
|
-
essential_dirs = [] # ControlNet models don't have complex directory structure
|
|
489
|
-
else:
|
|
490
|
-
# Regular diffusion models
|
|
491
|
-
essential_files = ['model_index.json']
|
|
492
|
-
essential_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'vae', 'scheduler']
|
|
493
|
-
|
|
494
|
-
missing_essential = []
|
|
495
|
-
for essential_file in essential_files:
|
|
496
|
-
if not (model_path / essential_file).exists():
|
|
497
|
-
missing_essential.append(essential_file)
|
|
498
|
-
|
|
499
|
-
existing_dirs = []
|
|
500
|
-
for essential_dir in essential_dirs:
|
|
501
|
-
if (model_path / essential_dir).exists():
|
|
502
|
-
existing_dirs.append(essential_dir)
|
|
503
|
-
|
|
504
|
-
if missing_essential:
|
|
505
|
-
rprint(f"[red]❌ Missing essential files: {', '.join(missing_essential)}[/red]")
|
|
506
|
-
return "incomplete"
|
|
507
|
-
|
|
508
|
-
if existing_dirs:
|
|
509
|
-
rprint(f"[green]✅ Found model components: {', '.join(existing_dirs)}[/green]")
|
|
510
|
-
|
|
511
|
-
# Check integrity
|
|
512
|
-
rprint("[blue]🔍 Checking download integrity...[/blue]")
|
|
513
|
-
if check_download_integrity(str(model_path), repo_id):
|
|
514
|
-
rprint("[green]✅ Download integrity verified![/green]")
|
|
515
|
-
|
|
516
|
-
# Check if model is in configuration
|
|
517
|
-
if model_manager.is_model_installed(model_name):
|
|
518
|
-
rprint("[green]✅ Model is properly configured and functional[/green]")
|
|
519
|
-
return True
|
|
520
|
-
else:
|
|
521
|
-
rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
|
|
522
|
-
return "needs_config"
|
|
523
|
-
else:
|
|
524
|
-
rprint("[red]❌ Download integrity check failed[/red]")
|
|
525
|
-
return False
|
|
526
|
-
|
|
527
|
-
# Check if download process is running
|
|
528
|
-
rprint("[blue]🔍 Checking for active download processes...[/blue]")
|
|
529
|
-
try:
|
|
530
|
-
result = subprocess.run(['ps', 'aux'], capture_output=True, text=True)
|
|
531
|
-
if f'ollamadiffuser pull {model_name}' in result.stdout:
|
|
532
|
-
rprint("[yellow]🔄 Download process is currently running[/yellow]")
|
|
533
|
-
return "downloading"
|
|
534
|
-
else:
|
|
535
|
-
rprint("[blue]💤 No active download process found[/blue]")
|
|
536
|
-
except Exception as e:
|
|
537
|
-
rprint(f"[yellow]⚠️ Could not check processes: {e}[/yellow]")
|
|
538
|
-
|
|
539
|
-
return "incomplete"
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
def _show_model_specific_help(model_name: str):
|
|
543
|
-
"""Show model-specific help and recommendations"""
|
|
544
|
-
model_info = model_manager.get_model_info(model_name)
|
|
545
|
-
if not model_info:
|
|
546
|
-
return
|
|
547
|
-
|
|
548
|
-
rprint(f"\n[bold blue]💡 {model_name} Specific Tips:[/bold blue]")
|
|
549
|
-
|
|
550
|
-
# License-specific help
|
|
551
|
-
license_info = model_info.get("license_info", {})
|
|
552
|
-
if license_info.get("requires_agreement", False):
|
|
553
|
-
rprint(f" [yellow]🔑 Requires HuggingFace token and license agreement[/yellow]")
|
|
554
|
-
rprint(f" [blue]📝 Visit: https://huggingface.co/{model_info['repo_id']}[/blue]")
|
|
555
|
-
rprint(f" [cyan]🔧 Set token: export HF_TOKEN=your_token_here[/cyan]")
|
|
556
|
-
else:
|
|
557
|
-
rprint(f" [green]✅ No HuggingFace token required![/green]")
|
|
558
|
-
|
|
559
|
-
# Model-specific optimizations
|
|
560
|
-
if "schnell" in model_name.lower():
|
|
561
|
-
rprint(f" [green]⚡ FLUX.1-schnell is 12x faster than FLUX.1-dev[/green]")
|
|
562
|
-
rprint(f" [green]🎯 Optimized for 4-step generation[/green]")
|
|
563
|
-
rprint(f" [green]💼 Commercial use allowed (Apache 2.0)[/green]")
|
|
564
|
-
elif "flux.1-dev" in model_name.lower():
|
|
565
|
-
rprint(f" [blue]🎨 Best quality FLUX model[/blue]")
|
|
566
|
-
rprint(f" [blue]🔬 Requires 50 steps for optimal results[/blue]")
|
|
567
|
-
rprint(f" [yellow]⚠️ Non-commercial license only[/yellow]")
|
|
568
|
-
elif "stable-diffusion-1.5" in model_name.lower():
|
|
569
|
-
rprint(f" [green]🚀 Great for learning and quick tests[/green]")
|
|
570
|
-
rprint(f" [green]💾 Smallest model, runs on most hardware[/green]")
|
|
571
|
-
elif "stable-diffusion-3.5" in model_name.lower():
|
|
572
|
-
rprint(f" [green]🏆 Excellent quality-to-speed ratio[/green]")
|
|
573
|
-
rprint(f" [green]🔄 Great LoRA ecosystem[/green]")
|
|
574
|
-
|
|
575
|
-
# Hardware recommendations
|
|
576
|
-
hw_req = model_info.get("hardware_requirements", {})
|
|
577
|
-
if hw_req:
|
|
578
|
-
min_vram = hw_req.get("min_vram_gb", 0)
|
|
579
|
-
if min_vram >= 12:
|
|
580
|
-
rprint(f" [yellow]🖥️ Requires high-end GPU (RTX 4070+ or M2 Pro+)[/yellow]")
|
|
581
|
-
elif min_vram >= 8:
|
|
582
|
-
rprint(f" [blue]🖥️ Requires mid-range GPU (RTX 3080+ or M1 Pro+)[/blue]")
|
|
583
|
-
else:
|
|
584
|
-
rprint(f" [green]🖥️ Runs on most modern GPUs[/green]")
|
|
585
|
-
|
|
586
|
-
@cli.command()
|
|
587
|
-
@click.argument('model_name')
|
|
588
|
-
@click.confirmation_option(prompt='Are you sure you want to delete this model?')
|
|
589
|
-
def rm(model_name: str):
|
|
590
|
-
"""Remove model"""
|
|
591
|
-
if model_manager.remove_model(model_name):
|
|
592
|
-
rprint(f"[green]Model {model_name} removed successfully![/green]")
|
|
593
|
-
else:
|
|
594
|
-
rprint(f"[red]Failed to remove model {model_name}![/red]")
|
|
595
|
-
sys.exit(1)
|
|
596
|
-
|
|
597
|
-
@cli.command()
|
|
598
|
-
def ps():
|
|
599
|
-
"""Show currently running model"""
|
|
600
|
-
current_model = model_manager.get_current_model()
|
|
601
|
-
server_running = model_manager.is_server_running()
|
|
602
|
-
|
|
603
|
-
if current_model:
|
|
604
|
-
rprint(f"[green]Current model: {current_model}[/green]")
|
|
605
|
-
|
|
606
|
-
# Check server status
|
|
607
|
-
if server_running:
|
|
608
|
-
rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
|
|
609
|
-
|
|
610
|
-
# Try to get model info from the running server
|
|
611
|
-
try:
|
|
612
|
-
import requests
|
|
613
|
-
response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
|
|
614
|
-
if response.status_code == 200:
|
|
615
|
-
data = response.json()
|
|
616
|
-
if data.get('loaded'):
|
|
617
|
-
info = data.get('info', {})
|
|
618
|
-
rprint(f"Device: {info.get('device', 'Unknown')}")
|
|
619
|
-
rprint(f"Type: {info.get('type', 'Unknown')}")
|
|
620
|
-
rprint(f"Variant: {info.get('variant', 'Unknown')}")
|
|
621
|
-
else:
|
|
622
|
-
rprint("[yellow]Model loaded but not active in server[/yellow]")
|
|
623
|
-
except:
|
|
624
|
-
pass
|
|
625
|
-
else:
|
|
626
|
-
rprint("[yellow]Server status: Not running[/yellow]")
|
|
627
|
-
rprint("[dim]Model is set as current but server is not active[/dim]")
|
|
628
|
-
|
|
629
|
-
# Show model info from local config
|
|
630
|
-
model_info = model_manager.get_model_info(current_model)
|
|
631
|
-
if model_info:
|
|
632
|
-
rprint(f"Model type: {model_info.get('model_type', 'Unknown')}")
|
|
633
|
-
if model_info.get('installed'):
|
|
634
|
-
rprint(f"Size: {model_info.get('size', 'Unknown')}")
|
|
635
|
-
else:
|
|
636
|
-
if server_running:
|
|
637
|
-
rprint("[yellow]Server is running but no model is loaded[/yellow]")
|
|
638
|
-
rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
|
|
639
|
-
else:
|
|
640
|
-
rprint("[yellow]No model is currently running[/yellow]")
|
|
641
|
-
rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
|
|
642
|
-
|
|
643
|
-
@cli.command()
|
|
644
|
-
@click.option('--host', '-h', default=None, help='Server host address')
|
|
645
|
-
@click.option('--port', '-p', default=None, type=int, help='Server port')
|
|
646
|
-
def serve(host: Optional[str], port: Optional[int]):
|
|
647
|
-
"""Start API server (without loading model)"""
|
|
648
|
-
rprint("[blue]Starting OllamaDiffuser API server...[/blue]")
|
|
649
|
-
|
|
650
|
-
try:
|
|
651
|
-
run_server(host=host, port=port)
|
|
652
|
-
except KeyboardInterrupt:
|
|
653
|
-
rprint("\n[yellow]Server stopped[/yellow]")
|
|
654
|
-
|
|
655
|
-
@cli.command()
|
|
656
|
-
@click.argument('model_name')
|
|
657
|
-
def load(model_name: str):
|
|
658
|
-
"""Load model into memory"""
|
|
659
|
-
rprint(f"[blue]Loading model: {model_name}[/blue]")
|
|
660
|
-
|
|
661
|
-
if model_manager.load_model(model_name):
|
|
662
|
-
rprint(f"[green]Model {model_name} loaded successfully![/green]")
|
|
663
|
-
else:
|
|
664
|
-
rprint(f"[red]Failed to load model {model_name}![/red]")
|
|
665
|
-
sys.exit(1)
|
|
72
|
+
cli.add_command(pull)
|
|
73
|
+
cli.add_command(run)
|
|
74
|
+
cli.add_command(list)
|
|
75
|
+
cli.add_command(show)
|
|
76
|
+
cli.add_command(check)
|
|
77
|
+
cli.add_command(rm)
|
|
78
|
+
cli.add_command(ps)
|
|
79
|
+
cli.add_command(serve)
|
|
80
|
+
cli.add_command(load)
|
|
81
|
+
cli.add_command(unload)
|
|
82
|
+
cli.add_command(stop)
|
|
666
83
|
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
"""Unload current model"""
|
|
670
|
-
if model_manager.is_model_loaded():
|
|
671
|
-
current_model = model_manager.get_current_model()
|
|
672
|
-
model_manager.unload_model()
|
|
673
|
-
rprint(f"[green]Model {current_model} unloaded[/green]")
|
|
674
|
-
else:
|
|
675
|
-
rprint("[yellow]No model to unload[/yellow]")
|
|
84
|
+
# --- Register LoRA commands ---
|
|
85
|
+
from .lora_commands import lora
|
|
676
86
|
|
|
677
|
-
|
|
678
|
-
def stop():
|
|
679
|
-
"""Stop running server"""
|
|
680
|
-
if not model_manager.is_server_running():
|
|
681
|
-
rprint("[yellow]No server is currently running[/yellow]")
|
|
682
|
-
return
|
|
683
|
-
|
|
684
|
-
try:
|
|
685
|
-
import requests
|
|
686
|
-
import signal
|
|
687
|
-
import psutil
|
|
688
|
-
|
|
689
|
-
host = settings.server.host
|
|
690
|
-
port = settings.server.port
|
|
691
|
-
|
|
692
|
-
# Try graceful shutdown via API first
|
|
693
|
-
try:
|
|
694
|
-
response = requests.post(f"http://{host}:{port}/api/shutdown", timeout=5)
|
|
695
|
-
if response.status_code == 200:
|
|
696
|
-
rprint("[green]Server stopped gracefully[/green]")
|
|
697
|
-
return
|
|
698
|
-
except:
|
|
699
|
-
pass
|
|
700
|
-
|
|
701
|
-
# Fallback: Find and terminate the process
|
|
702
|
-
for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
|
|
703
|
-
try:
|
|
704
|
-
cmdline = proc.info['cmdline']
|
|
705
|
-
if cmdline and any('uvicorn' in arg for arg in cmdline) and any(str(port) in arg for arg in cmdline):
|
|
706
|
-
proc.terminate()
|
|
707
|
-
proc.wait(timeout=10)
|
|
708
|
-
rprint(f"[green]Server process (PID: {proc.info['pid']}) stopped[/green]")
|
|
709
|
-
return
|
|
710
|
-
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.TimeoutExpired):
|
|
711
|
-
continue
|
|
712
|
-
|
|
713
|
-
rprint("[red]Could not find or stop the server process[/red]")
|
|
714
|
-
|
|
715
|
-
except ImportError:
|
|
716
|
-
rprint("[red]psutil package required for stop command. Install with: pip install psutil[/red]")
|
|
717
|
-
except Exception as e:
|
|
718
|
-
rprint(f"[red]Failed to stop server: {e}[/red]")
|
|
87
|
+
cli.add_command(lora)
|
|
719
88
|
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
"""LoRA (Low-Rank Adaptation) management commands"""
|
|
723
|
-
pass
|
|
89
|
+
# --- Register registry commands ---
|
|
90
|
+
from .registry_commands import registry
|
|
724
91
|
|
|
725
|
-
|
|
726
|
-
@click.argument('repo_id')
|
|
727
|
-
@click.option('--weight-name', '-w', help='Specific weight file name (e.g., lora.safetensors)')
|
|
728
|
-
@click.option('--alias', '-a', help='Local alias name for the LoRA')
|
|
729
|
-
def pull(repo_id: str, weight_name: Optional[str], alias: Optional[str]):
|
|
730
|
-
"""Download LoRA weights from Hugging Face Hub"""
|
|
731
|
-
from ..core.utils.lora_manager import lora_manager
|
|
732
|
-
|
|
733
|
-
rprint(f"[blue]Downloading LoRA: {repo_id}[/blue]")
|
|
734
|
-
|
|
735
|
-
with Progress(
|
|
736
|
-
SpinnerColumn(),
|
|
737
|
-
TextColumn("[progress.description]{task.description}"),
|
|
738
|
-
console=console
|
|
739
|
-
) as progress:
|
|
740
|
-
task = progress.add_task(f"Downloading LoRA...", total=None)
|
|
741
|
-
|
|
742
|
-
def progress_callback(message: str):
|
|
743
|
-
progress.update(task, description=message)
|
|
744
|
-
|
|
745
|
-
if lora_manager.pull_lora(repo_id, weight_name=weight_name, alias=alias, progress_callback=progress_callback):
|
|
746
|
-
progress.update(task, description=f"✅ LoRA download completed")
|
|
747
|
-
rprint(f"[green]LoRA {repo_id} downloaded successfully![/green]")
|
|
748
|
-
else:
|
|
749
|
-
progress.update(task, description=f"❌ LoRA download failed")
|
|
750
|
-
rprint(f"[red]LoRA {repo_id} download failed![/red]")
|
|
751
|
-
sys.exit(1)
|
|
92
|
+
cli.add_command(registry)
|
|
752
93
|
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
@click.option('--scale', '-s', default=1.0, type=float, help='LoRA scale/strength (default: 1.0)')
|
|
756
|
-
def load(lora_name: str, scale: float):
|
|
757
|
-
"""Load LoRA weights into the current model"""
|
|
758
|
-
from ..core.utils.lora_manager import lora_manager
|
|
759
|
-
|
|
760
|
-
rprint(f"[blue]Loading LoRA: {lora_name} (scale: {scale})[/blue]")
|
|
761
|
-
|
|
762
|
-
if lora_manager.load_lora(lora_name, scale=scale):
|
|
763
|
-
rprint(f"[green]LoRA {lora_name} loaded successfully![/green]")
|
|
764
|
-
else:
|
|
765
|
-
rprint(f"[red]Failed to load LoRA {lora_name}![/red]")
|
|
766
|
-
sys.exit(1)
|
|
94
|
+
# --- Register config commands ---
|
|
95
|
+
from .config_commands import config
|
|
767
96
|
|
|
768
|
-
|
|
769
|
-
def unload():
|
|
770
|
-
"""Unload current LoRA weights"""
|
|
771
|
-
from ..core.utils.lora_manager import lora_manager
|
|
772
|
-
|
|
773
|
-
rprint("[blue]Unloading LoRA weights...[/blue]")
|
|
774
|
-
|
|
775
|
-
if lora_manager.unload_lora():
|
|
776
|
-
rprint("[green]LoRA weights unloaded successfully![/green]")
|
|
777
|
-
else:
|
|
778
|
-
rprint("[red]Failed to unload LoRA weights![/red]")
|
|
779
|
-
sys.exit(1)
|
|
97
|
+
cli.add_command(config)
|
|
780
98
|
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
@click.confirmation_option(prompt='Are you sure you want to delete this LoRA?')
|
|
784
|
-
def rm(lora_name: str):
|
|
785
|
-
"""Remove LoRA weights"""
|
|
786
|
-
from ..core.utils.lora_manager import lora_manager
|
|
787
|
-
|
|
788
|
-
rprint(f"[blue]Removing LoRA: {lora_name}[/blue]")
|
|
789
|
-
|
|
790
|
-
if lora_manager.remove_lora(lora_name):
|
|
791
|
-
rprint(f"[green]LoRA {lora_name} removed successfully![/green]")
|
|
792
|
-
else:
|
|
793
|
-
rprint(f"[red]Failed to remove LoRA {lora_name}![/red]")
|
|
794
|
-
sys.exit(1)
|
|
99
|
+
# --- Register recommend command ---
|
|
100
|
+
from .recommend_command import recommend
|
|
795
101
|
|
|
796
|
-
|
|
797
|
-
def ps():
|
|
798
|
-
"""Show currently loaded LoRA status"""
|
|
799
|
-
from ..core.utils.lora_manager import lora_manager
|
|
800
|
-
|
|
801
|
-
# Check if server is running
|
|
802
|
-
server_running = lora_manager._is_server_running()
|
|
803
|
-
current_lora = lora_manager.get_current_lora()
|
|
804
|
-
|
|
805
|
-
if server_running:
|
|
806
|
-
rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
|
|
807
|
-
|
|
808
|
-
# Try to get LoRA status from the running server
|
|
809
|
-
try:
|
|
810
|
-
import requests
|
|
811
|
-
response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
|
|
812
|
-
if response.status_code == 200:
|
|
813
|
-
data = response.json()
|
|
814
|
-
if data.get('loaded'):
|
|
815
|
-
model_info = data.get('info', {})
|
|
816
|
-
rprint(f"Model: {data.get('model', 'Unknown')}")
|
|
817
|
-
rprint(f"Device: {model_info.get('device', 'Unknown')}")
|
|
818
|
-
rprint(f"Type: {model_info.get('type', 'Unknown')}")
|
|
819
|
-
else:
|
|
820
|
-
rprint("[yellow]No model loaded in server[/yellow]")
|
|
821
|
-
return
|
|
822
|
-
except Exception as e:
|
|
823
|
-
rprint(f"[red]Failed to get server status: {e}[/red]")
|
|
824
|
-
return
|
|
825
|
-
else:
|
|
826
|
-
# Check local model manager
|
|
827
|
-
if model_manager.is_model_loaded():
|
|
828
|
-
current_model = model_manager.get_current_model()
|
|
829
|
-
rprint(f"[green]Model loaded locally: {current_model}[/green]")
|
|
830
|
-
else:
|
|
831
|
-
rprint("[yellow]No server running and no local model loaded[/yellow]")
|
|
832
|
-
rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
|
|
833
|
-
return
|
|
834
|
-
|
|
835
|
-
# Show LoRA status
|
|
836
|
-
lora_status_shown = False
|
|
837
|
-
lora_loaded_on_server = False
|
|
838
|
-
|
|
839
|
-
# Try to get LoRA status from server if running
|
|
840
|
-
if server_running:
|
|
841
|
-
try:
|
|
842
|
-
import requests
|
|
843
|
-
response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/lora/status", timeout=2)
|
|
844
|
-
if response.status_code == 200:
|
|
845
|
-
lora_data = response.json()
|
|
846
|
-
if lora_data.get('loaded'):
|
|
847
|
-
lora_info = lora_data.get('info', {})
|
|
848
|
-
rprint(f"\n[bold green]🔄 LoRA Status: LOADED (via server)[/bold green]")
|
|
849
|
-
rprint(f"Adapter: {lora_info.get('adapter_name', 'Unknown')}")
|
|
850
|
-
if 'scale' in lora_info:
|
|
851
|
-
rprint(f"Scale: {lora_info.get('scale', 'Unknown')}")
|
|
852
|
-
if 'adapters' in lora_info:
|
|
853
|
-
rprint(f"Active Adapters: {', '.join(lora_info.get('adapters', []))}")
|
|
854
|
-
lora_status_shown = True
|
|
855
|
-
lora_loaded_on_server = True
|
|
856
|
-
else:
|
|
857
|
-
rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded (server)[/dim]")
|
|
858
|
-
lora_status_shown = True
|
|
859
|
-
except Exception as e:
|
|
860
|
-
rprint(f"\n[yellow]⚠️ Failed to get LoRA status from server: {e}[/yellow]")
|
|
861
|
-
|
|
862
|
-
# Fallback to local LoRA manager state
|
|
863
|
-
if not lora_status_shown:
|
|
864
|
-
if current_lora:
|
|
865
|
-
lora_info = lora_manager.get_lora_info(current_lora)
|
|
866
|
-
if lora_info:
|
|
867
|
-
rprint(f"\n[bold green]🔄 LoRA Status: LOADED (local)[/bold green]")
|
|
868
|
-
rprint(f"Name: {current_lora}")
|
|
869
|
-
rprint(f"Repository: {lora_info.get('repo_id', 'Unknown')}")
|
|
870
|
-
rprint(f"Weight File: {lora_info.get('weight_name', 'Unknown')}")
|
|
871
|
-
rprint(f"Size: {lora_info.get('size', 'Unknown')}")
|
|
872
|
-
rprint(f"Local Path: {lora_info.get('path', 'Unknown')}")
|
|
873
|
-
else:
|
|
874
|
-
rprint(f"\n[yellow]⚠️ LoRA {current_lora} is set as current but info not found[/yellow]")
|
|
875
|
-
else:
|
|
876
|
-
rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded[/dim]")
|
|
877
|
-
|
|
878
|
-
if not lora_loaded_on_server:
|
|
879
|
-
rprint("[dim]Use 'ollamadiffuser lora load <lora_name>' to load a LoRA[/dim]")
|
|
102
|
+
cli.add_command(recommend)
|
|
880
103
|
|
|
881
|
-
|
|
882
|
-
def list():
|
|
883
|
-
"""List available and installed LoRA weights"""
|
|
884
|
-
from ..core.utils.lora_manager import lora_manager
|
|
885
|
-
|
|
886
|
-
installed_loras = lora_manager.list_installed_loras()
|
|
887
|
-
current_lora = lora_manager.get_current_lora()
|
|
888
|
-
|
|
889
|
-
if not installed_loras:
|
|
890
|
-
rprint("[yellow]No LoRA weights installed.[/yellow]")
|
|
891
|
-
rprint("\n[dim]💡 Use 'ollamadiffuser lora pull <repo_id>' to download LoRA weights[/dim]")
|
|
892
|
-
return
|
|
893
|
-
|
|
894
|
-
table = Table(title="Installed LoRA Weights")
|
|
895
|
-
table.add_column("Name", style="cyan", no_wrap=True)
|
|
896
|
-
table.add_column("Repository", style="blue")
|
|
897
|
-
table.add_column("Status", style="green")
|
|
898
|
-
table.add_column("Size", style="yellow")
|
|
899
|
-
|
|
900
|
-
for lora_name, lora_info in installed_loras.items():
|
|
901
|
-
status = "🔄 Loaded" if lora_name == current_lora else "💾 Available"
|
|
902
|
-
size = lora_info.get('size', 'Unknown')
|
|
903
|
-
repo_id = lora_info.get('repo_id', 'Unknown')
|
|
904
|
-
|
|
905
|
-
table.add_row(lora_name, repo_id, status, size)
|
|
906
|
-
|
|
907
|
-
console.print(table)
|
|
104
|
+
# --- Utility commands ---
|
|
908
105
|
|
|
909
|
-
@lora.command()
|
|
910
|
-
@click.argument('lora_name')
|
|
911
|
-
def show(lora_name: str):
|
|
912
|
-
"""Show detailed LoRA information"""
|
|
913
|
-
from ..core.utils.lora_manager import lora_manager
|
|
914
|
-
|
|
915
|
-
lora_info = lora_manager.get_lora_info(lora_name)
|
|
916
|
-
|
|
917
|
-
if not lora_info:
|
|
918
|
-
rprint(f"[red]LoRA {lora_name} not found.[/red]")
|
|
919
|
-
sys.exit(1)
|
|
920
|
-
|
|
921
|
-
rprint(f"[bold cyan]LoRA Information: {lora_name}[/bold cyan]")
|
|
922
|
-
rprint(f"Repository: {lora_info.get('repo_id', 'Unknown')}")
|
|
923
|
-
rprint(f"Weight File: {lora_info.get('weight_name', 'Unknown')}")
|
|
924
|
-
rprint(f"Local Path: {lora_info.get('path', 'Unknown')}")
|
|
925
|
-
rprint(f"Size: {lora_info.get('size', 'Unknown')}")
|
|
926
|
-
rprint(f"Downloaded: {lora_info.get('downloaded_at', 'Unknown')}")
|
|
927
|
-
|
|
928
|
-
if lora_info.get('description'):
|
|
929
|
-
rprint(f"Description: {lora_info.get('description')}")
|
|
930
106
|
|
|
931
107
|
@cli.command()
|
|
932
108
|
def version():
|
|
933
109
|
"""Show version information"""
|
|
934
110
|
print_version()
|
|
935
|
-
rprint("\n[bold]Features:[/bold]")
|
|
936
|
-
rprint("• 🚀 Fast Startup with lazy loading architecture")
|
|
937
|
-
rprint("• 🎛️ ControlNet Support with 10+ control types")
|
|
938
|
-
rprint("• 🔄 LoRA Integration with dynamic loading")
|
|
939
|
-
rprint("• 🌐 Multiple Interfaces: CLI, Python API, Web UI, REST API")
|
|
940
|
-
rprint("• 📦 Easy model management and switching")
|
|
941
|
-
rprint("• ⚡ Performance optimized with GPU acceleration")
|
|
942
|
-
|
|
943
111
|
rprint("\n[bold]Supported Models:[/bold]")
|
|
944
|
-
rprint("
|
|
945
|
-
rprint("
|
|
946
|
-
rprint("
|
|
947
|
-
rprint("
|
|
948
|
-
rprint("• Stable Diffusion 1.5")
|
|
949
|
-
rprint("• ControlNet models for SD15 and SDXL")
|
|
950
|
-
|
|
112
|
+
rprint(" FLUX.1-schnell, FLUX.1-dev, SD 3.5 Medium, SDXL, SD 1.5")
|
|
113
|
+
rprint(" ControlNet (SD15 + SDXL), AnimateDiff, HiDream, GGUF variants")
|
|
114
|
+
rprint("\n[bold]Features:[/bold]")
|
|
115
|
+
rprint(" img2img, inpainting, LoRA, ControlNet, async API, Web UI")
|
|
951
116
|
rprint("\n[dim]For help: ollamadiffuser --help[/dim]")
|
|
952
|
-
rprint("[dim]For diagnostics: ollamadiffuser doctor[/dim]")
|
|
953
|
-
rprint("[dim]For ControlNet samples: ollamadiffuser create-samples[/dim]")
|
|
954
117
|
|
|
955
|
-
|
|
118
|
+
|
|
119
|
+
@cli.command(name="verify-deps")
|
|
956
120
|
def verify_deps_cmd():
|
|
957
121
|
"""Verify and install missing dependencies"""
|
|
958
122
|
from .commands import verify_deps
|
|
123
|
+
|
|
959
124
|
ctx = click.Context(verify_deps)
|
|
960
125
|
ctx.invoke(verify_deps)
|
|
961
126
|
|
|
127
|
+
|
|
962
128
|
@cli.command()
|
|
963
129
|
def doctor():
|
|
964
130
|
"""Run comprehensive system diagnostics"""
|
|
965
131
|
from .commands import doctor
|
|
132
|
+
|
|
966
133
|
ctx = click.Context(doctor)
|
|
967
134
|
ctx.invoke(doctor)
|
|
968
135
|
|
|
969
|
-
@cli.command(name='create-samples')
|
|
970
|
-
@click.option('--force', is_flag=True, help='Force recreation of all samples even if they exist')
|
|
971
|
-
def create_samples_cmd(force):
|
|
972
|
-
"""Create ControlNet sample images for the Web UI"""
|
|
973
|
-
from .commands import create_samples
|
|
974
|
-
ctx = click.Context(create_samples)
|
|
975
|
-
ctx.invoke(create_samples, force=force)
|
|
976
|
-
|
|
977
|
-
@cli.group(hidden=True)
|
|
978
|
-
def registry():
|
|
979
|
-
"""Manage model registry (internal command)"""
|
|
980
|
-
pass
|
|
981
136
|
|
|
982
|
-
@
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
def list(format: str, installed_only: bool, available_only: bool, external_only: bool):
|
|
988
|
-
"""List models in the registry with installation status"""
|
|
989
|
-
|
|
990
|
-
# Get different model categories
|
|
991
|
-
if installed_only:
|
|
992
|
-
models = model_registry.get_installed_models()
|
|
993
|
-
title = "Installed Models"
|
|
994
|
-
elif available_only:
|
|
995
|
-
models = model_registry.get_available_models()
|
|
996
|
-
title = "Available Models (Not Installed)"
|
|
997
|
-
elif external_only:
|
|
998
|
-
models = model_registry.get_external_api_models_only()
|
|
999
|
-
title = "External API Models"
|
|
1000
|
-
else:
|
|
1001
|
-
models = model_registry.get_all_models()
|
|
1002
|
-
title = "All Models (Installed + Available)"
|
|
1003
|
-
|
|
1004
|
-
installed_model_names = set(model_registry.get_installed_models().keys())
|
|
1005
|
-
local_model_names = set(model_registry.get_local_models_only().keys())
|
|
1006
|
-
external_model_names = set(model_registry.get_external_api_models_only().keys())
|
|
1007
|
-
current_model = model_manager.get_current_model()
|
|
1008
|
-
|
|
1009
|
-
if not models:
|
|
1010
|
-
rprint(f"[yellow]No models found in category: {title}[/yellow]")
|
|
1011
|
-
return
|
|
1012
|
-
|
|
1013
|
-
if format == 'table':
|
|
1014
|
-
table = Table(title=title)
|
|
1015
|
-
table.add_column("Model Name", style="cyan", no_wrap=True)
|
|
1016
|
-
table.add_column("Type", style="yellow")
|
|
1017
|
-
table.add_column("Repository", style="blue")
|
|
1018
|
-
table.add_column("Status", style="green")
|
|
1019
|
-
table.add_column("Source", style="magenta")
|
|
1020
|
-
|
|
1021
|
-
for model_name, model_info in models.items():
|
|
1022
|
-
# Check installation status
|
|
1023
|
-
if model_name in installed_model_names:
|
|
1024
|
-
status = "✅ Installed"
|
|
1025
|
-
if model_name == current_model:
|
|
1026
|
-
status += " (current)"
|
|
1027
|
-
else:
|
|
1028
|
-
status = "⬇️ Available"
|
|
1029
|
-
|
|
1030
|
-
# Determine source
|
|
1031
|
-
if model_name in local_model_names and model_name in external_model_names:
|
|
1032
|
-
source = "Local + External"
|
|
1033
|
-
elif model_name in local_model_names:
|
|
1034
|
-
source = "Local"
|
|
1035
|
-
elif model_name in external_model_names:
|
|
1036
|
-
source = "External API"
|
|
1037
|
-
else:
|
|
1038
|
-
source = "Unknown"
|
|
1039
|
-
|
|
1040
|
-
table.add_row(
|
|
1041
|
-
model_name,
|
|
1042
|
-
model_info.get('model_type', 'Unknown'),
|
|
1043
|
-
model_info.get('repo_id', 'Unknown'),
|
|
1044
|
-
status,
|
|
1045
|
-
source
|
|
1046
|
-
)
|
|
1047
|
-
|
|
1048
|
-
console.print(table)
|
|
1049
|
-
|
|
1050
|
-
# Show summary
|
|
1051
|
-
if not (installed_only or available_only or external_only):
|
|
1052
|
-
total_count = len(models)
|
|
1053
|
-
installed_count = len(installed_model_names)
|
|
1054
|
-
available_count = total_count - installed_count
|
|
1055
|
-
local_count = len(local_model_names)
|
|
1056
|
-
external_count = len(external_model_names)
|
|
1057
|
-
|
|
1058
|
-
console.print(f"\n[dim]Summary:[/dim]")
|
|
1059
|
-
console.print(f"[dim] • Total: {total_count} models[/dim]")
|
|
1060
|
-
console.print(f"[dim] • Installed: {installed_count} models[/dim]")
|
|
1061
|
-
console.print(f"[dim] • Available: {available_count} models[/dim]")
|
|
1062
|
-
console.print(f"[dim] • Local registry: {local_count} models[/dim]")
|
|
1063
|
-
console.print(f"[dim] • External API: {external_count} models[/dim]")
|
|
1064
|
-
|
|
1065
|
-
elif format == 'json':
|
|
1066
|
-
import json
|
|
1067
|
-
print(json.dumps(models, indent=2, ensure_ascii=False))
|
|
1068
|
-
|
|
1069
|
-
elif format == 'yaml':
|
|
1070
|
-
import yaml
|
|
1071
|
-
print(yaml.dump(models, default_flow_style=False, allow_unicode=True))
|
|
1072
|
-
|
|
1073
|
-
@registry.command()
|
|
1074
|
-
@click.argument('model_name')
|
|
1075
|
-
@click.argument('repo_id')
|
|
1076
|
-
@click.argument('model_type')
|
|
1077
|
-
@click.option('--variant', help='Model variant (e.g., fp16, bf16)')
|
|
1078
|
-
@click.option('--license-type', help='License type')
|
|
1079
|
-
@click.option('--commercial-use', type=bool, help='Whether commercial use is allowed')
|
|
1080
|
-
@click.option('--save', is_flag=True, help='Save to user configuration file')
|
|
1081
|
-
def add(model_name: str, repo_id: str, model_type: str, variant: Optional[str],
|
|
1082
|
-
license_type: Optional[str], commercial_use: Optional[bool], save: bool):
|
|
1083
|
-
"""Add a new model to the registry"""
|
|
1084
|
-
|
|
1085
|
-
model_config = {
|
|
1086
|
-
"repo_id": repo_id,
|
|
1087
|
-
"model_type": model_type
|
|
1088
|
-
}
|
|
1089
|
-
|
|
1090
|
-
if variant:
|
|
1091
|
-
model_config["variant"] = variant
|
|
1092
|
-
|
|
1093
|
-
if license_type or commercial_use is not None:
|
|
1094
|
-
license_info = {}
|
|
1095
|
-
if license_type:
|
|
1096
|
-
license_info["type"] = license_type
|
|
1097
|
-
if commercial_use is not None:
|
|
1098
|
-
license_info["commercial_use"] = commercial_use
|
|
1099
|
-
model_config["license_info"] = license_info
|
|
1100
|
-
|
|
1101
|
-
if model_registry.add_model(model_name, model_config):
|
|
1102
|
-
rprint(f"[green]Model '{model_name}' added to registry successfully![/green]")
|
|
1103
|
-
|
|
1104
|
-
if save:
|
|
1105
|
-
try:
|
|
1106
|
-
# Load existing user models and add the new one
|
|
1107
|
-
user_models = {}
|
|
1108
|
-
config_path = settings.config_dir / "models.json"
|
|
1109
|
-
if config_path.exists():
|
|
1110
|
-
import json
|
|
1111
|
-
with open(config_path, 'r') as f:
|
|
1112
|
-
data = json.load(f)
|
|
1113
|
-
user_models = data.get('models', {})
|
|
1114
|
-
|
|
1115
|
-
user_models[model_name] = model_config
|
|
1116
|
-
model_registry.save_user_config(user_models, config_path)
|
|
1117
|
-
rprint(f"[green]Model configuration saved to {config_path}[/green]")
|
|
1118
|
-
except Exception as e:
|
|
1119
|
-
rprint(f"[red]Failed to save configuration: {e}[/red]")
|
|
1120
|
-
else:
|
|
1121
|
-
rprint(f"[red]Failed to add model '{model_name}' to registry![/red]")
|
|
1122
|
-
sys.exit(1)
|
|
137
|
+
@cli.command(name="mcp")
|
|
138
|
+
def mcp_cmd():
|
|
139
|
+
"""Start the MCP (Model Context Protocol) server for AI assistant integration."""
|
|
140
|
+
try:
|
|
141
|
+
from ..mcp.server import main as mcp_main
|
|
1123
142
|
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
"""Remove a model from the registry"""
|
|
1129
|
-
|
|
1130
|
-
if model_registry.remove_model(model_name):
|
|
1131
|
-
rprint(f"[green]Model '{model_name}' removed from registry![/green]")
|
|
1132
|
-
|
|
1133
|
-
if from_file:
|
|
1134
|
-
try:
|
|
1135
|
-
config_path = settings.config_dir / "models.json"
|
|
1136
|
-
if config_path.exists():
|
|
1137
|
-
import json
|
|
1138
|
-
with open(config_path, 'r') as f:
|
|
1139
|
-
data = json.load(f)
|
|
1140
|
-
|
|
1141
|
-
user_models = data.get('models', {})
|
|
1142
|
-
if model_name in user_models:
|
|
1143
|
-
del user_models[model_name]
|
|
1144
|
-
model_registry.save_user_config(user_models, config_path)
|
|
1145
|
-
rprint(f"[green]Model removed from configuration file[/green]")
|
|
1146
|
-
else:
|
|
1147
|
-
rprint(f"[yellow]Model not found in configuration file[/yellow]")
|
|
1148
|
-
else:
|
|
1149
|
-
rprint(f"[yellow]No user configuration file found[/yellow]")
|
|
1150
|
-
except Exception as e:
|
|
1151
|
-
rprint(f"[red]Failed to update configuration file: {e}[/red]")
|
|
1152
|
-
else:
|
|
1153
|
-
rprint(f"[red]Model '{model_name}' not found in registry![/red]")
|
|
143
|
+
mcp_main()
|
|
144
|
+
except ImportError:
|
|
145
|
+
rprint("[red]MCP package not installed. Install with:[/red]")
|
|
146
|
+
rprint("[yellow] pip install 'ollamadiffuser[mcp]'[/yellow]")
|
|
1154
147
|
sys.exit(1)
|
|
1155
148
|
|
|
1156
|
-
@registry.command()
|
|
1157
|
-
def reload():
|
|
1158
|
-
"""Reload the model registry from configuration files"""
|
|
1159
|
-
try:
|
|
1160
|
-
model_registry.reload()
|
|
1161
|
-
rprint("[green]Model registry reloaded successfully![/green]")
|
|
1162
|
-
|
|
1163
|
-
# Show summary
|
|
1164
|
-
models = model_registry.get_all_models()
|
|
1165
|
-
external_registries = model_registry.get_external_registries()
|
|
1166
|
-
|
|
1167
|
-
rprint(f"[dim]Total models: {len(models)}[/dim]")
|
|
1168
|
-
if external_registries:
|
|
1169
|
-
rprint(f"[dim]External registries: {len(external_registries)}[/dim]")
|
|
1170
|
-
for registry_path in external_registries:
|
|
1171
|
-
rprint(f"[dim] • {registry_path}[/dim]")
|
|
1172
|
-
else:
|
|
1173
|
-
rprint("[dim]No external registries loaded[/dim]")
|
|
1174
|
-
|
|
1175
|
-
except Exception as e:
|
|
1176
|
-
rprint(f"[red]Failed to reload registry: {e}[/red]")
|
|
1177
|
-
sys.exit(1)
|
|
1178
149
|
|
|
1179
|
-
@
|
|
1180
|
-
@click.
|
|
1181
|
-
def
|
|
1182
|
-
"""
|
|
1183
|
-
|
|
1184
|
-
from pathlib import Path
|
|
1185
|
-
import json
|
|
1186
|
-
import yaml
|
|
1187
|
-
|
|
1188
|
-
config_path = Path(config_file)
|
|
1189
|
-
|
|
1190
|
-
with open(config_path, 'r', encoding='utf-8') as f:
|
|
1191
|
-
if config_path.suffix.lower() == '.json':
|
|
1192
|
-
data = json.load(f)
|
|
1193
|
-
elif config_path.suffix.lower() in ['.yaml', '.yml']:
|
|
1194
|
-
data = yaml.safe_load(f)
|
|
1195
|
-
else:
|
|
1196
|
-
rprint(f"[red]Unsupported file format: {config_path.suffix}[/red]")
|
|
1197
|
-
sys.exit(1)
|
|
1198
|
-
|
|
1199
|
-
if 'models' not in data:
|
|
1200
|
-
rprint("[red]Configuration file must contain a 'models' section[/red]")
|
|
1201
|
-
sys.exit(1)
|
|
1202
|
-
|
|
1203
|
-
imported_count = 0
|
|
1204
|
-
for model_name, model_config in data['models'].items():
|
|
1205
|
-
if model_registry.add_model(model_name, model_config):
|
|
1206
|
-
imported_count += 1
|
|
1207
|
-
rprint(f"[green]✓ Imported: {model_name}[/green]")
|
|
1208
|
-
else:
|
|
1209
|
-
rprint(f"[red]✗ Failed to import: {model_name}[/red]")
|
|
1210
|
-
|
|
1211
|
-
rprint(f"[green]Successfully imported {imported_count} models[/green]")
|
|
1212
|
-
|
|
1213
|
-
except Exception as e:
|
|
1214
|
-
rprint(f"[red]Failed to import configuration: {e}[/red]")
|
|
1215
|
-
sys.exit(1)
|
|
150
|
+
@cli.command(name="create-samples")
|
|
151
|
+
@click.option("--force", is_flag=True, help="Force recreation of samples")
|
|
152
|
+
def create_samples_cmd(force):
|
|
153
|
+
"""Create ControlNet sample images for the Web UI"""
|
|
154
|
+
from .commands import create_samples
|
|
1216
155
|
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
@click.option('--format', '-f', type=click.Choice(['json', 'yaml']), default='json', help='Output format')
|
|
1220
|
-
@click.option('--user-only', is_flag=True, help='Export only user-defined models')
|
|
1221
|
-
def export(output: Optional[str], format: str, user_only: bool):
|
|
1222
|
-
"""Export model registry to a configuration file"""
|
|
1223
|
-
try:
|
|
1224
|
-
from pathlib import Path
|
|
1225
|
-
import json
|
|
1226
|
-
import yaml
|
|
1227
|
-
|
|
1228
|
-
if user_only:
|
|
1229
|
-
# Only export models from external registries
|
|
1230
|
-
models = {}
|
|
1231
|
-
external_registries = model_registry.get_external_registries()
|
|
1232
|
-
if external_registries:
|
|
1233
|
-
rprint(f"[yellow]User-only export not fully supported yet. Exporting all models.[/yellow]")
|
|
1234
|
-
|
|
1235
|
-
models = model_registry.get_all_models()
|
|
1236
|
-
|
|
1237
|
-
config_data = {"models": models}
|
|
1238
|
-
|
|
1239
|
-
if output:
|
|
1240
|
-
output_path = Path(output)
|
|
1241
|
-
else:
|
|
1242
|
-
if format == 'json':
|
|
1243
|
-
output_path = Path('models.json')
|
|
1244
|
-
else:
|
|
1245
|
-
output_path = Path('models.yaml')
|
|
1246
|
-
|
|
1247
|
-
with open(output_path, 'w', encoding='utf-8') as f:
|
|
1248
|
-
if format == 'json':
|
|
1249
|
-
json.dump(config_data, f, indent=2, ensure_ascii=False)
|
|
1250
|
-
else:
|
|
1251
|
-
yaml.safe_dump(config_data, f, default_flow_style=False, allow_unicode=True)
|
|
1252
|
-
|
|
1253
|
-
rprint(f"[green]Model registry exported to {output_path}[/green]")
|
|
1254
|
-
rprint(f"[dim]Exported {len(models)} models[/dim]")
|
|
1255
|
-
|
|
1256
|
-
except Exception as e:
|
|
1257
|
-
rprint(f"[red]Failed to export registry: {e}[/red]")
|
|
1258
|
-
sys.exit(1)
|
|
156
|
+
ctx = click.Context(create_samples)
|
|
157
|
+
ctx.invoke(create_samples, force=force)
|
|
1259
158
|
|
|
1260
|
-
@registry.command('check-gguf')
|
|
1261
|
-
def check_gguf():
|
|
1262
|
-
"""Check GGUF support status"""
|
|
1263
|
-
from ..core.models.gguf_loader import GGUF_AVAILABLE
|
|
1264
|
-
|
|
1265
|
-
if GGUF_AVAILABLE:
|
|
1266
|
-
rprint("✅ [green]GGUF Support Available[/green]")
|
|
1267
|
-
|
|
1268
|
-
# Show GGUF models
|
|
1269
|
-
models = model_registry.get_all_models()
|
|
1270
|
-
gguf_models = {name: info for name, info in models.items()
|
|
1271
|
-
if model_manager.is_gguf_model(name)}
|
|
1272
|
-
|
|
1273
|
-
if gguf_models:
|
|
1274
|
-
rprint(f"\n🔥 Found {len(gguf_models)} GGUF models:")
|
|
1275
|
-
|
|
1276
|
-
table = Table()
|
|
1277
|
-
table.add_column("Model", style="cyan")
|
|
1278
|
-
table.add_column("Variant", style="yellow")
|
|
1279
|
-
table.add_column("VRAM", style="green")
|
|
1280
|
-
table.add_column("Size", style="blue")
|
|
1281
|
-
table.add_column("Installed", style="red")
|
|
1282
|
-
|
|
1283
|
-
for name, info in gguf_models.items():
|
|
1284
|
-
hw_req = info.get('hardware_requirements', {})
|
|
1285
|
-
installed = "✅" if model_manager.is_model_installed(name) else "❌"
|
|
1286
|
-
|
|
1287
|
-
table.add_row(
|
|
1288
|
-
name,
|
|
1289
|
-
info.get('variant', 'unknown'),
|
|
1290
|
-
f"{hw_req.get('min_vram_gb', '?')}GB",
|
|
1291
|
-
f"{hw_req.get('disk_space_gb', '?')}GB",
|
|
1292
|
-
installed
|
|
1293
|
-
)
|
|
1294
|
-
|
|
1295
|
-
console.print(table)
|
|
1296
|
-
|
|
1297
|
-
rprint("\n📋 [blue]Usage:[/blue]")
|
|
1298
|
-
rprint(" ollamadiffuser pull <model-name> # Download GGUF model")
|
|
1299
|
-
rprint(" ollamadiffuser load <model-name> # Load GGUF model")
|
|
1300
|
-
rprint("\n💡 [yellow]Tip:[/yellow] Start with flux.1-dev-gguf-q4ks for best balance")
|
|
1301
|
-
else:
|
|
1302
|
-
rprint("ℹ️ No GGUF models found in registry")
|
|
1303
|
-
else:
|
|
1304
|
-
rprint("❌ [red]GGUF Support Not Available[/red]")
|
|
1305
|
-
rprint("📦 Install with: [yellow]pip install llama-cpp-python gguf[/yellow]")
|
|
1306
|
-
rprint("🔧 Or install all dependencies: [yellow]pip install -r requirements.txt[/yellow]")
|
|
1307
159
|
|
|
1308
|
-
if __name__ ==
|
|
1309
|
-
cli()
|
|
160
|
+
if __name__ == "__main__":
|
|
161
|
+
cli()
|