sat-water 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,347 @@
1
+ Metadata-Version: 2.4
2
+ Name: sat-water
3
+ Version: 0.1.0
4
+ Summary: Satellite water body segmentation with pretrained weights hosted on Hugging Face.
5
+ Project-URL: Homepage, https://github.com/busayojee/sat-water
6
+ Project-URL: Repository, https://github.com/busayojee/sat-water
7
+ Author: Busayo Alabi
8
+ License: MIT License
9
+
10
+ Copyright (c) 2023 Oluwabusayo Alabi
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: computer-vision,satellite,segmentation,water
31
+ Classifier: License :: OSI Approved :: MIT License
32
+ Classifier: Operating System :: OS Independent
33
+ Classifier: Programming Language :: Python :: 3
34
+ Requires-Python: >=3.9
35
+ Requires-Dist: huggingface-hub>=0.20
36
+ Requires-Dist: matplotlib>=3.5
37
+ Requires-Dist: numpy>=1.23
38
+ Requires-Dist: pillow>=9
39
+ Provides-Extra: tf
40
+ Requires-Dist: efficientnet==1.0.0; extra == 'tf'
41
+ Requires-Dist: image-classifiers==1.0.0; extra == 'tf'
42
+ Requires-Dist: keras-applications<=1.0.8; extra == 'tf'
43
+ Requires-Dist: segmentation-models==1.0.1; extra == 'tf'
44
+ Requires-Dist: tensorflow-macos>=2.15; (platform_system == 'Darwin' and platform_machine == 'arm64') and extra == 'tf'
45
+ Requires-Dist: tensorflow-metal>=1.0; (platform_system == 'Darwin' and platform_machine == 'arm64') and extra == 'tf'
46
+ Requires-Dist: tensorflow>=2.15; (platform_system != 'Darwin') and extra == 'tf'
47
+ Requires-Dist: tf-keras>=2.16; extra == 'tf'
48
+ Description-Content-Type: text/markdown
49
+
50
+ # SEGMENTATION OF WATER BODIES FROM SATELLITE IMAGES (sat-water)
51
+
52
+ ## Introduction
53
+ Satellite imagery is a rich source of information, and the accurate segmentation of water bodies is crucial for understanding environmental patterns and changes over time. This project aims to provide a reliable and efficient tool for extracting water regions from raw satellite images.
54
+
55
+ This repository supports two workflows:
56
+ 1. **Library usage**: install with `pip` and run inference (pretrained weights downloaded on-demand).
57
+ 2. **Training workflow**: train your own models using the included preprocessing + training pipeline.
58
+
59
+ ---
60
+
61
+ ## Dataset
62
+
63
+ The dataset for this project was obtained from [kaggle.com](https://www.kaggle.com/datasets/franciscoescobar/satellite-images-of-water-bodies). It consists of JPEG images of water bodies taken by satellites, together with their masks. More details about the dataset are provided on the website.
64
+
65
+ ---
66
+
67
+ ## Installation
68
+
69
+ ### As a library
70
+ ```bash
71
+ pip install sat-water
72
+ ```
73
+
74
+ To run inference/training you must install the TensorFlow extras:
75
+
76
+ ```bash
77
+ pip install "sat-water[tf]"
78
+ ```
79
+
80
+ ### From source (development)
81
+ ```bash
82
+ git clone https://github.com/busayojee/sat-water.git
83
+ cd sat-water
84
+ pip install -e .
85
+ ```
86
+
87
+ > Note: `sat-water` sets `TF_USE_LEGACY_KERAS=1` and `SM_FRAMEWORK=tf.keras` by default at import time to keep `segmentation-models` compatible.
88
+
89
+ ---
90
+
91
+ ## Pretrained models
92
+
93
+ Pretrained weights are hosted on Hugging Face and downloaded at inference time with SHA256 integrity verification.
94
+
95
+ Default weights repo:
96
+ - `busayojee/sat-water-weights`
97
+
98
+ Override weights source:
99
+ ```bash
100
+ export SATWATER_WEIGHTS_REPO="busayojee/sat-water-weights"
101
+ export SATWATER_WEIGHTS_REV="main"
102
+ ```
103
+
104
+ ### Available model keys
105
+
106
+ Two architectures were trained for this project: a <b>UNET</b> with no backbone, and a UNET with a <b>RESNET34</b> backbone. Two variants of the RESNET34 model were trained, using different image sizes and different hyperparameters.
107
+
108
+ | Model key | Architecture | Input size | Notes |
109
+ |---|---|---:|---|
110
+ | `resnet34_256` | UNet + ResNet34 backbone | 256×256 | Best speed/quality tradeoff |
111
+ | `resnet34_512` | UNet + ResNet34 backbone | 512×512 | Higher-res boundaries; slower |
112
+ | `unet` | UNet (no backbone) | 128×128 | Currently unavailable in weights repo |
113
+
114
+ ---
115
+
116
+ ## Quickstart (library inference)
117
+
118
+ ```python
119
+ from satwater.inference import segment_image
120
+
121
+ res = segment_image(
122
+ "path/to/image.jpg",
123
+ model="resnet34_512", # or "resnet34_256"
124
+ return_overlay=True,
125
+ show=False,
126
+ )
127
+
128
+ mask = res.masks["resnet34_512"] # (H, W, 1)
129
+ overlay = res.overlays["resnet34_512"] # (H, W, 3)
130
+ ```
131
+
132
+ ---
133
+
134
+ ## Inference API
135
+
136
+ `segment_image(...)` is the recommended entrypoint for package users.
137
+
138
+ ### Parameters (commonly used)
139
+
140
+ - `image_path` *(str)*: path to an input image (`.jpg`, `.png`, etc.)
141
+ - `model` *(str)*: one of `resnet34_256`, `resnet34_512` (and `unet` once available)
142
+ - `return_overlay` *(bool)*: whether to return an overlay image (original image + blended water mask)
143
+ - `show` *(bool)*: whether to display the result via matplotlib (useful in notebooks / local runs)
144
+
145
+ ### Weights source / versioning
146
+
147
+ - `repo_id` *(str, optional)*: Hugging Face repo containing weights (defaults to `SATWATER_WEIGHTS_REPO`)
148
+ - `revision` *(str, optional)*: branch / tag / commit (defaults to `SATWATER_WEIGHTS_REV`)
149
+ - `save_dir` *(str | Path | None, optional)*: output directory (if supported in your local version).
150
+ If you want saving, you can always do it manually from the returned arrays (example below).
151
+
152
+ #### Manual saving
153
+
154
+ ```python
155
+ from PIL import Image
156
+ import numpy as np
157
+
158
+ Image.fromarray((mask.squeeze(-1) * 255).astype(np.uint8)).save("mask.png")
159
+ Image.fromarray(overlay).save("overlay.png")
160
+ ```
161
+
162
+ ---
163
+
164
+ ## Training history (reference)
165
+
166
+ The plots below are from historical runs in this repository and are provided to show convergence behavior.
167
+
168
+ | UNet (baseline) | ResNet34-UNet (256×256) | ResNet34-UNet (512×512) |
169
+ |:--:|:--:|:--:|
170
+ | <img width="260" alt="UNet History" src="https://github.com/busayojee/sat-water/blob/main/assets/results/history_unet.png"> | <img width="260" alt="ResNet34 256 History" src="https://github.com/busayojee/sat-water/blob/main/assets/results/history_resnet34.png"> | <img width="260" alt="ResNet34 512 History" src="https://github.com/busayojee/sat-water/blob/main/assets/results/historyresnet34(2).png"> |
171
+
172
+ ---
173
+
174
+ ## Inference examples
175
+
176
+ Qualitative predictions produced by the three models.
177
+
178
+ | UNet | ResNet34-UNet (256×256) | ResNet34-UNet (512×512) |
179
+ |:--:|:--:|:--:|
180
+ | <img width="260" alt="UNet Prediction" src="https://github.com/busayojee/sat-water/blob/main/assets/results/prediciton_unet.png"> | <img width="260" alt="ResNet34 256 Prediction" src="https://github.com/busayojee/sat-water/blob/main/assets/results/prediciton_resnet34.png"> | <img width="260" alt="ResNet34 512 Prediction" src="https://github.com/busayojee/sat-water/blob/main/assets/results/prediciton_resnet34(2).png"> |
181
+
182
+ ---
183
+
184
+ ## Single test instance (end-to-end)
185
+
186
+ Using all models to predict a single test instance.
187
+
188
+ | Test Image | Prediction |
189
+ |:--:|:--:|
190
+ | <img width="300" alt="Test Image" src="https://github.com/busayojee/sat-water/blob/main/assets/results/test2.jpg"> | <img width="300" alt="Prediction" src="https://github.com/busayojee/sat-water/blob/main/assets/results/prediciton_test.png"> |
191
+
192
+ Label overlay of the best prediction (ResNet34-UNet 512×512 in that run):
193
+
194
+ <img width="320" alt="Overlay" src="https://github.com/busayojee/sat-water/blob/main/assets/results/test2.png">
195
+
196
+ ---
197
+
198
+ ## Train your own model
199
+
200
+ ### Preprocessing
201
+
202
+ ```python
203
+ from satwater.preprocess import Preprocess
204
+
205
+ train_ds, val_ds, test_ds = Preprocess.data_load(
206
+ dataset_dir="path/to/dataset",
207
+ masks_dir="/Masks",
208
+ images_dir="/Images",
209
+ split=(0.7, 0.2, 0.1),
210
+ shape=(256, 256),
211
+ batch_size=16,
212
+ channels=3,
213
+ )
214
+ ```
215
+
216
+ ### Training (UNet baseline)
217
+ ```python
218
+ from satwater.models import Unet
219
+
220
+ history = Unet.train(
221
+ train_ds,
222
+ val_ds,
223
+ shape=(128, 128, 3),
224
+ n_classes=2,
225
+ lr=1e-4,
226
+ loss=Unet.loss,
227
+ metrics=Unet.metrics,
228
+ name="unet",
229
+ )
230
+ ```
231
+
232
+ ### Training (ResNet34-UNet)
233
+ ```python
234
+ from satwater.models import BackboneModels
235
+
236
+ bm = BackboneModels("resnet34", train_ds, val_ds, test_ds, name="resnet34_256")
237
+ bm.build_model(n_classes=2, n=1, lr=1e-4)
238
+ history = bm.train()
239
+ ```
240
+
241
+ > For a 512×512 run, load a second dataset with `shape=(512, 512)` and use a different model name (e.g. `resnet34_512`) to keep artifacts separate.
242
+
243
+
244
+ ### Inference
245
+ To run inference with the UNET model:
246
+
247
+ ```
248
+ inference_u = Inference(model="path/to/model",name="unet")
249
+ inference_u.predict_ds(test_ds)
250
+ ```
251
+
252
+ For the two RESNET34 models (`resnet34` and `resnet34(2)`):
253
+
254
+ ```
255
+ inference_r = Inference(model="path/to/model",name="resnet34")
256
+ inference_r.predict_ds(test_ds)
257
+
258
+ inference_r2 = Inference(model="path/to/model",name="resnet34(2)")
259
+ inference_r2.predict_ds(test_ds1)
260
+ ```
261
+
262
+ For all 3 models together
263
+
264
+ ```
265
+ models={"unet":"path/to/model1", "resnet34":"path/to/model2", "resnet34(2)":"path/to/model3"}
266
+ inference_multiple = Inference(model=models)
267
+ inference_multiple.predict_ds(test_ds)
268
+ ```
269
+
270
+ ## CLI (optional)
271
+
272
+ If you included the `scripts/` folder in your package/repo, you can run the scripts directly.
273
+
274
+ ### Training CLI
275
+
276
+ UNet:
277
+
278
+ ```bash
279
+ python scripts/train.py --dataset path/to/dataset --image-folder /Images --mask-folder /Masks --shape 128,128,3 --batch-size 16 --split 0.2,0.1 --channels 3 --model unet --name unet --epochs 100 --lr 1e-4
280
+ ```
281
+
282
+ ResNet34-UNet (256):
283
+
284
+ ```bash
285
+ python scripts/train.py --dataset path/to/dataset --image-folder /Images --mask-folder /Masks --shape 256,256,3 --batch-size 8 --split 0.2,0.1 --channels 3 --model resnet34 --name resnet34_256 --epochs 100 --lr 1e-4
286
+ ```
287
+
288
+ ResNet34-UNet (512):
289
+
290
+ ```bash
291
+ python scripts/train.py --dataset path/to/dataset --image-folder /Images --mask-folder /Masks --shape 512,512,3 --batch-size 4 --split 0.2,0.1 --channels 3 --model "resnet34(2)" --name resnet34_512 --epochs 100 --lr 1e-4
292
+ ```
293
+
294
+ ### Inference CLI
295
+
296
+ Single model:
297
+
298
+ ```bash
299
+ python scripts/infer.py --image path/to/image.jpg --model path/to/model.keras --name unet --out prediction
300
+ ```
301
+
302
+ Multiple models:
303
+
304
+ ```bash
305
+ python scripts/infer.py --image path/to/image.jpg --models "unet=path/to/unet.keras,resnet34=path/to/resnet34.keras,resnet34(2)=path/to/resnet34_2.keras" --out prediction
306
+ ```
307
+
308
+ ### Upload weights to Hugging Face (optional)
309
+
310
+ ```bash
311
+ export HF_TOKEN="YOUR_HUGGINGFACE_TOKEN"
312
+
313
+ python scripts/weights.py --repo-id user/repo --hf-root weights --out-dir dist/weights --model unet=path/to/unet.keras@128,128,3 --model resnet34_256=path/to/resnet34_256.keras@256,256,3 --model resnet34_512=path/to/resnet34_512.keras@512,512,3
314
+ ```
315
+
316
+ ---
317
+
318
+ ## Contributing
319
+
320
+ Contributions are welcome — especially around:
321
+ - adding/refreshing pretrained weights (including UNet)
322
+ - improving inference UX (CLI, batch inference, better overlays)
323
+ - expanding tests and CI matrix
324
+ - model evaluation and benchmarking on additional datasets
325
+
326
+ ### How to contribute
327
+ 1. Fork the repo
328
+ 2. Create a feature branch:
329
+ ```bash
330
+ git checkout -b feat/my-change
331
+ ```
332
+ 3. Run checks locally:
333
+ ```bash
334
+ pytest -q
335
+ ruff check .
336
+ ruff format .
337
+ ```
338
+ 4. Open a pull request with a short summary + screenshots (if changing inference output)
339
+
340
+ If you’re reporting a bug, please include:
341
+ - OS + Python version
342
+ - TensorFlow version
343
+ - full traceback + a minimal repro snippet
344
+
345
+ ---
346
+
347
+
@@ -0,0 +1,11 @@
1
+ satwater/__init__.py,sha256=6qCWit47FBUB1-Six68y9SxvH7XP7Jl86KE_pft17VU,111
2
+ satwater/builders.py,sha256=JhuTFhWs-a_w5n6wQBGybJILJXUs_rEH7xTQ1ZWJLBo,5280
3
+ satwater/inference.py,sha256=b_SDDocomT8I_xt8Um6jHHj7E1CclRgyx79Z8ZIWM_k,10054
4
+ satwater/models.py,sha256=DEaQ_Bx8ghPD9NPrKTZ2c9ZGTCuZIWrrTYQgd8SdBfQ,8622
5
+ satwater/preprocess.py,sha256=zvPJOXhj2DeKCo0H6Df0OlKdZB52woGuI4Z7orCtvjg,5120
6
+ satwater/utils.py,sha256=6QL6exM0bjjXK44oguAphkT7dZVv-tQhsQ3csemlxDM,989
7
+ satwater/weights.py,sha256=g4KeTpYaB74TQCvjLcy5eIn4vahM_P50yKYDSf-XGe8,4611
8
+ sat_water-0.1.0.dist-info/METADATA,sha256=Zg4h6fnvfQ5IEHbHpdbK6Y_bGC9zkGD8NG2bRPg-wXQ,11989
9
+ sat_water-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
10
+ sat_water-0.1.0.dist-info/licenses/LICENSE,sha256=RbsEzAC18a5cM_uuYZCLWyQlYBEJX_M_U_kA58cYqkc,1073
11
+ sat_water-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Oluwabusayo Alabi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
satwater/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ import os
2
+
3
# Provide package defaults for the Keras-related switches while leaving any
# value the user already exported untouched. The README states these keep
# `segmentation-models` compatible.
if "TF_USE_LEGACY_KERAS" not in os.environ:
    os.environ["TF_USE_LEGACY_KERAS"] = "1"
if "SM_FRAMEWORK" not in os.environ:
    os.environ["SM_FRAMEWORK"] = "tf.keras"
satwater/builders.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ Created on Fri Jan 16 19:43:12 2026
3
+
4
+ @author: Busayo Alabi
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass
12
+
13
+ from satwater.weights import get_weights_path
14
+
15
+ try:
16
+ import tensorflow as tf
17
+ except Exception as e:
18
+ raise ImportError(
19
+ "TensorFlow is required for sat-water inference/training.\n\n"
20
+ "Install TensorFlow first, then reinstall sat-water.\n"
21
+ "Recommended:\n"
22
+ " Linux/Windows: pip install 'tensorflow'\n"
23
+ " Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
24
+ "If you are using segmentation-models with TF legacy Keras:\n"
25
+ " pip install tf-keras segmentation-models\n"
26
+ ) from e
27
+
28
+ try:
29
+ import segmentation_models as sm
30
+ except Exception:
31
+ sm = None
32
+
33
+
34
class BuilderError(RuntimeError):
    """Raised when a pretrained model cannot be built, e.g. for an unknown model key."""
36
+
37
+
38
def build_custom_unet(input_shape=(128, 128, 3), n_classes=2):
    """
    Assemble the plain (backbone-free) U-Net.

    The original docstring describes this as the builder matching the
    exported weights.

    args
        input_shape: shape handed to the Keras ``Input`` layer
        n_classes: number of filters of the final 1x1 softmax convolution

    Returns a ``tf.keras.Model`` named "satwater_unet"; the model is not
    compiled here.
    """
    layers = tf.keras.layers
    encoder_filters = (32, 64, 128, 256, 512)

    def bn_relu(t):
        t = layers.BatchNormalization()(t)
        return layers.Activation("relu")(t)

    def double_conv(t, filters: int):
        # Two rounds of 3x3 conv -> batch norm -> ReLU.
        for _ in range(2):
            t = layers.Conv2D(filters, (3, 3), padding="same")(t)
            t = bn_relu(t)
        return t

    def upsample_and_merge(t, skip, filters: int):
        # Transposed conv doubles the resolution, then the encoder skip is
        # concatenated in front of the upsampled features.
        t = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(t)
        t = layers.Concatenate(axis=-1)([skip, t])
        t = bn_relu(t)
        return double_conv(t, filters)

    inputs = layers.Input(shape=input_shape)

    # NOTE(review): layers are created in exactly the same order as before;
    # presumably this matters for loading previously saved weights - confirm
    # before reordering anything.
    skips = []
    x = inputs
    for filters in encoder_filters:
        skip = double_conv(x, filters)
        skips.append(skip)
        x = layers.MaxPooling2D((2, 2), strides=(2, 2))(skip)

    # Bottleneck between the contracting and expanding paths.
    x = double_conv(x, 1024)

    # Expanding path: consume the skips from the deepest to the shallowest.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = upsample_and_merge(x, skip, filters)

    outputs = layers.Conv2D(n_classes, (1, 1), activation="softmax")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="satwater_unet")
