sat-water 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sat_water-0.1.0.dist-info/METADATA +347 -0
- sat_water-0.1.0.dist-info/RECORD +11 -0
- sat_water-0.1.0.dist-info/WHEEL +4 -0
- sat_water-0.1.0.dist-info/licenses/LICENSE +21 -0
- satwater/__init__.py +4 -0
- satwater/builders.py +179 -0
- satwater/inference.py +313 -0
- satwater/models.py +229 -0
- satwater/preprocess.py +161 -0
- satwater/utils.py +45 -0
- satwater/weights.py +176 -0
satwater/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import random
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and, if available, TensorFlow.

    Args:
        seed: Integer seed applied to every generator (default 42).

    Notes:
        ``PYTHONHASHSEED`` is also exported. The interpreter reads it only at
        start-up, so it influences subprocesses launched afterwards, not
        ``hash()`` randomization of the current process.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf

        tf.random.set_seed(seed)
    except ImportError:
        # TensorFlow is an optional dependency; nothing to seed without it.
        pass
    except Exception as exc:
        # Keep seeding best-effort (a broken TF install must not abort the
        # caller), but do not hide that TF-level randomness is unpinned.
        import warnings

        warnings.warn(
            f"Could not seed TensorFlow RNG: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def ensure_dir(path):
    """Create directory *path* (including missing parents) and return it as a Path.

    An already-existing directory is accepted silently.
    """
    target = Path(path)
    target.mkdir(exist_ok=True, parents=True)
    return target
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def parse_shape(shape):
    """Parse an ``"H,W,C"`` string into a ``(height, width, channels)`` tuple.

    Whitespace around each component is ignored, e.g. ``" 128, 128 ,3"``.

    Args:
        shape: Comma-separated string with exactly three integer components.

    Returns:
        Tuple of three ints ``(H, W, C)``.

    Raises:
        ValueError: If the string does not have exactly three components or a
            component is not an integer.
    """
    message = "shape must be 'H,W,C' e.g. '128,128,3'"
    raw_parts = shape.split(",")
    # Validate the component count before converting, so inputs such as
    # "128,128,3," report the expected format instead of an int() error.
    if len(raw_parts) != 3:
        raise ValueError(message)
    try:
        height, width, channels = (int(part.strip()) for part in raw_parts)
    except ValueError as exc:
        raise ValueError(f"{message}; got {shape!r}") from exc
    return height, width, channels
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parse_models_arg(models_arg):
    """Parse ``"key=path,key2=path2"`` into ``{"key": "path", "key2": "path2"}``.

    Empty items (e.g. from a trailing comma) are skipped and whitespace around
    keys and paths is stripped. Only the first ``=`` of an item separates key
    from path, so a path may itself contain ``=``. A repeated key keeps the
    last path given.

    Raises:
        ValueError: If an item has no ``=`` or has an empty key or path.
    """
    out = {}
    for item in models_arg.split(","):
        item = item.strip()
        if not item:
            continue
        key, sep, path = item.partition("=")
        if not sep:
            raise ValueError("models must be comma-separated key=path pairs")
        key, path = key.strip(), path.strip()
        # "=path" or "key=" would otherwise yield an empty key/path silently.
        if not key or not path:
            raise ValueError(
                f"models must be comma-separated key=path pairs; got {item!r}"
            )
        out[key] = path
    return out
|
satwater/weights.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Created on Fri Jan 16 16:21:44 2026
|
|
3
|
+
|
|
4
|
+
@author: Busayo Alabi
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import hashlib
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from huggingface_hub import hf_hub_download
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class WeightsError(RuntimeError):
    """Raised when a weights manifest entry cannot be resolved or a download fails SHA-256 verification."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, eq=True)
class WeightsRef:
    """Immutable description of one model's weights file in a Hugging Face repo.

    Instances are produced by ``resolve_weights_ref`` from the repo's manifest.
    """

    # Key of the model inside the manifest's "models" mapping.
    model_key: str
    # Hugging Face repository id, e.g. "owner/name".
    repo_id: str
    # Branch, tag or commit the files are fetched from.
    revision: str
    # Folder inside the repo holding manifest.json and the weights.
    hf_root: str
    # Repo-relative path of the weights file to download.
    weights_file_in_repo: str
    # Hex SHA-256 digest the downloaded file is expected to match.
    expected_sha256: str
    # Optional input shape copied from the manifest entry (presumably an
    # "H,W,C" string — confirm against the manifest producer).
    input_shape: str | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _sha256_file(path, chunk_size=1024 * 1024):
|
|
33
|
+
h = hashlib.sha256()
|
|
34
|
+
with open(path, "rb") as f:
|
|
35
|
+
for chunk in iter(lambda: f.read(chunk_size), b""):
|
|
36
|
+
h.update(chunk)
|
|
37
|
+
return h.hexdigest()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _load_manifest(repo_id, revision, hf_root="weights", cache_dir=None):
    """Download ``<hf_root>/manifest.json`` from a Hugging Face model repo and parse it.

    Args:
        repo_id: Hugging Face repository id.
        revision: Branch, tag or commit to download from.
        hf_root: Folder inside the repo that contains ``manifest.json``.
            Surrounding slashes are ignored; an empty value means the repo root.
        cache_dir: Optional cache directory forwarded to ``hf_hub_download``.

    Returns:
        The decoded JSON document.
    """
    root = hf_root.strip("/")
    # An empty root must not produce a leading "/" in the repo-relative name.
    filename = f"{root}/manifest.json" if root else "manifest.json"
    manifest_path = hf_hub_download(
        repo_id=repo_id,
        repo_type="model",
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
    )
    with open(manifest_path, encoding="utf-8") as f:
        return json.load(f)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def resolve_weights_ref(
    model_key, repo_id, revision="main", hf_root="weights", cache_dir=None
):
    """Resolve *model_key* to its weights file and expected SHA-256 via the manifest.

    The manifest (``<hf_root>/manifest.json``) must contain a ``"models"``
    mapping whose entry for *model_key* provides ``"weights_file"`` (relative
    to *hf_root*) and ``"sha256"``; ``"input_shape"`` is optional.

    Returns:
        A ``WeightsRef`` describing the file to download.

    Raises:
        WeightsError: If the manifest lacks a ``"models"`` mapping, the key is
            unknown, or its entry is malformed.
    """
    manifest = _load_manifest(
        repo_id=repo_id, revision=revision, hf_root=hf_root, cache_dir=cache_dir
    )

    if (
        not isinstance(manifest, dict)
        or "models" not in manifest
        or not isinstance(manifest["models"], dict)
    ):
        raise WeightsError("Invalid manifest.json: missing 'models' mapping")

    models = manifest["models"]
    if model_key not in models:
        available = ", ".join(sorted(models))
        raise WeightsError(f"Unknown model_key='{model_key}'. Available: {available}")

    entry = models[model_key]
    # A non-object entry would otherwise fail below with an opaque TypeError.
    if not isinstance(entry, dict):
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': expected an object"
        )
    try:
        weights_file = entry["weights_file"]
        sha = entry["sha256"]
    except KeyError as e:
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': missing {e!s}"
        ) from e
    input_shape = entry.get("input_shape")

    root = hf_root.strip("/")
    # Normalize to single "/" separators with no leading slash, even when
    # hf_root is empty or weights_file carries its own extra slashes.
    weights_file_in_repo = "/".join(
        part for part in f"{root}/{weights_file}".split("/") if part
    )

    return WeightsRef(
        model_key=model_key,
        repo_id=repo_id,
        revision=revision,
        hf_root=root,
        weights_file_in_repo=weights_file_in_repo,
        expected_sha256=sha,
        input_shape=input_shape,
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def download_weights(
    model_key,
    repo_id,
    revision="main",
    hf_root="weights",
    cache_dir=None,
    verify=True,
    retry_on_mismatch=True,
):
    """Download the weights for *model_key*, verify their SHA-256, and return the local path.

    Args:
        model_key: Key of the model in the repo's manifest.
        repo_id: Hugging Face repository id.
        revision: Branch, tag or commit to download from.
        hf_root: Folder inside the repo holding the manifest and weights.
        cache_dir: Optional cache directory forwarded to ``hf_hub_download``.
        verify: When False, return the downloaded path without hashing it.
        retry_on_mismatch: When True, a digest mismatch triggers one forced
            re-download (covers a corrupted cache entry or partial download).

    Returns:
        Local filesystem path returned by ``hf_hub_download``.

    Raises:
        WeightsError: If the manifest lookup fails in ``resolve_weights_ref`` or
            the digest still does not match after the optional retry. Errors
            from ``hf_hub_download`` itself propagate unchanged.
    """
    ref = resolve_weights_ref(
        model_key=model_key,
        repo_id=repo_id,
        revision=revision,
        hf_root=hf_root,
        cache_dir=cache_dir,
    )

    def _download(force):
        return hf_hub_download(
            repo_id=ref.repo_id,
            repo_type="model",
            filename=ref.weights_file_in_repo,
            revision=ref.revision,
            cache_dir=cache_dir,
            force_download=force,
        )

    local_path = _download(force=False)

    if not verify:
        return local_path

    # hexdigest() is lowercase; normalize the manifest value so an uppercase
    # or whitespace-padded digest does not cause a spurious mismatch.
    expected = str(ref.expected_sha256).strip().lower()

    actual = _sha256_file(local_path)
    if actual == expected:
        return local_path

    if retry_on_mismatch:
        # Corrupted cache or partial download. Force one fresh download.
        local_path = _download(force=True)
        actual = _sha256_file(local_path)
        if actual == expected:
            return local_path

    raise WeightsError(
        "SHA256 mismatch for downloaded weights.\n"
        f"model_key: {ref.model_key}\n"
        f"repo_id: {ref.repo_id}\n"
        f"revision: {ref.revision}\n"
        f"file: {ref.weights_file_in_repo}\n"
        f"expected: {ref.expected_sha256}\n"
        f"actual: {actual}\n"
        "Tip: If you are offline or behind a proxy, downloads may be partial or may not work."
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Default Hugging Face repo and revision for weight downloads. Override them
# with the SATWATER_WEIGHTS_REPO / SATWATER_WEIGHTS_REV environment variables;
# both are read once, when this module is imported.
DEFAULT_WEIGHTS_REPO = os.getenv("SATWATER_WEIGHTS_REPO", "busayojee/sat-water-weights")
DEFAULT_WEIGHTS_REV = os.getenv("SATWATER_WEIGHTS_REV", "main")
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def get_weights_path(
    model_key,
    repo_id=DEFAULT_WEIGHTS_REPO,
    revision=DEFAULT_WEIGHTS_REV,
    hf_root="weights",
    cache_dir=None,
):
    """Return a verified local path to the weights of *model_key*, for inference.

    Thin wrapper around ``download_weights`` that always checks the SHA-256
    and retries the download once on mismatch. The ``repo_id`` / ``revision``
    defaults are the module-level ``DEFAULT_WEIGHTS_REPO`` /
    ``DEFAULT_WEIGHTS_REV`` values captured when this function was defined.
    """
    return download_weights(
        model_key=model_key,
        cache_dir=cache_dir,
        hf_root=hf_root,
        repo_id=repo_id,
        revision=revision,
        retry_on_mismatch=True,
        verify=True,
    )
|