tensors 0.1.3__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tensors-0.1.5/.claude/settings.local.json +7 -0
- tensors-0.1.5/.github/workflows/publish.yml +49 -0
- {tensors-0.1.3 → tensors-0.1.5}/PKG-INFO +1 -1
- tensors-0.1.5/RELEASE.md +89 -0
- tensors-0.1.5/justfile +28 -0
- {tensors-0.1.3 → tensors-0.1.5}/pyproject.toml +5 -6
- tensors-0.1.5/tensors/__init__.py +26 -0
- tensors-0.1.5/tensors/api.py +288 -0
- tensors-0.1.5/tensors/cli.py +413 -0
- tensors-0.1.5/tensors/config.py +166 -0
- tensors-0.1.5/tensors/display.py +331 -0
- tensors-0.1.5/tensors/safetensor.py +95 -0
- tensors-0.1.5/tests/test_tensors.py +590 -0
- {tensors-0.1.3 → tensors-0.1.5}/uv.lock +15 -1
- tensors-0.1.3/.github/workflows/publish.yml +0 -128
- tensors-0.1.3/tensors.py +0 -1071
- tensors-0.1.3/tests/test_tensors.py +0 -138
- {tensors-0.1.3 → tensors-0.1.5}/.github/workflows/ci.yml +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/.gitignore +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/.pre-commit-config.yaml +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/.tool-versions +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/README.md +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/TODO.md +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/check.py +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/civit.md +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/nuitka_build.py +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/tests/__init__.py +0 -0
- {tensors-0.1.3 → tensors-0.1.5}/tests/conftest.py +0 -0
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
name: Publish Package

on:
  push:
    tags:
      - 'v*'

permissions:
  contents: write   # create the GitHub Release
  id-token: write   # PyPI trusted publishing (OIDC)

jobs:
  publish:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install build tools
        run: pip install build

      - name: Extract version from tag
        id: version
        run: |
          TAG=${GITHUB_REF#refs/tags/v}
          echo "version=$TAG" >> "$GITHUB_OUTPUT"
          # Recognize pre-releases in both dashed (v1.0.0-rc1, v1.0.0-beta) and
          # PEP 440 (v1.0.0rc1, v1.0.0b2, v1.0.0.dev1) tag styles, so such
          # releases are never marked as the latest GitHub release.
          if [[ "$TAG" =~ [0-9][-.]?(a|alpha|b|beta|rc|pre|dev)[0-9]*$ ]]; then
            echo "prerelease=true" >> "$GITHUB_OUTPUT"
          else
            echo "prerelease=false" >> "$GITHUB_OUTPUT"
          fi

      - name: Build package
        run: python -m build

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*
          prerelease: ${{ steps.version.outputs.prerelease }}
          generate_release_notes: true
|
tensors-0.1.5/RELEASE.md
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# Release Process
|
|
2
|
+
|
|
3
|
+
## Publishing a Release
|
|
4
|
+
|
|
5
|
+
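Update `version` in `pyproject.toml`, then push a matching version tag to trigger the publish workflow (the built package takes its version from `pyproject.toml`, not from the tag):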
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
git tag v0.1.2
|
|
9
|
+
git push origin v0.1.2
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
The workflow will:
|
|
13
|
+
1. Build the Python package
|
|
14
|
+
2. Publish to PyPI
|
|
15
|
+
3. Create a GitHub Release
|
|
16
|
+
|
|
17
|
+
## Building Platform Binaries
|
|
18
|
+
|
|
19
|
+
Platform-specific binaries can be built locally using [Nuitka](https://nuitka.net/).
|
|
20
|
+
|
|
21
|
+
### Prerequisites
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install nuitka
|
|
25
|
+
pip install -e .
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
### Build Commands
|
|
29
|
+
|
|
30
|
+
**Linux / macOS:**
|
|
31
|
+
```bash
|
|
32
|
+
python -m nuitka \
|
|
33
|
+
--standalone \
|
|
34
|
+
--onefile \
|
|
35
|
+
--output-dir=dist \
|
|
36
|
+
--output-filename=tsr \
|
|
37
|
+
--assume-yes-for-downloads \
|
|
38
|
+
--remove-output \
|
|
39
|
+
tensors.py
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
**Windows:**
|
|
43
|
+
```powershell
|
|
44
|
+
python -m nuitka `
|
|
45
|
+
--standalone `
|
|
46
|
+
--onefile `
|
|
47
|
+
--output-dir=dist `
|
|
48
|
+
--output-filename=tsr.exe `
|
|
49
|
+
--assume-yes-for-downloads `
|
|
50
|
+
--remove-output `
|
|
51
|
+
tensors.py
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
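### Output Artifacts

The build commands above write `dist/tsr` (`dist/tsr.exe` on Windows); the table lists the platform-specific filename used for each release artifact.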
|
|
55
|
+
|
|
56
|
+
| Platform | Arch | Filename |
|
|
57
|
+
|---------------|-------|--------------------|
|
|
58
|
+
| Linux | x64 | `tsr-linux-x64` |
|
|
59
|
+
| Linux | arm64 | `tsr-linux-arm64` |
|
|
60
|
+
| macOS | arm64 | `tsr-macos-arm64` |
|
|
61
|
+
| macOS | x64 | `tsr-macos-x64` |
|
|
62
|
+
| Windows | x64 | `tsr-windows-x64.exe` |
|
|
63
|
+
|
|
64
|
+
### macOS Code Signing (Optional)
|
|
65
|
+
|
|
66
|
+
To sign and notarize macOS binaries:
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# Sign the binary
|
|
70
|
+
codesign --force --options runtime --sign "Developer ID Application" dist/tsr
|
|
71
|
+
|
|
72
|
+
# Create zip for notarization
|
|
73
|
+
ditto -c -k --keepParent dist/tsr dist/tsr.zip
|
|
74
|
+
|
|
75
|
+
# Submit for notarization
|
|
76
|
+
xcrun notarytool submit dist/tsr.zip \
|
|
77
|
+
--apple-id "$APPLE_ID" \
|
|
78
|
+
--password "$APPLE_ID_PASSWORD" \
|
|
79
|
+
--team-id "$APPLE_TEAM_ID" \
|
|
80
|
+
--wait
|
|
81
|
+
|
|
82
|
+
# Staple the notarization ticket
|
|
83
|
+
xcrun stapler staple dist/tsr
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
Required environment variables:
|
|
87
|
+
- `APPLE_ID` - Apple Developer account email
|
|
88
|
+
- `APPLE_ID_PASSWORD` - App-specific password
|
|
89
|
+
- `APPLE_TEAM_ID` - Developer Team ID
|
tensors-0.1.5/justfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Default recipe: auto-fix, then lint and type-check, then run the test suite.
default: fix check test

# All static checks: ruff lint followed by mypy.
check: lint types

# Run the pytest suite.
test:
    uv run pytest

# Apply ruff's automatic fixes, then reformat the code.
fix:
    uv run ruff check --fix .
    uv run ruff format .

# Reformat only.
format:
    uv run ruff format .

# Lint only; never modifies files.
lint:
    uv run ruff check .

# Type-check the package with mypy.
types:
    uv run mypy tensors/
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "tensors"
|
|
3
|
-
version = "0.1.
|
|
3
|
+
version = "0.1.5"
|
|
4
4
|
description = "Read safetensor metadata and fetch CivitAI model information"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.12"
|
|
@@ -19,7 +19,7 @@ requires = ["hatchling"]
|
|
|
19
19
|
build-backend = "hatchling.build"
|
|
20
20
|
|
|
21
21
|
[tool.hatch.build.targets.wheel]
|
|
22
|
-
packages = ["tensors
|
|
22
|
+
packages = ["tensors"]
|
|
23
23
|
|
|
24
24
|
[dependency-groups]
|
|
25
25
|
dev = [
|
|
@@ -29,11 +29,12 @@ dev = [
|
|
|
29
29
|
"pytest>=8.0",
|
|
30
30
|
"pytest-cov>=4.1",
|
|
31
31
|
"pre-commit>=3.6",
|
|
32
|
+
"respx>=0.22.0",
|
|
32
33
|
]
|
|
33
34
|
|
|
34
35
|
[tool.ruff]
|
|
35
36
|
target-version = "py312"
|
|
36
|
-
line-length =
|
|
37
|
+
line-length = 130
|
|
37
38
|
|
|
38
39
|
[tool.ruff.lint]
|
|
39
40
|
select = [
|
|
@@ -52,9 +53,7 @@ select = [
|
|
|
52
53
|
"RUF", # ruff-specific
|
|
53
54
|
]
|
|
54
55
|
ignore = [
|
|
55
|
-
"
|
|
56
|
-
"PLR0913", # too many arguments
|
|
57
|
-
"PLR2004", # magic value comparison
|
|
56
|
+
"PLR0913", # Too many arguments - CLI commands need many options
|
|
58
57
|
]
|
|
59
58
|
|
|
60
59
|
[tool.ruff.lint.isort]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""tsr: Read safetensor metadata, search and download CivitAI models."""
|
|
2
|
+
|
|
3
|
+
from tensors.cli import main
|
|
4
|
+
from tensors.config import (
|
|
5
|
+
CONFIG_DIR,
|
|
6
|
+
CONFIG_FILE,
|
|
7
|
+
LEGACY_RC_FILE,
|
|
8
|
+
get_default_output_path,
|
|
9
|
+
load_api_key,
|
|
10
|
+
load_config,
|
|
11
|
+
save_config,
|
|
12
|
+
)
|
|
13
|
+
from tensors.safetensor import get_base_name, read_safetensor_metadata
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"CONFIG_DIR",
|
|
17
|
+
"CONFIG_FILE",
|
|
18
|
+
"LEGACY_RC_FILE",
|
|
19
|
+
"get_base_name",
|
|
20
|
+
"get_default_output_path",
|
|
21
|
+
"load_api_key",
|
|
22
|
+
"load_config",
|
|
23
|
+
"main",
|
|
24
|
+
"read_safetensor_metadata",
|
|
25
|
+
"save_config",
|
|
26
|
+
]
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""CivitAI API functions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from http import HTTPStatus
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
from rich.progress import (
|
|
14
|
+
BarColumn,
|
|
15
|
+
DownloadColumn,
|
|
16
|
+
Progress,
|
|
17
|
+
SpinnerColumn,
|
|
18
|
+
TaskProgressColumn,
|
|
19
|
+
TextColumn,
|
|
20
|
+
TimeRemainingColumn,
|
|
21
|
+
TransferSpeedColumn,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from tensors.config import CIVITAI_API_BASE, CIVITAI_DOWNLOAD_BASE, BaseModel, ModelType, SortOrder
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from rich.console import Console
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_headers(api_key: str | None) -> dict[str, str]:
|
|
31
|
+
"""Get headers for CivitAI API requests."""
|
|
32
|
+
headers: dict[str, str] = {}
|
|
33
|
+
if api_key:
|
|
34
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
|
35
|
+
return headers
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def fetch_civitai_model_version(version_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model version information from CivitAI by version ID.

    Args:
        version_id: CivitAI model-version ID.
        api_key: Optional API key, sent as a bearer token.
        console: Rich console used to report errors.

    Returns:
        The decoded JSON payload. Returns ``None`` when:
        - the version does not exist (HTTP 404),
        - any other HTTP or network error occurs,
        - the body is not valid JSON.
        Every case except 404 is reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/model-versions/{version_id}"

    try:
        response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
        if response.status_code == HTTPStatus.NOT_FOUND:
            return None
        response.raise_for_status()
        result: dict[str, Any] = response.json()
    except httpx.HTTPStatusError as e:
        console.print(f"[red]API error: {e.response.status_code}[/red]")
        return None
    except httpx.RequestError as e:
        console.print(f"[red]Request error: {e}[/red]")
        return None
    except ValueError:
        # response.json() raises json.JSONDecodeError (a ValueError subclass)
        # when the server answers with a non-JSON body such as an HTML page.
        console.print("[red]API error: response was not valid JSON[/red]")
        return None
    return result
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def fetch_civitai_model(model_id: int, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model information from CivitAI by model ID.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        model_id: CivitAI model ID.
        api_key: Optional API key, sent as a bearer token.
        console: Rich console used for the spinner and error messages.

    Returns:
        The decoded JSON payload. Returns ``None`` when:
        - the model does not exist (HTTP 404),
        - any other HTTP or network error occurs,
        - the body is not valid JSON.
        Every case except 404 is reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/models/{model_id}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Fetching model from CivitAI...", total=None)

        try:
            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
            if response.status_code == HTTPStatus.NOT_FOUND:
                return None
            response.raise_for_status()
            result: dict[str, Any] = response.json()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None
        except ValueError:
            # response.json() raises json.JSONDecodeError (a ValueError subclass)
            # when the server answers with a non-JSON body such as an HTML page.
            console.print("[red]API error: response was not valid JSON[/red]")
            return None
        return result
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def fetch_civitai_by_hash(sha256_hash: str, api_key: str | None, console: Console) -> dict[str, Any] | None:
    """Fetch model version information from CivitAI by a file's SHA256 hash.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        sha256_hash: SHA256 hash of the model file, inserted into the URL path.
        api_key: Optional API key, sent as a bearer token.
        console: Rich console used for the spinner and error messages.

    Returns:
        The decoded JSON payload. Returns ``None`` when:
        - no version matches the hash (HTTP 404),
        - any other HTTP or network error occurs,
        - the body is not valid JSON.
        Every case except 404 is reported on ``console``.
    """
    url = f"{CIVITAI_API_BASE}/model-versions/by-hash/{sha256_hash}"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Fetching from CivitAI...", total=None)

        try:
            response = httpx.get(url, headers=_get_headers(api_key), timeout=30.0)
            if response.status_code == HTTPStatus.NOT_FOUND:
                return None
            response.raise_for_status()
            result: dict[str, Any] = response.json()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None
        except ValueError:
            # response.json() raises json.JSONDecodeError (a ValueError subclass)
            # when the server answers with a non-JSON body such as an HTML page.
            console.print("[red]API error: response was not valid JSON[/red]")
            return None
        return result
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _build_search_params(
|
|
112
|
+
query: str | None,
|
|
113
|
+
model_type: ModelType | None,
|
|
114
|
+
base_model: BaseModel | None,
|
|
115
|
+
sort: SortOrder,
|
|
116
|
+
limit: int,
|
|
117
|
+
) -> tuple[dict[str, Any], bool]:
|
|
118
|
+
"""Build search parameters and return (params, has_filters)."""
|
|
119
|
+
params: dict[str, Any] = {
|
|
120
|
+
"limit": min(limit, 100),
|
|
121
|
+
"nsfw": "true",
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
# API quirk: query + filters don't work reliably together
|
|
125
|
+
has_filters = model_type is not None or base_model is not None
|
|
126
|
+
|
|
127
|
+
if query and not has_filters:
|
|
128
|
+
params["query"] = query
|
|
129
|
+
|
|
130
|
+
if model_type:
|
|
131
|
+
params["types"] = model_type.to_api()
|
|
132
|
+
|
|
133
|
+
if base_model:
|
|
134
|
+
params["baseModels"] = base_model.to_api()
|
|
135
|
+
|
|
136
|
+
params["sort"] = sort.to_api()
|
|
137
|
+
|
|
138
|
+
# Request more if we need client-side filtering
|
|
139
|
+
if query and has_filters:
|
|
140
|
+
params["limit"] = 100
|
|
141
|
+
|
|
142
|
+
return params, has_filters
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _filter_results(result: dict[str, Any], query: str | None, has_filters: bool, limit: int) -> dict[str, Any]:
|
|
146
|
+
"""Apply client-side filtering when query + filters combined."""
|
|
147
|
+
if query and has_filters:
|
|
148
|
+
q_lower = query.lower()
|
|
149
|
+
result["items"] = [m for m in result.get("items", []) if q_lower in m.get("name", "").lower()][:limit]
|
|
150
|
+
return result
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def search_civitai(
    query: str | None,
    model_type: ModelType | None,
    base_model: BaseModel | None,
    sort: SortOrder,
    limit: int,
    api_key: str | None,
    console: Console,
) -> dict[str, Any] | None:
    """Search CivitAI models.

    A transient spinner is shown on ``console`` while the request runs.

    Args:
        query: Free-text name search, or ``None``.
        model_type: Optional model-type filter.
        base_model: Optional base-model filter.
        sort: Sort order for the results.
        limit: Maximum number of results to return.
        api_key: Optional API key, sent as a bearer token.
        console: Rich console used for the spinner and error messages.

    Returns:
        The decoded search response, with the text query applied client-side
        when it was combined with filters. Returns ``None`` on HTTP or network
        errors, or when the body is not valid JSON; each is reported on
        ``console``.
    """
    params, has_filters = _build_search_params(query, model_type, base_model, sort, limit)
    url = f"{CIVITAI_API_BASE}/models"

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
        transient=True,
    ) as progress:
        progress.add_task("[cyan]Searching CivitAI...", total=None)

        try:
            response = httpx.get(url, params=params, headers=_get_headers(api_key), timeout=30.0)
            response.raise_for_status()
            result: dict[str, Any] = response.json()
        except httpx.HTTPStatusError as e:
            console.print(f"[red]API error: {e.response.status_code}[/red]")
            return None
        except httpx.RequestError as e:
            console.print(f"[red]Request error: {e}[/red]")
            return None
        except ValueError:
            # response.json() raises json.JSONDecodeError (a ValueError subclass)
            # when the server answers with a non-JSON body such as an HTML page.
            console.print("[red]API error: response was not valid JSON[/red]")
            return None
        return _filter_results(result, query, has_filters, limit)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _setup_resume(dest_path: Path, resume: bool, console: Console) -> tuple[dict[str, str], str, int]:
|
|
188
|
+
"""Set up resume headers and mode for download."""
|
|
189
|
+
headers: dict[str, str] = {}
|
|
190
|
+
mode = "wb"
|
|
191
|
+
initial_size = 0
|
|
192
|
+
|
|
193
|
+
if resume and dest_path.exists():
|
|
194
|
+
initial_size = dest_path.stat().st_size
|
|
195
|
+
headers["Range"] = f"bytes={initial_size}-"
|
|
196
|
+
mode = "ab"
|
|
197
|
+
console.print(f"[cyan]Resuming download from {initial_size / (1024**2):.1f} MB[/cyan]")
|
|
198
|
+
|
|
199
|
+
return headers, mode, initial_size
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _get_dest_from_response(response: httpx.Response, dest_path: Path) -> Path:
|
|
203
|
+
"""Extract destination path from response headers if dest is directory."""
|
|
204
|
+
content_disp = response.headers.get("content-disposition", "")
|
|
205
|
+
if "filename=" in content_disp:
|
|
206
|
+
match = re.search(r'filename="?([^";\n]+)"?', content_disp)
|
|
207
|
+
if match and dest_path.is_dir():
|
|
208
|
+
return dest_path / match.group(1)
|
|
209
|
+
return dest_path
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def _stream_download(
    response: httpx.Response,
    dest_path: Path,
    mode: str,
    initial_size: int,
    console: Console,
) -> bool:
    """Write a streamed response body to ``dest_path`` with a progress bar.

    Args:
        response: Open streaming response whose body is the file content.
        dest_path: File to write.
        mode: ``"wb"`` for a fresh download or ``"ab"`` to append to a partial file.
        initial_size: Bytes already on disk; the progress bar starts from here.
        console: Rich console that renders the progress bar and the summary line.

    Returns:
        True once the whole body has been written.
    """
    header_len = response.headers.get("content-length")
    # For a ranged (resumed) response, Content-Length covers only the remaining
    # bytes, so the bytes already on disk are added to get the full size.
    expected_total = (int(header_len) + initial_size) if header_len else 0
    chunk_size = 1024 * 1024

    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        DownloadColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
    )
    with Progress(*columns, console=console) as bar:
        task_id = bar.add_task(
            f"[cyan]Downloading {dest_path.name}...",
            total=expected_total if expected_total > 0 else None,
            completed=initial_size,
        )
        with dest_path.open(mode) as out:
            for piece in response.iter_bytes(chunk_size):
                out.write(piece)
                bar.update(task_id, advance=len(piece))

    console.print()
    console.print(f'[magenta]Downloaded:[/magenta] [green]"{dest_path}"[/green]')
    return True
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def download_model(
    version_id: int,
    dest_path: Path,
    api_key: str | None,
    console: Console,
    resume: bool = True,
) -> bool:
    """Download a model from CivitAI by version ID, resuming when possible.

    Args:
        version_id: CivitAI model-version ID to download.
        dest_path: Target file, or an existing directory. For a directory, the
            filename is taken from the response's ``Content-Disposition``.
        api_key: Optional API key, passed as the ``token`` query parameter.
        console: Rich console for progress and error output.
        resume: Continue a partial download at ``dest_path`` if one exists.

    Returns:
        True on success, including when the server reports the requested range
        as unsatisfiable (the file is treated as already complete). False on
        HTTP or network errors.
    """
    url = f"{CIVITAI_DOWNLOAD_BASE}/{version_id}"
    params: dict[str, str] = {}
    if api_key:
        params["token"] = api_key

    headers, mode, initial_size = _setup_resume(dest_path, resume, console)

    try:
        with httpx.stream(
            "GET",
            url,
            params=params,
            headers=headers,
            follow_redirects=True,
            # No read timeout: large model files can stream for a long time.
            timeout=httpx.Timeout(30.0, read=None),
        ) as response:
            if response.status_code == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
                console.print("[green]File already fully downloaded.[/green]")
                return True

            response.raise_for_status()

            # A server that ignores the Range header answers 200 with the full
            # body. Appending that to the partial file would corrupt it, so
            # start over from byte 0 instead.
            if initial_size and response.status_code != HTTPStatus.PARTIAL_CONTENT:
                console.print("[yellow]Server did not honor the resume request; restarting download.[/yellow]")
                mode, initial_size = "wb", 0

            dest_path = _get_dest_from_response(response, dest_path)
            return _stream_download(response, dest_path, mode, initial_size, console)

    except httpx.HTTPStatusError as e:
        console.print(f"[red]Download error: HTTP {e.response.status_code}[/red]")
        if e.response.status_code == HTTPStatus.UNAUTHORIZED:
            console.print("[yellow]Hint: This model may require an API key.[/yellow]")
        return False
    except httpx.RequestError as e:
        console.print(f"[red]Download error: {e}[/red]")
        return False
|