tensors 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tensors/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ """tsr: Read safetensor metadata, search and download CivitAI models."""
2
+
3
+ from tensors.cli import main
4
+ from tensors.config import (
5
+ CONFIG_DIR,
6
+ CONFIG_FILE,
7
+ LEGACY_RC_FILE,
8
+ get_default_output_path,
9
+ load_api_key,
10
+ load_config,
11
+ save_config,
12
+ )
13
+ from tensors.safetensor import get_base_name, read_safetensor_metadata
14
+
15
# Public API of the ``tensors`` package: the names bound by ``from tensors import *``.
__all__ = [
    "CONFIG_DIR", "CONFIG_FILE", "LEGACY_RC_FILE",
    "get_base_name", "get_default_output_path",
    "load_api_key", "load_config",
    "main", "read_safetensor_metadata", "save_config",
]
tensors/api.py ADDED
@@ -0,0 +1,288 @@
1
+ """CivitAI API functions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from http import HTTPStatus
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ if TYPE_CHECKING:
10
+ from pathlib import Path
11
+
12
+ import httpx
13
+ from rich.progress import (
14
+ BarColumn,
15
+ DownloadColumn,
16
+ Progress,
17
+ SpinnerColumn,
18
+ TaskProgressColumn,
19
+ TextColumn,
20
+ TimeRemainingColumn,
21
+ TransferSpeedColumn,
22
+ )
23
+
24
+ from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
25
+
26
+ if TYPE_CHECKING:
27
+ from rich.console import Console
28
+
29
+
30
+ def _get_headers(api_key: str | None) -> dict[str, str]:
31
+ """Get headers for CivitAI API requests."""
32
+ headers: dict[str, str] = {}
33
+ if api_key:
34
+ headers["Authorization"] = f"Bearer {api_key}"
35
+ return headers
36
+
37
+
38
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model version information from CivitAI by version ID.

    Args:
        version_id: CivitAI model-version ID.
        api_key: Optional API key, sent as a Bearer token when truthy.
        console: Rich console used to report errors.

    Returns:
        The decoded JSON body of ``GET {CIVITAI_API_BASE}/model-versions/{version_id}``.
        Returns ``None`` silently on HTTP 404. Also returns ``None`` on any
        other HTTP error status, a transport error or a non-JSON body; those
        cases are reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"

    try:
        response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
        if response.status_code == HTTPStatus.NOT_FOUND:
            return None
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        console.print(f"[red]API error: {e.response.status_code}[/red]")
        return None
    except httpx.RequestError as e:
        console.print(f"[red]Request error: {e}[/red]")
        return None

    # Response.json() decodes with the stdlib json module. Its JSONDecodeError
    # (like UnicodeDecodeError) subclasses ValueError. Without this guard, an
    # HTML error page would escape as a traceback instead of returning None.
    try:
        result: dict[str, Any] = response.json()
    except ValueError:
        console.print("[red]API error: response was not valid JSON[/red]")
        return None
    return result
55
+
56
+
57
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model information from CivitAI by model ID.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        model_id: CivitAI model ID.
        api_key: Optional API key, sent as a Bearer token when truthy.
        console: Rich console used for the spinner and error reports.

    Returns:
        The decoded JSON body of ``GET {CIVITAI_API_BASE}/models/{model_id}``.
        Returns ``None`` silently on HTTP 404. Also returns ``None`` on any
        other HTTP error status, a transport error or a non-JSON body; those
        cases are reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/models/{model_id}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Fetching model from CivitAI...", total=None)

        try:
            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
            if response.status_code == HTTPStatus.NOT_FOUND:
                return None
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None

        # A non-JSON body makes Response.json() raise a ValueError subclass
        # (json.JSONDecodeError). Report it instead of crashing.
        try:
            result: dict[str, Any] = response.json()
        except ValueError:
            console.print("[red]API error: response was not valid JSON[/red]")
            return None
        return result
82
+
83
+
84
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model information from CivitAI by SHA256 hash.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        sha256_hash: SHA256 digest string, interpolated into the URL as given.
        api_key: Optional API key, sent as a Bearer token when truthy.
        console: Rich console used for the spinner and error reports.

    Returns:
        The decoded JSON body of
        ``GET {CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}``.
        Returns ``None`` silently on HTTP 404. Also returns ``None`` on any
        other HTTP error status, a transport error or a non-JSON body; those
        cases are reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Fetching from CivitAI...", total=None)

        try:
            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
            if response.status_code == HTTPStatus.NOT_FOUND:
                return None
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None

        # A non-JSON body makes Response.json() raise a ValueError subclass
        # (json.JSONDecodeError). Report it instead of crashing.
        try:
            result: dict[str, Any] = response.json()
        except ValueError:
            console.print("[red]API error: response was not valid JSON[/red]")
            return None
        return result
109
+
110
+
111
+ def _build_search_params(
112
+ query: str | None,
113
+ model_type: ModelType | None,
114
+ base_model: BaseModel | None,
115
+ sort: SortOrder,
116
+ limit: int,
117
+ ) -> tuple[dict[str, Any], bool]:
118
+ """Build search parameters and return (params, has_filters)."""
119
+ params: dict[str, Any] = {
120
+ "limit": min(limit, 100),
121
+ "nsfw": "true",
122
+ }
123
+
124
+ # API quirk: query + filters don't work reliably together
125
+ has_filters = model_type is not None or base_model is not None
126
+
127
+ if query and not has_filters:
128
+ params["query"] = query
129
+
130
+ if model_type:
131
+ params["types"] = model_type.to_api()
132
+
133
+ if base_model:
134
+ params["baseModels"] = base_model.to_api()
135
+
136
+ params["sort"] = sort.to_api()
137
+
138
+ # Request more if we need client-side filtering
139
+ if query and has_filters:
140
+ params["limit"] = 100
141
+
142
+ return params, has_filters
143
+
144
+
145
+ def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]:
146
+ """Apply client-side filtering when query + filters combined."""
147
+ if query and has_filters:
148
+ q_lower = query.lower()
149
+ result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit]
150
+ return result
151
+
152
+
153
def search_civitai(
    query: str | None,
    model_type: ModelType | None,
    base_model: BaseModel | None,
    sort: SortOrder,
    limit: int,
    api_key: str | None,
    console: Console,
) -> dict[str, Any] | None:
    """Search CivitAI models.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        query: Optional free-text query. When combined with ``model_type`` or
            ``base_model``, it is matched against model names client-side.
        model_type: Optional model-type filter.
        base_model: Optional base-model filter.
        sort: Sort order sent to the API.
        limit: Requested number of results (the API request is capped at 100).
        api_key: Optional API key, sent as a Bearer token when truthy.
        console: Rich console used for the spinner and error reports.

    Returns:
        The decoded JSON body of ``GET {CIVITAI_API_BASE}/models``, with
        ``items`` narrowed by name when query and filters are combined.
        Returns ``None`` on an HTTP error status, a transport error or a
        non-JSON body; each case is reported on ``console``.
    """
    params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
    url = f"{CIVITAI_API_BASE}/models"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Searching CivitAI...", total=None)

        try:
            response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None

        # A non-JSON body makes Response.json() raise a ValueError subclass
        # (json.JSONDecodeError). Report it instead of crashing.
        try:
            result: dict[str, Any] = response.json()
        except ValueError:
            console.print("[red]API error: response was not valid JSON[/red]")
            return None

        return _filter_results(result, query, has_filters, limit)
185
+
186
+
187
+ def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]:
188
+ """Set up resume headers and mode for download."""
189
+ headers: dict[str, str] = {}
190
+ mode = "wb"
191
+ initial_size = 0
192
+
193
+ if resume and dest_path.exists():
194
+ initial_size = dest_path.stat().st_size
195
+ headers["Range"] = f"bytes={initial_size}-"
196
+ mode = "ab"
197
+ console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
198
+
199
+ return headers, mode, initial_size
200
+
201
+
202
+ def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path:
203
+ """Extract destination path from response headers if dest is directory."""
204
+ content_disp = response.headers.get("content-disposition", "")
205
+ if "filename=" in content_disp:
206
+ match = re.search(r'filename="?([^";\n]+)"?', content_disp)
207
+ if match and dest_path.is_dir():
208
+ return dest_path / match.group(1)
209
+ return dest_path
210
+
211
+
212
def _stream_download(
    response: httpx.Response,
    dest_path: Path,
    mode: str,
    initial_size: int,
    console: Console,
) -> bool:
    """Copy the streamed response body into ``dest_path`` behind a progress bar.

    Args:
        response: Open streaming response. Its body is read in 1 MiB chunks.
        dest_path: File to write to.
        mode: ``open()`` mode chosen by the caller (``"wb"`` or ``"ab"``).
        initial_size: Bytes already on disk. The bar starts here, and this
            value is added to ``Content-Length`` to form the expected total.
        console: Rich console that renders the bar and the final message.

    Returns:
        True once the whole body has been written. I/O and stream errors are
        not caught here.
    """
    declared_length = response.headers.get("content-length")
    if declared_length:
        expected_total = initial_size + int(declared_length)
    else:
        expected_total = 0

    chunk_size = 1024 * 1024  # 1 MiB per read/write
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        console=console,
    ) as progress:
        # Without a usable (positive) total, the bar is indeterminate (total=None).
        bar = progress.add_task(
            f"[cyan]Downloading {dest_path.name}...",
            total=None if expected_total <= 0 else expected_total,
            completed=initial_size,
        )
        with dest_path.open(mode) as sink:
            for piece in response.iter_bytes(chunk_size):
                sink.write(piece)
                progress.update(bar, advance=len(piece))

    console.print()
    console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]')
    return True
247
+
248
+
249
def download_model(
    version_id: int,
    dest_path: Path,
    api_key: str | None,
    console: Console,
    resume: bool = True,
) -> bool:
    """Download a model from CivitAI by version ID with resume support.

    Args:
        version_id: CivitAI model-version ID to download.
        dest_path: Target file, or an existing directory. For a directory, the
            server-supplied ``Content-Disposition`` filename is used inside it
            (see ``_get_dest_from_response``).
        api_key: Optional API key, passed as the ``token`` query parameter.
        console: Rich console for progress and error output.
        resume: When True and ``dest_path`` already exists and is not a
            directory, request only the missing bytes and append them.

    Returns:
        True when the file was written, or when the server answers 416 (taken
        to mean "already fully downloaded"). False on HTTP or transport
        errors, which are reported on ``console``.
    """
    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
    params: dict[str, str] = {"token": api_key} if api_key else {}

    # A directory is not a partial download. Never derive a Range offset from it.
    can_resume = resume and not dest_path.is_dir()
    headers, mode, initial_size = _setup_resume(dest_path, can_resume, console)

    try:
        with httpx.stream(
            "GET",
            url,
            params=params,
            headers=headers,
            follow_redirects=True,
            timeout=httpx.Timeout(30.0, read=None),
        ) as response:
            if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
                console.print("[green]File already fully downloaded.[/green]")
                return True

            response.raise_for_status()

            if mode == "ab" and response.status_code != HTTPStatus.PARTIAL_CONTENT:
                # The server ignored the Range header and is sending the whole
                # file from byte 0. Appending it to the partial file would
                # corrupt the model, so start over.
                console.print("[yellow]Server did not honor the resume request; restarting download.[/yellow]")
                mode, initial_size = "wb", 0

            dest_path = _get_dest_from_response(response, dest_path)
            return _stream_download(response, dest_path, mode, initial_size, console)

    except httpx.HTTPStatusError as e:
        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
        if e.response.status_code == HTTPStatus.UNAUTHORIZED:
            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
        return False
    except httpx.RequestError as e:
        console.print(f"[red]Download error: {e}[/red]")
        return False