tensors 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensors/display.py ADDED
@@ -0,0 +1,331 @@
1
+ """Rich table display functions for tsr CLI."""
2
+
3
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from rich.markup import escape
from rich.table import Table

if TYPE_CHECKING:
    from pathlib import Path

    from rich.console import Console
14
+
15
# Size formatting constants.
# Sizes read from CivitAI file entries ("sizeKB") are in kilobytes, so the size
# thresholds below are expressed in KB units (see _format_size).
KB = 1024
# Number of KB in one GB (1024 * 1024 KB). Despite the name, _format_size uses
# this as the MB -> GB switch-over threshold, not as "one MB expressed in KB".
MB_IN_KB = KB * KB
# Thresholds used by _format_count for the K / M suffixes.
THOUSAND = 1000
MILLION = 1_000_000
# Maximum number of tags listed in the model info table before appending "...".
MAX_TAGS_DISPLAY = 10
21
+
22
+
23
def _format_size(size_kb: float) -> str:
    """Render a size given in kilobytes as a short human-readable label.

    Below 1024 KB the value stays in KB with no decimals; below 1024 * 1024 KB it
    is shown in MB with one decimal; anything larger is shown in GB with two.
    """
    if size_kb < KB:
        value, unit, precision = size_kb, "KB", 0
    elif size_kb < MB_IN_KB:
        value, unit, precision = size_kb / KB, "MB", 1
    else:
        value, unit, precision = size_kb / KB / KB, "GB", 2
    return f"{value:.{precision}f} {unit}"
30
+
31
+
32
def _format_count(count: int) -> str:
    """Abbreviate a count: plain below 1000, otherwise one decimal plus a K or M suffix."""
    if count < THOUSAND:
        return str(count)
    divisor, suffix = (THOUSAND, "K") if count < MILLION else (MILLION, "M")
    return f"{count / divisor:.1f}{suffix}"
39
+
40
+
41
def display_file_info(file_path: Path, local_metadata: dict[str, Any], sha256_hash: str, console: Console) -> None:
    """Print a "File Information" table for a local safetensor file.

    Args:
        file_path: Path of the inspected file; ``stat()`` is called on it for the size.
        local_metadata: Parsed header info; must provide ``header_size`` (int)
            and ``tensor_count`` entries.
        sha256_hash: Precomputed hash string, shown verbatim.
        console: Rich console to print to.
    """
    prop_width = 12

    file_table = Table(title="File Information", show_header=True, header_style="bold magenta", expand=True)
    file_table.add_column("Property", style="cyan", width=prop_width, no_wrap=True)
    file_table.add_column("Value", style="green", no_wrap=True, overflow="ellipsis")

    # File names and paths are arbitrary user data. Escape them so Rich does not
    # parse bracketed parts (e.g. "model [v2].safetensors") as markup tags.
    file_table.add_row("File", escape(file_path.name))
    file_table.add_row("Path", escape(str(file_path.parent)))
    # Use the shared size formatter so small files (LoRAs, embeddings) are not
    # reported as "0.00 GB"; files of 1 GB and more still render in GB.
    file_table.add_row("Size", _format_size(file_path.stat().st_size / KB))
    file_table.add_row("SHA256", sha256_hash)
    file_table.add_row("Header Size", f"{local_metadata['header_size']:,} bytes")
    file_table.add_row("Tensor Count", str(local_metadata["tensor_count"]))

    console.print()
    console.print(file_table)
