tensors-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensors-0.1.1.dist-info/METADATA ADDED
@@ -0,0 +1,154 @@
1
+ Metadata-Version: 2.4
2
+ Name: tensors
3
+ Version: 0.1.1
4
+ Summary: Read safetensor metadata and fetch CivitAI model information
5
+ Requires-Python: >=3.12
6
+ Requires-Dist: httpx>=0.27.0
7
+ Requires-Dist: rich>=13.0.0
8
+ Requires-Dist: safetensors>=0.4.0
9
+ Requires-Dist: typer>=0.15.0
10
+ Description-Content-Type: text/markdown
11
+
12
+ # tensors
13
+
14
+ A CLI tool for working with safetensor files and CivitAI models.
15
+
16
+ ## Features
17
+
18
+ - **Read safetensor metadata** - Parse headers, count tensors, extract embedded metadata
19
+ - **CivitAI integration** - Search models, fetch info, identify files by hash
20
+ - **Download models** - Resume support, type-based default paths
21
+ - **Hash verification** - SHA256 computation with progress display
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ # Clone and install
27
+ git clone https://github.com/aladac/tensors.git
28
+ cd tensors
29
+ uv sync
30
+
31
+ # Or install directly
32
+ uv pip install git+https://github.com/aladac/tensors.git
33
+ ```
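Either route installs a single `tsr` console script (the wheel's `entry_points.txt` maps `tsr = tensors:main`). The CLI prints its help text when run without arguments, which is a quick way to verify the install:

```bash
tsr --help
```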
34
+
35
+ ## Usage
36
+
37
+ ### Search CivitAI
38
+
39
+ ```bash
40
+ # Search by query
41
+ tsr search "illustrious"
42
+
43
+ # Filter by type and base model
44
+ tsr search -t lora -b sdxl
45
+
46
+ # Sort by newest, limit results
47
+ tsr search -t checkpoint -s newest -n 10
48
+ ```
49
+
50
+ ### Get Model Info
51
+
52
+ ```bash
53
+ # Get model info by ID (shows all versions)
54
+ tsr get 12345
55
+
56
+ # Get specific version info
57
+ tsr get -v 67890
58
+ ```
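Both forms accept `--json`/`-j` to print the raw CivitAI response, which is convenient for scripting. As an illustrative example (assuming `jq` is installed; it is not a dependency of this package):

```bash
# Pull the latest version ID out of the model's JSON
tsr get 12345 -j | jq '.modelVersions[0].id'
```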
59
+
60
+ ### Download Models
61
+
62
+ ```bash
63
+ # Download latest version of a model
64
+ tsr dl -m 12345
65
+
66
+ # Download specific version
67
+ tsr dl -v 67890
68
+
69
+ # Download by hash lookup
70
+ tsr dl -H ABC123...
71
+
72
+ # Custom output directory
73
+ tsr dl -m 12345 -o ./models
74
+ ```
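Downloads resume by default: if the destination file already exists, the tool sends an HTTP `Range` request and appends the missing bytes, so re-running the same command after an interruption continues where it stopped. Pass `--no-resume` to start from scratch:

```bash
# Re-running the same command resumes a partial file
tsr dl -m 12345 -o ./models

# Force a fresh download instead
tsr dl -m 12345 -o ./models --no-resume
```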
75
+
76
+ ### Inspect Local Files
77
+
78
+ ```bash
79
+ # Read safetensor file and lookup on CivitAI
80
+ tsr info model.safetensors
81
+
82
+ # Skip CivitAI lookup
83
+ tsr info model.safetensors --skip-civitai
84
+
85
+ # Output as JSON
86
+ tsr info model.safetensors -j
87
+
88
+ # Save metadata files
89
+ tsr info model.safetensors --save-to ./metadata
90
+ ```
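`tsr info` never loads tensor data; it parses the safetensors header directly. The first 8 bytes of the file are a little-endian u64 giving the length of a JSON header, whose `__metadata__` key holds any embedded metadata and whose remaining keys are the tensor names. A minimal sketch of that read path (the `peek_safetensor` helper is illustrative, mirroring `read_safetensor_metadata` in `tensors.py`):

```python
import json
import struct
from pathlib import Path


def peek_safetensor(path: Path) -> dict:
    """Read embedded metadata and count tensors without loading any tensor data."""
    with path.open("rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]  # little-endian u64
        header = json.loads(f.read(header_size))
    metadata = header.get("__metadata__", {})
    tensor_count = sum(1 for key in header if key != "__metadata__")
    return {"metadata": metadata, "tensor_count": tensor_count}


print(peek_safetensor(Path("model.safetensors")))
```

A bare file path also works as shorthand: `tsr model.safetensors` is rewritten internally to `tsr info model.safetensors`.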
91
+
92
+ ### Configuration
93
+
94
+ ```bash
95
+ # Show current config
96
+ tsr config
97
+
98
+ # Set CivitAI API key
99
+ tsr config --set-key YOUR_API_KEY
100
+ ```
101
+
102
+ ## Configuration
103
+
104
+ Config file: `~/.config/tensors/config.toml`
105
+
106
+ ```toml
107
+ [api]
108
+ civitai_key = "your-api-key"
109
+ ```
110
+
111
+ Or set via environment variable:
112
+ ```bash
113
+ export CIVITAI_API_KEY="your-api-key"
114
+ ```
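The key is resolved in a fixed order: the `CIVITAI_API_KEY` environment variable wins, then `[api].civitai_key` in the TOML config, then the legacy `~/.sftrc` file kept for migration. A condensed sketch of that lookup (see `load_api_key` in `tensors.py`):

```python
import os
import tomllib
from pathlib import Path


def resolve_civitai_key() -> str | None:
    # 1. Environment variable takes precedence
    if key := os.environ.get("CIVITAI_API_KEY"):
        return key
    # 2. TOML config: ~/.config/tensors/config.toml (honoring XDG_CONFIG_HOME)
    config = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors" / "config.toml"
    if config.exists():
        with config.open("rb") as f:
            if key := tomllib.load(f).get("api", {}).get("civitai_key"):
                return str(key)
    # 3. Legacy ~/.sftrc fallback
    legacy = Path.home() / ".sftrc"
    if legacy.exists():
        return legacy.read_text().strip() or None
    return None
```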
115
+
116
+ ## Default Paths
117
+
118
+ Models are downloaded to XDG-compliant paths:
119
+
120
+ | Type | Path |
121
+ |------|------|
122
+ | Checkpoint | `~/.local/share/tensors/models/checkpoints/` |
123
+ | LoRA | `~/.local/share/tensors/models/loras/` |
124
+ | Metadata | `~/.local/share/tensors/metadata/` |
125
+
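These defaults follow the XDG base directory spec: the data root is `$XDG_DATA_HOME/tensors` when `XDG_DATA_HOME` is set and `~/.local/share/tensors` otherwise (the config file honors `XDG_CONFIG_HOME` the same way). Relocating downloads is therefore just a matter of exporting the variable, e.g. (the `/mnt/storage` path is only an example):

```bash
export XDG_DATA_HOME=/mnt/storage
tsr dl -m 12345   # now saves under /mnt/storage/tensors/models/...
```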
126
+ ## Search Options
127
+
128
+ | Option | Values |
129
+ |--------|--------|
130
+ | `-t, --type` | checkpoint, lora, embedding, vae, controlnet, locon |
131
+ | `-b, --base` | sd15, sdxl, pony, flux, illustrious |
132
+ | `-s, --sort` | downloads, rating, newest |
133
+ | `-n, --limit` | Number of results (default: 20) |
134
+
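One quirk worth knowing: the CivitAI API does not reliably combine a free-text query with type/base filters, so when both are supplied the tool sends only the filters (requesting up to 100 results) and then matches the query against model names client-side, trimming to the requested limit. A combined search therefore looks like any other:

```bash
# Filters go to the API; "pony" is matched against returned model names locally
tsr search "pony" -t lora -b sdxl
```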
135
+ ## Development
136
+
137
+ ```bash
138
+ # Install dev dependencies
139
+ uv sync --group dev
140
+
141
+ # Run tests
142
+ uv run pytest
143
+
144
+ # Lint and format
145
+ uv run ruff check .
146
+ uv run ruff format .
147
+
148
+ # Type check
149
+ uv run mypy tensors.py
150
+ ```
151
+
152
+ ## License
153
+
154
+ MIT
tensors-0.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,5 @@
1
+ tensors.py,sha256=iWMZ9U9hZFFrRW3X7eh8GeMKJ1oo6u0jAmmvyvotT0g,37521
2
+ tensors-0.1.1.dist-info/METADATA,sha256=B2b3anQ8-OZruHJLnHCgiJ74zlALCCDrPbNtlSYuWYs,2855
3
+ tensors-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
4
+ tensors-0.1.1.dist-info/entry_points.txt,sha256=wuNX2VdjEEyFmGaDk-iSxuecbHpixSrzHAWgfCkNUEY,37
5
+ tensors-0.1.1.dist-info/RECORD,,
tensors-0.1.1.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
tensors-0.1.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ tsr = tensors:main
tensors.py ADDED
@@ -0,0 +1,1071 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ tsr: Read safetensor metadata, search and download CivitAI models.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import hashlib
9
+ import json
10
+ import os
11
+ import re
12
+ import struct
13
+ import sys
14
+ import tomllib
15
+ from enum import Enum
16
+ from pathlib import Path
17
+ from typing import Annotated, Any
18
+
19
+ import httpx
20
+ import typer
21
+ from rich.console import Console
22
+ from rich.progress import (
23
+ BarColumn,
24
+ DownloadColumn,
25
+ Progress,
26
+ SpinnerColumn,
27
+ TaskProgressColumn,
28
+ TextColumn,
29
+ TimeRemainingColumn,
30
+ TransferSpeedColumn,
31
+ )
32
+ from rich.table import Table
33
+
34
+ # ============================================================================
35
+ # App and Console Setup
36
+ # ============================================================================
37
+
38
+ app = typer.Typer(
39
+ name="tsr",
40
+ help="Read safetensor metadata, search and download CivitAI models.",
41
+ no_args_is_help=True,
42
+ )
43
+ console = Console()
44
+
45
+ # ============================================================================
46
+ # Configuration
47
+ # ============================================================================
48
+
49
+ # XDG Base Directory spec
50
+ # Config: ~/.config/tensors/config.toml
51
+ # Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/
52
+ CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) / "tensors"
53
+ CONFIG_FILE = CONFIG_DIR / "config.toml"
54
+
55
+ DATA_DIR = Path(os.environ.get("XDG_DATA_HOME", Path.home() / ".local" / "share")) / "tensors"
56
+ MODELS_DIR = DATA_DIR / "models"
57
+ METADATA_DIR = DATA_DIR / "metadata"
58
+
59
+ # Legacy config for migration
60
+ LEGACY_RC_FILE = Path.home() / ".sftrc"
61
+
62
+ # Default download paths by model type
63
+ DEFAULT_PATHS: dict[str, Path] = {
64
+ "Checkpoint": MODELS_DIR / "checkpoints",
65
+ "LORA": MODELS_DIR / "loras",
66
+ "LoCon": MODELS_DIR / "loras",
67
+ }
68
+
69
+ CIVITAI_API_BASE = "https://civitai.com/api/v1"
70
+ CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
71
+
72
+
73
+ # ============================================================================
74
+ # Enums for CLI
75
+ # ============================================================================
76
+
77
+
78
+ class ModelType(str, Enum):
79
+ """CivitAI model types."""
80
+
81
+ checkpoint = "checkpoint"
82
+ lora = "lora"
83
+ embedding = "embedding"
84
+ vae = "vae"
85
+ controlnet = "controlnet"
86
+ locon = "locon"
87
+
88
+ def to_api(self) -> str:
89
+ """Convert to CivitAI API value."""
90
+ mapping = {
91
+ "checkpoint": "Checkpoint",
92
+ "lora": "LORA",
93
+ "embedding": "TextualInversion",
94
+ "vae": "VAE",
95
+ "controlnet": "Controlnet",
96
+ "locon": "LoCon",
97
+ }
98
+ return mapping[self.value]
99
+
100
+
101
+ class BaseModel(str, Enum):
102
+ """Common base models."""
103
+
104
+ sd15 = "sd15"
105
+ sdxl = "sdxl"
106
+ pony = "pony"
107
+ flux = "flux"
108
+ illustrious = "illustrious"
109
+
110
+ def to_api(self) -> str:
111
+ """Convert to CivitAI API value."""
112
+ mapping = {
113
+ "sd15": "SD 1.5",
114
+ "sdxl": "SDXL 1.0",
115
+ "pony": "Pony",
116
+ "flux": "Flux.1 D",
117
+ "illustrious": "Illustrious",
118
+ }
119
+ return mapping[self.value]
120
+
121
+
122
+ class SortOrder(str, Enum):
123
+ """Sort options for search."""
124
+
125
+ downloads = "downloads"
126
+ rating = "rating"
127
+ newest = "newest"
128
+
129
+ def to_api(self) -> str:
130
+ """Convert to CivitAI API value."""
131
+ mapping = {
132
+ "downloads": "Most Downloaded",
133
+ "rating": "Highest Rated",
134
+ "newest": "Newest",
135
+ }
136
+ return mapping[self.value]
137
+
138
+
139
+ # ============================================================================
140
+ # Config Functions
141
+ # ============================================================================
142
+
143
+
144
+ def load_config() -> dict[str, Any]:
145
+ """Load configuration from TOML config file."""
146
+ if CONFIG_FILE.exists():
147
+ with CONFIG_FILE.open("rb") as f:
148
+ return tomllib.load(f)
149
+ return {}
150
+
151
+
152
+ def save_config(config: dict[str, Any]) -> None:
153
+ """Save configuration to TOML config file."""
154
+ CONFIG_DIR.mkdir(parents=True, exist_ok=True)
155
+
156
+ lines: list[str] = []
157
+ for key, value in config.items():
158
+ if isinstance(value, dict):
159
+ lines.append(f"[{key}]")
160
+ for k, v in value.items():
161
+ if isinstance(v, str):
162
+ lines.append(f'{k} = "{v}"')
163
+ else:
164
+ lines.append(f"{k} = {v}")
165
+ lines.append("")
166
+ elif isinstance(value, str):
167
+ lines.append(f'{key} = "{value}"')
168
+ else:
169
+ lines.append(f"{key} = {value}")
170
+
171
+ CONFIG_FILE.write_text("\n".join(lines) + "\n")
172
+
173
+
174
+ def load_api_key() -> str | None:
175
+ """Load API key from config file or CIVITAI_API_KEY env var."""
176
+ # Check environment variable first
177
+ env_key = os.environ.get("CIVITAI_API_KEY")
178
+ if env_key:
179
+ return env_key
180
+
181
+ # Check TOML config file
182
+ config = load_config()
183
+ api_section = config.get("api", {})
184
+ if isinstance(api_section, dict):
185
+ key = api_section.get("civitai_key")
186
+ if key:
187
+ return str(key)
188
+
189
+ # Fall back to legacy RC file for migration
190
+ if LEGACY_RC_FILE.exists():
191
+ content = LEGACY_RC_FILE.read_text().strip()
192
+ if content:
193
+ return content
194
+ return None
195
+
196
+
197
+ def get_default_output_path(model_type: str | None) -> Path | None:
198
+ """Get default output path based on model type."""
199
+ if model_type and model_type in DEFAULT_PATHS:
200
+ return DEFAULT_PATHS[model_type]
201
+ return None
202
+
203
+
204
+ # ============================================================================
205
+ # Safetensor Functions
206
+ # ============================================================================
207
+
208
+
209
+ def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
210
+ """Read metadata from a safetensor file header."""
211
+ with file_path.open("rb") as f:
212
+ # First 8 bytes are the header size (little-endian u64)
213
+ header_size_bytes = f.read(8)
214
+ if len(header_size_bytes) < 8:
215
+ raise ValueError("Invalid safetensor file: too short")
216
+
217
+ header_size = struct.unpack("<Q", header_size_bytes)[0]
218
+
219
+ if header_size > 100_000_000: # 100MB sanity check
220
+ raise ValueError(f"Invalid header size: {header_size}")
221
+
222
+ header_bytes = f.read(header_size)
223
+ if len(header_bytes) < header_size:
224
+ raise ValueError("Invalid safetensor file: header truncated")
225
+
226
+ header: dict[str, Any] = json.loads(header_bytes.decode("utf-8"))
227
+
228
+ # Extract __metadata__ if present
229
+ metadata: dict[str, Any] = header.get("__metadata__", {})
230
+
231
+ # Count tensors (keys that aren't __metadata__)
232
+ tensor_count = sum(1 for k in header if k != "__metadata__")
233
+
234
+ return {
235
+ "metadata": metadata,
236
+ "tensor_count": tensor_count,
237
+ "header_size": header_size,
238
+ }
239
+
240
+
241
+ def compute_sha256(file_path: Path) -> str:
242
+ """Compute SHA256 hash of a file with progress display."""
243
+ file_size = file_path.stat().st_size
244
+ sha256 = hashlib.sha256()
245
+ chunk_size = 1024 * 1024 * 8 # 8MB chunks
246
+
247
+ with Progress(
248
+ SpinnerColumn(),
249
+ TextColumn("[progress.description]{task.description}"),
250
+ BarColumn(),
251
+ TaskProgressColumn(),
252
+ DownloadColumn(),
253
+ TransferSpeedColumn(),
254
+ TimeRemainingColumn(),
255
+ console=console,
256
+ ) as progress:
257
+ task = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=file_size)
258
+
259
+ with file_path.open("rb") as f:
260
+ while chunk := f.read(chunk_size):
261
+ sha256.update(chunk)
262
+ progress.update(task, advance=len(chunk))
263
+
264
+ return sha256.hexdigest().upper()
265
+
266
+
267
+ def get_base_name(file_path: Path) -> str:
268
+ """Get base filename without .safetensors extension."""
269
+ name = file_path.name
270
+ for ext in (".safetensors", ".sft"):
271
+ if name.lower().endswith(ext):
272
+ return name[: -len(ext)]
273
+ return file_path.stem
274
+
275
+
276
+ # ============================================================================
277
+ # CivitAI API Functions
278
+ # ============================================================================
279
+
280
+
281
+ def _get_headers(api_key: str | None) -> dict[str, str]:
282
+ """Get headers for CivitAI API requests."""
283
+ headers: dict[str, str] = {}
284
+ if api_key:
285
+ headers["Authorization"] = f"Bearer {api_key}"
286
+ return headers
287
+
288
+
289
+ def fetch_civitai_model_version(
290
+ version_id: int, api_key: str | None = None
291
+ ) -> dict[str, Any] | None:
292
+ """Fetch model version information from CivitAI by version ID."""
293
+ url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
294
+
295
+ try:
296
+ response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
297
+ if response.status_code == 404:
298
+ return None
299
+ response.raise_for_status()
300
+ result: dict[str, Any] = response.json()
301
+ return result
302
+ except httpx.HTTPStatusError as e:
303
+ console.print(f"[red]API error: {e.response.status_code}[/red]")
304
+ return None
305
+ except httpx.RequestError as e:
306
+ console.print(f"[red]Request error: {e}[/red]")
307
+ return None
308
+
309
+
310
+ def fetch_civitai_model(model_id: int, api_key: str | None = None) -> dict[str, Any] | None:
311
+ """Fetch model information from CivitAI by model ID."""
312
+ url = f"{CIVITAI_API_BASE}/models/{model_id}"
313
+
314
+ with Progress(
315
+ SpinnerColumn(),
316
+ TextColumn("[progress.description]{task.description}"),
317
+ console=console,
318
+ transient=True,
319
+ ) as progress:
320
+ progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
321
+
322
+ try:
323
+ response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
324
+ if response.status_code == 404:
325
+ return None
326
+ response.raise_for_status()
327
+ result: dict[str, Any] = response.json()
328
+ return result
329
+ except httpx.HTTPStatusError as e:
330
+ console.print(f"[red]API error: {e.response.status_code}[/red]")
331
+ return None
332
+ except httpx.RequestError as e:
333
+ console.print(f"[red]Request error: {e}[/red]")
334
+ return None
335
+
336
+
337
+ def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None = None) -> dict[str, Any] | None:
338
+ """Fetch model information from CivitAI by SHA256 hash."""
339
+ url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
340
+
341
+ with Progress(
342
+ SpinnerColumn(),
343
+ TextColumn("[progress.description]{task.description}"),
344
+ console=console,
345
+ transient=True,
346
+ ) as progress:
347
+ progress.add_task("[cyan]Fetching from CivitAI...", total=None)
348
+
349
+ try:
350
+ response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
351
+ if response.status_code == 404:
352
+ return None
353
+ response.raise_for_status()
354
+ result: dict[str, Any] = response.json()
355
+ return result
356
+ except httpx.HTTPStatusError as e:
357
+ console.print(f"[red]API error: {e.response.status_code}[/red]")
358
+ return None
359
+ except httpx.RequestError as e:
360
+ console.print(f"[red]Request error: {e}[/red]")
361
+ return None
362
+
363
+
364
+ def search_civitai(
365
+ query: str | None = None,
366
+ model_type: ModelType | None = None,
367
+ base_model: BaseModel | None = None,
368
+ sort: SortOrder = SortOrder.downloads,
369
+ limit: int = 20,
370
+ api_key: str | None = None,
371
+ ) -> dict[str, Any] | None:
372
+ """Search CivitAI models."""
373
+ params: dict[str, Any] = {
374
+ "limit": min(limit, 100),
375
+ "nsfw": "true",
376
+ }
377
+
378
+ # API quirk: query + filters don't work reliably together
379
+ # If we have filters, skip query and filter client-side
380
+ has_filters = model_type is not None or base_model is not None
381
+
382
+ if query and not has_filters:
383
+ params["query"] = query
384
+
385
+ if model_type:
386
+ params["types"] = model_type.to_api()
387
+
388
+ if base_model:
389
+ params["baseModels"] = base_model.to_api()
390
+
391
+ params["sort"] = sort.to_api()
392
+
393
+ # Request more if we need client-side filtering
394
+ if query and has_filters:
395
+ params["limit"] = 100
396
+
397
+ url = f"{CIVITAI_API_BASE}/models"
398
+
399
+ with Progress(
400
+ SpinnerColumn(),
401
+ TextColumn("[progress.description]{task.description}"),
402
+ console=console,
403
+ transient=True,
404
+ ) as progress:
405
+ progress.add_task("[cyan]Searching CivitAI...", total=None)
406
+
407
+ try:
408
+ response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
409
+ response.raise_for_status()
410
+ result: dict[str, Any] = response.json()
411
+
412
+ # Client-side filtering when query + filters combined
413
+ if query and has_filters:
414
+ q_lower = query.lower()
415
+ result["items"] = [
416
+ m for m in result.get("items", []) if q_lower in m.get("name", "").lower()
417
+ ][:limit]
418
+
419
+ return result
420
+ except httpx.HTTPStatusError as e:
421
+ console.print(f"[red]API error: {e.response.status_code}[/red]")
422
+ return None
423
+ except httpx.RequestError as e:
424
+ console.print(f"[red]Request error: {e}[/red]")
425
+ return None
426
+
427
+
428
+ def download_model(
429
+ version_id: int,
430
+ dest_path: Path,
431
+ api_key: str | None = None,
432
+ resume: bool = True,
433
+ ) -> bool:
434
+ """Download a model from CivitAI by version ID with resume support."""
435
+ url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
436
+ params: dict[str, str] = {}
437
+ if api_key:
438
+ params["token"] = api_key
439
+
440
+ headers: dict[str, str] = {}
441
+ mode = "wb"
442
+ initial_size = 0
443
+
444
+ # Check for existing partial download
445
+ if resume and dest_path.exists():
446
+ initial_size = dest_path.stat().st_size
447
+ headers["Range"] = f"bytes={initial_size}-"
448
+ mode = "ab"
449
+ console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
450
+
451
+ try:
452
+ with httpx.stream(
453
+ "GET",
454
+ url,
455
+ params=params,
456
+ headers=headers,
457
+ follow_redirects=True,
458
+ timeout=httpx.Timeout(30.0, read=None),
459
+ ) as response:
460
+ if response.status_code == 416:
461
+ console.print("[green]File already fully downloaded.[/green]")
462
+ return True
463
+
464
+ response.raise_for_status()
465
+
466
+ content_length = response.headers.get("content-length")
467
+ total_size = int(content_length) + initial_size if content_length else 0
468
+
469
+ content_disp = response.headers.get("content-disposition", "")
470
+ if "filename=" in content_disp:
471
+ match = re.search(r'filename="?([^";\n]+)"?', content_disp)
472
+ if match and dest_path.is_dir():
473
+ dest_path = dest_path / match.group(1)
474
+
475
+ with Progress(
476
+ SpinnerColumn(),
477
+ TextColumn("[progress.description]{task.description}"),
478
+ BarColumn(),
479
+ TaskProgressColumn(),
480
+ DownloadColumn(),
481
+ TransferSpeedColumn(),
482
+ TimeRemainingColumn(),
483
+ console=console,
484
+ ) as progress:
485
+ task = progress.add_task(
486
+ f"[cyan]Downloading {dest_path.name}...",
487
+ total=total_size if total_size > 0 else None,
488
+ completed=initial_size,
489
+ )
490
+
491
+ with dest_path.open(mode) as f:
492
+ for chunk in response.iter_bytes(1024 * 1024):
493
+ f.write(chunk)
494
+ progress.update(task, advance=len(chunk))
495
+
496
+ console.print(f"[green]Downloaded:[/green] {dest_path}")
497
+ return True
498
+
499
+ except httpx.HTTPStatusError as e:
500
+ console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
501
+ if e.response.status_code == 401:
502
+ console.print("[yellow]Hint: This model may require an API key.[/yellow]")
503
+ return False
504
+ except httpx.RequestError as e:
505
+ console.print(f"[red]Download error: {e}[/red]")
506
+ return False
507
+
508
+
509
+ # ============================================================================
510
+ # Display Functions
511
+ # ============================================================================
512
+
513
+
514
+ def _format_size(size_kb: float) -> str:
515
+ """Format size in KB to human-readable string."""
516
+ if size_kb < 1024:
517
+ return f"{size_kb:.0f} KB"
518
+ if size_kb < 1024 * 1024:
519
+ return f"{size_kb / 1024:.1f} MB"
520
+ return f"{size_kb / 1024 / 1024:.2f} GB"
521
+
522
+
523
+ def _format_count(count: int) -> str:
524
+ """Format large numbers with K/M suffix."""
525
+ if count < 1000:
526
+ return str(count)
527
+ if count < 1_000_000:
528
+ return f"{count / 1000:.1f}K"
529
+ return f"{count / 1_000_000:.1f}M"
530
+
531
+
532
+ def _display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str) -> None:
533
+ """Display file information table."""
534
+ file_table = Table(title="File Information", show_header=True, header_style="bold magenta")
535
+ file_table.add_column("Property", style="cyan")
536
+ file_table.add_column("Value", style="green")
537
+
538
+ file_table.add_row("File", str(file_path.name))
539
+ file_table.add_row("Path", str(file_path.parent))
540
+ file_table.add_row("Size", f"{file_path.stat().st_size / (1024**3):.2f} GB")
541
+ file_table.add_row("SHA256", sha256_hash)
542
+ file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes")
543
+ file_table.add_row("Tensor Count", str(local_metadata["tensor_count"]))
544
+
545
+ console.print()
546
+ console.print(file_table)
547
+
548
+
549
+ def _display_local_metadata(local_metadata: dict[str, Any]) -> None:
550
+ """Display local safetensor metadata table."""
551
+ if local_metadata["metadata"]:
552
+ meta_table = Table(
553
+ title="Safetensor Metadata", show_header=True, header_style="bold magenta"
554
+ )
555
+ meta_table.add_column("Key", style="cyan")
556
+ meta_table.add_column("Value", style="green", max_width=80)
557
+
558
+ for key, value in sorted(local_metadata["metadata"].items()):
559
+ display_value = str(value)
560
+ if len(display_value) > 200:
561
+ display_value = display_value[:200] + "..."
562
+ meta_table.add_row(key, display_value)
563
+
564
+ console.print()
565
+ console.print(meta_table)
566
+ else:
567
+ console.print()
568
+ console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
569
+
570
+
571
+ def _display_civitai_data(civitai_data: dict[str, Any] | None) -> None:
572
+ """Display CivitAI model information table."""
573
+ if not civitai_data:
574
+ console.print()
575
+ console.print("[yellow]Model not found on CivitAI.[/yellow]")
576
+ return
577
+
578
+ civit_table = Table(
579
+ title="CivitAI Model Information", show_header=True, header_style="bold magenta"
580
+ )
581
+ civit_table.add_column("Property", style="cyan")
582
+ civit_table.add_column("Value", style="green", max_width=80)
583
+
584
+ civit_table.add_row("Model ID", str(civitai_data.get("modelId", "N/A")))
585
+ civit_table.add_row("Version ID", str(civitai_data.get("id", "N/A")))
586
+ civit_table.add_row("Version Name", str(civitai_data.get("name", "N/A")))
587
+ civit_table.add_row("Base Model", str(civitai_data.get("baseModel", "N/A")))
588
+ civit_table.add_row("Created At", str(civitai_data.get("createdAt", "N/A")))
589
+
590
+ trained_words: list[str] = civitai_data.get("trainedWords", [])
591
+ if trained_words:
592
+ civit_table.add_row("Trigger Words", ", ".join(trained_words))
593
+
594
+ download_url = str(civitai_data.get("downloadUrl", "N/A"))
595
+ civit_table.add_row("Download URL", download_url)
596
+
597
+ files: list[dict[str, Any]] = civitai_data.get("files", [])
598
+ for f in files:
599
+ if f.get("primary"):
600
+ civit_table.add_row("Primary File", str(f.get("name", "N/A")))
601
+ civit_table.add_row("File Size (CivitAI)", _format_size(f.get("sizeKB", 0)))
602
+ meta: dict[str, Any] = f.get("metadata", {})
603
+ if meta:
604
+ civit_table.add_row("Format", str(meta.get("format", "N/A")))
605
+ civit_table.add_row("Precision", str(meta.get("fp", "N/A")))
606
+ civit_table.add_row("Size Type", str(meta.get("size", "N/A")))
607
+
608
+ console.print()
609
+ console.print(civit_table)
610
+
611
+ model_id = civitai_data.get("modelId")
612
+ if model_id:
613
+ console.print()
614
+ console.print(
615
+ f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
616
+ )
617
+
618
+
619
+ def _display_model_info(model_data: dict[str, Any]) -> None:
620
+ """Display full CivitAI model information."""
621
+ model_table = Table(title="Model Information", show_header=True, header_style="bold magenta")
622
+ model_table.add_column("Property", style="cyan")
623
+ model_table.add_column("Value", style="green", max_width=80)
624
+
625
+ model_table.add_row("ID", str(model_data.get("id", "N/A")))
626
+ model_table.add_row("Name", str(model_data.get("name", "N/A")))
627
+ model_table.add_row("Type", str(model_data.get("type", "N/A")))
628
+ model_table.add_row("NSFW", str(model_data.get("nsfw", False)))
629
+
630
+ creator = model_data.get("creator", {})
631
+ if creator:
632
+ model_table.add_row("Creator", str(creator.get("username", "N/A")))
633
+
634
+ tags: list[str] = model_data.get("tags", [])
635
+ if tags:
636
+ model_table.add_row("Tags", ", ".join(tags[:10]) + ("..." if len(tags) > 10 else ""))
637
+
638
+ stats: dict[str, Any] = model_data.get("stats", {})
639
+ if stats:
640
+ model_table.add_row("Downloads", f"{stats.get('downloadCount', 0):,}")
641
+ model_table.add_row("Favorites", f"{stats.get('favoriteCount', 0):,}")
642
+ model_table.add_row(
643
+ "Rating", f"{stats.get('rating', 0):.1f} ({stats.get('ratingCount', 0)} ratings)"
644
+ )
645
+
646
+ mode = model_data.get("mode")
647
+ if mode:
648
+ model_table.add_row("Status", str(mode))
649
+
650
+ console.print()
651
+ console.print(model_table)
652
+
653
+ versions: list[dict[str, Any]] = model_data.get("modelVersions", [])
654
+ if versions:
655
+ ver_table = Table(title="Model Versions", show_header=True, header_style="bold magenta")
656
+ ver_table.add_column("ID", style="cyan")
657
+ ver_table.add_column("Name", style="green")
658
+ ver_table.add_column("Base Model", style="yellow")
659
+ ver_table.add_column("Created", style="blue")
660
+ ver_table.add_column("Primary File", style="white")
661
+
662
+ for ver in versions:
663
+ files: list[dict[str, Any]] = ver.get("files", [])
664
+ primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
665
+ file_info = ""
666
+ if primary_file:
667
+ file_info = (
668
+ f"{primary_file.get('name', 'N/A')} "
669
+ f"({_format_size(primary_file.get('sizeKB', 0))})"
670
+ )
671
+
672
+ created = str(ver.get("createdAt", "N/A"))[:10]
673
+ ver_table.add_row(
674
+ str(ver.get("id", "N/A")),
675
+ str(ver.get("name", "N/A")),
676
+ str(ver.get("baseModel", "N/A")),
677
+ created,
678
+ file_info,
679
+ )
680
+
681
+ console.print()
682
+ console.print(ver_table)
683
+
684
+ model_id = model_data.get("id")
685
+ if model_id:
686
+ console.print()
687
+ console.print(
688
+ f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}"
689
+ )
690
+
691
+
692
+ def _display_search_results(results: dict[str, Any]) -> None:
693
+ """Display search results in a table."""
694
+ items = results.get("items", [])
695
+ if not items:
696
+ console.print("[yellow]No results found.[/yellow]")
697
+ return
698
+
699
+ table = Table(show_header=True, header_style="bold magenta")
700
+ table.add_column("ID", style="cyan", justify="right")
701
+ table.add_column("Name", style="green", max_width=40)
702
+ table.add_column("Type", style="yellow")
703
+ table.add_column("Base", style="blue")
704
+ table.add_column("Size", justify="right")
705
+ table.add_column("DLs", justify="right")
706
+ table.add_column("Rating", justify="right")
707
+
708
+ for model in items:
709
+ model_id = str(model.get("id", ""))
710
+ name = model.get("name", "N/A")
711
+ if len(name) > 40:
712
+ name = name[:37] + "..."
713
+ model_type = model.get("type", "N/A")
714
+
715
+ # Get latest version info
716
+ versions = model.get("modelVersions", [])
717
+ base_model = "N/A"
718
+ size = "N/A"
719
+ if versions:
720
+ latest = versions[0]
721
+ base_model = latest.get("baseModel", "N/A")
722
+ files = latest.get("files", [])
723
+ primary = next((f for f in files if f.get("primary")), files[0] if files else None)
724
+ if primary:
725
+ size = _format_size(primary.get("sizeKB", 0))
726
+
727
+ stats = model.get("stats", {})
728
+ downloads = _format_count(stats.get("downloadCount", 0))
729
+ rating = f"{stats.get('rating', 0):.1f}"
730
+
731
+ table.add_row(model_id, name, model_type, base_model, size, downloads, rating)
732
+
733
+ console.print()
734
+ console.print(table)
735
+
736
+ metadata = results.get("metadata", {})
737
+ total = metadata.get("totalItems", len(items))
738
+ console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
739
+ console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
740
+
741
+
742
+ # ============================================================================
743
+ # CLI Commands
744
+ # ============================================================================
745
+
746
+
747
+ @app.command()
748
+ def info(
749
+ file: Annotated[Path, typer.Argument(help="Path to the safetensor file")],
750
+ api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
751
+ skip_civitai: Annotated[
752
+ bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")
753
+ ] = False,
754
+ json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
755
+ save_to: Annotated[
756
+ Path | None, typer.Option("--save-to", help="Save metadata to directory")
757
+ ] = None,
758
+ ) -> None:
759
+ """Read safetensor metadata and fetch CivitAI info."""
760
+ file_path = file.resolve()
761
+
762
+ if not file_path.exists():
763
+ console.print(f"[red]Error: File not found: {file_path}[/red]")
764
+ raise typer.Exit(1)
765
+
766
+ if file_path.suffix.lower() not in (".safetensors", ".sft"):
767
+ console.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")
768
+
769
+ try:
770
+ console.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
771
+ local_metadata = read_safetensor_metadata(file_path)
772
+ sha256_hash = compute_sha256(file_path)
773
+
774
+ civitai_data = None
775
+ if not skip_civitai:
776
+ key = api_key or load_api_key()
777
+ civitai_data = fetch_civitai_by_hash(sha256_hash, key)
778
+
779
+ if json_output:
780
+ output = {
781
+ "file": str(file_path),
782
+ "sha256": sha256_hash,
783
+ "header_size": local_metadata["header_size"],
784
+ "tensor_count": local_metadata["tensor_count"],
785
+ "metadata": local_metadata["metadata"],
786
+ "civitai": civitai_data,
787
+ }
788
+ console.print_json(data=output)
789
+ else:
790
+ _display_file_info(file_path, local_metadata, sha256_hash)
791
+ _display_local_metadata(local_metadata)
792
+ _display_civitai_data(civitai_data)
793
+
794
+ if save_to:
795
+ output_dir = save_to.resolve()
796
+ if not output_dir.exists() or not output_dir.is_dir():
797
+ console.print(f"[red]Error: Invalid directory: {output_dir}[/red]")
798
+ raise typer.Exit(1)
799
+
800
+ base_name = get_base_name(file_path)
801
+ json_path = output_dir / f"{base_name}.json"
802
+ sha_path = output_dir / f"{base_name}.sha256"
803
+
804
+ output = {
805
+ "file": str(file_path),
806
+ "sha256": sha256_hash,
807
+ "header_size": local_metadata["header_size"],
808
+ "tensor_count": local_metadata["tensor_count"],
809
+ "metadata": local_metadata["metadata"],
810
+ "civitai": civitai_data,
811
+ }
812
+ json_path.write_text(json.dumps(output, indent=2))
813
+ sha_path.write_text(f"{sha256_hash} {file_path.name}\n")
814
+
815
+ console.print()
816
+ console.print(f"[green]Saved:[/green] {json_path}")
817
+ console.print(f"[green]Saved:[/green] {sha_path}")
818
+
819
+ except ValueError as e:
820
+ console.print(f"[red]Error reading safetensor: {e}[/red]")
821
+ raise typer.Exit(1) from e
822
+
823
+
824
+ @app.command()
825
+ def search(
826
+ query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
827
+ model_type: Annotated[
828
+ ModelType | None, typer.Option("-t", "--type", help="Model type filter")
829
+ ] = None,
830
+ base: Annotated[
831
+ BaseModel | None, typer.Option("-b", "--base", help="Base model filter")
832
+ ] = None,
833
+ sort: Annotated[
834
+ SortOrder, typer.Option("-s", "--sort", help="Sort order")
835
+ ] = SortOrder.downloads,
836
+ limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
837
+ json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
838
+ api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
839
+ ) -> None:
840
+ """Search CivitAI models."""
841
+ key = api_key or load_api_key()
842
+
843
+ results = search_civitai(
844
+ query=query,
845
+ model_type=model_type,
846
+ base_model=base,
847
+ sort=sort,
848
+ limit=limit,
849
+ api_key=key,
850
+ )
851
+
852
+ if not results:
853
+ console.print("[red]Search failed.[/red]")
854
+ raise typer.Exit(1)
855
+
856
+ if json_output:
857
+ console.print_json(data=results)
858
+ else:
859
+ _display_search_results(results)
860
+
861
+
862
+ @app.command()
863
+ def get(
864
+ id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")],
865
+ version: Annotated[
866
+ bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")
867
+ ] = False,
868
+ api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
869
+ json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
870
+ ) -> None:
871
+ """Fetch model information from CivitAI by model ID or version ID."""
872
+ key = api_key or load_api_key()
873
+
874
+ if version:
875
+ # Fetch by version ID
876
+ version_data = fetch_civitai_model_version(id_value, key)
877
+ if not version_data:
878
+ console.print(f"[red]Error: Version {id_value} not found on CivitAI.[/red]")
879
+ raise typer.Exit(1)
880
+
881
+ if json_output:
882
+ console.print_json(data=version_data)
883
+ else:
884
+ _display_civitai_data(version_data)
885
+ else:
886
+ # Fetch by model ID
887
+ model_data = fetch_civitai_model(id_value, key)
888
+ if not model_data:
889
+ console.print(f"[red]Error: Model {id_value} not found on CivitAI.[/red]")
890
+ raise typer.Exit(1)
891
+
892
+ if json_output:
893
+ console.print_json(data=model_data)
894
+ else:
895
+ _display_model_info(model_data)
896
+
897
+
898
+ def _resolve_version_id(
899
+ version_id: int | None,
900
+ hash_val: str | None,
901
+ model_id: int | None,
902
+ api_key: str | None,
903
+ ) -> int | None:
904
+ """Resolve version ID from hash or model ID."""
905
+ if version_id:
906
+ return version_id
907
+
908
+ if hash_val:
909
+ console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
910
+ civitai_data = fetch_civitai_by_hash(hash_val.upper(), api_key)
911
+ if not civitai_data:
912
+ console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
913
+ return None
914
+ vid: int | None = civitai_data.get("id")
915
+ if vid:
916
+ console.print(f"[green]Found:[/green] {civitai_data.get('name', 'N/A')}")
917
+ return vid
918
+
919
+ if model_id:
920
+ console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
921
+ model_data = fetch_civitai_model(model_id, api_key)
922
+ if not model_data:
923
+ console.print(f"[red]Error: Model {model_id} not found.[/red]")
924
+ return None
925
+ versions = model_data.get("modelVersions", [])
926
+ if not versions:
927
+ console.print("[red]Error: Model has no versions.[/red]")
928
+ return None
929
+ latest = versions[0]
930
+ latest_vid: int | None = latest.get("id")
931
+ if latest_vid:
932
+ name = latest.get("name", "N/A")
933
+ console.print(f"[green]Found latest:[/green] {name} (ID: {latest_vid})")
934
+ return latest_vid
935
+
936
+ return None
937
+
938
+
939
+ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None:
940
+ """Prepare output directory for download."""
941
+ if output is None:
942
+ output_dir = get_default_output_path(model_type_str)
943
+ if output_dir is None:
944
+ console.print(
945
+ f"[red]Error: No default path for type '{model_type_str}'. "
946
+ "Use --output to specify.[/red]"
947
+ )
948
+ return None
949
+ console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]")
950
+ else:
951
+ output_dir = output.resolve()
952
+
953
+ if not output_dir.exists():
954
+ console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
955
+ output_dir.mkdir(parents=True, exist_ok=True)
956
+ elif not output_dir.is_dir():
957
+ console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
958
+ return None
959
+
960
+ return output_dir
961
+
962
+
963
+ @app.command("dl")
964
+ def download(
965
+ version_id: Annotated[
966
+ int | None, typer.Option("-v", "--version-id", help="Model version ID")
967
+ ] = None,
968
+ model_id: Annotated[
969
+ int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")
970
+ ] = None,
971
+ hash_val: Annotated[
972
+ str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")
973
+ ] = None,
974
+ output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
975
+ no_resume: Annotated[
976
+ bool, typer.Option("--no-resume", help="Don't resume partial downloads")
977
+ ] = False,
978
+ api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
979
+ ) -> None:
980
+ """Download a model from CivitAI."""
981
+ key = api_key or load_api_key()
982
+
983
+ resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
984
+ if not resolved_version_id:
985
+ if not version_id and not hash_val and not model_id:
986
+ console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
987
+ raise typer.Exit(1)
988
+
989
+ console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
990
+ version_info = fetch_civitai_model_version(resolved_version_id, key)
991
+ if not version_info:
992
+ console.print("[red]Error: Could not fetch model version info.[/red]")
993
+ raise typer.Exit(1)
994
+
995
+ model_type_str: str | None = version_info.get("model", {}).get("type")
996
+ output_dir = _prepare_download_dir(output, model_type_str)
997
+ if not output_dir:
998
+ raise typer.Exit(1)
999
+
1000
+ files: list[dict[str, Any]] = version_info.get("files", [])
1001
+ primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
1002
+ if not primary_file:
1003
+ console.print("[red]Error: No files found for this version.[/red]")
1004
+ raise typer.Exit(1)
1005
+
1006
+ filename = primary_file.get("name", f"model-{resolved_version_id}.safetensors")
1007
+ dest_path = output_dir / filename
1008
+
1009
+ table = Table(title="Model Download", show_header=True, header_style="bold magenta")
1010
+ table.add_column("Property", style="cyan")
1011
+ table.add_column("Value", style="green")
1012
+ table.add_row("Version", version_info.get("name", "N/A"))
1013
+ table.add_row("Base Model", version_info.get("baseModel", "N/A"))
1014
+ table.add_row("File", filename)
1015
+ table.add_row("Size", _format_size(primary_file.get("sizeKB", 0)))
1016
+ table.add_row("Destination", str(dest_path))
1017
+ console.print()
1018
+ console.print(table)
1019
+ console.print()
1020
+
1021
+ success = download_model(resolved_version_id, dest_path, key, resume=not no_resume)
1022
+ if not success:
1023
+ raise typer.Exit(1)
1024
+
1025
+
1026
+ @app.command()
1027
+ def config(
1028
+ show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
1029
+ set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
1030
+ ) -> None:
1031
+ """Manage configuration."""
1032
+ if set_key:
1033
+ cfg = load_config()
1034
+ if "api" not in cfg:
1035
+ cfg["api"] = {}
1036
+ cfg["api"]["civitai_key"] = set_key
1037
+ save_config(cfg)
1038
+ console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
1039
+ return
1040
+
1041
+ if show or (not set_key):
1042
+ console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
1043
+ console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")
1044
+
1045
+ key = load_api_key()
1046
+ if key:
1047
+ masked = key[:4] + "..." + key[-4:] if len(key) > 8 else "***"
1048
+ console.print(f"[bold]API key:[/bold] {masked}")
1049
+ else:
1050
+ console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
1051
+
1052
+ console.print()
1053
+ console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
1054
+
1055
+
1056
+ def main() -> int:
1057
+ """Main entry point."""
1058
+ # Handle legacy invocation: tsr <file.safetensors> -> tsr info <file>
1059
+ if len(sys.argv) > 1 and not sys.argv[1].startswith("-"):
1060
+ arg = sys.argv[1]
1061
+ if arg not in ("info", "search", "get", "dl", "download", "config") and (
1062
+ arg.endswith(".safetensors") or arg.endswith(".sft") or Path(arg).exists()
1063
+ ):
1064
+ sys.argv = [sys.argv[0], "info", *sys.argv[1:]]
1065
+
1066
+ app()
1067
+ return 0
1068
+
1069
+
1070
+ if __name__ == "__main__":
1071
+ sys.exit(main())