tensors 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensors/__init__.py +26 -0
- tensors/api.py +288 -0
- tensors/cli.py +395 -0
- tensors/config.py +166 -0
- tensors/display.py +331 -0
- tensors/safetensor.py +95 -0
- {tensors-0.1.3.dist-info → tensors-0.1.4.dist-info}/METADATA +1 -1
- tensors-0.1.4.dist-info/RECORD +10 -0
- tensors-0.1.3.dist-info/RECORD +0 -5
- tensors.py +0 -1071
- {tensors-0.1.3.dist-info → tensors-0.1.4.dist-info}/WHEEL +0 -0
- {tensors-0.1.3.dist-info → tensors-0.1.4.dist-info}/entry_points.txt +0 -0
tensors.py
DELETED
|
@@ -1,1071 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
"""
|
|
3
|
-
tsr: Read safetensor metadata, search and download CivitAI models.
|
|
4
|
-
"""
|
|
5
|
-
|
|
6
|
-
from __future__ import annotations
|
|
7
|
-
|
|
8
|
-
import hashlib
|
|
9
|
-
import json
|
|
10
|
-
import os
|
|
11
|
-
import re
|
|
12
|
-
import struct
|
|
13
|
-
import sys
|
|
14
|
-
import tomllib
|
|
15
|
-
from enum import Enum
|
|
16
|
-
from pathlib import Path
|
|
17
|
-
from typing import Annotated, Any
|
|
18
|
-
|
|
19
|
-
import httpx
|
|
20
|
-
import typer
|
|
21
|
-
from rich.console import Console
|
|
22
|
-
from rich.progress import (
|
|
23
|
-
BarColumn,
|
|
24
|
-
DownloadColumn,
|
|
25
|
-
Progress,
|
|
26
|
-
SpinnerColumn,
|
|
27
|
-
TaskProgressColumn,
|
|
28
|
-
TextColumn,
|
|
29
|
-
TimeRemainingColumn,
|
|
30
|
-
TransferSpeedColumn,
|
|
31
|
-
)
|
|
32
|
-
from rich.table import Table
|
|
33
|
-
|
|
34
|
-
# ============================================================================
|
|
35
|
-
# App and Console Setup
|
|
36
|
-
# ============================================================================
|
|
37
|
-
|
|
38
|
-
# Root Typer application; every @app.command() function registers a subcommand on it.
app = typer.Typer(
    name="tsr",
    no_args_is_help=True,
    help="Read safetensor metadata, search and download CivitAI models.",
)
# Single Rich console shared by all tables, progress bars and messages.
console = Console()
|
|
44
|
-
|
|
45
|
-
# ============================================================================
|
|
46
|
-
# Configuration
|
|
47
|
-
# ============================================================================
|
|
48
|
-
|
|
49
|
-
# XDG Base Directory spec
# Config: ~/.config/tensors/config.toml
# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/


def xdg_base_dir(env_var: str, fallback: Path) -> Path:
    """Resolve an XDG base directory from the environment.

    The XDG Base Directory spec treats an unset *or empty* variable as "use the
    default", and says a relative path in one of these variables is invalid and
    must be ignored. Only an absolute value is honoured here.

    Args:
        env_var: Environment variable name, e.g. ``"XDG_CONFIG_HOME"``.
        fallback: Directory used when the variable is unset, empty or relative.

    Returns:
        The directory named by ``env_var``, or ``fallback``.
    """
    value = os.environ.get(env_var, "")
    if value and Path(value).is_absolute():
        return Path(value)
    return fallback


CONFIG_DIR = xdg_base_dir("XDG_CONFIG_HOME", Path.home() / ".config") / "tensors"
CONFIG_FILE = CONFIG_DIR / "config.toml"

DATA_DIR = xdg_base_dir("XDG_DATA_HOME", Path.home() / ".local" / "share") / "tensors"
MODELS_DIR = DATA_DIR / "models"
METADATA_DIR = DATA_DIR / "metadata"

# Legacy config for migration: a plain-text file holding only the API key
LEGACY_RC_FILE = Path.home() / ".sftrc"

# Default download directories, keyed by the CivitAI API model-type spelling
DEFAULT_PATHS: dict[str, Path] = {
    "Checkpoint": MODELS_DIR / "checkpoints",
    "LORA": MODELS_DIR / "loras",
    "LoCon": MODELS_DIR / "loras",
}

CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
# ============================================================================
|
|
74
|
-
# Enums for CLI
|
|
75
|
-
# ============================================================================
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class ModelType(str, Enum):
|
|
79
|
-
"""CivitAI model types."""
|
|
80
|
-
|
|
81
|
-
checkpoint = "checkpoint"
|
|
82
|
-
lora = "lora"
|
|
83
|
-
embedding = "embedding"
|
|
84
|
-
vae = "vae"
|
|
85
|
-
controlnet = "controlnet"
|
|
86
|
-
locon = "locon"
|
|
87
|
-
|
|
88
|
-
def to_api(self) -> str:
|
|
89
|
-
"""Convert to CivitAI API value."""
|
|
90
|
-
mapping = {
|
|
91
|
-
"checkpoint": "Checkpoint",
|
|
92
|
-
"lora": "LORA",
|
|
93
|
-
"embedding": "TextualInversion",
|
|
94
|
-
"vae": "VAE",
|
|
95
|
-
"controlnet": "Controlnet",
|
|
96
|
-
"locon": "LoCon",
|
|
97
|
-
}
|
|
98
|
-
return mapping[self.value]
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class BaseModel(str, Enum):
|
|
102
|
-
"""Common base models."""
|
|
103
|
-
|
|
104
|
-
sd15 = "sd15"
|
|
105
|
-
sdxl = "sdxl"
|
|
106
|
-
pony = "pony"
|
|
107
|
-
flux = "flux"
|
|
108
|
-
illustrious = "illustrious"
|
|
109
|
-
|
|
110
|
-
def to_api(self) -> str:
|
|
111
|
-
"""Convert to CivitAI API value."""
|
|
112
|
-
mapping = {
|
|
113
|
-
"sd15": "SD 1.5",
|
|
114
|
-
"sdxl": "SDXL 1.0",
|
|
115
|
-
"pony": "Pony",
|
|
116
|
-
"flux": "Flux.1 D",
|
|
117
|
-
"illustrious": "Illustrious",
|
|
118
|
-
}
|
|
119
|
-
return mapping[self.value]
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
class SortOrder(str, Enum):
|
|
123
|
-
"""Sort options for search."""
|
|
124
|
-
|
|
125
|
-
downloads = "downloads"
|
|
126
|
-
rating = "rating"
|
|
127
|
-
newest = "newest"
|
|
128
|
-
|
|
129
|
-
def to_api(self) -> str:
|
|
130
|
-
"""Convert to CivitAI API value."""
|
|
131
|
-
mapping = {
|
|
132
|
-
"downloads": "Most Downloaded",
|
|
133
|
-
"rating": "Highest Rated",
|
|
134
|
-
"newest": "Newest",
|
|
135
|
-
}
|
|
136
|
-
return mapping[self.value]
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
# ============================================================================
|
|
140
|
-
# Config Functions
|
|
141
|
-
# ============================================================================
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def load_config() -> dict[str, Any]:
    """Return the parsed TOML config, or an empty dict when no config file exists.

    A malformed file is not caught here: ``tomllib`` raises ``TOMLDecodeError``.
    """
    if not CONFIG_FILE.exists():
        return {}
    with CONFIG_FILE.open("rb") as handle:
        return tomllib.load(handle)
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
def _toml_string(text: str) -> str:
    """Render ``text`` as a TOML basic (double-quoted) string.

    json.dumps already escapes double quotes, backslashes and the C0 control
    characters using escape sequences that TOML basic strings also accept.
    ensure_ascii=False keeps non-ASCII characters literal (TOML files are UTF-8)
    rather than emitting surrogate-pair escapes, which TOML rejects. TOML also
    requires DEL (U+007F) to be escaped, which JSON does not.
    """
    return json.dumps(text, ensure_ascii=False).replace("\x7f", "\\u007F")


def _toml_key(key: Any) -> str:
    """Render a TOML key: bare if it only uses A-Z, a-z, 0-9, '_' or '-', else quoted."""
    text = str(key)
    return text if re.fullmatch(r"[A-Za-z0-9_-]+", text) else _toml_string(text)


def _toml_value(value: Any) -> str:
    """Render a Python value as an inline TOML value.

    Supports bool, int, float, str, list/tuple (as arrays) and dict (as inline
    tables; None-valued entries are omitted).

    Raises:
        TypeError: For a value TOML cannot represent (e.g. None inside a list).
    """
    if isinstance(value, bool):  # checked before int: bool is an int subclass
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return str(value)  # str() spellings, including inf/nan, are valid TOML
    if isinstance(value, str):
        return _toml_string(value)
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(_toml_value(item) for item in value) + "]"
    if isinstance(value, dict):
        pairs = ", ".join(
            f"{_toml_key(k)} = {_toml_value(v)}" for k, v in value.items() if v is not None
        )
        return "{ " + pairs + " }" if pairs else "{}"
    raise TypeError(f"Cannot represent {type(value).__name__} as a TOML value")