58
+
59
+
60
def display_local_metadata(local_metadata: dict[str, Any], console: Console, keys_filter: list[str] | None = None) -> None:
    """Print the embedded ``__metadata__`` of a safetensor file.

    Args:
        local_metadata: Parsed header info; its ``metadata`` entry is the mapping
            that gets displayed.
        console: Rich console to print to.
        keys_filter: When given, print only these keys with their full,
            untruncated values (missing keys are reported as not found)
            instead of the summary table.
    """
    metadata = local_metadata["metadata"]
    if not metadata:
        console.print()
        console.print("[yellow]No embedded metadata found in safetensor file.[/yellow]")
        return

    # Keys and values come straight from the file (and from the CLI). Escape them:
    # otherwise Rich parses text like "[/foo]" as markup and raises MarkupError,
    # or silently drops "[tag]" fragments.
    if keys_filter:
        for key in keys_filter:
            if key in metadata:
                console.print(f"[cyan]{escape(key)}[/cyan]: {escape(str(metadata[key]))}")
            else:
                console.print(f"[yellow]{escape(key)}: not found[/yellow]")
        return

    # The key column is as wide as the longest key. The value column gets the rest
    # of the terminal minus table borders and padding (7 chars). It is clamped to a
    # usable minimum, like the other table builders, so narrow terminals or very
    # long keys cannot produce a zero or negative column width.
    key_width = max((len(k) for k in metadata), default=20)
    overhead = 7
    value_width = max(40, console.size.width - key_width - overhead)

    meta_table = Table(
        title="Safetensor Metadata",
        show_header=True,
        header_style="bold magenta",
    )
    meta_table.add_column("Key", style="cyan", width=key_width, no_wrap=True)
    meta_table.add_column("Value", style="green", width=value_width, no_wrap=True, overflow="ellipsis")

    for key, value in sorted(metadata.items()):
        meta_table.add_row(escape(key), escape(str(value)))

    console.print()
    console.print(meta_table)
99
+
100
+
101
def _build_civitai_table(console: Console) -> tuple[Table, int]:
    """Create the empty two-column "CivitAI Model Information" table.

    Returns the table together with the width given to its Value column: the
    terminal width minus the Property column and border allowance, at least 40.
    """
    label_width = 14
    chrome = 7  # table borders + cell padding
    value_col_width = max(40, console.size.width - label_width - chrome)

    civit_table = Table(title="CivitAI Model Information", show_header=True, header_style="bold magenta")
    civit_table.add_column("Property", style="cyan", width=label_width, no_wrap=True)
    civit_table.add_column(
        "Value",
        style="green",
        width=value_col_width,
        no_wrap=True,
        overflow="ellipsis",
    )
    return civit_table, value_col_width
113
+
114
+
115
def display_civitai_data(civitai_data: dict[str, Any] | None, console: Console) -> None:
    """Print CivitAI model-version details, then a link to the model page.

    Args:
        civitai_data: Model-version payload from CivitAI (keys such as ``modelId``,
            ``id``, ``name``, ``baseModel``, ``trainedWords``, ``files``), or
            ``None``/empty when nothing was found.
        console: Rich console to print to.
    """
    if not civitai_data:
        console.print()
        console.print("[yellow]Model not found on CivitAI.[/yellow]")
        return

    civit_table, _ = _build_civitai_table(console)

    # All values come from the remote API. Escape them so Rich shows bracketed
    # text (e.g. "[sdxl]" in a name or trigger word) literally, not as markup.
    civit_table.add_row("Model ID", escape(str(civitai_data.get("modelId", "N/A"))))
    civit_table.add_row("Version ID", escape(str(civitai_data.get("id", "N/A"))))
    civit_table.add_row("Version Name", escape(str(civitai_data.get("name", "N/A"))))
    civit_table.add_row("Base Model", escape(str(civitai_data.get("baseModel", "N/A"))))
    civit_table.add_row("Created At", escape(str(civitai_data.get("createdAt", "N/A"))))

    # `or []` / `or {}` also cover keys that are present but null in the payload.
    trained_words: list[str] = civitai_data.get("trainedWords") or []
    if trained_words:
        civit_table.add_row("Trigger Words", escape(", ".join(map(str, trained_words))))

    civit_table.add_row("Download URL", escape(str(civitai_data.get("downloadUrl", "N/A"))))

    files: list[dict[str, Any]] = civitai_data.get("files") or []
    for f in files:
        if not f.get("primary"):
            continue
        civit_table.add_row("Primary File", escape(str(f.get("name", "N/A"))))
        civit_table.add_row("File Size", _format_size(f.get("sizeKB") or 0))
        meta: dict[str, Any] = f.get("metadata") or {}
        if meta:
            civit_table.add_row("Format", escape(str(meta.get("format", "N/A"))))
            civit_table.add_row("Precision", escape(str(meta.get("fp", "N/A"))))
            civit_table.add_row("Size Type", escape(str(meta.get("size", "N/A"))))

    console.print()
    console.print(civit_table)

    model_id = civitai_data.get("modelId")
    if model_id:
        console.print()
        console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
