ollamadiffuser 1.2.3__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +1 -1
- ollamadiffuser/api/server.py +312 -312
- ollamadiffuser/cli/config_commands.py +119 -0
- ollamadiffuser/cli/lora_commands.py +169 -0
- ollamadiffuser/cli/main.py +85 -1233
- ollamadiffuser/cli/model_commands.py +664 -0
- ollamadiffuser/cli/recommend_command.py +205 -0
- ollamadiffuser/cli/registry_commands.py +197 -0
- ollamadiffuser/core/config/model_registry.py +562 -11
- ollamadiffuser/core/config/settings.py +24 -2
- ollamadiffuser/core/inference/__init__.py +5 -0
- ollamadiffuser/core/inference/base.py +182 -0
- ollamadiffuser/core/inference/engine.py +204 -1405
- ollamadiffuser/core/inference/strategies/__init__.py +1 -0
- ollamadiffuser/core/inference/strategies/controlnet_strategy.py +170 -0
- ollamadiffuser/core/inference/strategies/flux_strategy.py +136 -0
- ollamadiffuser/core/inference/strategies/generic_strategy.py +164 -0
- ollamadiffuser/core/inference/strategies/gguf_strategy.py +113 -0
- ollamadiffuser/core/inference/strategies/hidream_strategy.py +104 -0
- ollamadiffuser/core/inference/strategies/sd15_strategy.py +134 -0
- ollamadiffuser/core/inference/strategies/sd3_strategy.py +80 -0
- ollamadiffuser/core/inference/strategies/sdxl_strategy.py +131 -0
- ollamadiffuser/core/inference/strategies/video_strategy.py +108 -0
- ollamadiffuser/mcp/__init__.py +0 -0
- ollamadiffuser/mcp/server.py +184 -0
- ollamadiffuser/ui/templates/index.html +62 -1
- ollamadiffuser/ui/web.py +116 -54
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/METADATA +321 -108
- ollamadiffuser-2.0.0.dist-info/RECORD +61 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/WHEEL +1 -1
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/entry_points.txt +1 -0
- ollamadiffuser/core/models/registry.py +0 -384
- ollamadiffuser/ui/samples/.DS_Store +0 -0
- ollamadiffuser-1.2.3.dist-info/RECORD +0 -45
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {ollamadiffuser-1.2.3.dist-info → ollamadiffuser-2.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,664 @@
|
|
|
1
|
+
"""Model-related CLI commands for OllamaDiffuser."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import subprocess
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
from rich import print as rprint
|
|
11
|
+
|
|
12
|
+
from ..core.models.manager import model_manager
|
|
13
|
+
from ..core.config.settings import settings
|
|
14
|
+
from ..core.config.model_registry import model_registry
|
|
15
|
+
from ..api.server import run_server
|
|
16
|
+
|
|
17
|
+
console = Console()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OllamaStyleProgress:
    """Console progress reporter that prints download messages Ollama-style.

    Each distinct message is printed once through the supplied console; the
    rich style applied depends on the message's prefix.
    """

    def __init__(self, console: Console):
        self.console = console
        # Remembered so that an immediately repeated message is not re-printed.
        self.last_message = ""

    @staticmethod
    def _style_for(message: str):
        """Return the rich style name for *message*, or None for unstyled output."""
        # Per-file progress lines, e.g.
        # "pulling e6a7edc1a4d7: 12% ▕██ ▏ 617 MB/5200 MB 44 MB/s 1m44s"
        is_file_progress = (
            message.startswith("pulling ") and ":" in message and "%" in message
        )
        if is_file_progress or message.startswith("pulling manifest"):
            return None
        if message.startswith("✅") and "download completed" in message:
            return "green"
        if message.startswith("❌"):
            return "red"
        if message.startswith("⚠️"):
            return "yellow"
        # Repository info ("📦 Repository:"), existing-file info ("📁 Found")
        # and every other message are shown dimmed.
        return "dim"

    def update(self, message: str):
        """Print *message* unless it repeats the previously printed one."""
        if message == self.last_message:
            return
        self.last_message = message

        style = self._style_for(message)
        if style is None:
            self.console.print(message)
        else:
            self.console.print(f"[{style}]{message}[/{style}]")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# CLI command: download a model, streaming progress through OllamaStyleProgress.
# Exits with status 1 when the download fails, raises, or is interrupted.
@click.command()
@click.argument('model_name')
@click.option('--force', '-f', is_flag=True, help='Force re-download')
def pull(model_name: str, force: bool):
    """Download model"""
    rprint(f"[blue]Downloading model: {model_name}[/blue]")

    tracker = OllamaStyleProgress(console)

    try:
        # The tracker's bound method is handed over directly as the callback.
        downloaded = model_manager.pull_model(
            model_name, force=force, progress_callback=tracker.update
        )
        if not downloaded:
            rprint(f"[red]Model {model_name} download failed![/red]")
            sys.exit(1)
        tracker.update("✅ download completed")
        rprint(f"[green]Model {model_name} downloaded successfully![/green]")
    except KeyboardInterrupt:
        rprint("\n[yellow]Download cancelled by user[/yellow]")
        sys.exit(1)
    except Exception as error:
        # sys.exit() above raises SystemExit, which is not an Exception
        # subclass and therefore passes through this handler untouched.
        rprint(f"[red]Download failed: {error!s}[/red]")
        sys.exit(1)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# CLI command: load an installed model and serve it until the server stops.
# Exits with status 1 when the model is not installed or fails to load.
@click.command()
@click.argument('model_name')
@click.option('--host', '-h', default=None, help='Server host address')
@click.option('--port', '-p', default=None, type=int, help='Server port')
def run(model_name: str, host: Optional[str], port: Optional[int]):
    """Run model service"""
    rprint(f"[blue]Starting model service: {model_name}[/blue]")

    # Check if model is installed
    if not model_manager.is_model_installed(model_name):
        rprint(f"[red]Model {model_name} is not installed. Please run first: ollamadiffuser pull {model_name}[/red]")
        sys.exit(1)

    # Load model
    rprint("[yellow]Loading model...[/yellow]")
    if not model_manager.load_model(model_name):
        rprint(f"[red]Failed to load model {model_name}![/red]")
        sys.exit(1)

    rprint(f"[green]Model {model_name} loaded successfully![/green]")

    # Start server
    try:
        run_server(host=host, port=port)
    except KeyboardInterrupt:
        rprint("\n[yellow]Server stopped[/yellow]")
    finally:
        # Clean up however the server ended (Ctrl-C, a normal return, or an
        # error that keeps propagating): previously this ran only on
        # KeyboardInterrupt, so any other exit left the model loaded and the
        # saved settings still naming it as the current model.
        model_manager.unload_model()
        settings.current_model = None
        settings.save_config()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# CLI command: print the installed models, either as one compact table or,
# with --hardware, as one detailed table per model.
@click.command()
@click.option('--hardware', '-hw', is_flag=True, help='Show hardware requirements')
def list(hardware: bool):
    """List installed models only"""
    installed_models = model_manager.list_installed_models()
    current_model = model_manager.get_current_model()

    if not installed_models:
        rprint("[yellow]No models installed[/yellow]")
        rprint("\n[dim]💡 Download models with: ollamadiffuser pull <model-name>[/dim]")
        rprint("[dim]💡 See all available models: ollamadiffuser registry list[/dim]")
        rprint("[dim]💡 See only available models: ollamadiffuser registry list --available-only[/dim]")
        return

    def status_of(name):
        # Every listed model is installed; the active one gets a suffix.
        return "✅ Installed (current)" if name == current_model else "✅ Installed"

    if hardware:
        # Rows whose value is a number of gigabytes, in display order.
        gigabyte_rows = (
            ("Min VRAM", 'min_vram_gb'),
            ("Recommended VRAM", 'recommended_vram_gb'),
            ("Min RAM", 'min_ram_gb'),
            ("Recommended RAM", 'recommended_ram_gb'),
            ("Disk Space", 'disk_space_gb'),
        )

        for model_name in installed_models:
            info = model_manager.get_model_info(model_name)
            if not info:
                # Nothing to show for a model without info.
                continue

            # One table per model, titled with its name and status.
            table = Table(title=f"[bold cyan]{model_name}[/bold cyan] - {status_of(model_name)}")
            table.add_column("Property", style="yellow", no_wrap=True)
            table.add_column("Value", style="white")

            table.add_row("Type", info.get('model_type', 'Unknown'))
            table.add_row("Size", info.get('size', 'Unknown'))

            hw_req = info.get('hardware_requirements', {})
            if hw_req:
                for label, key in gigabyte_rows:
                    table.add_row(label, f"{hw_req.get(key, 'Unknown')} GB")
                table.add_row("Supported Devices", ", ".join(hw_req.get('supported_devices', [])))
                table.add_row("Performance Notes", hw_req.get('performance_notes', 'N/A'))

            console.print(table)
            console.print()  # blank line between models
        return

    # Compact view: a single table with one row per installed model.
    table = Table(title="Installed Models")
    compact_columns = (
        ("Model Name", "cyan", {"no_wrap": True}),
        ("Status", "green", {}),
        ("Size", "blue", {}),
        ("Type", "magenta", {}),
        ("Min VRAM", "yellow", {}),
    )
    for heading, colour, extra in compact_columns:
        table.add_column(heading, style=colour, **extra)

    for model_name in installed_models:
        status = status_of(model_name)

        # Missing info degrades to "Unknown" cells instead of skipping the row.
        details = model_manager.get_model_info(model_name) or {}
        size = details.get('size', 'Unknown')
        model_type = details.get('model_type', 'Unknown')

        hw_req = details.get('hardware_requirements', {})
        min_vram = f"{hw_req.get('min_vram_gb', '?')} GB" if hw_req else "Unknown"

        table.add_row(model_name, status, size, model_type, min_vram)

    console.print(table)

    # Summary counts and usage hints.
    available_models = model_registry.get_available_models()
    external_models = model_registry.get_external_api_models_only()

    console.print(f"\n[dim]💡 Installed: {len(installed_models)} models[/dim]")
    console.print(f"[dim]💡 Available for download: {len(available_models)} models[/dim]")
    if external_models:
        console.print(f"[dim]💡 External API models: {len(external_models)} models[/dim]")
    console.print("\n[dim]💡 Use --hardware flag to see detailed hardware requirements[/dim]")
    console.print("[dim]💡 See all models: ollamadiffuser registry list[/dim]")
    console.print("[dim]💡 See available models: ollamadiffuser registry list --available-only[/dim]")
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# CLI command: print everything known about one model. Exits with status 1
# when the model manager has no info for the given name.
@click.command()
@click.argument('model_name')
def show(model_name: str):
    """Show model detailed information"""
    info = model_manager.get_model_info(model_name)

    if info is None:
        rprint(f"[red]Model {model_name} does not exist[/red]")
        sys.exit(1)

    is_installed = info.get('installed', False)

    rprint(f"[bold cyan]Model Information: {model_name}[/bold cyan]")
    rprint(f"Type: {info.get('model_type', 'Unknown')}")
    rprint(f"Variant: {info.get('variant', 'Unknown')}")
    rprint(f"Installed: {'Yes' if is_installed else 'No'}")

    # Path and size are only shown for installed models.
    if is_installed:
        rprint(f"Local Path: {info.get('local_path', 'Unknown')}")
        rprint(f"Size: {info.get('size', 'Unknown')}")

    # Hardware requirements (skipped when absent or empty).
    hw_req = info.get('hardware_requirements')
    if hw_req:
        rprint("\n[bold]Hardware Requirements:[/bold]")
        rprint(f" Min VRAM: {hw_req.get('min_vram_gb', 'Unknown')} GB")
        rprint(f" Recommended VRAM: {hw_req.get('recommended_vram_gb', 'Unknown')} GB")
        rprint(f" Min RAM: {hw_req.get('min_ram_gb', 'Unknown')} GB")
        rprint(f" Recommended RAM: {hw_req.get('recommended_ram_gb', 'Unknown')} GB")
        rprint(f" Disk Space: {hw_req.get('disk_space_gb', 'Unknown')} GB")
        rprint(f" Supported Devices: {', '.join(hw_req.get('supported_devices', []))}")
        notes = hw_req.get('performance_notes')
        if notes:
            rprint(f" Performance Notes: {notes}")

    # Key/value sections that share one layout, in display order.
    for heading, section_key in (("Default Parameters", 'parameters'),
                                 ("Components", 'components')):
        section = info.get(section_key)
        if not section:
            continue
        rprint(f"\n[bold]{heading}:[/bold]")
        for key, value in section.items():
            rprint(f" {key}: {value}")
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# CLI command: with --list, print every available model with its install
# status and license type; otherwise report the download status of one model
# and suggest the next command to run.
@click.command()
@click.argument('model_name', required=False)
@click.option('--list', '-l', is_flag=True, help='List all available models')
def check(model_name: str, list: bool):
    """Check model download status and integrity"""
    # NOTE: the parameter name `list` is dictated by the --list option and
    # shadows the builtin inside this function.
    if list:
        rprint("[bold blue]📋 Available Models:[/bold blue]")
        available_models = model_manager.list_available_models()
        for model in available_models:
            # get_model_info() is treated as nullable by the other commands
            # in this module; fall back to empty mappings so one model with
            # missing info (or a null license entry) cannot abort the listing.
            model_info = model_manager.get_model_info(model) or {}
            license_info = model_info.get("license_info") or {}
            status = "✅ Installed" if model_manager.is_model_installed(model) else "⬇️ Available"
            license_type = license_info.get("type", "Unknown")
            rprint(f" {model:<30} {status:<15} ({license_type})")
        return

    if not model_name:
        rprint("[bold red]❌ Please specify a model name or use --list[/bold red]")
        rprint("[dim]Usage: ollamadiffuser check MODEL_NAME[/dim]")
        rprint("[dim] ollamadiffuser check --list[/dim]")
        return

    # Check model download status directly
    status = _check_download_status(model_name)

    rprint("\n" + "="*60)

    # `status` is True, False or one of the strings compared below; anything
    # unrecognised is reported as "not downloaded".
    if status is True:
        rprint(f"[green]🎉 {model_name} is ready to use![/green]")
        rprint("\n[blue]💡 You can now run:[/blue]")
        rprint(f" [cyan]ollamadiffuser run {model_name}[/cyan]")
    elif status == "needs_config":
        rprint(f"[yellow]⚠️ {model_name} files are complete but model needs configuration[/yellow]")
        rprint("\n[blue]💡 Try reinstalling:[/blue]")
        rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
    elif status == "downloading":
        rprint(f"[yellow]🔄 {model_name} is currently downloading[/yellow]")
        rprint("\n[blue]💡 Wait for download to complete or check progress[/blue]")
    elif status == "incomplete":
        rprint("[yellow]⚠️ Download is incomplete[/yellow]")
        rprint("\n[blue]💡 Resume download with:[/blue]")
        rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")
        rprint("\n[blue]💡 Or force fresh download with:[/blue]")
        rprint(f" [cyan]ollamadiffuser pull {model_name} --force[/cyan]")
    else:
        rprint(f"[red]❌ {model_name} is not downloaded[/red]")
        rprint("\n[blue]💡 Download with:[/blue]")
        rprint(f" [cyan]ollamadiffuser pull {model_name}[/cyan]")

    _show_model_specific_help(model_name)

    rprint("\n[dim]📚 For more help: ollamadiffuser --help[/dim]")
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _report_integrity_and_config(model_name: str, model_path, repo_id: str, configured_message: str):
    """Run the integrity check on a complete-looking download and report it.

    Returns True when the integrity check passes and the model manager reports
    the model as installed, "needs_config" when the check passes but the model
    is not installed, and False when the integrity check fails.
    """
    from ..core.utils.download_utils import check_download_integrity

    rprint("[blue]🔍 Checking download integrity...[/blue]")
    if not check_download_integrity(str(model_path), repo_id):
        rprint("[red]❌ Download integrity check failed[/red]")
        return False

    rprint("[green]✅ Download integrity verified![/green]")

    # Check if model is in configuration
    if model_manager.is_model_installed(model_name):
        rprint(configured_message)
        return True

    rprint("[yellow]⚠️ Model files complete but not in configuration[/yellow]")
    return "needs_config"


def _check_download_status(model_name: str):
    """Check the current download status of any model

    Prints a report along the way and returns one of:
      True           - files verified and the model is installed
      "needs_config" - files verified but the model is not installed
      "downloading"  - a matching `ollamadiffuser pull` process was found
      "incomplete"   - files are missing or have the wrong size
      False          - unknown model, nothing downloaded, or integrity failure
    """
    from ..core.utils.download_utils import get_repo_file_list, format_size

    rprint(f"[blue]🔍 Checking {model_name} download status...[/blue]\n")

    # Check if model is in registry
    if model_name not in model_manager.model_registry:
        rprint(f"[red]❌ {model_name} not found in model registry[/red]")
        available_models = model_manager.list_available_models()
        rprint(f"[blue]📋 Available models: {', '.join(available_models)}[/blue]")
        return False

    model_info = model_manager.model_registry[model_name]
    repo_id = model_info["repo_id"]
    model_path = settings.get_model_path(model_name)

    rprint(f"[cyan]📦 Model: {model_name}[/cyan]")
    rprint(f"[cyan]🔗 Repository: {repo_id}[/cyan]")
    rprint(f"[cyan]📁 Local path: {model_path}[/cyan]")

    # Show model-specific info
    license_info = model_info.get("license_info", {})
    if license_info:
        rprint(f"[yellow]📄 License: {license_info.get('type', 'Unknown')}[/yellow]")
        rprint(f"[yellow]🔑 HF Token Required: {'Yes' if license_info.get('requires_agreement', False) else 'No'}[/yellow]")
        rprint(f"[yellow]💼 Commercial Use: {'Allowed' if license_info.get('commercial_use', False) else 'Not Allowed'}[/yellow]")

    # Show optimal parameters
    params = model_info.get("parameters", {})
    if params:
        rprint("[green]⚡ Optimal Settings:[/green]")
        rprint(f" Steps: {params.get('num_inference_steps', 'N/A')}")
        rprint(f" Guidance: {params.get('guidance_scale', 'N/A')}")
        if 'max_sequence_length' in params:
            rprint(f" Max Seq Length: {params['max_sequence_length']}")

    rprint()

    # Check if directory exists
    if not model_path.exists():
        rprint("[yellow]📂 Status: Not downloaded[/yellow]")
        return False

    # Get repository file list
    rprint("[blue]🌐 Getting repository information...[/blue]")
    try:
        file_sizes = get_repo_file_list(repo_id)
        total_expected_size = sum(file_sizes.values())
        total_files_expected = len(file_sizes)

        rprint(f"[blue]📊 Expected: {total_files_expected} files, {format_size(total_expected_size)} total[/blue]")
    except Exception as e:
        # Best effort: without a listing, only the local scan and the
        # process check at the bottom of this function can run.
        rprint(f"[yellow]⚠️ Could not get repository info: {e}[/yellow]")
        file_sizes = {}
        total_expected_size = 0
        total_files_expected = 0

    # Check local files: one size per regular file under the model directory.
    local_file_sizes = [
        file_path.stat().st_size
        for file_path in model_path.rglob('*')
        if file_path.is_file()
    ]
    local_size = sum(local_file_sizes)

    rprint(f"[blue]💾 Downloaded: {len(local_file_sizes)} files, {format_size(local_size)} total[/blue]")

    if total_expected_size > 0:
        progress_percent = (local_size / total_expected_size) * 100
        rprint(f"[blue]📈 Progress: {progress_percent:.1f}%[/blue]")

    rprint()

    # Check for missing files
    if file_sizes:
        # Check if we have size information from the API
        has_size_info = any(size > 0 for size in file_sizes.values())

        if has_size_info:
            # Normal case: we have size information, do detailed comparison
            missing_files = []
            incomplete_files = []

            for expected_file, expected_size in file_sizes.items():
                local_file_path = model_path / expected_file
                if not local_file_path.exists():
                    missing_files.append(expected_file)
                    continue
                if expected_size > 0:
                    # stat() once so the size that is compared is also the
                    # size that is reported.
                    actual_size = local_file_path.stat().st_size
                    if actual_size != expected_size:
                        incomplete_files.append((expected_file, actual_size, expected_size))

            if missing_files:
                rprint(f"[red]❌ Missing files ({len(missing_files)}):[/red]")
                for missing_file in missing_files[:10]:  # Show first 10
                    rprint(f" - {missing_file}")
                if len(missing_files) > 10:
                    rprint(f" ... and {len(missing_files) - 10} more")
                rprint()

            if incomplete_files:
                rprint(f"[yellow]⚠️ Incomplete files ({len(incomplete_files)}):[/yellow]")
                for incomplete_file, actual_size, expected_size in incomplete_files[:5]:
                    rprint(f" - {incomplete_file}: {format_size(actual_size)}/{format_size(expected_size)}")
                if len(incomplete_files) > 5:
                    rprint(f" ... and {len(incomplete_files) - 5} more")
                rprint()

            if missing_files or incomplete_files:
                rprint("[yellow]⚠️ Download is incomplete[/yellow]")
                return "incomplete"

            rprint("[green]✅ All files present and complete![/green]")
            return _report_integrity_and_config(
                model_name, model_path, repo_id,
                "[green]✅ Model is properly configured[/green]",
            )

        # No size information available from API (common with gated repos)
        rprint("[blue]ℹ️ Repository API doesn't provide file sizes (common with gated models)[/blue]")
        rprint("[blue]🔍 Checking essential model files instead...[/blue]")

        # Determine model type based on repo_id
        if 'controlnet' in repo_id.lower():
            # ControlNet models have different essential files
            essential_files = ['config.json']
            essential_dirs = []  # ControlNet models don't have complex directory structure
        else:
            # Regular diffusion models
            essential_files = ['model_index.json']
            essential_dirs = ['transformer', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'vae', 'scheduler']

        missing_essential = [
            name for name in essential_files if not (model_path / name).exists()
        ]
        # Component directories are only reported, never required.
        existing_dirs = [
            name for name in essential_dirs if (model_path / name).exists()
        ]

        if missing_essential:
            rprint(f"[red]❌ Missing essential files: {', '.join(missing_essential)}[/red]")
            return "incomplete"

        if existing_dirs:
            rprint(f"[green]✅ Found model components: {', '.join(existing_dirs)}[/green]")

        return _report_integrity_and_config(
            model_name, model_path, repo_id,
            "[green]✅ Model is properly configured and functional[/green]",
        )

    # Only reached when no repository file list is available: every branch
    # above returns. Check if download process is running.
    rprint("[blue]🔍 Checking for active download processes...[/blue]")
    try:
        # Bounded so a stuck `ps` cannot hang the command; a timeout or a
        # missing `ps` binary lands in the handler below.
        result = subprocess.run(['ps', 'aux'], capture_output=True, text=True, timeout=10)
        if f'ollamadiffuser pull {model_name}' in result.stdout:
            rprint("[yellow]🔄 Download process is currently running[/yellow]")
            return "downloading"
        rprint("[blue]💤 No active download process found[/blue]")
    except Exception as e:
        rprint(f"[yellow]⚠️ Could not check processes: {e}[/yellow]")

    return "incomplete"
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _show_model_specific_help(model_name: str):
    """Print license, model-family and hardware tips for *model_name*.

    Prints nothing when the model manager has no info for the model.
    """
    model_info = model_manager.get_model_info(model_name)
    if not model_info:
        return

    rprint(f"\n[bold blue]💡 {model_name} Specific Tips:[/bold blue]")

    # License-specific help
    needs_agreement = model_info.get("license_info", {}).get("requires_agreement", False)
    if not needs_agreement:
        rprint(" [green]✅ No HuggingFace token required![/green]")
    else:
        rprint(" [yellow]🔑 Requires HuggingFace token and license agreement[/yellow]")
        rprint(f" [blue]📝 Visit: https://huggingface.co/{model_info['repo_id']}[/blue]")
        rprint(" [cyan]🔧 Set token: export HF_TOKEN=your_token_here[/cyan]")

    # Model-specific optimizations: ordered (name fragment, tips) pairs; only
    # the first fragment found in the lower-cased model name is used.
    family_tips = (
        ("schnell", (
            " [green]⚡ FLUX.1-schnell is 12x faster than FLUX.1-dev[/green]",
            " [green]🎯 Optimized for 4-step generation[/green]",
            " [green]💼 Commercial use allowed (Apache 2.0)[/green]",
        )),
        ("flux.1-dev", (
            " [blue]🎨 Best quality FLUX model[/blue]",
            " [blue]🔬 Requires 50 steps for optimal results[/blue]",
            " [yellow]⚠️ Non-commercial license only[/yellow]",
        )),
        ("stable-diffusion-1.5", (
            " [green]🚀 Great for learning and quick tests[/green]",
            " [green]💾 Smallest model, runs on most hardware[/green]",
        )),
        ("stable-diffusion-3.5", (
            " [green]🏆 Excellent quality-to-speed ratio[/green]",
            " [green]🔄 Great LoRA ecosystem[/green]",
        )),
    )
    lowered_name = model_name.lower()
    for fragment, tips in family_tips:
        if fragment in lowered_name:
            for tip in tips:
                rprint(tip)
            break

    # Hardware recommendations, bucketed by the minimum VRAM figure.
    hw_req = model_info.get("hardware_requirements", {})
    if not hw_req:
        return

    min_vram = hw_req.get("min_vram_gb", 0)
    if min_vram >= 12:
        rprint(" [yellow]🖥️ Requires high-end GPU (RTX 4070+ or M2 Pro+)[/yellow]")
    elif min_vram >= 8:
        rprint(" [blue]🖥️ Requires mid-range GPU (RTX 3080+ or M1 Pro+)[/blue]")
    else:
        rprint(" [green]🖥️ Runs on most modern GPUs[/green]")
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
# CLI command: delete a model after an interactive confirmation prompt.
# Exits with status 1 when the model manager reports that removal failed.
@click.command()
@click.argument('model_name')
@click.confirmation_option(prompt='Are you sure you want to delete this model?')
def rm(model_name: str):
    """Remove model"""
    removed = model_manager.remove_model(model_name)
    if not removed:
        rprint(f"[red]Failed to remove model {model_name}![/red]")
        sys.exit(1)
    rprint(f"[green]Model {model_name} removed successfully![/green]")
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
# CLI command: report the current model and whether the API server is up,
# adding details fetched from the running server when they are available.
@click.command()
def ps():
    """Show currently running model"""
    current_model = model_manager.get_current_model()
    server_running = model_manager.is_server_running()

    if current_model:
        rprint(f"[green]Current model: {current_model}[/green]")

        # Check server status
        if server_running:
            rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")

            # Try to get model info from the running server
            try:
                import requests
                response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
                if response.status_code == 200:
                    data = response.json()
                    if data.get('loaded'):
                        info = data.get('info', {})
                        rprint(f"Device: {info.get('device', 'Unknown')}")
                        rprint(f"Type: {info.get('type', 'Unknown')}")
                        rprint(f"Variant: {info.get('variant', 'Unknown')}")
                    else:
                        rprint("[yellow]Model loaded but not active in server[/yellow]")
            except Exception:
                # Still best-effort, but no longer a bare `except:` (which
                # also swallowed KeyboardInterrupt/SystemExit), and the user
                # is told why the live details are absent.
                rprint("[dim]Could not retrieve model details from the running server[/dim]")
        else:
            rprint("[yellow]Server status: Not running[/yellow]")
            rprint("[dim]Model is set as current but server is not active[/dim]")

        # Show model info from local config
        model_info = model_manager.get_model_info(current_model)
        if model_info:
            rprint(f"Model type: {model_info.get('model_type', 'Unknown')}")
            if model_info.get('installed'):
                rprint(f"Size: {model_info.get('size', 'Unknown')}")
    else:
        if server_running:
            rprint("[yellow]Server is running but no model is loaded[/yellow]")
            rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
        else:
            rprint("[yellow]No model is currently running[/yellow]")
            rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
# CLI command: load a model through the model manager.
# Exits with status 1 when the manager reports that loading failed.
@click.command()
@click.argument('model_name')
def load(model_name: str):
    """Load model into memory"""
    rprint(f"[blue]Loading model: {model_name}[/blue]")

    loaded = model_manager.load_model(model_name)
    if not loaded:
        rprint(f"[red]Failed to load model {model_name}![/red]")
        sys.exit(1)
    rprint(f"[green]Model {model_name} loaded successfully![/green]")
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
# CLI command: unload the model the manager currently has loaded, if any.
@click.command()
def unload():
    """Unload current model"""
    if not model_manager.is_model_loaded():
        rprint("[yellow]No model to unload[/yellow]")
        return

    # Read the name first so it can still be reported after unloading.
    unloaded_name = model_manager.get_current_model()
    model_manager.unload_model()
    rprint(f"[green]Model {unloaded_name} unloaded[/green]")
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
# CLI command: start the API server on its own; unlike `run`, no model is
# loaded first. Ctrl-C is reported as a normal stop.
@click.command()
@click.option('--host', '-h', default=None, help='Server host address')
@click.option('--port', '-p', default=None, type=int, help='Server port')
def serve(host: Optional[str], port: Optional[int]):
    """Start the API server without loading a model"""
    rprint("[blue]Starting OllamaDiffuser API server...[/blue]")

    # None values are passed through unchanged for run_server to interpret.
    server_options = {"host": host, "port": port}
    try:
        run_server(**server_options)
    except KeyboardInterrupt:
        rprint("\n[yellow]Server stopped[/yellow]")
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
# CLI command: ask the server at the configured host/port to shut down by
# POSTing to its /api/shutdown endpoint, and report the outcome.
@click.command()
def stop():
    """Stop the running server"""
    import requests

    shutdown_url = f"http://{settings.server.host}:{settings.server.port}/api/shutdown"

    try:
        status_code = requests.post(shutdown_url, timeout=5).status_code
        if status_code != 200:
            rprint(f"[red]Failed to stop server: {status_code}[/red]")
        else:
            rprint("[green]Server shutdown initiated[/green]")
    except requests.ConnectionError:
        # Nothing is listening on the configured address.
        rprint("[yellow]No server running or already stopped[/yellow]")
    except Exception as error:
        rprint(f"[red]Error stopping server: {error}[/red]")
|