ollamadiffuser 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ollamadiffuser/__init__.py +0 -0
- ollamadiffuser/__main__.py +50 -0
- ollamadiffuser/api/__init__.py +0 -0
- ollamadiffuser/api/server.py +297 -0
- ollamadiffuser/cli/__init__.py +0 -0
- ollamadiffuser/cli/main.py +597 -0
- ollamadiffuser/core/__init__.py +0 -0
- ollamadiffuser/core/config/__init__.py +0 -0
- ollamadiffuser/core/config/settings.py +137 -0
- ollamadiffuser/core/inference/__init__.py +0 -0
- ollamadiffuser/core/inference/engine.py +926 -0
- ollamadiffuser/core/models/__init__.py +0 -0
- ollamadiffuser/core/models/manager.py +436 -0
- ollamadiffuser/core/utils/__init__.py +3 -0
- ollamadiffuser/core/utils/download_utils.py +356 -0
- ollamadiffuser/core/utils/lora_manager.py +390 -0
- ollamadiffuser/ui/__init__.py +0 -0
- ollamadiffuser/ui/templates/index.html +496 -0
- ollamadiffuser/ui/web.py +278 -0
- ollamadiffuser/utils/__init__.py +0 -0
- ollamadiffuser-1.0.0.dist-info/METADATA +493 -0
- ollamadiffuser-1.0.0.dist-info/RECORD +26 -0
- ollamadiffuser-1.0.0.dist-info/WHEEL +5 -0
- ollamadiffuser-1.0.0.dist-info/entry_points.txt +2 -0
- ollamadiffuser-1.0.0.dist-info/licenses/LICENSE +21 -0
- ollamadiffuser-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import click
|
|
3
|
+
import sys
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|
9
|
+
from rich import print as rprint
|
|
10
|
+
|
|
11
|
+
from ..core.models.manager import model_manager
|
|
12
|
+
from ..core.config.settings import settings
|
|
13
|
+
from ..api.server import run_server
|
|
14
|
+
|
|
15
|
+
console = Console()
|
|
16
|
+
|
|
17
|
+
@click.group()
@click.option('--verbose', '-v', is_flag=True, help='Enable verbose output')
def cli(verbose):
    """OllamaDiffuser - Image generation model management tool"""
    # The verbose flag controls logging for the whole CLI session:
    # DEBUG when requested, otherwise only warnings and errors.
    level = logging.DEBUG if verbose else logging.WARNING
    logging.basicConfig(level=level)
|
|
25
|
+
|
|
26
|
+
@cli.command()
@click.argument('model_name')
@click.option('--force', '-f', is_flag=True, help='Force re-download')
def pull(model_name: str, force: bool):
    """Download model"""
    rprint(f"[blue]Downloading model: {model_name}[/blue]")

    # Spinner-style progress display; the downloader reports status via callback.
    spinner = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    )
    with spinner as progress:
        task = progress.add_task(f"Downloading {model_name}...", total=None)

        def progress_callback(message: str):
            """Update progress display with download status"""
            progress.update(task, description=message)

        succeeded = model_manager.pull_model(model_name, force=force, progress_callback=progress_callback)
        if succeeded:
            progress.update(task, description=f"✅ {model_name} download completed")
            rprint(f"[green]Model {model_name} downloaded successfully![/green]")
        else:
            progress.update(task, description=f"❌ {model_name} download failed")
            rprint(f"[red]Model {model_name} download failed![/red]")
            sys.exit(1)