def save_config(config: dict[str, Any]) -> None:
    """Write ``config`` to ``CONFIG_FILE`` as UTF-8 TOML, creating ``CONFIG_DIR`` if needed.

    Top-level dict values become ``[section]`` tables; other top-level values
    are written as root keys. Keys whose value is None are omitted, because
    TOML has no null and absence is how "unset" is stored.

    Args:
        config: Mapping to serialise, e.g. ``{"api": {"civitai_key": "..."}}``.

    Raises:
        TypeError: If a value cannot be represented in TOML. Nothing is written
            in that case, because the whole document is rendered before writing.
        OSError: If the directory or file cannot be written.
    """
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    root_lines: list[str] = []
    table_lines: list[str] = []
    for key, value in config.items():
        if value is None:
            continue
        if isinstance(value, dict):
            table_lines.append(f"[{_toml_key(key)}]")
            table_lines.extend(
                f"{_toml_key(k)} = {_toml_value(v)}" for k, v in value.items() if v is not None
            )
            table_lines.append("")
        else:
            root_lines.append(f"{_toml_key(key)} = {_toml_value(value)}")

    # Root keys must precede the first [table] header: TOML assigns any key that
    # follows a header to that table.
    if root_lines and table_lines:
        root_lines.append("")
    CONFIG_FILE.write_text("\n".join(root_lines + table_lines) + "\n", encoding="utf-8")
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def load_api_key() -> str | None:
    """Find a CivitAI API key, or return None when none is configured.

    Sources are tried in order:

    1. the ``CIVITAI_API_KEY`` environment variable;
    2. ``[api] civitai_key`` in the TOML config;
    3. the stripped contents of the legacy ``~/.sftrc`` file.

    Empty values are skipped.
    """
    if env_key := os.environ.get("CIVITAI_API_KEY"):
        return env_key

    api_section = load_config().get("api", {})
    if isinstance(api_section, dict) and (config_key := api_section.get("civitai_key")):
        return str(config_key)

    # Older releases stored the bare key in ~/.sftrc; still honoured for migration.
    if not LEGACY_RC_FILE.exists():
        return None
    legacy_key = LEGACY_RC_FILE.read_text().strip()
    return legacy_key or None
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def get_default_output_path(model_type: str | None) -> Path | None:
    """Look up the default download directory for a CivitAI model type.

    ``model_type`` uses the spelling of the ``DEFAULT_PATHS`` keys (e.g.
    ``"LORA"``). Returns None for an empty or unmapped type.
    """
    if not model_type:
        return None
    return DEFAULT_PATHS.get(model_type)
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
# ============================================================================
|
|
205
|
-
# Safetensor Functions
|
|
206
|
-
# ============================================================================
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def read_safetensor_metadata(
    file_path: Path, max_header_size: int = 100_000_000
) -> dict[str, Any]:
    """Read the JSON header of a safetensors file without loading tensor data.

    The file starts with an 8-byte little-endian unsigned header length,
    followed by that many bytes of UTF-8 JSON. The JSON object maps tensor
    names to descriptors, plus an optional ``__metadata__`` object.

    Args:
        file_path: Path to the safetensors file.
        max_header_size: Largest accepted declared header length, in bytes.
            Anything larger is treated as corrupt (default 100 MB).

    Returns:
        A dict with:

        - ``metadata``: the ``__metadata__`` object, or ``{}`` if absent/null;
        - ``tensor_count``: the number of header entries other than ``__metadata__``;
        - ``header_size``: the declared header length in bytes.

    Raises:
        ValueError: If the file is too short or truncated, the header length
            exceeds ``max_header_size``, the header is not valid UTF-8 JSON, or
            the header (or its ``__metadata__``) is not a JSON object.
        OSError: If the file cannot be opened or read.
    """
    with file_path.open("rb") as f:
        size_prefix = f.read(8)
        if len(size_prefix) < 8:
            raise ValueError("Invalid safetensor file: too short")

        (header_size,) = struct.unpack("<Q", size_prefix)
        if header_size > max_header_size:
            raise ValueError(f"Invalid header size: {header_size}")

        header_bytes = f.read(header_size)
        if len(header_bytes) < header_size:
            raise ValueError("Invalid safetensor file: header truncated")

    # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    header = json.loads(header_bytes.decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("Invalid safetensor file: header is not a JSON object")

    metadata = header.get("__metadata__")
    if metadata is None:
        metadata = {}
    elif not isinstance(metadata, dict):
        raise ValueError("Invalid safetensor file: __metadata__ is not a JSON object")

    return {
        "metadata": metadata,
        "tensor_count": sum(1 for key in header if key != "__metadata__"),
        "header_size": header_size,
    }
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
def compute_sha256(file_path: Path) -> str:
    """Hash ``file_path`` with SHA-256 while showing a live progress bar.

    The file is streamed in 8 MiB blocks, so memory use does not grow with file
    size. Returns the digest as uppercase hexadecimal.
    """
    total_bytes = file_path.stat().st_size
    digest = hashlib.sha256()
    block_size = 8 * 1024 * 1024

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=console) as progress:
        task_id = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=total_bytes)
        with file_path.open("rb") as stream:
            # iter() with a b"" sentinel stops at EOF, like a read-until-empty loop.
            for block in iter(lambda: stream.read(block_size), b""):
                digest.update(block)
                progress.update(task_id, advance=len(block))

    return digest.hexdigest().upper()
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
def get_base_name(file_path: Path) -> str:
    """Return the file name without a ``.safetensors``/``.sft`` suffix (case-insensitive).

    Any other file falls back to ``Path.stem``, which drops only the last suffix.
    """
    filename = file_path.name
    lowered = filename.lower()
    suffix = next((s for s in (".safetensors", ".sft") if lowered.endswith(s)), "")
    return filename[: -len(suffix)] if suffix else file_path.stem
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
# ============================================================================
|
|
277
|
-
# CivitAI API Functions
|
|
278
|
-
# ============================================================================
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
def _get_headers(api_key: str | None) -> dict[str, str]:
    """Build CivitAI request headers, adding Bearer auth only when a key is given."""
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
def _civitai_get_json(url: str, api_key: str | None) -> dict[str, Any] | None:
    """GET a CivitAI API URL and return its decoded JSON object.

    Args:
        url: Fully built API URL.
        api_key: Optional key, sent as a Bearer ``Authorization`` header.

    Returns:
        The decoded JSON object. Returns None silently on HTTP 404, so callers
        can treat it as "not found". Returns None after printing the reason in
        red on any other HTTP error, a network error, or a body that is not a
        JSON object.
    """
    try:
        response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
        if response.status_code == 404:
            return None
        response.raise_for_status()
        result = response.json()
    except httpx.HTTPStatusError as e:
        console.print(f"[red]API error: {e.response.status_code}[/red]")
        return None
    except httpx.RequestError as e:
        console.print(f"[red]Request error: {e}[/red]")
        return None
    except ValueError:
        # response.json() raises json.JSONDecodeError (a ValueError), e.g. for an
        # HTML error page served by a proxy.
        console.print("[red]API error: response was not valid JSON[/red]")
        return None

    if not isinstance(result, dict):
        console.print("[red]API error: unexpected response format[/red]")
        return None
    return result


def _transient_spinner() -> Progress:
    """Create a spinner-only progress display that disappears when it exits."""
    return Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    )


