sat-water 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
satwater/utils.py ADDED
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+
9
+
10
def set_seed(seed=42):
    """Seed the global random number generators used by this package.

    Seeds Python's ``random`` module and NumPy's legacy global RNG and, when
    TensorFlow can be imported, TensorFlow's global seed.

    ``PYTHONHASHSEED`` is exported as well. The running interpreter fixes its
    hash randomisation at start-up, so this only influences child processes
    launched afterwards.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # TensorFlow is optional. Any failure to import it or to seed it is
    # ignored on purpose, so the function also works in TF-free environments.
    try:
        import tensorflow as tf

        tf.random.set_seed(seed)
    except Exception:
        pass
20
+
21
+
22
def ensure_dir(path):
    """Create directory *path* (and any missing parents) and return it as a ``Path``.

    An already-existing directory is left untouched. ``Path.mkdir`` still
    raises if *path* exists but is not a directory.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
    return directory
26
+
27
+
28
def parse_shape(shape):
    """Parse an ``"H,W,C"`` input-shape spec into an ``(H, W, C)`` tuple of ints.

    Accepts the comma-separated string form, e.g. ``"128,128,3"``; whitespace
    around each number is ignored. Any 3-item sequence of integers, such as
    ``(128, 128, 3)``, is also accepted.

    Raises:
        ValueError: if *shape* does not consist of exactly three integer
            components, or if any component is zero or negative.
    """
    usage = "shape must be 'H,W,C' e.g. '128,128,3'"
    if isinstance(shape, str):
        items = shape.split(",")
        to_int = int  # int() already tolerates surrounding whitespace
    else:
        import operator

        items = shape
        # operator.index rejects floats instead of silently truncating them.
        to_int = operator.index
    try:
        parts = [to_int(item) for item in items]
    except (TypeError, ValueError) as e:
        # Covers non-numeric parts, empty parts (e.g. a trailing comma) and
        # non-iterable inputs, all reported with the same usage hint.
        raise ValueError(f"{usage}; got {shape!r}") from e
    if len(parts) != 3:
        raise ValueError(usage)
    if any(dim <= 0 for dim in parts):
        raise ValueError(f"shape dimensions must be positive; got {shape!r}")
    return parts[0], parts[1], parts[2]
33
+
34
+
35
def parse_models_arg(models_arg):
    """Parse ``"key=path,key2=path2"`` into ``{"key": "path", "key2": "path2"}``.

    Whitespace around items, keys and paths is stripped. Empty items (for
    example from a trailing comma) are skipped. Only the first ``=`` separates
    key from path, so paths may themselves contain ``=``. A repeated key keeps
    the last path given for it.

    Raises:
        ValueError: if a non-empty item has no ``=``, or has an empty key or
            an empty path.
    """
    usage = "models must be comma-separated key=path pairs"
    out = {}
    for item in models_arg.split(","):
        item = item.strip()
        if not item:
            continue
        key, sep, path = item.partition("=")
        key, path = key.strip(), path.strip()
        # Reject "path-only" items as well as "=path" and "key=" forms.
        # The last two would otherwise create an empty key or empty path.
        if not sep or not key or not path:
            raise ValueError(f"{usage}; got {item!r}")
        out[key] = path
    return out
satwater/weights.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ Created on Fri Jan 16 16:21:44 2026
3
+
4
+ @author: Busayo Alabi
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import hashlib
10
+ import json
11
+ import os
12
+ from dataclasses import dataclass
13
+
14
+ from huggingface_hub import hf_hub_download
15
+
16
+
17
class WeightsError(RuntimeError):
    """Raised when a weights manifest or weights file cannot be resolved or verified."""
19
+
20
+
21
@dataclass(frozen=True)
class WeightsRef:
    """Immutable description of one model's weights file, resolved from manifest.json.

    Built by ``resolve_weights_ref`` and consumed by ``download_weights``.
    """

    # Key of this model in the manifest's "models" mapping.
    model_key: str
    # Hugging Face repository id that holds the weights.
    repo_id: str
    # Branch, tag or commit of that repository.
    revision: str
    # Folder inside the repo holding manifest.json and the weights
    # (stored with leading/trailing slashes stripped).
    hf_root: str
    # Repo-relative path of the weights file: hf_root joined with the
    # manifest entry's "weights_file".
    weights_file_in_repo: str
    # Hex SHA-256 digest the downloaded file is compared against.
    expected_sha256: str
    # Optional "input_shape" value copied from the manifest entry; presumably
    # an 'H,W,C' string like parse_shape expects — TODO confirm.
    input_shape: str | None = None