|
|
51
|
+
|
|
52
|
+
@cli.command()
@click.argument('model_name')
@click.option('--host', '-h', default=None, help='Server host address')
@click.option('--port', '-p', default=None, type=int, help='Server port')
def run(model_name: str, host: Optional[str], port: Optional[int]):
    """Run model service"""
    rprint(f"[blue]Starting model service: {model_name}[/blue]")

    # Guard: refuse to start unless the model has already been pulled.
    if not model_manager.is_model_installed(model_name):
        rprint(f"[red]Model {model_name} is not installed. Please run first: ollamadiffuser pull {model_name}[/red]")
        sys.exit(1)

    # Load the model into memory before the server starts serving requests.
    rprint("[yellow]Loading model...[/yellow]")
    if not model_manager.load_model(model_name):
        rprint(f"[red]Failed to load model {model_name}![/red]")
        sys.exit(1)

    rprint(f"[green]Model {model_name} loaded successfully![/green]")

    # Block on the server until the user interrupts with Ctrl-C.
    try:
        run_server(host=host, port=port)
    except KeyboardInterrupt:
        rprint("\n[yellow]Server stopped[/yellow]")
        model_manager.unload_model()
        # Forget the active model so the next start is not confused by stale state.
        settings.current_model = None
        settings.save_config()
|
|
82
|
+
|
|
83
|
+
@cli.command()
@click.option('--hardware', '-hw', is_flag=True, help='Show hardware requirements')
def list(hardware: bool):
    """List all models"""
    available_models = model_manager.list_available_models()
    installed_models = model_manager.list_installed_models()
    current_model = model_manager.get_current_model()

    def _status_and_size(name, info):
        """Return (status label, size string) for one registry entry."""
        if name in installed_models:
            label = "✅ Installed"
            if name == current_model:
                label += " (current)"
            return label, (info.get('size', 'Unknown') if info else 'Unknown')
        return "⬇️ Available", "-"

    if hardware:
        # Detailed view: one table per model with full hardware requirements.
        for model_name in available_models:
            info = model_manager.get_model_info(model_name)
            if not info:
                continue

            status, size = _status_and_size(model_name, info)

            table = Table(title=f"[bold cyan]{model_name}[/bold cyan] - {status}")
            table.add_column("Property", style="yellow", no_wrap=True)
            table.add_column("Value", style="white")

            # Basic info
            table.add_row("Type", info.get('model_type', 'Unknown'))
            table.add_row("Size", size)

            # Hardware requirements (may be absent from the registry entry).
            hw_req = info.get('hardware_requirements', {})
            if hw_req:
                table.add_row("Min VRAM", f"{hw_req.get('min_vram_gb', 'Unknown')} GB")
                table.add_row("Recommended VRAM", f"{hw_req.get('recommended_vram_gb', 'Unknown')} GB")
                table.add_row("Min RAM", f"{hw_req.get('min_ram_gb', 'Unknown')} GB")
                table.add_row("Recommended RAM", f"{hw_req.get('recommended_ram_gb', 'Unknown')} GB")
                table.add_row("Disk Space", f"{hw_req.get('disk_space_gb', 'Unknown')} GB")
                table.add_row("Supported Devices", ", ".join(hw_req.get('supported_devices', [])))
                table.add_row("Performance Notes", hw_req.get('performance_notes', 'N/A'))

            console.print(table)
            console.print()  # blank line between per-model tables
    else:
        # Compact view: one row per model.
        table = Table(title="OllamaDiffuser Model List")
        table.add_column("Model Name", style="cyan", no_wrap=True)
        table.add_column("Status", style="green")
        table.add_column("Size", style="blue")
        table.add_column("Type", style="magenta")
        table.add_column("Min VRAM", style="yellow")

        for model_name in available_models:
            info = model_manager.get_model_info(model_name)
            status, size = _status_and_size(model_name, info)
            model_type = info.get('model_type', 'Unknown') if info else 'Unknown'

            hw_req = info.get('hardware_requirements', {}) if info else {}
            min_vram = f"{hw_req.get('min_vram_gb', '?')} GB" if hw_req else "Unknown"

            table.add_row(model_name, status, size, model_type, min_vram)

        console.print(table)
        console.print("\n[dim]💡 Use --hardware flag to see detailed hardware requirements[/dim]")
