tensors 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- tensors/__init__.py +26 -0
- tensors/api.py +288 -0
- tensors/cli.py +413 -0
- tensors/config.py +166 -0
- tensors/display.py +331 -0
- tensors/safetensor.py +95 -0
- {tensors-0.1.3.dist-info → tensors-0.1.5.dist-info}/METADATA +1 -1
- tensors-0.1.5.dist-info/RECORD +10 -0
- tensors-0.1.3.dist-info/RECORD +0 -5
- tensors.py +0 -1071
- {tensors-0.1.3.dist-info → tensors-0.1.5.dist-info}/WHEEL +0 -0
- {tensors-0.1.3.dist-info → tensors-0.1.5.dist-info}/entry_points.txt +0 -0
tensors/__init__.py
ADDED
@@ -0,0 +1,26 @@
+"""tsr: Read safetensor metadata, search and download CivitAI models."""
+
+from tensors.cli import main
+from tensors.config import (
+    CONFIG_DIR,
+    CONFIG_FILE,
+    LEGACY_RC_FILE,
+    get_default_output_path,
+    load_api_key,
+    load_config,
+    save_config,
+)
+from tensors.safetensor import get_base_name, read_safetensor_metadata
+
+__all__ = [
+    "CONFIG_DIR",
+    "CONFIG_FILE",
+    "LEGACY_RC_FILE",
+    "get_base_name",
+    "get_default_output_path",
+    "load_api_key",
+    "load_config",
+    "main",
+    "read_safetensor_metadata",
+    "save_config",
+]
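Note: the exports above form the new public surface of the package. A minimal usage sketch follows; the signatures of load_api_key and read_safetensor_metadata are not visible in this diff, so the call shapes below are assumptions.

from pathlib import Path

from tensors import load_api_key, read_safetensor_metadata

# Hypothetical usage; both call shapes are assumed, not confirmed by this diff.
api_key = load_api_key()                                        # assumed: returns the stored key or None
metadata = read_safetensor_metadata(Path("model.safetensors"))  # assumed: takes a path to a .safetensors file
print(metadata)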
tensors/api.py
ADDED
@@ -0,0 +1,288 @@
+"""CivitAI API functions."""
+
+from __future__ import annotations
+
+import re
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+import httpx
+from rich.progress import (
+    BarColumn,
+    DownloadColumn,
+    Progress,
+    SpinnerColumn,
+    TaskProgressColumn,
+    TextColumn,
+    TimeRemainingColumn,
+    TransferSpeedColumn,
+)
+
+from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
+
+if TYPE_CHECKING:
+    from rich.console import Console
+
+
+def _get_headers(api_key: str | None) -> dict[str, str]:
+    """Get headers for CivitAI API requests."""
+    headers: dict[str, str] = {}
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
+    return headers
+
+
+def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
+    """Fetch model version information from CivitAI by version ID."""
+    url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"
+
+    try:
+        response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
+        if response.status_code == HTTPStatus.NOT_FOUND:
+            return None
+        response.raise_for_status()
+        result: dict[str, Any] = response.json()
+        return result
+    except httpx.HTTPStatusError as e:
+        console.print(f"[red]API error: {e.response.status_code}[/red]")
+        return None
+    except httpx.RequestError as e:
+        console.print(f"[red]Request error: {e}[/red]")
+        return None
+
+
+def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
+    """Fetch model information from CivitAI by model ID."""
+    url = f"{CIVITAI_API_BASE}/models/{model_id}"
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        console=console,
+        transient=True,
+    ) as progress:
+        progress.add_task("[cyan]Fetching model from CivitAI...", total=None)
+
+        try:
+            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
+            if response.status_code == HTTPStatus.NOT_FOUND:
+                return None
+            response.raise_for_status()
+            result: dict[str, Any] = response.json()
+            return result
+        except httpx.HTTPStatusError as e:
+            console.print(f"[red]API error: {e.response.status_code}[/red]")
+            return None
+        except httpx.RequestError as e:
+            console.print(f"[red]Request error: {e}[/red]")
+            return None
+
+
+def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None:
+    """Fetch model information from CivitAI by SHA256 hash."""
+    url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        console=console,
+        transient=True,
+    ) as progress:
+        progress.add_task("[cyan]Fetching from CivitAI...", total=None)
+
+        try:
+            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
+            if response.status_code == HTTPStatus.NOT_FOUND:
+                return None
+            response.raise_for_status()
+            result: dict[str, Any] = response.json()
+            return result
+        except httpx.HTTPStatusError as e:
+            console.print(f"[red]API error: {e.response.status_code}[/red]")
+            return None
+        except httpx.RequestError as e:
+            console.print(f"[red]Request error: {e}[/red]")
+            return None
+
+
+def _build_search_params(
+    query: str | None,
+    model_type: ModelType | None,
+    base_model: BaseModel | None,
+    sort: SortOrder,
+    limit: int,
+) -> tuple[dict[str, Any], bool]:
+    """Build search parameters and return (params, has_filters)."""
+    params: dict[str, Any] = {
+        "limit": min(limit, 100),
+        "nsfw": "true",
+    }
+
+    # API quirk: query + filters don't work reliably together
+    has_filters = model_type is not None or base_model is not None
+
+    if query and not has_filters:
+        params["query"] = query
+
+    if model_type:
+        params["types"] = model_type.to_api()
+
+    if base_model:
+        params["baseModels"] = base_model.to_api()
+
+    params["sort"] = sort.to_api()
+
+    # Request more if we need client-side filtering
+    if query and has_filters:
+        params["limit"] = 100
+
+    return params, has_filters
+
+
+def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]:
+    """Apply client-side filtering when query + filters combined."""
+    if query and has_filters:
+        q_lower = query.lower()
+        result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit]
+    return result
+
+
+def search_civitai(
+    query: str | None,
+    model_type: ModelType | None,
+    base_model: BaseModel | None,
+    sort: SortOrder,
+    limit: int,
+    api_key: str | None,
+    console: Console,
+) -> dict[str, Any] | None:
+    """Search CivitAI models."""
+    params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
+    url = f"{CIVITAI_API_BASE}/models"
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        console=console,
+        transient=True,
+    ) as progress:
+        progress.add_task("[cyan]Searching CivitAI...", total=None)
+
+        try:
+            response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
+            response.raise_for_status()
+            result: dict[str, Any] = response.json()
+            return _filter_results(result, query, has_filters, limit)
+        except httpx.HTTPStatusError as e:
+            console.print(f"[red]API error: {e.response.status_code}[/red]")
+            return None
+        except httpx.RequestError as e:
+            console.print(f"[red]Request error: {e}[/red]")
+            return None
+
+
+def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]:
+    """Set up resume headers and mode for download."""
+    headers: dict[str, str] = {}
+    mode = "wb"
+    initial_size = 0
+
+    if resume and dest_path.exists():
+        initial_size = dest_path.stat().st_size
+        headers["Range"] = f"bytes={initial_size}-"
+        mode = "ab"
+        console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
+
+    return headers, mode, initial_size
+
+
+def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path:
+    """Extract destination path from response headers if dest is directory."""
+    content_disp = response.headers.get("content-disposition", "")
+    if "filename=" in content_disp:
+        match = re.search(r'filename="?([^";\n]+)"?', content_disp)
+        if match and dest_path.is_dir():
+            return dest_path / match.group(1)
+    return dest_path
+
+
+def _stream_download(
+    response: httpx.Response,
+    dest_path: Path,
+    mode: str,
+    initial_size: int,
+    console: Console,
+) -> bool:
+    """Stream download content to file with progress."""
+    content_length = response.headers.get("content-length")
+    total_size = int(content_length) + initial_size if content_length else 0
+
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        TaskProgressColumn(),
+        DownloadColumn(),
+        TransferSpeedColumn(),
+        TimeRemainingColumn(),
+        console=console,
+    ) as progress:
+        task = progress.add_task(
+            f"[cyan]Downloading {dest_path.name}...",
+            total=total_size if total_size > 0 else None,
+            completed=initial_size,
+        )
+
+        with dest_path.open(mode) as f:
+            for chunk in response.iter_bytes(1024 * 1024):
+                f.write(chunk)
+                progress.update(task, advance=len(chunk))
+
+    console.print()
+    console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]')
+    return True
+
+
+def download_model(
+    version_id: int,
+    dest_path: Path,
+    api_key: str | None,
+    console: Console,
+    resume: bool = True,
+) -> bool:
+    """Download a model from CivitAI by version ID with resume support."""
+    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
+    params: dict[str, str] = {}
+    if api_key:
+        params["token"] = api_key
+
+    headers, mode, initial_size = _setup_resume(dest_path, resume, console)
+
+    try:
+        with httpx.stream(
+            "GET",
+            url,
+            params=params,
+            headers=headers,
+            follow_redirects=True,
+            timeout=httpx.Timeout(30.0, read=None),
+        ) as response:
+            if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
+                console.print("[green]File already fully downloaded.[/green]")
+                return True
+
+            response.raise_for_status()
+            dest_path = _get_dest_from_response(response, dest_path)
+            return _stream_download(response, dest_path, mode, initial_size, console)
+
+    except httpx.HTTPStatusError as e:
+        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
+        if e.response.status_code == HTTPStatus.UNAUTHORIZED:
+            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
+        return False
+    except httpx.RequestError as e:
+        console.print(f"[red]Download error: {e}[/red]")
+        return False
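Note: taken together, the new api.py supports a search-then-download flow with resumable downloads. A minimal sketch under stated assumptions: ModelType, BaseModel, and SortOrder are enums from tensors.config whose members are not shown in this diff, and the CivitAI search response is assumed to expose items[*].modelVersions[*].id.

from pathlib import Path

from rich.console import Console

from tensors.api import download_model, search_civitai
from tensors.config import SortOrder

console = Console()
results = search_civitai(
    query="example",            # free-text search; sent server-side because no filters are set
    model_type=None,
    base_model=None,
    sort=next(iter(SortOrder)), # placeholder: any enum member; real member names are not in this diff
    limit=5,
    api_key=None,               # optional CivitAI API key
    console=console,
)
if results and results.get("items"):
    version_id = results["items"][0]["modelVersions"][0]["id"]  # assumed response shape
    # Downloads into the current directory; the filename comes from Content-Disposition when present.
    download_model(version_id, Path("."), api_key=None, console=console)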