30
+
31
+
32
+ def _sha256_file(path, chunk_size=1024 * 1024):
33
+ h = hashlib.sha256()
34
+ with open(path, "rb") as f:
35
+ for chunk in iter(lambda: f.read(chunk_size), b""):
36
+ h.update(chunk)
37
+ return h.hexdigest()
38
+
39
+
40
def _load_manifest(repo_id, revision, hf_root="weights", cache_dir=None):
    """Download ``manifest.json`` from a Hugging Face model repo and parse it.

    Args:
        repo_id: Hugging Face repository id.
        revision: branch, tag or commit to read from.
        hf_root: folder inside the repo that holds ``manifest.json``.
            Leading, trailing and repeated slashes are ignored, and an empty
            root means the repository root.
        cache_dir: optional override of the Hugging Face cache directory.

    Returns:
        The decoded JSON document. Its structure is validated by the caller,
        not here.
    """
    # Join only non-empty segments. An empty or "/" root then yields
    # "manifest.json" instead of the invalid repo path "/manifest.json".
    segments = [seg for seg in hf_root.split("/") if seg]
    filename = "/".join([*segments, "manifest.json"])
    manifest_path = hf_hub_download(
        repo_id=repo_id,
        repo_type="model",
        filename=filename,
        revision=revision,
        cache_dir=cache_dir,
    )
    with open(manifest_path, encoding="utf-8") as f:
        return json.load(f)
53
+
54
+
55
def resolve_weights_ref(
    model_key, repo_id, revision="main", hf_root="weights", cache_dir=None
):
    """Resolve *model_key* to its weights file and expected SHA-256 via manifest.json.

    The manifest must be a JSON object with a ``models`` mapping. Each entry
    in that mapping is an object with ``weights_file`` and ``sha256`` and an
    optional ``input_shape``.

    Args:
        model_key: key of the model in the manifest's ``models`` mapping.
        repo_id: Hugging Face repository id.
        revision: branch, tag or commit to read the manifest from.
        hf_root: folder inside the repo holding the manifest and the weights.
        cache_dir: optional override of the Hugging Face cache directory.

    Returns:
        A ``WeightsRef``. Its ``expected_sha256`` is normalised to lowercase
        hex, matching ``hashlib``'s ``hexdigest()``.

    Raises:
        WeightsError: if the manifest or the model's entry is malformed, or
            *model_key* is not listed.
    """
    manifest = _load_manifest(
        repo_id=repo_id, revision=revision, hf_root=hf_root, cache_dir=cache_dir
    )

    # Check the type first: `"models" in manifest` would do a substring test
    # on a str manifest and raise TypeError on a numeric one.
    models = manifest.get("models") if isinstance(manifest, dict) else None
    if not isinstance(models, dict):
        raise WeightsError("Invalid manifest.json: missing 'models' mapping")

    if model_key not in models:
        available = ", ".join(sorted(models.keys()))
        raise WeightsError(f"Unknown model_key='{model_key}'. Available: {available}")

    entry = models[model_key]
    if not isinstance(entry, dict):
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': expected an object"
        )
    try:
        weights_file = entry["weights_file"]
        sha = entry["sha256"]
    except KeyError as e:
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': missing {e!s}"
        ) from e
    input_shape = entry.get("input_shape")

    if not isinstance(weights_file, str) or not weights_file.strip("/"):
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': "
            "'weights_file' must be a non-empty string"
        )
    if not isinstance(sha, str) or not sha.strip():
        raise WeightsError(
            f"Invalid manifest entry for '{model_key}': "
            "'sha256' must be a non-empty string"
        )

    root = hf_root.strip("/")
    # Join non-empty segments. This avoids a leading "/" when root is empty
    # and collapses any doubled slashes between root and file name.
    segments = [seg for seg in f"{root}/{weights_file}".split("/") if seg]
    weights_file_in_repo = "/".join(segments)

    return WeightsRef(
        model_key=model_key,
        repo_id=repo_id,
        revision=revision,
        hf_root=root,
        weights_file_in_repo=weights_file_in_repo,
        # hexdigest() is lowercase; normalise so upper-case hex in the
        # manifest does not cause a spurious mismatch.
        expected_sha256=sha.strip().lower(),
        input_shape=input_shape,
    )