|
|
165
|
+
|
|
166
|
+
@cli.command()
@click.argument('model_name')
def show(model_name: str):
    """Show model detailed information"""
    info = model_manager.get_model_info(model_name)
    if info is None:
        rprint(f"[red]Model {model_name} does not exist[/red]")
        sys.exit(1)

    rprint(f"[bold cyan]Model Information: {model_name}[/bold cyan]")
    rprint(f"Type: {info.get('model_type', 'Unknown')}")
    rprint(f"Variant: {info.get('variant', 'Unknown')}")
    installed = info.get('installed', False)
    rprint(f"Installed: {'Yes' if installed else 'No'}")

    # Local details only make sense for an installed model.
    if installed:
        rprint(f"Local Path: {info.get('local_path', 'Unknown')}")
        rprint(f"Size: {info.get('size', 'Unknown')}")

    # Hardware requirements section (skipped when absent or empty).
    hw_req = info.get('hardware_requirements')
    if hw_req:
        rprint("\n[bold]Hardware Requirements:[/bold]")
        rprint(f" Min VRAM: {hw_req.get('min_vram_gb', 'Unknown')} GB")
        rprint(f" Recommended VRAM: {hw_req.get('recommended_vram_gb', 'Unknown')} GB")
        rprint(f" Min RAM: {hw_req.get('min_ram_gb', 'Unknown')} GB")
        rprint(f" Recommended RAM: {hw_req.get('recommended_ram_gb', 'Unknown')} GB")
        rprint(f" Disk Space: {hw_req.get('disk_space_gb', 'Unknown')} GB")
        rprint(f" Supported Devices: {', '.join(hw_req.get('supported_devices', []))}")
        if hw_req.get('performance_notes'):
            rprint(f" Performance Notes: {hw_req.get('performance_notes')}")

    # Key/value sections that share the same printing shape.
    for section, heading in (('parameters', "\n[bold]Default Parameters:[/bold]"),
                             ('components', "\n[bold]Components:[/bold]")):
        entries = info.get(section)
        if entries:
            rprint(heading)
            for key, value in entries.items():
                rprint(f" {key}: {value}")
|
|
207
|
+
|
|
208
|
+
@cli.command()
@click.argument('model_name', required=False)
@click.option('--list', '-l', is_flag=True, help='List all available models')
def check(model_name: str, list: bool):
    """Check model download status and integrity.

    With --list, print every registry model and its install status.
    Otherwise delegate the integrity check for MODEL_NAME to the bundled
    examples/check_model_download.py script run in a subprocess.
    """
    if list:
        rprint("[bold blue]📋 Available Models:[/bold blue]")
        available_models = model_manager.list_available_models()
        for model in available_models:
            model_info = model_manager.get_model_info(model)
            status = "✅ Installed" if model_manager.is_model_installed(model) else "⬇️ Available"
            # BUG FIX: get_model_info may return None; the original crashed
            # with AttributeError instead of showing "Unknown".
            license_type = (model_info or {}).get("license_info", {}).get("type", "Unknown")
            rprint(f" {model:<30} {status:<15} ({license_type})")
        return

    if not model_name:
        rprint("[bold red]❌ Please specify a model name or use --list[/bold red]")
        rprint("[dim]Usage: ollamadiffuser check MODEL_NAME[/dim]")
        rprint("[dim] ollamadiffuser check --list[/dim]")
        return

    # Run the check script shipped in the repository's examples directory.
    # NOTE: the original also re-imported sys here, shadowing the module-level
    # import for no reason; removed.
    import subprocess
    from pathlib import Path

    script_path = Path(__file__).parent.parent.parent / "examples" / "check_model_download.py"
    # BUG FIX: report a missing script explicitly instead of surfacing a
    # confusing FileNotFoundError from subprocess.
    if not script_path.exists():
        rprint(f"[red]❌ Error running check: script not found at {script_path}[/red]")
        return
    try:
        result = subprocess.run([sys.executable, str(script_path), model_name],
                                capture_output=True, text=True)
        rprint(result.stdout)
        if result.stderr:
            rprint(f"[red]{result.stderr}[/red]")
    except Exception as e:
        rprint(f"[red]❌ Error running check: {e}[/red]")