94
+
95
+
96
@dataclass(frozen=True)
class PretrainedModel:
    """
    Immutable bundle returned by `load_pretrained()`.

    Attributes
        model: tf.keras.Model ready for inference
        preprocess_func: optional preprocessing callable for backbone models,
            for example ResNet34; may be None
        input_shape: expected input shape as (H, W, C)
        model_key: key saved in manifest: "unet" | "resnet34_256" | "resnet34_512"
    """

    model: tf.keras.Model
    preprocess_func: Callable | None
    input_shape: tuple[int, int, int]
    model_key: str
111
+
112
+
113
# Maps each supported pretrained model key to the (H, W, C) input shape that
# `load_pretrained()` builds the network with.
MODEL_SPECS = {
    "unet": (128, 128, 3),
    # The ResNet34 variants only differ by their square input resolution.
    **{f"resnet34_{side}": (side, side, 3) for side in (256, 512)},
}
118
+
119
+
120
def _require_segmentation_models():
    """
    Fail fast when the optional `segmentation_models` import did not succeed.

    Raises ImportError with install instructions if the module-level `sm`
    is None; otherwise returns None.
    """
    if sm is not None:
        return
    message = (
        "segmentation-models is required for ResNet34 pretrained models.\n"
        "Install it with:\n"
        " pip install segmentation-models tf-keras\n"
    )
    raise ImportError(message)
