tensors 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensors/cli.py ADDED
@@ -0,0 +1,395 @@
1
+ """CLI application and commands for tsr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Annotated, Any
9
+
10
+ import typer
11
+ from rich.console import Console
12
+ from rich.table import Table
13
+
14
+ from tensors.api import (
15
+ download_model,
16
+ fetch_civitai_by_hash,
17
+ fetch_civitai_model,
18
+ fetch_civitai_model_version,
19
+ search_civitai,
20
+ )
21
+ from tensors.config import (
22
+ CONFIG_FILE,
23
+ BaseModel,
24
+ ModelType,
25
+ SortOrder,
26
+ get_default_output_path,
27
+ load_api_key,
28
+ load_config,
29
+ save_config,
30
+ )
31
+ from tensors.display import (
32
+ _format_size,
33
+ display_civitai_data,
34
+ display_file_info,
35
+ display_local_metadata,
36
+ display_model_info,
37
+ display_search_results,
38
+ )
39
+ from tensors.safetensor import compute_sha256, get_base_name, read_safetensor_metadata
40
+
41
# API keys longer than this are shown as "abcd...wxyz"; shorter keys are
# replaced entirely by "***".
MIN_KEY_LENGTH_FOR_MASKING = 8

# Root Typer application; every subcommand below registers itself on it.
app = typer.Typer(
    no_args_is_help=True,
    name="tsr",
    help="Read safetensor metadata, search and download CivitAI models.",
)
# Shared Rich console used for user-facing output.
console = Console()
50
+
51
+
52
@app.command()
def info(
    file: Annotated[Path, typer.Argument(help="Path to the safetensor file")],
    meta: Annotated[list[str] | None, typer.Option("--meta", "-m", help="Show specific metadata key(s) in full")] = None,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    skip_civitai: Annotated[bool, typer.Option("--skip-civitai", help="Skip CivitAI API lookup")] = False,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    save_to: Annotated[Path | None, typer.Option("--save-to", help="Save metadata to directory")] = None,
) -> None:
    """Read safetensor metadata and fetch CivitAI info."""
    # (The docstring above is the command's --help text; notes live in comments.)
    #
    # Flow: validate the path, read the local header, optionally print only the
    # --meta keys, otherwise hash the file, look it up on CivitAI (unless
    # --skip-civitai), print the result and optionally save sidecar files.
    # Exits with code 1 when the path is not a regular file or the header read
    # raises ValueError.
    file_path = file.resolve()

    # With --json, stdout must carry only the JSON document, so status lines,
    # warnings and errors go to stderr instead of the shared stdout console.
    status = Console(stderr=True) if json_output else console

    # is_file() rejects directories as well as missing paths.
    if not file_path.is_file():
        status.print(f"[red]Error: File not found: {file_path}[/red]")
        raise typer.Exit(1)

    if file_path.suffix.lower() not in (".safetensors", ".sft"):
        status.print("[yellow]Warning: File does not have .safetensors extension[/yellow]")

    try:
        local_metadata = read_safetensor_metadata(file_path)

        if meta:
            # --meta short-circuits: only the requested keys, no hashing/lookup.
            display_local_metadata(local_metadata, console, keys_filter=meta)
            return

        status.print(f"[bold]Reading safetensor file:[/bold] {file_path.name}")
        # The hash and lookup helpers receive `status`, so any output they
        # produce also stays off stdout in JSON mode.
        sha256_hash = compute_sha256(file_path, status)

        civitai_data = None
        if not skip_civitai:
            key = api_key or load_api_key()
            civitai_data = fetch_civitai_by_hash(sha256_hash, key, status)

        if json_output:
            _output_info_json(file_path, sha256_hash, local_metadata, civitai_data)
        else:
            display_file_info(file_path, local_metadata, sha256_hash, console)
            display_local_metadata(local_metadata, console)
            display_civitai_data(civitai_data, console)

        if save_to:
            # NOTE: _save_metadata reports the saved paths on the shared console.
            _save_metadata(save_to, file_path, sha256_hash, local_metadata, civitai_data)

    except ValueError as e:
        status.print(f"[red]Error reading safetensor: {e}[/red]")
        raise typer.Exit(1) from e
99
+
100
+
101
def _output_info_json(
    file_path: Path,
    sha256_hash: str,
    local_metadata: dict[str, Any],
    civitai_data: dict[str, Any] | None,
) -> None:
    """Print the ``info`` result as a single JSON document on the module console.

    Keys, in order: file, sha256, header_size, tensor_count, metadata, civitai.
    The middle three are copied from *local_metadata*; a missing one raises
    KeyError.
    """
    payload: dict[str, Any] = {"file": str(file_path), "sha256": sha256_hash}
    for field in ("header_size", "tensor_count", "metadata"):
        payload[field] = local_metadata[field]
    payload["civitai"] = civitai_data
    console.print_json(data=payload)
117
+
118
+
119
def _save_metadata(
    save_to: Path,
    file_path: Path,
    sha256_hash: str,
    local_metadata: dict[str, Any],
    civitai_data: dict[str, Any] | None,
) -> None:
    """Write ``<base>.json`` and ``<base>.sha256`` sidecar files into *save_to*.

    ``<base>`` is ``get_base_name(file_path)``. The JSON file holds the file
    path, hash, header size, tensor count, local metadata and CivitAI data;
    the ``.sha256`` file holds a single ``<hash> <filename>`` line.

    Raises:
        typer.Exit: with code 1 when *save_to* is not an existing directory.
    """
    output_dir = save_to.resolve()
    # is_dir() is False for missing paths too, so one check covers both cases.
    if not output_dir.is_dir():
        console.print(f"[red]Error: Invalid directory: {output_dir}[/red]")
        raise typer.Exit(1)

    base_name = get_base_name(file_path)
    json_path = output_dir / f"{base_name}.json"
    sha_path = output_dir / f"{base_name}.sha256"

    output = {
        "file": str(file_path),
        "sha256": sha256_hash,
        "header_size": local_metadata["header_size"],
        "tensor_count": local_metadata["tensor_count"],
        "metadata": local_metadata["metadata"],
        "civitai": civitai_data,
    }
    # Explicit UTF-8: the file name may be non-ASCII, and the platform default
    # encoding (e.g. a Windows code page) cannot always represent it.
    json_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
    sha_path.write_text(f"{sha256_hash} {file_path.name}\n", encoding="utf-8")

    console.print()
    console.print(f"[green]Saved:[/green] {json_path}")
    console.print(f"[green]Saved:[/green] {sha_path}")
150
+
151
+
152
@app.command()
def search(
    query: Annotated[str | None, typer.Argument(help="Search query (optional)")] = None,
    model_type: Annotated[ModelType | None, typer.Option("-t", "--type", help="Model type filter")] = None,
    base: Annotated[BaseModel | None, typer.Option("-b", "--base", help="Base model filter")] = None,
    sort: Annotated[SortOrder, typer.Option("-s", "--sort", help="Sort order")] = SortOrder.downloads,
    limit: Annotated[int, typer.Option("-n", "--limit", help="Max results")] = 20,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Search CivitAI models."""
    # Filters are forwarded unchanged; an explicit --api-key takes precedence
    # over the configured key.
    search_kwargs: dict[str, Any] = {
        "query": query,
        "model_type": model_type,
        "base_model": base,
        "sort": sort,
        "limit": limit,
        "api_key": api_key or load_api_key(),
        "console": console,
    }
    results = search_civitai(**search_kwargs)

    # Any falsy result is reported as a failed search.
    if not results:
        console.print("[red]Search failed.[/red]")
        raise typer.Exit(1)

    if not json_output:
        display_search_results(results, console)
        return
    console.print_json(data=results)