|
|
244
|
+
|
|
245
|
+
@cli.command()
@click.argument('model_name')
@click.confirmation_option(prompt='Are you sure you want to delete this model?')
def rm(model_name: str):
    """Remove model"""
    # Confirmation is handled by the click decorator above.
    if not model_manager.remove_model(model_name):
        rprint(f"[red]Failed to remove model {model_name}![/red]")
        sys.exit(1)
    rprint(f"[green]Model {model_name} removed successfully![/green]")
|
|
255
|
+
|
|
256
|
+
@cli.command()
def ps():
    """Show currently running model.

    Reports the configured current model, whether the API server is up,
    and (best-effort) live details queried from the running server.
    """
    current_model = model_manager.get_current_model()
    server_running = model_manager.is_server_running()

    if current_model:
        rprint(f"[green]Current model: {current_model}[/green]")

        # Check server status
        if server_running:
            rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")

            # Best-effort: ask the running server which model it actually loaded.
            try:
                import requests
                response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
                if response.status_code == 200:
                    data = response.json()
                    if data.get('loaded'):
                        info = data.get('info', {})
                        rprint(f"Device: {info.get('device', 'Unknown')}")
                        rprint(f"Type: {info.get('type', 'Unknown')}")
                        rprint(f"Variant: {info.get('variant', 'Unknown')}")
                    else:
                        rprint("[yellow]Model loaded but not active in server[/yellow]")
            except Exception:
                # BUG FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit. Keep the deliberate
                # best-effort behaviour but narrow to Exception.
                pass
        else:
            rprint("[yellow]Server status: Not running[/yellow]")
            rprint("[dim]Model is set as current but server is not active[/dim]")

        # Show model info from local config
        model_info = model_manager.get_model_info(current_model)
        if model_info:
            rprint(f"Model type: {model_info.get('model_type', 'Unknown')}")
            if model_info.get('installed'):
                rprint(f"Size: {model_info.get('size', 'Unknown')}")
    else:
        if server_running:
            rprint("[yellow]Server is running but no model is loaded[/yellow]")
            rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")
        else:
            rprint("[yellow]No model is currently running[/yellow]")
            rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
|
|
301
|
+
|
|
302
|
+
@cli.command()
@click.option('--host', '-h', default=None, help='Server host address')
@click.option('--port', '-p', default=None, type=int, help='Server port')
def serve(host: Optional[str], port: Optional[int]):
    """Start API server (without loading model)"""
    # Unlike `run`, this starts the API server with no model preloaded;
    # host/port fall back to configured defaults when left as None.
    rprint("[blue]Starting OllamaDiffuser API server...[/blue]")
    try:
        run_server(host=host, port=port)
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop a foreground server.
        rprint("\n[yellow]Server stopped[/yellow]")
|
|
313
|
+
|
|
314
|
+
@cli.command()
@click.argument('model_name')
def load(model_name: str):
    """Load model into memory"""
    rprint(f"[blue]Loading model: {model_name}[/blue]")

    # Guard-clause form: bail out on failure, report success at the end.
    if not model_manager.load_model(model_name):
        rprint(f"[red]Failed to load model {model_name}![/red]")
        sys.exit(1)
    rprint(f"[green]Model {model_name} loaded successfully![/green]")
|
|
325
|
+
|
|
326
|
+
@cli.command()
def unload():
    """Unload current model"""
    # Nothing loaded -> nothing to do.
    if not model_manager.is_model_loaded():
        rprint("[yellow]No model to unload[/yellow]")
        return
    # Capture the name before unloading so the message can report it.
    name = model_manager.get_current_model()
    model_manager.unload_model()
    rprint(f"[green]Model {name} unloaded[/green]")