127
+
128
+
129
def build_resnet34_unet(input_shape, n_classes=2):
    """
    Create a U-Net with a ResNet34 backbone through segmentation-models.

    args
        input_shape: input shape forwarded to `sm.Unet`
        n_classes: number of output classes (softmax activation)

    Returns a `(model, preprocess_func)` tuple, where `preprocess_func` is the
    result of `sm.get_preprocessing("resnet34")`.

    Raises ImportError when segmentation-models could not be imported.
    """
    _require_segmentation_models()

    backbone = "resnet34"
    unet_kwargs = {
        "classes": n_classes,
        "activation": "softmax",
        # No encoder weights are requested at build time.
        "encoder_weights": None,
        "input_shape": input_shape,
    }
    # Tuple elements evaluate left to right: the network is built first, then
    # the matching preprocessing function is looked up.
    return sm.Unet(backbone, **unet_kwargs), sm.get_preprocessing(backbone)
140
+
141
+
142
def load_pretrained(
    model_key, repo_id=None, revision="main", n_classes=2, hf_root="weights"
):
    """
    Fetch pretrained weights (through `get_weights_path`), build the matching
    architecture, load the weights, and return a ready model.

    args
        model_key: one of the keys of MODEL_SPECS
        repo_id: Hugging Face repo holding the weights; when falsy, the
            SATWATER_WEIGHTS_REPO environment variable is used, and failing
            that "busayojee/sat-water-weights"
        revision: revision forwarded to `get_weights_path`
        n_classes: number of output classes of the built network
        hf_root: forwarded unchanged to `get_weights_path`
            (presumably the folder inside the repo holding the weights -
            confirm against satwater.weights)

    Returns a PretrainedModel.

    Raises
        BuilderError: `model_key` is not a key of MODEL_SPECS
        NotImplementedError: `model_key` is "unet", which is not supported yet
    """
    if model_key not in MODEL_SPECS:
        expected = ", ".join(MODEL_SPECS)
        raise BuilderError(
            f"Unknown model_key='{model_key}'. Expected the following: {expected}"
        )

    # Reject the unsupported model *before* resolving weights: fetching them
    # first would cost a network round-trip (or surface an unrelated download
    # error) for a model that cannot be loaded anyway.
    if model_key == "unet":
        raise NotImplementedError("Unet not ready yet")

    input_shape = MODEL_SPECS[model_key]

    weights_path = get_weights_path(
        model_key=model_key,
        repo_id=repo_id
        or os.environ.get("SATWATER_WEIGHTS_REPO", "busayojee/sat-water-weights"),
        revision=revision,
        hf_root=hf_root,
    )

    # Every remaining key is a ResNet34-backed U-Net.
    model, preprocess_func = build_resnet34_unet(
        input_shape=input_shape, n_classes=n_classes
    )
    model.load_weights(weights_path)

    return PretrainedModel(
        model=model,
        preprocess_func=preprocess_func,
        input_shape=input_shape,
        model_key=model_key,
    )