def fetch_civitai_model_version(
    version_id: int, api_key: str | None = None
) -> dict[str, Any] | None:
    """Fetch a model-version record from CivitAI by version ID (no spinner).

    Returns the JSON object, or None if it was not found or an error occurred.
    """
    return _civitai_get_json(f"{CIVITAI_API_BASE}/model-versions/{version_id}", api_key)


def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None:
    """Fetch a full model record (all versions) from CivitAI by model ID.

    Shows a transient spinner while waiting. Returns the JSON object, or None
    if it was not found or an error occurred.
    """
    with _transient_spinner() as progress:
        progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
        return _civitai_get_json(f"{CIVITAI_API_BASE}/models/{model_id}", api_key)


def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
    """Look up the CivitAI model version whose file has the given SHA-256 hash.

    Shows a transient spinner while waiting. Returns the JSON object, or None
    if no version matches or an error occurred.
    """
    with _transient_spinner() as progress:
        progress.add_task("[cyan]Fetching from CivitAI...", total=None)
        return _civitai_get_json(f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}", api_key)
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
def search_civitai(
    query: str | None = None,
    model_type: ModelType | None = None,
    base_model: BaseModel | None = None,
    sort: SortOrder = SortOrder.downloads,
    limit: int = 20,
    api_key: str | None = None,
) -> dict[str, Any] | None:
    """Search CivitAI models, optionally filtered by type and base model.

    Args:
        query: Free-text search matched against model names.
        model_type: Restrict results to one model type.
        base_model: Restrict results to one base-model family.
        sort: Result ordering.
        limit: Maximum number of results; clamped to CivitAI's accepted 1-100 range.
        api_key: Optional CivitAI API key.

    Returns:
        The decoded ``/models`` response (``items`` plus ``metadata``), or None
        after printing an error.
    """
    # CivitAI accepts 1..100 results per page. Clamping also keeps the local
    # slice below from receiving a zero or negative bound, where [:-n] would
    # drop items from the end.
    limit = max(1, min(limit, 100))
    params: dict[str, Any] = {
        "limit": limit,
        "nsfw": "true",
    }

    # API quirk: query + filters don't work reliably together.
    # If we have filters, skip query and filter client-side.
    has_filters = model_type is not None or base_model is not None

    if query and not has_filters:
        params["query"] = query

    if model_type:
        params["types"] = model_type.to_api()

    if base_model:
        params["baseModels"] = base_model.to_api()

    params["sort"] = sort.to_api()

    # Request a full page when names will be filtered locally.
    if query and has_filters:
        params["limit"] = 100

    url = f"{CIVITAI_API_BASE}/models"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Searching CivitAI...", total=None)

        try:
            response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None
        except ValueError:
            # response.json() raises json.JSONDecodeError (a ValueError) on non-JSON bodies.
            console.print("[red]API error: response was not valid JSON[/red]")
            return None

    if not isinstance(result, dict):
        console.print("[red]API error: unexpected response format[/red]")
        return None

    # Client-side name filtering when query + filters are combined.
    # result["metadata"] is left untouched, so its counts describe the
    # unfiltered API page.
    if query and has_filters:
        needle = query.lower()
        result["items"] = [
            item
            for item in result.get("items") or []
            if needle in (item.get("name") or "").lower()
        ][:limit]

    return result
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
def download_model(
    version_id: int,
    dest_path: Path,
    api_key: str | None = None,
    resume: bool = True,
) -> bool:
    """Download a CivitAI model version, resuming a partial file when possible.

    Args:
        version_id: CivitAI model-version ID to download.
        dest_path: Target file path, or an existing directory. For a directory,
            the file name is taken from the response's Content-Disposition
            header (reduced to its final path component).
        api_key: Optional CivitAI API key, sent as the ``token`` query parameter.
        resume: If True and ``dest_path`` is an existing regular file, request
            only the missing bytes via an HTTP Range header and append them.
            Directory targets always download from the start.

    Returns:
        True on success, or when the server answers 416 (nothing left to fetch).
        False on an HTTP or network error, or when a directory target gets no
        usable file name.

    Raises:
        OSError: If the destination file cannot be written.
    """
    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
    params: dict[str, str] = {}
    if api_key:
        params["token"] = api_key

    headers: dict[str, str] = {}
    initial_size = 0

    # Only an existing regular file can be resumed. For a directory target the
    # real file name is unknown until the response arrives, and the directory's
    # own st_size must never be used as a byte offset.
    if resume and dest_path.is_file():
        initial_size = dest_path.stat().st_size
        headers["Range"] = f"bytes={initial_size}-"
        console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")

    try:
        with httpx.stream(
            "GET",
            url,
            params=params,
            headers=headers,
            follow_redirects=True,
            timeout=httpx.Timeout(30.0, read=None),
        ) as response:
            if response.status_code == 416:
                # The requested range starts at or past the end of the file.
                console.print("[green]File already fully downloaded.[/green]")
                return True

            response.raise_for_status()

            # A server may ignore Range and reply 200 with the whole body;
            # appending that to the partial file would corrupt it, so restart.
            if initial_size and response.status_code != 206:
                console.print(
                    "[yellow]Server ignored the resume request; restarting download.[/yellow]"
                )
                initial_size = 0
            mode = "ab" if initial_size else "wb"

            content_length = response.headers.get("content-length")
            total_size = int(content_length) + initial_size if content_length else 0

            if dest_path.is_dir():
                content_disp = response.headers.get("content-disposition", "")
                match = re.search(r'filename="?([^";\n]+)"?', content_disp)
                # The header is server-controlled: keep only the final path
                # component so a name like "../../x" cannot escape the directory.
                remote_name = Path(match.group(1).strip()).name if match else ""
                if remote_name in ("", ".", ".."):
                    console.print(
                        "[red]Download error: server did not provide a usable file name[/red]"
                    )
                    return False
                dest_path = dest_path / remote_name

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                TaskProgressColumn(),
                DownloadColumn(),
                TransferSpeedColumn(),
                TimeRemainingColumn(),
                console=console,
            ) as progress:
                task = progress.add_task(
                    f"[cyan]Downloading {dest_path.name}...",
                    total=total_size if total_size > 0 else None,
                    completed=initial_size,
                )

                with dest_path.open(mode) as f:
                    for chunk in response.iter_bytes(1024 * 1024):
                        f.write(chunk)
                        progress.update(task, advance=len(chunk))

        console.print(f"[green]Downloaded:[/green] {dest_path}")
        return True

    except httpx.HTTPStatusError as e:
        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
        if e.response.status_code == 401:
            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
        return False
    except httpx.RequestError as e:
        console.print(f"[red]Download error: {e}[/red]")
        return False
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
# ============================================================================
|
|
510
|
-
# Display Functions
|
|
511
|
-
# ============================================================================
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
def _format_size(size_kb: float) -> str:
    """Render a size given in kilobytes as KB, MB or GB.

    Precision grows with the unit: whole KB, one decimal for MB, two for GB.
    """
    tiers = (
        (1024, 1, "KB", 0),
        (1024 * 1024, 1024, "MB", 1),
    )
    for limit, divisor, unit, digits in tiers:
        if size_kb < limit:
            return f"{size_kb / divisor:.{digits}f} {unit}"
    return f"{size_kb / 1024 / 1024:.2f} GB"
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
def _format_count(count: int) -> str:
    """Abbreviate a count: plain below 1,000, then one-decimal K, then one-decimal M."""
    return (
        str(count)
        if count < 1000
        else f"{count / 1000:.1f}K"
        if count < 1_000_000
        else f"{count / 1_000_000:.1f}M"
    )
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None:
    """Print a Property/Value table describing the local file and its header.

    Args:
        file_path: The safetensor file being described.
        local_metadata: Header summary; its ``header_size`` and ``tensor_count``
            keys are shown.
        sha256_hash: Precomputed digest, shown verbatim.
    """
    table = Table(title="File Information", show_header=True, header_style="bold magenta")
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="green")

    size_gb = file_path.stat().st_size / (1024**3)
    rows = (
        ("File", str(file_path.name)),
        ("Path", str(file_path.parent)),
        ("Size", f"{size_gb:.2f} GB"),
        ("SHA256", sha256_hash),
        ("Header Size", f"{local_metadata['header_size']:,} bytes"),
        ("Tensor Count", str(local_metadata["tensor_count"])),
    )
    for label, text in rows:
        table.add_row(label, text)

    console.print()
    console.print(table)
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
def _display_local_metadata(local_metadata: dict[str, Any]) -> None:
    """Print the embedded metadata entries sorted by key, or a notice if there are none.

    Values longer than 200 characters are cut to 200 and suffixed with ``...``.
    """
    entries = local_metadata["metadata"]
    if not entries:
        console.print()
        console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
        return

    table = Table(title="Safetensor Metadata", show_header=True, header_style="bold magenta")
    table.add_column("Key", style="cyan")
    table.add_column("Value", style="green", max_width=80)

    for name, raw in sorted(entries.items()):
        text = str(raw)
        table.add_row(name, text if len(text) <= 200 else text[:200] + "...")

    console.print()
    console.print(table)
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None:
    """Print a CivitAI model-version record as a table, followed by a model link.

    Primary files contribute name, size and (when present) format details. A
    None or empty record prints a "not found" notice instead.
    """
    if not civitai_data:
        console.print()
        console.print("[yellow]Model not found on CivitAI.[/yellow]")
        return

    def text_of(key: str) -> str:
        return str(civitai_data.get(key, "N/A"))

    table = Table(
        title="CivitAI Model Information", show_header=True, header_style="bold magenta"
    )
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="green", max_width=80)

    for label, key in (
        ("Model ID", "modelId"),
        ("Version ID", "id"),
        ("Version Name", "name"),
        ("Base Model", "baseModel"),
        ("Created At", "createdAt"),
    ):
        table.add_row(label, text_of(key))

    trigger_words = civitai_data.get("trainedWords", [])
    if trigger_words:
        table.add_row("Trigger Words", ", ".join(trigger_words))

    table.add_row("Download URL", text_of("downloadUrl"))

    primary_files = (entry for entry in civitai_data.get("files", []) if entry.get("primary"))
    for entry in primary_files:
        table.add_row("Primary File", str(entry.get("name", "N/A")))
        table.add_row("File Size (CivitAI)", _format_size(entry.get("sizeKB", 0)))
        file_meta = entry.get("metadata", {})
        if file_meta:
            for label, key in (("Format", "format"), ("Precision", "fp"), ("Size Type", "size")):
                table.add_row(label, str(file_meta.get(key, "N/A")))

    console.print()
    console.print(table)

    model_id = civitai_data.get("modelId")
    if model_id:
        console.print()
        console.print(
            f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
        )
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
def _display_model_info(model_data: dict[str, Any]) -> None:
    """Print a CivitAI model record: a summary table, a versions table and a link.

    ``model_data`` is expected to be a CivitAI model JSON object (e.g. from
    ``fetch_civitai_model``). Missing fields are shown as "N/A" or skipped.
    """
    summary = Table(title="Model Information", show_header=True, header_style="bold magenta")
    summary.add_column("Property", style="cyan")
    summary.add_column("Value", style="green", max_width=80)

    summary.add_row("ID", str(model_data.get("id", "N/A")))
    summary.add_row("Name", str(model_data.get("name", "N/A")))
    summary.add_row("Type", str(model_data.get("type", "N/A")))
    summary.add_row("NSFW", str(model_data.get("nsfw", False)))

    if creator := model_data.get("creator", {}):
        summary.add_row("Creator", str(creator.get("username", "N/A")))

    if tags := model_data.get("tags", []):
        shown = ", ".join(tags[:10])
        summary.add_row("Tags", f"{shown}..." if len(tags) > 10 else shown)

    if stats := model_data.get("stats", {}):
        summary.add_row("Downloads", f"{stats.get('downloadCount', 0):,}")
        summary.add_row("Favorites", f"{stats.get('favoriteCount', 0):,}")
        average = stats.get("rating", 0)
        summary.add_row("Rating", f"{average:.1f} ({stats.get('ratingCount', 0)} ratings)")

    if status := model_data.get("mode"):
        summary.add_row("Status", str(status))

    console.print()
    console.print(summary)

    if versions := model_data.get("modelVersions", []):
        versions_table = Table(title="Model Versions", show_header=True, header_style="bold magenta")
        for heading, colour in (
            ("ID", "cyan"),
            ("Name", "green"),
            ("Base Model", "yellow"),
            ("Created", "blue"),
            ("Primary File", "white"),
        ):
            versions_table.add_column(heading, style=colour)

        for version in versions:
            version_files = version.get("files", [])
            # Prefer the file flagged primary; otherwise fall back to the first file.
            chosen = next((f for f in version_files if f.get("primary")), None)
            if chosen is None and version_files:
                chosen = version_files[0]

            file_info = ""
            if chosen:
                file_info = f"{chosen.get('name', 'N/A')} ({_format_size(chosen.get('sizeKB', 0))})"

            versions_table.add_row(
                str(version.get("id", "N/A")),
                str(version.get("name", "N/A")),
                str(version.get("baseModel", "N/A")),
                str(version.get("createdAt", "N/A"))[:10],  # keep only the date part
                file_info,
            )

        console.print()
        console.print(versions_table)

    if model_id := model_data.get("id"):
        console.print()
        console.print(
            f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
        )
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
def _display_search_results(results: dict[str, Any]) -> None:
    """Print CivitAI search hits as a table, then a result count and a usage hint.

    Base model and size come from the first entry of each model's
    ``modelVersions`` list, using its primary file (or its first file).
    """
    items = results.get("items", [])
    if not items:
        console.print("[yellow]No results found.[/yellow]")
        return

    table = Table(show_header=True, header_style="bold magenta")
    for header, options in (
        ("ID", {"style": "cyan", "justify": "right"}),
        ("Name", {"style": "green", "max_width": 40}),
        ("Type", {"style": "yellow"}),
        ("Base", {"style": "blue"}),
        ("Size", {"justify": "right"}),
        ("DLs", {"justify": "right"}),
        ("Rating", {"justify": "right"}),
    ):
        table.add_column(header, **options)

    for model in items:
        model_id = str(model.get("id", ""))
        name = model.get("name", "N/A")
        if len(name) > 40:
            name = f"{name[:37]}..."
        kind = model.get("type", "N/A")

        base_model, size = "N/A", "N/A"
        versions = model.get("modelVersions", [])
        if versions:
            newest = versions[0]
            base_model = newest.get("baseModel", "N/A")
            newest_files = newest.get("files", [])
            primary = next((f for f in newest_files if f.get("primary")), None)
            if primary is None and newest_files:
                primary = newest_files[0]
            if primary:
                size = _format_size(primary.get("sizeKB", 0))

        stats = model.get("stats", {})
        downloads = _format_count(stats.get("downloadCount", 0))
        rating = f"{stats.get('rating', 0):.1f}"

        table.add_row(model_id, name, kind, base_model, size, downloads, rating)

    console.print()
    console.print(table)

    total = results.get("metadata", {}).get("totalItems", len(items))
    console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
    console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