|
|
335
|
+
|
|
336
|
+
@cli.command()
def stop():
    """Stop running server.

    Tries a graceful shutdown via the server's /api/shutdown endpoint
    first, then falls back to finding and terminating the uvicorn process
    bound to the configured port.
    """
    if not model_manager.is_server_running():
        rprint("[yellow]No server is currently running[/yellow]")
        return

    try:
        # NOTE: the original also imported `signal` here but never used it;
        # removed.
        import requests
        import psutil

        host = settings.server.host
        port = settings.server.port

        # Try graceful shutdown via API first
        try:
            response = requests.post(f"http://{host}:{port}/api/shutdown", timeout=5)
            if response.status_code == 200:
                rprint("[green]Server stopped gracefully[/green]")
                return
        except requests.RequestException:
            # BUG FIX: the original bare `except:` also swallowed
            # KeyboardInterrupt; only network failures should fall through
            # to the process-kill fallback.
            pass

        # Fallback: find the uvicorn process mentioning our port and terminate it.
        for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
            try:
                cmdline = proc.info['cmdline']
                if cmdline and any('uvicorn' in arg for arg in cmdline) and any(str(port) in arg for arg in cmdline):
                    proc.terminate()
                    proc.wait(timeout=10)
                    rprint(f"[green]Server process (PID: {proc.info['pid']}) stopped[/green]")
                    return
            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.TimeoutExpired):
                continue

        rprint("[red]Could not find or stop the server process[/red]")

    except ImportError:
        rprint("[red]psutil package required for stop command. Install with: pip install psutil[/red]")
    except Exception as e:
        rprint(f"[red]Failed to stop server: {e}[/red]")
|
|
378
|
+
|
|
379
|
+
@cli.group()
def lora():
    """LoRA (Low-Rank Adaptation) management commands"""
    # Group container only; subcommands are registered via @lora.command().
|
|
383
|
+
|
|
384
|
+
@lora.command()
@click.argument('repo_id')
@click.option('--weight-name', '-w', help='Specific weight file name (e.g., lora.safetensors)')
@click.option('--alias', '-a', help='Local alias name for the LoRA')
def pull(repo_id: str, weight_name: Optional[str], alias: Optional[str]):
    """Download LoRA weights from Hugging Face Hub.

    REPO_ID is the Hub repository id; --weight-name selects a specific
    file within it, and --alias stores the LoRA under a local name.
    """
    # Imported lazily so the lora manager is only loaded when needed.
    from ..core.utils.lora_manager import lora_manager

    rprint(f"[blue]Downloading LoRA: {repo_id}[/blue]")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        # FIX (idiom): the original used f-strings with no placeholders on
        # these three constant messages; the rendered text is unchanged.
        task = progress.add_task("Downloading LoRA...", total=None)

        def progress_callback(message: str):
            """Forward downloader status messages to the spinner line."""
            progress.update(task, description=message)

        if lora_manager.pull_lora(repo_id, weight_name=weight_name, alias=alias, progress_callback=progress_callback):
            progress.update(task, description="✅ LoRA download completed")
            rprint(f"[green]LoRA {repo_id} downloaded successfully![/green]")
        else:
            progress.update(task, description="❌ LoRA download failed")
            rprint(f"[red]LoRA {repo_id} download failed![/red]")
            sys.exit(1)
|
|
411
|
+
|
|
412
|
+
@lora.command()
@click.argument('lora_name')
@click.option('--scale', '-s', default=1.0, type=float, help='LoRA scale/strength (default: 1.0)')
def load(lora_name: str, scale: float):
    """Load LoRA weights into the current model"""
    from ..core.utils.lora_manager import lora_manager

    rprint(f"[blue]Loading LoRA: {lora_name} (scale: {scale})[/blue]")

    # Guard-clause form: exit on failure, report success otherwise.
    if not lora_manager.load_lora(lora_name, scale=scale):
        rprint(f"[red]Failed to load LoRA {lora_name}![/red]")
        sys.exit(1)
    rprint(f"[green]LoRA {lora_name} loaded successfully![/green]")