93
+
94
+
95
def download_weights(
    model_key,
    repo_id,
    revision="main",
    hf_root="weights",
    cache_dir=None,
    verify=True,
    retry_on_mismatch=True,
):
    """Download the weights for *model_key*, optionally verify them, and return the local path.

    The file to fetch and its expected SHA-256 come from the repo's
    ``manifest.json``, resolved via ``resolve_weights_ref``.

    Args:
        model_key: key of the model in the manifest's ``models`` mapping.
        repo_id: Hugging Face repository id.
        revision: branch, tag or commit to download from.
        hf_root: folder inside the repo holding ``manifest.json`` and weights.
        cache_dir: optional override of the Hugging Face cache directory.
        verify: when False, skip the SHA-256 check and return the path as-is.
        retry_on_mismatch: on a checksum mismatch, force one fresh download
            before giving up.

    Returns:
        Local filesystem path of the weights file, as returned by
        ``hf_hub_download``.

    Raises:
        WeightsError: if the manifest is invalid or *model_key* is unknown, or
            if the file's SHA-256 still differs after the optional retry.
    """
    ref = resolve_weights_ref(
        model_key=model_key,
        repo_id=repo_id,
        revision=revision,
        hf_root=hf_root,
        cache_dir=cache_dir,
    )
    # _sha256_file returns lowercase hex; compare case-insensitively so an
    # upper-case digest in the manifest does not always "mismatch".
    expected = str(ref.expected_sha256).strip().lower()

    def _download(force):
        # force=True bypasses any cached copy (used for the mismatch retry).
        return hf_hub_download(
            repo_id=ref.repo_id,
            repo_type="model",
            filename=ref.weights_file_in_repo,
            revision=ref.revision,
            cache_dir=cache_dir,
            force_download=force,
        )

    local_path = _download(force=False)

    if not verify:
        return local_path

    actual = _sha256_file(local_path)
    if actual == expected:
        return local_path

    if retry_on_mismatch:
        # Corrupted cache or partial download. Force another new download once.
        local_path = _download(force=True)
        actual = _sha256_file(local_path)
        if actual == expected:
            return local_path

    raise WeightsError(
        "SHA256 mismatch for downloaded weights.\n"
        f"model_key: {ref.model_key}\n"
        f"repo_id: {ref.repo_id}\n"
        f"revision: {ref.revision}\n"
        f"file: {ref.weights_file_in_repo}\n"
        f"expected: {ref.expected_sha256}\n"
        f"actual: {actual}\n"
        "Tip: If you are offline or behind a proxy, downloads may be partial or may not work."
    )
152
+
153
+
154
# Default Hugging Face repository and revision for the published weights.
# Both can be overridden via environment variables, which are read once at
# import time. An empty value is treated like "unset", so an accidentally
# blank variable does not produce repo_id="" or revision="".
DEFAULT_WEIGHTS_REPO = (
    os.environ.get("SATWATER_WEIGHTS_REPO") or "busayojee/sat-water-weights"
)
DEFAULT_WEIGHTS_REV = os.environ.get("SATWATER_WEIGHTS_REV") or "main"
158
+
159
+
160
def get_weights_path(
    model_key,
    repo_id=DEFAULT_WEIGHTS_REPO,
    revision=DEFAULT_WEIGHTS_REV,
    hf_root="weights",
    cache_dir=None,
):
    """Return a local, SHA-256-verified path to the weights for *model_key*.

    Convenience wrapper for inference around ``download_weights``. It always
    verifies the checksum and, on a mismatch, retries once with a forced
    re-download. The ``repo_id``/``revision`` defaults are the module-level
    values captured when this function was defined.
    """
    request = dict(
        model_key=model_key,
        repo_id=repo_id,
        revision=revision,
        hf_root=hf_root,
        cache_dir=cache_dir,
    )
    return download_weights(**request, verify=True, retry_on_mismatch=True)
176
+ )