155
+
156
+
157
def _build_model_table(console: Console) -> Table:
    """Create the empty two-column "Model Information" table sized to the terminal.

    The Value column takes the terminal width minus the Property column and
    border allowance, but never less than 40 characters.
    """
    label_width = 10
    chrome = 7  # table borders + cell padding
    value_width = max(40, console.size.width - label_width - chrome)

    model_table = Table(title="Model Information", show_header=True, header_style="bold magenta")
    model_table.add_column("Property", style="cyan", width=label_width, no_wrap=True)
    model_table.add_column(
        "Value",
        style="green",
        width=value_width,
        no_wrap=True,
        overflow="ellipsis",
    )
    return model_table
169
+
170
+
171
def _add_model_basic_info(table: Table, model_data: dict[str, Any]) -> None:
    """Append the basic model rows (ID, name, type, NSFW, creator, tags, stats, status).

    Values come from the CivitAI model payload and are markup-escaped, so bracketed
    text is shown literally instead of being parsed by Rich. Optional sections are
    skipped when missing or null.

    Args:
        table: Two-column (Property, Value) table to append rows to.
        model_data: CivitAI model payload.
    """
    table.add_row("ID", escape(str(model_data.get("id", "N/A"))))
    table.add_row("Name", escape(str(model_data.get("name", "N/A"))))
    table.add_row("Type", escape(str(model_data.get("type", "N/A"))))
    table.add_row("NSFW", str(model_data.get("nsfw", False)))

    # `or {}` / `or []` also cover keys that are present but null in the payload.
    creator = model_data.get("creator") or {}
    if creator:
        table.add_row("Creator", escape(str(creator.get("username", "N/A"))))

    tags: list[str] = model_data.get("tags") or []
    if tags:
        shown = ", ".join(map(str, tags[:MAX_TAGS_DISPLAY]))
        more = "..." if len(tags) > MAX_TAGS_DISPLAY else ""
        table.add_row("Tags", escape(shown) + more)

    stats: dict[str, Any] = model_data.get("stats") or {}
    if stats:
        downloads = stats.get("downloadCount") or 0
        likes = stats.get("thumbsUpCount") or 0
        table.add_row("Downloads", f"{downloads:,}")
        table.add_row("Likes", f"{likes:,}")

    mode = model_data.get("mode")
    if mode:
        table.add_row("Status", escape(str(mode)))
194
+
195
+
196
def _build_versions_table(console: Console) -> Table:
    """Create the empty "Model Versions" table.

    ID, Base Model, Created and Size get fixed widths. The width left after those
    columns and a 20-character border/padding allowance (never less than 40) is
    split: one third (rounded down) to Name, the rest to Filename.
    """
    id_w, base_w, created_w, size_w = 7, 20, 10, 8
    chrome = 20
    leftover = max(40, console.size.width - (id_w + base_w + created_w + size_w) - chrome)
    name_w = leftover // 3
    file_w = leftover - name_w

    versions_table = Table(title="Model Versions", show_header=True, header_style="bold magenta")
    versions_table.add_column("ID", style="cyan", width=id_w, no_wrap=True)
    versions_table.add_column("Name", style="green", width=name_w, no_wrap=True, overflow="ellipsis")
    versions_table.add_column("Base Model", style="yellow", width=base_w, no_wrap=True, overflow="ellipsis")
    versions_table.add_column("Created", style="blue", width=created_w, no_wrap=True)
    versions_table.add_column("Filename", style="white", width=file_w, no_wrap=True, overflow="ellipsis")
    versions_table.add_column("Size", justify="right", width=size_w, no_wrap=True)
    return versions_table