|
|
426
|
+
|
|
427
|
+
@lora.command()
def unload():
    """Unload current LoRA weights"""
    from ..core.utils.lora_manager import lora_manager

    rprint("[blue]Unloading LoRA weights...[/blue]")

    # Guard-clause form: exit on failure, report success otherwise.
    if not lora_manager.unload_lora():
        rprint("[red]Failed to unload LoRA weights![/red]")
        sys.exit(1)
    rprint("[green]LoRA weights unloaded successfully![/green]")
|
|
439
|
+
|
|
440
|
+
@lora.command()
@click.argument('lora_name')
@click.confirmation_option(prompt='Are you sure you want to delete this LoRA?')
def rm(lora_name: str):
    """Remove LoRA weights"""
    from ..core.utils.lora_manager import lora_manager

    rprint(f"[blue]Removing LoRA: {lora_name}[/blue]")

    # Guard-clause form: exit on failure, report success otherwise.
    if not lora_manager.remove_lora(lora_name):
        rprint(f"[red]Failed to remove LoRA {lora_name}![/red]")
        sys.exit(1)
    rprint(f"[green]LoRA {lora_name} removed successfully![/green]")
|
|
454
|
+
|
|
455
|
+
@lora.command()
def ps():
    """Show currently loaded LoRA status.

    Two phases: first report model/server status (returning early when
    neither a server nor a local model is available), then report LoRA
    status — preferring the running server's answer and falling back to
    the local LoRA manager's bookkeeping.
    """
    from ..core.utils.lora_manager import lora_manager

    # Check if server is running
    # NOTE(review): relies on a private helper of lora_manager — confirm
    # there is no public equivalent.
    server_running = lora_manager._is_server_running()
    current_lora = lora_manager.get_current_lora()

    if server_running:
        rprint(f"[green]Server status: Running on {settings.server.host}:{settings.server.port}[/green]")

        # Try to get LoRA status from the running server
        try:
            import requests
            response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/models/running", timeout=2)
            if response.status_code == 200:
                data = response.json()
                if data.get('loaded'):
                    model_info = data.get('info', {})
                    rprint(f"Model: {data.get('model', 'Unknown')}")
                    rprint(f"Device: {model_info.get('device', 'Unknown')}")
                    rprint(f"Type: {model_info.get('type', 'Unknown')}")
                else:
                    # No model on the server means no LoRA can be active either.
                    rprint("[yellow]No model loaded in server[/yellow]")
                    return
        except Exception as e:
            rprint(f"[red]Failed to get server status: {e}[/red]")
            return
    else:
        # Check local model manager
        if model_manager.is_model_loaded():
            current_model = model_manager.get_current_model()
            rprint(f"[green]Model loaded locally: {current_model}[/green]")
        else:
            rprint("[yellow]No server running and no local model loaded[/yellow]")
            rprint("[dim]Use 'ollamadiffuser run <model>' to start a model[/dim]")
            return

    # Show LoRA status
    # Flags track whether the server already answered, so the local
    # fallback and the hint line below don't duplicate information.
    lora_status_shown = False
    lora_loaded_on_server = False

    # Try to get LoRA status from server if running
    if server_running:
        try:
            import requests
            response = requests.get(f"http://{settings.server.host}:{settings.server.port}/api/lora/status", timeout=2)
            if response.status_code == 200:
                lora_data = response.json()
                if lora_data.get('loaded'):
                    lora_info = lora_data.get('info', {})
                    rprint(f"\n[bold green]🔄 LoRA Status: LOADED (via server)[/bold green]")
                    rprint(f"Adapter: {lora_info.get('adapter_name', 'Unknown')}")
                    if 'scale' in lora_info:
                        rprint(f"Scale: {lora_info.get('scale', 'Unknown')}")
                    if 'adapters' in lora_info:
                        rprint(f"Active Adapters: {', '.join(lora_info.get('adapters', []))}")
                    lora_status_shown = True
                    lora_loaded_on_server = True
                else:
                    rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded (server)[/dim]")
                    lora_status_shown = True
        except Exception as e:
            # Non-fatal: fall through to the local state below.
            rprint(f"\n[yellow]⚠️ Failed to get LoRA status from server: {e}[/yellow]")

    # Fallback to local LoRA manager state
    if not lora_status_shown:
        if current_lora:
            lora_info = lora_manager.get_lora_info(current_lora)
            if lora_info:
                rprint(f"\n[bold green]🔄 LoRA Status: LOADED (local)[/bold green]")
                rprint(f"Name: {current_lora}")
                rprint(f"Repository: {lora_info.get('repo_id', 'Unknown')}")
                rprint(f"Weight File: {lora_info.get('weight_name', 'Unknown')}")
                rprint(f"Size: {lora_info.get('size', 'Unknown')}")
                rprint(f"Local Path: {lora_info.get('path', 'Unknown')}")
            else:
                rprint(f"\n[yellow]⚠️ LoRA {current_lora} is set as current but info not found[/yellow]")
        else:
            rprint(f"\n[dim]💾 LoRA Status: No LoRA loaded[/dim]")

    if not lora_loaded_on_server:
        rprint("[dim]Use 'ollamadiffuser lora load <lora_name>' to load a LoRA[/dim]")