# ============================================================================
|
|
743
|
-
# CLI Commands
|
|
744
|
-
# ============================================================================
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
@app.command()
def info(
    file: Annotated[Path, typer.Argument(help="Path to the safetensor file")],
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    skip_civitai: Annotated[
        bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")
    ] = False,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    save_to: Annotated[
        Path | None, typer.Option("--save-to", help="Save metadata to directory")
    ] = None,
) -> None:
    """Read safetensor metadata and fetch CivitAI info.

    Reads the local header of ``file``, computes its SHA256 and, unless
    ``--skip-civitai`` is given, looks that hash up on CivitAI. The combined
    record is printed as JSON (``--json``) or as Rich tables. With
    ``--save-to`` it is also written to ``<base>.json`` and ``<base>.sha256``
    inside that directory.

    Raises:
        typer.Exit: code 1 when the file does not exist, the ``--save-to``
            target is not an existing directory, the safetensor header cannot
            be parsed (ValueError), or a filesystem operation fails (OSError).
    """
    file_path = file.resolve()

    if not file_path.exists():
        console.print(f"[red]Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if file_path.suffix.lower() not in (".safetensors", ".sft"):
        console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")

    # Validate the destination up front. Hashing a large model file is slow,
    # so a bad --save-to should be reported before that work starts.
    output_dir: Path | None = None
    if save_to:
        output_dir = save_to.resolve()
        if not output_dir.is_dir():  # False for missing paths as well as non-directories
            console.print(f"[red]Error: Invalid directory: {output_dir}[/red]")
            raise typer.Exit(1)

    try:
        console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
        local_metadata = read_safetensor_metadata(file_path)
        sha256_hash = compute_sha256(file_path)

        civitai_data = None
        if not skip_civitai:
            civitai_data = fetch_civitai_by_hash(sha256_hash, api_key or load_api_key())

        # A single record feeds both the --json output and the saved .json
        # file, so the two representations cannot drift apart.
        record: dict[str, Any] = {
            "file": str(file_path),
            "sha256": sha256_hash,
            "header_size": local_metadata["header_size"],
            "tensor_count": local_metadata["tensor_count"],
            "metadata": local_metadata["metadata"],
            "civitai": civitai_data,
        }

        if json_output:
            console.print_json(data=record)
        else:
            _display_file_info(file_path, local_metadata, sha256_hash)
            _display_local_metadata(local_metadata)
            _display_civitai_data(civitai_data)

        if output_dir is not None:
            base_name = get_base_name(file_path)
            json_path = output_dir / f"{base_name}.json"
            sha_path = output_dir / f"{base_name}.sha256"

            json_path.write_text(json.dumps(record, indent=2))
            sha_path.write_text(f"{sha256_hash} {file_path.name}\n")

            console.print()
            console.print(f"[green]Saved:[/green] {json_path}")
            console.print(f"[green]Saved:[/green] {sha_path}")

    except ValueError as e:
        console.print(f"[red]Error reading safetensor: {e}[/red]")
        raise typer.Exit(1) from e
    except OSError as e:
        # Unreadable input file or unwritable output file: report it cleanly
        # instead of crashing with a traceback.
        console.print(f"[red]Error: {e}[/red]")
        raise typer.Exit(1) from e
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
@app.command()
def search(
    query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
    model_type: Annotated[
        ModelType | None, typer.Option("-t", "--type", help="Model type filter")
    ] = None,
    base: Annotated[
        BaseModel | None, typer.Option("-b", "--base", help="Base model filter")
    ] = None,
    sort: Annotated[
        SortOrder, typer.Option("-s", "--sort", help="Sort order")
    ] = SortOrder.downloads,
    limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Search CivitAI models.

    Sends the query and filters to ``search_civitai``. The response is either
    dumped verbatim (``--json``) or rendered as a results table. An empty or
    missing response is reported and the command exits with code 1.
    """
    response = search_civitai(
        query=query,
        model_type=model_type,
        base_model=base,
        sort=sort,
        limit=limit,
        api_key=api_key or load_api_key(),
    )

    if not response:
        console.print("[red]Search failed.[/red]")
        raise typer.Exit(1)

    if not json_output:
        _display_search_results(response)
        return

    console.print_json(data=response)
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
@app.command()
def get(
    id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")],
    version: Annotated[
        bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")
    ] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Fetch model information from CivitAI by model ID or version ID.

    With ``--version``, ``id_value`` is looked up as a model-version ID.
    Otherwise it is looked up as a model ID. The result is printed as JSON
    (``--json``) or with the matching Rich renderer. Exits with code 1 when
    CivitAI returns nothing.
    """
    token = api_key or load_api_key()

    if version:
        payload = fetch_civitai_model_version(id_value, token)
        not_found = f"[red]Error: Version {id_value} not found on CivitAI.[/red]"
        render = _display_civitai_data
    else:
        payload = fetch_civitai_model(id_value, token)
        not_found = f"[red]Error: Model {id_value} not found on CivitAI.[/red]"
        render = _display_model_info

    if not payload:
        console.print(not_found)
        raise typer.Exit(1)

    if json_output:
        console.print_json(data=payload)
    else:
        render(payload)
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
def _resolve_version_id(
|
|
899
|
-
version_id: int | None,
|
|
900
|
-
hash_val: str | None,
|
|
901
|
-
model_id: int | None,
|
|
902
|
-
api_key: str | None,
|
|
903
|
-
) -> int | None:
|
|
904
|
-
"""Resolve version ID from hash or model ID."""
|
|
905
|
-
if version_id:
|
|
906
|
-
return version_id
|
|
907
|
-
|
|
908
|
-
if hash_val:
|
|
909
|
-
console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
|
|
910
|
-
civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key)
|
|
911
|
-
if not civitai_data:
|
|
912
|
-
console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
|
|
913
|
-
return None
|
|
914
|
-
vid: int | None = civitai_data.get("id")
|
|
915
|
-
if vid:
|
|
916
|
-
console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
|
|
917
|
-
return vid
|
|
918
|
-
|
|
919
|
-
if model_id:
|
|
920
|
-
console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
|
|
921
|
-
model_data = fetch_civitai_model(model_id, api_key)
|
|
922
|
-
if not model_data:
|
|
923
|
-
console.print(f"[red]Error: Model {model_id} not found.[/red]")
|
|
924
|
-
return None
|
|
925
|
-
versions = model_data.get("modelVersions", [])
|
|
926
|
-
if not versions:
|
|
927
|
-
console.print("[red]Error: Model has no versions.[/red]")
|
|
928
|
-
return None
|
|
929
|
-
latest = versions[0]
|
|
930
|
-
latest_vid: int | None = latest.get("id")
|
|
931
|
-
if latest_vid:
|
|
932
|
-
name = latest.get("name", "N/A")
|
|
933
|
-
console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
|
|
934
|
-
return latest_vid
|
|
935
|
-
|
|
936
|
-
return None
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None:
|
|
940
|
-
"""Prepare output directory for download."""
|
|
941
|
-
if output is None:
|
|
942
|
-
output_dir = get_default_output_path(model_type_str)
|
|
943
|
-
if output_dir is None:
|
|
944
|
-
console.print(
|
|
945
|
-
f"[red]Error: No default path for type '{model_type_str}'. "
|
|
946
|
-
"Use --output to specify.[/red]"
|
|
947
|
-
)
|
|
948
|
-
return None
|
|
949
|
-
console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]")
|
|
950
|
-
else:
|
|
951
|
-
output_dir = output.resolve()
|
|
952
|
-
|
|
953
|
-
if not output_dir.exists():
|
|
954
|
-
console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
|
|
955
|
-
output_dir.mkdir(parents=True, exist_ok=True)
|
|
956
|
-
elif not output_dir.is_dir():
|
|
957
|
-
console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
|
|
958
|
-
return None
|
|
959
|
-
|
|
960
|
-
return output_dir
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
def _safe_filename(name: object, fallback: str) -> str:
    """Reduce a server-supplied file name to a bare name that stays inside the target dir."""
    # Path(...).name drops any directory part ("../../x" -> "x", "/etc/x" -> "x").
    # "", "." and ".." are not usable file names, so they use the fallback.
    candidate = Path(str(name)).name if name else ""
    if candidate in ("", ".", ".."):
        return fallback
    return candidate