219
+
220
+
221
def _add_version_rows(table: Table, versions: list[dict[str, Any]]) -> None:
    """Append one row per model version: id, name, base model, created date, file, size.

    The file shown is the one flagged ``primary``. When none is flagged, the first
    file is used; "N/A" is shown when a version has no files. Values from the API
    are markup-escaped so Rich renders them literally.

    Args:
        table: Versions table with six columns, in the order listed above.
        versions: ``modelVersions`` entries from the CivitAI model payload.
    """
    for ver in versions:
        # `or []` also covers a `files` key that is present but null.
        files: list[dict[str, Any]] = ver.get("files") or []
        primary_file = next((f for f in files if f.get("primary")), files[0] if files else None)
        filename = "N/A"
        size = "N/A"
        if primary_file:
            filename = str(primary_file.get("name") or "N/A")
            size = _format_size(primary_file.get("sizeKB") or 0)

        # Keep the first 10 characters (the date part, assuming an ISO-8601 timestamp).
        created = str(ver.get("createdAt", "N/A"))[:10]
        table.add_row(
            escape(str(ver.get("id", "N/A"))),
            escape(str(ver.get("name", "N/A"))),
            escape(str(ver.get("baseModel", "N/A"))),
            escape(created),
            escape(filename),
            size,
        )
241
+
242
+
243
def display_model_info(model_data: dict[str, Any], console: Console) -> None:
    """Print a CivitAI model: summary table, versions table (if any) and a page link.

    Args:
        model_data: CivitAI model payload (``id``, ``modelVersions`` and the
            fields consumed by the table helpers).
        console: Rich console to print to.
    """
    summary = _build_model_table(console)
    _add_model_basic_info(summary, model_data)
    console.print()
    console.print(summary)

    versions: list[dict[str, Any]]
    if versions := model_data.get("modelVersions", []):
        versions_table = _build_versions_table(console)
        _add_version_rows(versions_table, versions)
        console.print()
        console.print(versions_table)

    if model_id := model_data.get("id"):
        console.print()
        console.print(f"[bold blue]View on CivitAI:[/bold blue] https://civitai.com/models/{model_id}")
262
+
263
+
264
def _build_search_table(console: Console) -> Table:
    """Create the empty search-results table.

    Every column except Name has a fixed width. Name absorbs whatever terminal
    width remains after a 23-character border/padding allowance, with a minimum of 20.
    """
    fixed = {"ID": 7, "Type": 16, "Base": 20, "Size": 8, "DLs": 6, "Likes": 6}
    chrome = 23
    name_width = max(20, console.size.width - sum(fixed.values()) - chrome)

    results_table = Table(show_header=True, header_style="bold magenta")
    results_table.add_column("ID", style="cyan", justify="right", width=fixed["ID"], no_wrap=True)
    results_table.add_column("Name", style="green", width=name_width, no_wrap=True, overflow="ellipsis")
    results_table.add_column("Type", style="yellow", width=fixed["Type"], no_wrap=True)
    results_table.add_column("Base", style="blue", width=fixed["Base"], no_wrap=True, overflow="ellipsis")
    for numeric_header in ("Size", "DLs", "Likes"):
        results_table.add_column(numeric_header, justify="right", width=fixed[numeric_header], no_wrap=True)
    return results_table
288
+
289
+
290
def _add_search_rows(table: Table, items: list[dict[str, Any]]) -> None:
    """Append one row per search hit: id, name, type, base model, size, downloads, likes.

    Base model and size come from the first ``modelVersions`` entry (presumably the
    latest version) and its primary file, falling back to that version's first file.
    Missing or null fields render as "N/A" (text) or 0 (counts). Text from the API
    is markup-escaped.

    Args:
        table: Search table with seven columns, in the order listed above.
        items: ``items`` list from a CivitAI search response.
    """
    for model in items:
        model_id = str(model.get("id", ""))
        name = str(model.get("name") or "N/A")
        model_type = str(model.get("type") or "N/A")

        base_model = "N/A"
        size = "N/A"
        # `or []` / `or {}` / `or 0` also cover keys that are present but null.
        versions = model.get("modelVersions") or []
        if versions:
            latest = versions[0]
            base_model = str(latest.get("baseModel") or "N/A")
            files = latest.get("files") or []
            primary = next((f for f in files if f.get("primary")), files[0] if files else None)
            if primary:
                size = _format_size(primary.get("sizeKB") or 0)

        stats = model.get("stats") or {}
        downloads = _format_count(stats.get("downloadCount") or 0)
        likes = _format_count(stats.get("thumbsUpCount") or 0)

        table.add_row(
            escape(model_id),
            escape(name),
            escape(model_type),
            escape(base_model),
            size,
            downloads,
            likes,
        )