|
|
539
|
+
|
|
540
|
+
@lora.command()
def list():
    """List available and installed LoRA weights"""
    from ..core.utils.lora_manager import lora_manager

    installed_loras = lora_manager.list_installed_loras()
    current_lora = lora_manager.get_current_lora()

    # Nothing installed -> show a hint instead of an empty table.
    if not installed_loras:
        rprint("[yellow]No LoRA weights installed.[/yellow]")
        rprint("\n[dim]💡 Use 'ollamadiffuser lora pull <repo_id>' to download LoRA weights[/dim]")
        return

    table = Table(title="Installed LoRA Weights")
    table.add_column("Name", style="cyan", no_wrap=True)
    table.add_column("Repository", style="blue")
    table.add_column("Status", style="green")
    table.add_column("Size", style="yellow")

    for name, meta in installed_loras.items():
        marker = "🔄 Loaded" if name == current_lora else "💾 Available"
        table.add_row(name, meta.get('repo_id', 'Unknown'), marker, meta.get('size', 'Unknown'))

    console.print(table)
|
|
567
|
+
|
|
568
|
+
@lora.command()
@click.argument('lora_name')
def show(lora_name: str):
    """Show detailed LoRA information"""
    from ..core.utils.lora_manager import lora_manager

    meta = lora_manager.get_lora_info(lora_name)
    if not meta:
        rprint(f"[red]LoRA {lora_name} not found.[/red]")
        sys.exit(1)

    rprint(f"[bold cyan]LoRA Information: {lora_name}[/bold cyan]")
    # Fixed fields share a label/key table; each prints "Label: value".
    for label, key in (("Repository", 'repo_id'),
                       ("Weight File", 'weight_name'),
                       ("Local Path", 'path'),
                       ("Size", 'size'),
                       ("Downloaded", 'downloaded_at')):
        rprint(f"{label}: {meta.get(key, 'Unknown')}")

    # Description is optional metadata.
    if meta.get('description'):
        rprint(f"Description: {meta.get('description')}")
|
|
589
|
+
|
|
590
|
+
@cli.command()
def version():
    """Show version information"""
    # NOTE(review): version string is hard-coded; presumably should track
    # the package metadata — confirm when releasing.
    rprint("[bold cyan]OllamaDiffuser v1.0.0[/bold cyan]")
    rprint("Image generation model management tool")
|
|
595
|
+
|
|
596
|
+
# Allow direct execution (e.g. `python -m ollamadiffuser.cli.main`).
if __name__ == '__main__':
    cli()
|
|
File without changes
|
|
File without changes
|