@app.command("dl")
def download(
    version_id: Annotated[
        int | None, typer.Option("-v", "--version-id", help="Model version ID")
    ] = None,
    model_id: Annotated[
        int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")
    ] = None,
    hash_val: Annotated[
        str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")
    ] = None,
    output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
    no_resume: Annotated[
        bool, typer.Option("--no-resume", help="Don't resume partial downloads")
    ] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Download a model from CivitAI.

    The target version is chosen from ``--version-id``, ``--hash`` or
    ``--model-id``, in that order. Its primary file (or the first file, if
    none is marked primary) is downloaded into ``--output`` or the default
    directory for the model's type. Before the transfer starts, a summary
    table is printed.

    Raises:
        typer.Exit: code 1 when no identifier is given, resolution or the
            version lookup fails, no usable output directory exists, the
            version has no files, or the download itself fails.
    """
    key = api_key or load_api_key()

    resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
    if not resolved_version_id:
        if not (version_id or hash_val or model_id):
            console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
        # Always abort here. When an identifier was given, _resolve_version_id
        # has already printed why it could not be resolved.
        raise typer.Exit(1)

    console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
    version_info = fetch_civitai_model_version(resolved_version_id, key)
    if not version_info:
        console.print("[red]Error: Could not fetch model version info.[/red]")
        raise typer.Exit(1)

    # `or {}` / `or []` also cover keys present with an explicit JSON null.
    model_type_str: str | None = (version_info.get("model") or {}).get("type")
    output_dir = _prepare_download_dir(output, model_type_str)
    if not output_dir:
        raise typer.Exit(1)

    files: list[dict[str, Any]] = version_info.get("files") or []
    primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
    if not primary_file:
        console.print("[red]Error: No files found for this version.[/red]")
        raise typer.Exit(1)

    # The name comes from the remote API, so it is untrusted. Strip any path
    # components so the download cannot be written outside output_dir.
    filename = _safe_filename(
        primary_file.get("name"), f"model-{resolved_version_id}.safetensors"
    )
    dest_path = output_dir / filename

    table = Table(title="Model Download", show_header=True, header_style="bold magenta")
    table.add_column("Property", style="cyan")
    table.add_column("Value", style="green")
    table.add_row("Version", version_info.get("name", "N/A"))
    table.add_row("Base Model", version_info.get("baseModel", "N/A"))
    table.add_row("File", filename)
    table.add_row("Size", _format_size(primary_file.get("sizeKB", 0)))
    table.add_row("Destination", str(dest_path))
    console.print()
    console.print(table)
    console.print()

    success = download_model(resolved_version_id, dest_path, key, resume=not no_resume)
    if not success:
        raise typer.Exit(1)
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
@app.command()
def config(
    show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
    set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
) -> None:
    """Manage configuration.

    With ``--set-key``, the CivitAI API key is stored under
    ``[api].civitai_key`` in the config file and nothing else is printed.
    In every other case, with or without ``--show``, the config file
    location, whether it exists, and a masked form of the active API key are
    printed.
    """
    if set_key:
        settings = load_config()
        settings.setdefault("api", {})["civitai_key"] = set_key
        save_config(settings)
        console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
        return

    # No key was supplied, so always report the current configuration;
    # --show only makes that explicit.
    console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
    console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")

    active_key = load_api_key()
    if not active_key:
        console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
    else:
        # Show only the first and last four characters of longer keys.
        shown = f"{active_key[:4]}...{active_key[-4:]}" if len(active_key) > 8 else "***"
        console.print(f"[bold]API key:[/bold] {shown}")

    console.print()
    console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
def main() -> int:
|
|
1057
|
-
"""Main entry point."""
|
|
1058
|
-
# Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
|
|
1059
|
-
if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
|
|
1060
|
-
arg = sys.argv[1]
|
|
1061
|
-
if arg not in ("info", "search", "get", "dl", "download", "config") and (
|
|
1062
|
-
arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
|
|
1063
|
-
):
|
|
1064
|
-
sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
|
|
1065
|
-
|
|
1066
|
-
app()
|
|
1067
|
-
return 0
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
if __name__ == "__main__":
|
|
1071
|
-
sys.exit(main())
|