183
+
184
+
185
@app.command()
def get(
    id_value: Annotated[int, typer.Argument(help="CivitAI model ID or version ID")],
    version: Annotated[bool, typer.Option("-v", "--version", help="Treat ID as version ID instead of model ID")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
    json_output: Annotated[bool, typer.Option("--json", "-j", help="Output as JSON")] = False,
) -> None:
    """Fetch model information from CivitAI by model ID or version ID."""
    key = api_key or load_api_key()

    # Pick the fetcher, its not-found message and its renderer in one place,
    # then share the not-found / JSON / rich-display handling below.
    if version:
        data = fetch_civitai_model_version(id_value, key, console)
        not_found_msg = f"[red]Error: Version {id_value} not found on CivitAI.[/red]"
        render = display_civitai_data
    else:
        data = fetch_civitai_model(id_value, key, console)
        not_found_msg = f"[red]Error: Model {id_value} not found on CivitAI.[/red]"
        render = display_model_info

    if not data:
        console.print(not_found_msg)
        raise typer.Exit(1)

    if json_output:
        console.print_json(data=data)
    else:
        render(data, console)
215
+
216
+
217
def _resolve_by_hash(hash_val: str, api_key: str | None) -> int | None:
    """Resolve a CivitAI version ID from a SHA256 hash.

    The hash is upper-cased before the lookup. Returns the ``id`` field of the
    match, which may itself be None if the response lacks it. Returns None
    after printing an error when CivitAI has no match.
    """
    console.print(f"[cyan]Looking up model by hash: {hash_val[:16]}...[/cyan]")
    hit = fetch_civitai_by_hash(hash_val.upper(), api_key, console)
    if hit:
        found_id: int | None = hit.get("id")
        if found_id:
            console.print(f"[green]Found:[/green] {hit.get('name', 'N/A')}")
        return found_id

    console.print("[red]Error: Model not found on CivitAI for this hash.[/red]")
    return None
228
+
229
+
230
def _resolve_by_model_id(model_id: int, api_key: str | None) -> int | None:
    """Resolve a model ID to the ID of its first listed (latest) version.

    Returns None after printing an error when the model is unknown or has no
    versions. The returned ID may also be None if that version entry lacks an
    ``id`` field.
    """
    console.print(f"[cyan]Looking up model {model_id}...[/cyan]")
    if not (model_data := fetch_civitai_model(model_id, api_key, console)):
        console.print(f"[red]Error: Model {model_id} not found.[/red]")
        return None

    versions = model_data.get("modelVersions", [])
    if not versions:
        console.print("[red]Error: Model has no versions.[/red]")
        return None

    # The first entry of modelVersions is treated as the latest version.
    newest = versions[0]
    newest_id: int | None = newest.get("id")
    if newest_id:
        console.print(f"[green]Found latest:[/green] {newest.get('name', 'N/A')} (ID: {newest_id})")
    return newest_id
246
+
247
+
248
+ def _resolve_version_id(
249
+ version_id: int | None,
250
+ hash_val: str | None,
251
+ model_id: int | None,
252
+ api_key: str | None,
253
+ ) -> int | None:
254
+ """Resolve version ID from direct ID, hash, or model ID."""
255
+ if version_id:
256
+ return version_id
257
+ if hash_val:
258
+ return _resolve_by_hash(hash_val, api_key)
259
+ if model_id:
260
+ return _resolve_by_model_id(model_id, api_key)
261
+ return None
262
+
263
+
264
+ def _prepare_download_dir(output: Path | None, model_type_str: str | None) -> Path | None:
265
+ """Prepare output directory for download."""
266
+ if output is None:
267
+ output_dir = get_default_output_path(model_type_str)
268
+ if output_dir is None:
269
+ console.print(f"[red]Error: No default path for type '{model_type_str}'. Use --output to specify.[/red]")
270
+ return None
271
+ console.print(f"[dim]Using default path for {model_type_str}: {output_dir}[/dim]")
272
+ else:
273
+ output_dir = output.resolve()
274
+
275
+ if not output_dir.exists():
276
+ console.print(f"[cyan]Creating directory: {output_dir}[/cyan]")
277
+ output_dir.mkdir(parents=True, exist_ok=True)
278
+ elif not output_dir.is_dir():
279
+ console.print(f"[red]Error: Not a directory: {output_dir}[/red]")
280
+ return None
281
+
282
+ return output_dir
283
+
284
+
285
@app.command("dl")
def download(
    version_id: Annotated[int | None, typer.Option("-v", "--version-id", help="Model version ID")] = None,
    model_id: Annotated[int | None, typer.Option("-m", "--model-id", help="Model ID (downloads latest)")] = None,
    hash_val: Annotated[str | None, typer.Option("-H", "--hash", help="SHA256 hash to look up")] = None,
    output: Annotated[Path | None, typer.Option("-o", "--output", help="Output directory")] = None,
    no_resume: Annotated[bool, typer.Option("--no-resume", help="Don't resume partial downloads")] = False,
    api_key: Annotated[str | None, typer.Option("--api-key", help="CivitAI API key")] = None,
) -> None:
    """Download a model from CivitAI."""
    # (The docstring above is the command's --help text; notes live in comments.)
    #
    # Resolves a version ID, fetches its metadata, picks the primary file and
    # downloads it into --output or the default directory for the model type.
    # Every failure path exits with code 1.
    key = api_key or load_api_key()

    # Nothing to resolve: report the usage error before any lookup.
    if not (version_id or hash_val or model_id):
        console.print("[red]Error: Must specify --version-id, --model-id, or --hash[/red]")
        raise typer.Exit(1)

    # Precedence: --version-id, then --hash, then --model-id (latest version).
    resolved_version_id = _resolve_version_id(version_id, hash_val, model_id, key)
    if not resolved_version_id:
        # Lookup failed; stop instead of fetching version info for None.
        raise typer.Exit(1)

    console.print(f"[cyan]Fetching version info for {resolved_version_id}...[/cyan]")
    version_info = fetch_civitai_model_version(resolved_version_id, key, console)
    if not version_info:
        console.print("[red]Error: Could not fetch model version info.[/red]")
        raise typer.Exit(1)

    # `or {}` also covers an explicit JSON null for "model".
    model_type_str: str | None = (version_info.get("model") or {}).get("type")
    output_dir = _prepare_download_dir(output, model_type_str)
    if not output_dir:
        raise typer.Exit(1)

    files: list[dict[str, Any]] = version_info.get("files") or []
    # Prefer the file flagged as primary; otherwise fall back to the first one.
    primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
    if not primary_file:
        console.print("[red]Error: No files found for this version.[/red]")
        raise typer.Exit(1)

    fallback_name = f"model-{resolved_version_id}.safetensors"
    # The file name comes from the remote API. Keep only its last path
    # component so values such as "../x" or "/abs/x" cannot place the file
    # outside output_dir (joining an absolute path would discard output_dir).
    filename = Path(str(primary_file.get("name") or fallback_name)).name
    if filename in ("", ".", ".."):
        filename = fallback_name
    dest_path = output_dir / filename

    _display_download_info(version_info, filename, primary_file, dest_path)

    success = download_model(resolved_version_id, dest_path, key, console, resume=not no_resume)
    if not success:
        raise typer.Exit(1)
328
+
329
+
330
def _display_download_info(
    version_info: dict[str, Any],
    filename: str,
    primary_file: dict[str, Any],
    dest_path: Path,
) -> None:
    """Print a two-column "Model Download" summary table framed by blank lines.

    Rows: version name, base model, file name, formatted size (from the
    file's ``sizeKB``, default 0) and destination path. Missing name/base
    fields show as "N/A".
    """
    summary = Table(title="Model Download", show_header=True, header_style="bold magenta")
    summary.add_column("Property", style="cyan")
    summary.add_column("Value", style="green")

    rows = (
        ("Version", version_info.get("name", "N/A")),
        ("Base Model", version_info.get("baseModel", "N/A")),
        ("File", filename),
        ("Size", _format_size(primary_file.get("sizeKB", 0))),
        ("Destination", str(dest_path)),
    )
    for label, value in rows:
        summary.add_row(label, value)

    console.print()
    console.print(summary)
    console.print()
347
+ console.print()
348
+
349
+
350
@app.command()
def config(
    show: Annotated[bool, typer.Option("--show", help="Show current config")] = False,
    set_key: Annotated[str | None, typer.Option("--set-key", help="Set CivitAI API key")] = None,
) -> None:
    """Manage configuration."""
    # --set-key stores the key under [api] civitai_key and stops there.
    if set_key:
        cfg = load_config()
        cfg.setdefault("api", {})["civitai_key"] = set_key
        save_config(cfg)
        console.print(f"[green]API key saved to {CONFIG_FILE}[/green]")
        return

    # Every invocation without --set-key shows the current configuration, so
    # --show is accepted but not required.
    console.print(f"[bold]Config file:[/bold] {CONFIG_FILE}")
    console.print(f"[bold]Config exists:[/bold] {CONFIG_FILE.exists()}")

    key = load_api_key()
    if not key:
        console.print("[bold]API key:[/bold] [yellow]Not set[/yellow]")
    else:
        # Reveal only the first and last 4 characters, and only for keys long
        # enough that doing so still hides most of the key.
        if len(key) > MIN_KEY_LENGTH_FOR_MASKING:
            masked = key[:4] + "..." + key[-4:]
        else:
            masked = "***"
        console.print(f"[bold]API key:[/bold] {masked}")

    console.print()
    console.print("[dim]Set API key with: tsr config --set-key YOUR_KEY[/dim]")
378
+
379
+
380
def main() -> int:
    """Entry point for the ``tsr`` command.

    For backwards compatibility, ``tsr <file> ...`` is rewritten to
    ``tsr info <file> ...`` when the first argument is not an option, is not a
    known subcommand name, and either ends in ``.safetensors``/``.sft`` or
    names an existing path. Returns 0 once ``app()`` returns.
    """
    args = sys.argv[1:]
    if args and not args[0].startswith("-"):
        first = args[0]
        subcommands = ("info", "search", "get", "dl", "download", "config")
        if first not in subcommands and (first.endswith((".safetensors", ".sft")) or Path(first).exists()):
            # Legacy form: treat the bare path as the argument to `info`.
            sys.argv = [sys.argv[0], "info", *args]

    app()
    return 0
392
+
393
+
394
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
tensors/config.py ADDED
@@ -0,0 +1,166 @@
1
+ """Configuration, constants, and enums for tsr CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import tomllib
7
+ from enum import Enum
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ # ============================================================================
12
+ # XDG Base Directory Configuration
13
+ # ============================================================================
14
+
15
# Config: ~/.config/tensors/config.toml
# Data: ~/.local/share/tensors/models/, ~/.local/share/tensors/metadata/
# (the roots honour XDG_CONFIG_HOME / XDG_DATA_HOME when set to absolute paths)


def _xdg_base_dir(env_var: str, default: Path) -> Path:
    """Return the XDG base directory named by *env_var*, or *default*.

    Per the XDG Base Directory spec, a variable that is unset or empty falls
    back to the default. A variable holding a relative path is invalid and is
    ignored as well.
    """
    value = os.environ.get(env_var, "")
    if value and Path(value).is_absolute():
        return Path(value)
    return default


CONFIG_DIR = _xdg_base_dir("XDG_CONFIG_HOME", Path.home() / ".config") / "tensors"
CONFIG_FILE = CONFIG_DIR / "config.toml"

DATA_DIR = _xdg_base_dir("XDG_DATA_HOME", Path.home() / ".local" / "share") / "tensors"
MODELS_DIR = DATA_DIR / "models"
METADATA_DIR = DATA_DIR / "metadata"

# Legacy plain-text API key file, still read as a fallback for migration.
LEGACY_RC_FILE = Path.home() / ".sftrc"

# Default download directories, keyed by CivitAI API model type.
DEFAULT_PATHS: dict[str, Path] = {
    "Checkpoint": MODELS_DIR / "checkpoints",
    "LORA": MODELS_DIR / "loras",
    "LoCon": MODELS_DIR / "loras",
}

CIVITAI_API_BASE = "https://civitai.com/api/v1"
CIVITAI_DOWNLOAD_BASE = "https://civitai.com/api/download/models"
36
+
37
+
38
+ # ============================================================================
39
+ # Enums for CLI
40
+ # ============================================================================
41
+
42
+
43
+ class ModelType(str, Enum):
44
+ """CivitAI model types."""
45
+
46
+ checkpoint = "checkpoint"
47
+ lora = "lora"
48
+ embedding = "embedding"
49
+ vae = "vae"
50
+ controlnet = "controlnet"
51
+ locon = "locon"
52
+
53
+ def to_api(self) -> str:
54
+ """Convert to CivitAI API value."""
55
+ mapping = {
56
+ "checkpoint": "Checkpoint",
57
+ "lora": "LORA",
58
+ "embedding": "TextualInversion",
59
+ "vae": "VAE",
60
+ "controlnet": "Controlnet",
61
+ "locon": "LoCon",
62
+ }
63
+ return mapping[self.value]
64
+
65
+
66
+ class BaseModel(str, Enum):
67
+ """Common base models."""
68
+
69
+ sd15 = "sd15"
70
+ sdxl = "sdxl"
71
+ pony = "pony"
72
+ flux = "flux"
73
+ illustrious = "illustrious"
74
+
75
+ def to_api(self) -> str:
76
+ """Convert to CivitAI API value."""
77
+ mapping = {
78
+ "sd15": "SD 1.5",
79
+ "sdxl": "SDXL 1.0",
80
+ "pony": "Pony",
81
+ "flux": "Flux.1 D",
82
+ "illustrious": "Illustrious",
83
+ }
84
+ return mapping[self.value]
85
+
86
+
87
+ class SortOrder(str, Enum):
88
+ """Sort options for search."""
89
+
90
+ downloads = "downloads"
91
+ rating = "rating"
92
+ newest = "newest"
93
+
94
+ def to_api(self) -> str:
95
+ """Convert to CivitAI API value."""
96
+ mapping = {
97
+ "downloads": "Most Downloaded",
98
+ "rating": "Highest Rated",
99
+ "newest": "Newest",
100
+ }
101
+ return mapping[self.value]
102
+
103
+
104
+ # ============================================================================
105
+ # Config Functions
106
+ # ============================================================================
107
+
108
+
109
+ def load_config() -> dict[str, Any]:
110
+ """Load configuration from TOML config file."""
111
+ if CONFIG_FILE.exists():
112
+ with CONFIG_FILE.open("rb") as f:
113
+ return tomllib.load(f)
114
+ return {}
115
+
116
+
117
# Escapes for characters that may not appear raw in a TOML basic string.
_TOML_ESCAPES = {
    '"': '\\"',
    "\\": "\\\\",
    "\b": "\\b",
    "\t": "\\t",
    "\n": "\\n",
    "\f": "\\f",
    "\r": "\\r",
}


def _toml_string(text: str) -> str:
    """Return *text* as a double-quoted TOML basic string."""
    pieces: list[str] = []
    for ch in text:
        if ch in _TOML_ESCAPES:
            pieces.append(_TOML_ESCAPES[ch])
        elif ord(ch) < 0x20 or ord(ch) == 0x7F:
            # Remaining control characters must be written as \uXXXX escapes.
            pieces.append(f"\\u{ord(ch):04X}")
        else:
            pieces.append(ch)
    return '"' + "".join(pieces) + '"'


def _toml_key(key: object) -> str:
    """Return *key* as a bare TOML key when allowed, otherwise as a quoted key."""
    text = str(key)
    if text and all(ch.isascii() and (ch.isalnum() or ch in "-_") for ch in text):
        return text
    return _toml_string(text)


def _toml_value(value: Any) -> str:
    """Render a Python value as a TOML value."""
    if isinstance(value, bool):  # checked before int: bool is an int subclass
        return "true" if value else "false"
    if isinstance(value, str):
        return _toml_string(value)
    if isinstance(value, (int, float)):
        return repr(value)
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(_toml_value(item) for item in value) + "]"
    if isinstance(value, dict):
        # Nested tables inside a section are written as inline tables.
        pairs = ", ".join(f"{_toml_key(k)} = {_toml_value(v)}" for k, v in value.items())
        return "{" + pairs + "}"
    # Other types (e.g. dates/times returned by tomllib) fall back to str().
    return str(value)


def save_config(config: dict[str, Any]) -> None:
    """Save configuration to the TOML config file, creating its directory.

    Top-level dict values become ``[table]`` sections. All other top-level
    values are written first, because TOML assigns a key to the most recently
    opened table and writing them later would move them into that table.
    Strings are escaped, booleans use TOML's lowercase ``true``/``false``, and
    the file is written as UTF-8 (the encoding TOML mandates).
    """
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    lines = [f"{_toml_key(k)} = {_toml_value(v)}" for k, v in config.items() if not isinstance(v, dict)]
    for name, table in config.items():
        if not isinstance(table, dict):
            continue
        if lines:
            lines.append("")
        lines.append(f"[{_toml_key(name)}]")
        lines.extend(f"{_toml_key(k)} = {_toml_value(v)}" for k, v in table.items())

    CONFIG_FILE.write_text("\n".join(lines) + "\n", encoding="utf-8")
137
+
138
+
139
def load_api_key() -> str | None:
    """Return the CivitAI API key, or None when none is configured.

    Sources, in order: the ``CIVITAI_API_KEY`` environment variable, the
    ``civitai_key`` entry of the config file's ``[api]`` table, and finally
    the whitespace-stripped contents of the legacy RC file. Empty values are
    skipped.
    """
    from_env = os.environ.get("CIVITAI_API_KEY")
    if from_env:
        return from_env

    section = load_config().get("api", {})
    from_config = section.get("civitai_key") if isinstance(section, dict) else None
    if from_config:
        return str(from_config)

    if not LEGACY_RC_FILE.exists():
        return None
    legacy = LEGACY_RC_FILE.read_text().strip()
    return legacy or None
160
+
161
+
162
+ def get_default_output_path(model_type: str | None) -> Path | None:
163
+ """Get default output path based on model type."""
164
+ if model_type and model_type in DEFAULT_PATHS:
165
+ return DEFAULT_PATHS[model_type]
166
+ return None