313
+
314
+
315
def display_search_results(results: dict[str, Any], console: Console) -> None:
    """Print CivitAI search hits as a table, a "Showing X of Y" footer and a usage hint.

    Args:
        results: Search response with an ``items`` list and an optional ``metadata``
            mapping whose ``totalItems`` gives the footer total.
        console: Rich console to print to.
    """
    items = results.get("items") or []
    if not items:
        console.print("[yellow]No results found.[/yellow]")
        return

    table = _build_search_table(console)
    _add_search_rows(table, items)

    console.print()
    console.print(table)

    # `metadata` and `totalItems` may be missing or null; fall back to the number
    # of items shown rather than crashing on the `:,` format of None.
    metadata = results.get("metadata") or {}
    total = metadata.get("totalItems")
    if total is None:
        total = len(items)
    console.print(f"\n[dim]Showing {len(items)} of {total:,} results[/dim]")
    console.print("[dim]Use 'tsr get <id>' to view details or 'tsr dl -m <id>' to download[/dim]")
tensors/safetensor.py ADDED
@@ -0,0 +1,95 @@
1
+ """Safetensor file reading functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ import struct
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ if TYPE_CHECKING:
11
+ from pathlib import Path
12
+
13
+ from rich.progress import (
14
+ BarColumn,
15
+ DownloadColumn,
16
+ Progress,
17
+ SpinnerColumn,
18
+ TaskProgressColumn,
19
+ TextColumn,
20
+ TimeRemainingColumn,
21
+ TransferSpeedColumn,
22
+ )
23
+
24
+ if TYPE_CHECKING:
25
+ from rich.console import Console
26
+
27
# Safetensor format constants.
# A safetensor file starts with a fixed-size unsigned little-endian integer
# (read with struct "<Q") giving the byte length of the JSON header that follows.
HEADER_SIZE_BYTES = 8  # u64 little-endian
# Sanity check: declared header lengths above this are rejected as invalid
# before the header bytes are read.
MAX_HEADER_SIZE = 100_000_000  # 100MB sanity check
30
+
31
+
32
def read_safetensor_metadata(file_path: Path) -> dict[str, Any]:
    """Read the JSON header of a safetensor file without loading tensor data.

    The file starts with an 8-byte little-endian u64 giving the header length,
    followed by that many bytes of UTF-8 JSON. The JSON object maps tensor names to
    their entries, plus an optional ``__metadata__`` object.

    Args:
        file_path: Path of the safetensor file.

    Returns:
        A dict with:
            ``metadata``: the ``__metadata__`` mapping (``{}`` when absent or null).
            ``tensor_count``: the number of header keys other than ``__metadata__``.
            ``header_size``: the declared JSON header length in bytes.

    Raises:
        ValueError: If any of the following holds:
            - the file is too short;
            - the declared header length exceeds MAX_HEADER_SIZE;
            - the header is truncated;
            - the header is not valid UTF-8 JSON;
            - the header or ``__metadata__`` is not a JSON object.
            (``json.JSONDecodeError`` and ``UnicodeDecodeError`` are ValueError
            subclasses.)
    """
    with file_path.open("rb") as f:
        header_size_bytes = f.read(HEADER_SIZE_BYTES)
        if len(header_size_bytes) < HEADER_SIZE_BYTES:
            raise ValueError("Invalid safetensor file: too short")

        header_size = struct.unpack("<Q", header_size_bytes)[0]
        if header_size > MAX_HEADER_SIZE:
            raise ValueError(f"Invalid header size: {header_size}")

        header_bytes = f.read(header_size)
        if len(header_bytes) < header_size:
            raise ValueError("Invalid safetensor file: header truncated")

    header = json.loads(header_bytes.decode("utf-8"))
    # A valid header is a JSON object. Anything else (list, number, string) would
    # otherwise fail below with an unhelpful AttributeError on `.get`.
    if not isinstance(header, dict):
        raise ValueError("Invalid safetensor file: header is not a JSON object")

    # `or {}` also normalizes an explicit `"__metadata__": null`. A non-object value
    # is rejected here instead of crashing the display code later.
    metadata: dict[str, Any] = header.get("__metadata__") or {}
    if not isinstance(metadata, dict):
        raise ValueError("Invalid safetensor file: __metadata__ is not a JSON object")

    # Every key other than __metadata__ describes one tensor.
    tensor_count = sum(1 for key in header if key != "__metadata__")

    return {
        "metadata": metadata,
        "tensor_count": tensor_count,
        "header_size": header_size,
    }
61
+
62
+
63
def compute_sha256(file_path: Path, console: Console) -> str:
    """Hash *file_path* with SHA-256 while showing a Rich progress bar.

    The file is streamed in 8 MiB reads, so memory use does not grow with file size.

    Returns:
        The digest as an upper-case hexadecimal string.
    """
    total_bytes = file_path.stat().st_size
    digest = hashlib.sha256()
    read_size = 8 * 1024 * 1024  # 8MB per read

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=console) as progress:
        task_id = progress.add_task(f"[cyan]Hashing {file_path.name}...", total=total_bytes)
        with file_path.open("rb") as stream:
            for block in iter(lambda: stream.read(read_size), b""):
                digest.update(block)
                progress.update(task_id, advance=len(block))

    return digest.hexdigest().upper()
87
+
88
+
89
def get_base_name(file_path: Path) -> str:
    """Return the file name without a ``.safetensors`` or ``.sft`` suffix.

    The suffix match is case-insensitive. For any other extension, the regular
    ``Path.stem`` is returned.
    """
    file_name = file_path.name
    lowered = file_name.lower()
    matched = next((ext for ext in (".safetensors", ".sft") if lowered.endswith(ext)), None)
    if matched is None:
        return file_path.stem
    return file_name[: -len(matched)]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tensors
3
- Version: 0.1.3
3
+ Version: 0.1.5
4
4
  Summary: Read safetensor metadata and fetch CivitAI model information
5
5
  Requires-Python: >=3.12
6
6
  Requires-Dist: httpx>=0.27.0
@@ -0,0 +1,10 @@
1
+ tensors/__init__.py,sha256=Mtf3EBB_VNGnVokXnGdelRR1vZuz30fHGFb2eywgr_M,567
2
+ tensors/api.py,sha256=cSA7x2Dc_yaxwmVIVg19GTsn_J-0ChO_fd7cDNvX0dk,9634
3
+ tensors/cli.py,sha256=s5efFuBBDo6dvFfQSz6v58TM49vuDnvLe-hSzPr1AcQ,14750
4
+ tensors/config.py,sha256=dqpycZfsPCDC6QbpwTJmIASbE4MDAewuqDGQZnT7WtI,4744
5
+ tensors/display.py,sha256=SvOoVMT-tav_4xzGKcjdyhHi-pagzYQv0JzoCbrZjAA,12493
6
+ tensors/safetensor.py,sha256=CGporkoEXWXrPBaYp3mZZ_rVCXbGFNCVOA4P1AqCBOI,2787
7
+ tensors-0.1.5.dist-info/METADATA,sha256=QFeN_oYv37pOmMRC7zDcTyFSwV804WDmblynhV7E4Bs,2855
8
+ tensors-0.1.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
9
+ tensors-0.1.5.dist-info/entry_points.txt,sha256=wuNX2VdjEEyFmGaDk-iSxuecbHpixSrzHAWgfCkNUEY,37
10
+ tensors-0.1.5.dist-info/RECORD,,
@@ -1,5 +0,0 @@
1
- tensors.py,sha256=iWMZ9U9hZFFrRW3X7eh8GeMKJ1oo6u0jAmmvyvotT0g,37521
2
- tensors-0.1.3.dist-info/METADATA,sha256=RU2iMZh2h4RLurcR6Fw6WND-rqsnVBrJ9JVZ-dOK4wk,2855
3
- tensors-0.1.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
4
- tensors-0.1.3.dist-info/entry_points.txt,sha256=wuNX2VdjEEyFmGaDk-iSxuecbHpixSrzHAWgfCkNUEY,37
5
- tensors-0.1.3.dist-info/RECORD,,