sat-water 0.1.0 (sat_water-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sat_water-0.1.0.dist-info/METADATA +347 -0
- sat_water-0.1.0.dist-info/RECORD +11 -0
- sat_water-0.1.0.dist-info/WHEEL +4 -0
- sat_water-0.1.0.dist-info/licenses/LICENSE +21 -0
- satwater/__init__.py +4 -0
- satwater/builders.py +179 -0
- satwater/inference.py +313 -0
- satwater/models.py +229 -0
- satwater/preprocess.py +161 -0
- satwater/utils.py +45 -0
- satwater/weights.py +176 -0
satwater/inference.py
ADDED
@@ -0,0 +1,313 @@
"""
Created on Wed Nov 18 13:22:57 2023

@author: Busayo Alabi
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import matplotlib.pyplot as plt
import numpy as np

try:
    import tensorflow as tf
except Exception as e:
    raise ImportError(
        "TensorFlow is required for sat-water inference/training.\n\n"
        "Install TensorFlow first, then reinstall sat-water.\n"
        "Recommended:\n"
        "  Linux/Windows: pip install 'tensorflow'\n"
        "  Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
        "If you are using segmentation-models with TF legacy Keras:\n"
        "  pip install tf-keras segmentation-models\n"
    ) from e

from satwater.builders import load_pretrained
from satwater.preprocess import Preprocess

try:
    from PIL import Image
except Exception:
    Image = None

ArrayLikeImage = Union[str, np.ndarray, tf.Tensor, "Image.Image"]


@dataclass(frozen=True)
class SegmentationResult:
    """Result from a segmentation call."""

    masks: dict[str, np.ndarray]
    overlays: dict[str, np.ndarray]
    base_image: np.ndarray


class Inference:
    """
    Runs water segmentation. Accepts a pretrained model key
    ("unet", "resnet34_256", "resnet34_512"), a local saved-model path,
    or a dict mapping names to either.
    """

    pretrained_keys = {"unet", "resnet34_256", "resnet34_512"}

    def __init__(
        self, model="unet", *, name=None, repo_id=None, revision="main", save_dir=None
    ):
        self.repo_id = repo_id or os.environ.get(
            "SATWATER_WEIGHTS_REPO", "busayojee/sat-water-weights"
        )
        self.revision = revision
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.name = name
        self.preprocess_func = None
        self.preprocess_funcs = {}

        if isinstance(model, dict):
            self.model = {}
            for key, spec in model.items():
                m, pp = self._load_one(spec, key_hint=key)
                self.model[key] = m
                self.preprocess_funcs[key] = pp
            return

        if isinstance(model, str):
            if model in self.pretrained_keys:
                m, pp = self._load_pretrained(model)
                self.model = m
                self.name = model
                self.preprocess_func = pp
                return

            if self._looks_like_file(model):
                self.model = tf.keras.models.load_model(model, compile=False)
                self.name = name or "local"
                return

            raise ValueError(
                "model must be a pretrained key (unet/resnet34_256/resnet34_512), "
                "a local path to a saved model, or a dict of them."
            )

        raise TypeError("model must be a str or dict[str, str]")

    @staticmethod
    def _looks_like_file(s):
        return ("/" in s) or s.endswith(".keras") or s.endswith(".h5")

    def _load_pretrained(self, key):
        pm = load_pretrained(
            model_key=key, repo_id=self.repo_id, revision=self.revision
        )
        return pm.model, pm.preprocess_func

    def _load_one(self, spec, key_hint=None):
        if spec in self.pretrained_keys:
            return self._load_pretrained(spec)
        if self._looks_like_file(spec):
            return tf.keras.models.load_model(spec, compile=False), None
        if key_hint in self.pretrained_keys:
            return self._load_pretrained(key_hint)
        raise ValueError(
            f"Unknown model spec '{spec}'. Use a pretrained key or a local model path."
        )

    @staticmethod
    def _to_tf_image(x):
        # Normalize any supported input to a float32 RGB tensor in [0, 1].
        if isinstance(x, str):
            raw = tf.io.read_file(x)
            img = tf.image.decode_image(raw, channels=3, expand_animations=False)
            img = tf.cast(img, tf.float32) / 255.0
            return img

        if Image is not None and isinstance(x, Image.Image):
            arr = np.asarray(x.convert("RGB"), dtype=np.float32) / 255.0
            return tf.convert_to_tensor(arr, dtype=tf.float32)

        if isinstance(x, np.ndarray):
            arr = x
            if arr.ndim == 2:
                arr = np.stack([arr, arr, arr], axis=-1)
            if arr.shape[-1] == 4:
                arr = arr[..., :3]
            arr = arr.astype(np.float32)
            if arr.max() > 1.5:
                arr = arr / 255.0
            return tf.convert_to_tensor(arr, dtype=tf.float32)

        if isinstance(x, tf.Tensor):
            t = tf.cast(x, tf.float32)
            if t.shape.rank == 2:
                t = tf.stack([t, t, t], axis=-1)
            if t.shape.rank == 3 and t.shape[-1] == 4:
                t = t[..., :3]
            tmax = tf.reduce_max(t)
            t = tf.cond(tmax > 1.5, lambda: t / 255.0, lambda: t)
            return t

        raise TypeError(
            "Unsupported image type. Use a path, PIL.Image, numpy array, or tf.Tensor."
        )

    @staticmethod
    def _resize(img, size):
        return tf.image.resize(img, size=size, method="bilinear")

    @staticmethod
    def _mask_to_uint8(mask):
        m = np.squeeze(mask)
        if m.dtype != np.uint8:
            m = m.astype(np.uint8)
        return m

    @staticmethod
    def overlay_mask(image, mask, alpha=0.55, color=(0.0, 0.4, 1.0)):
        # Alpha-blend a solid color over the water pixels of the base image.
        img = image
        if img.dtype != np.float32:
            img = img.astype(np.float32) / 255.0
        img = np.clip(img, 0.0, 1.0)

        m = np.squeeze(mask)
        water = (m > 0).astype(np.float32)
        water3 = np.repeat(water[:, :, None], 3, axis=2)

        col = np.array(color, dtype=np.float32)[None, None, :]
        overlay = img * (1.0 - alpha * water3) + col * (alpha * water3)
        return (np.clip(overlay, 0.0, 1.0) * 255.0).astype(np.uint8)

    def _maybe_save(self, arr, fname):
        if self.save_dir is None:
            return
        self.save_dir.mkdir(parents=True, exist_ok=True)
        out = self.save_dir / fname
        plt.imsave(out.as_posix(), arr)

    def predict(
        self, image, *, return_overlay=True, save=False, show=False, fname="prediction"
    ):
        raw_img = self._to_tf_image(image)
        model_img = Preprocess.normalization_layer(raw_img)
        base_img = np.clip(raw_img.numpy(), 0.0, 1.0).astype(np.float32)

        masks = {}
        overlays = {}

        if isinstance(self.model, dict):
            for key, model in self.model.items():
                mask_np = self._predict_for_key(key, model, model_img)
                masks[key] = mask_np
                if return_overlay:
                    ov = self.overlay_mask(
                        self._image_for_key(key, raw_img).numpy(), mask_np
                    )
                    overlays[key] = ov
                    if save and self.save_dir is not None:
                        self._maybe_save(ov, f"{fname}_{key}_overlay.png")
                        self._maybe_save(np.squeeze(mask_np), f"{fname}_{key}_mask.png")
            if show:
                self._plot_results(base_img, masks, overlays, title=fname)
            return SegmentationResult(
                masks=masks, overlays=overlays, base_image=base_img
            )

        key = self.name or "model"
        mask_np = self._predict_for_key(key, self.model, model_img)
        masks[key] = mask_np

        if return_overlay:
            ov = self.overlay_mask(self._image_for_key(key, raw_img).numpy(), mask_np)
            overlays[key] = ov
            if save and self.save_dir is not None:
                self._maybe_save(ov, f"{fname}_{key}_overlay.png")
                self._maybe_save(np.squeeze(mask_np), f"{fname}_{key}_mask.png")

        if show:
            self._plot_results(base_img, masks, overlays, title=fname)

        return SegmentationResult(masks=masks, overlays=overlays, base_image=base_img)

    def _image_for_key(self, key, img):
        # Each pretrained checkpoint expects a fixed input resolution.
        if key == "unet":
            return self._resize(img, (128, 128))
        if key == "resnet34_256":
            return self._resize(img, (256, 256))
        if key == "resnet34_512":
            return self._resize(img, (512, 512))
        return self._resize(img, (128, 128))

    def _predict_for_key(self, key, model, img):
        x = self._image_for_key(key, img)

        if isinstance(self.model, dict):
            pp = self.preprocess_funcs.get(key)
        else:
            pp = self.preprocess_func

        if pp is not None:
            x = pp(x)

        # Min-max normalize so the input scale is consistent regardless of source.
        x = (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x) + 1e-8)

        pred = model.predict(x[tf.newaxis, ...], verbose=0)
        pred = tf.argmax(pred, axis=-1)
        pred = tf.expand_dims(pred, axis=-1)[0, :, :, :]
        pred_np = pred.numpy().astype(np.int32)

        # The resnet34_512 checkpoint predicts the classes inverted; flip them back.
        if key == "resnet34_512":
            pred_np = 1 - pred_np
        return pred_np

    @staticmethod
    def _plot_results(base_img, masks, overlays, title="result"):
        keys = list(masks.keys())
        n = len(keys)
        cols = 3 if overlays else 2
        plt.figure(figsize=(5 * cols, 4 * n))

        for i, k in enumerate(keys):
            r = i * cols

            plt.subplot(n, cols, r + 1)
            plt.title("Image")
            plt.imshow(base_img)
            plt.axis("off")

            plt.subplot(n, cols, r + 2)
            plt.title(f"Mask: {k}")
            plt.imshow(np.squeeze(masks[k]))
            plt.axis("off")

            if overlays:
                plt.subplot(n, cols, r + 3)
                plt.title(f"Overlay: {k}")
                plt.imshow(overlays[k])
                plt.axis("off")

        plt.suptitle(title)
        plt.tight_layout()
        plt.show()


def segment_image(
    image,
    *,
    model="unet",
    repo_id=None,
    revision="main",
    return_overlay=True,
    save=False,
    save_dir=None,
    show=False,
    fname="prediction",
):
    infer = Inference(
        model=model, repo_id=repo_id, revision=revision, save_dir=save_dir
    )
    return infer.predict(
        image, return_overlay=return_overlay, save=save, show=show, fname=fname
    )
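
Usage note: a minimal sketch of how the entry points above are meant to be called, inferred from the signatures in this file; "lake.png" and the "out" directory are hypothetical placeholders.

from satwater.inference import Inference, segment_image

# One-shot helper: loads the pretrained "unet" weights and segments one image.
result = segment_image("lake.png", model="unet", save=True, save_dir="out")
print(result.masks["unet"].shape)  # (128, 128, 1) integer water mask

# Reusable object comparing two pretrained checkpoints on the same image.
infer = Inference(
    model={"resnet34_256": "resnet34_256", "resnet34_512": "resnet34_512"}
)
result = infer.predict("lake.png", show=True)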
satwater/models.py
ADDED
@@ -0,0 +1,229 @@
"""
Created on Wed Nov 17 11:05:25 2023

@author: Busayo Alabi
"""

import matplotlib.pyplot as plt

try:
    import tensorflow as tf
except Exception as e:
    raise ImportError(
        "TensorFlow is required for sat-water inference/training.\n\n"
        "Install TensorFlow first, then reinstall sat-water.\n"
        "Recommended:\n"
        "  Linux/Windows: pip install 'tensorflow'\n"
        "  Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
        "If you are using segmentation-models with TF legacy Keras:\n"
        "  pip install tf-keras segmentation-models\n"
    ) from e

try:
    import segmentation_models as sm
except Exception as e:
    raise ImportError(
        "segmentation-models is required for training backbone models.\n"
        "Install it with:\n"
        "  pip install segmentation-models tf-keras\n"
    ) from e

from satwater.preprocess import Preprocess


class Unet:
    loss = sm.losses.categorical_focal_dice_loss
    metrics = [sm.metrics.iou_score, sm.metrics.f1_score]

    @staticmethod
    def dice_coef_water(y_true, y_pred, smooth=1e-6):
        # Dice coefficient computed on the water channel (class 1) only.
        y_true = tf.cast(y_true, tf.float32)
        y_true = tf.squeeze(y_true, axis=-1)
        y_pred = tf.cast(y_pred[..., 1], tf.float32)

        y_true_f = tf.reshape(y_true, [tf.shape(y_true)[0], -1])
        y_pred_f = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])

        intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=1)
        denom = tf.reduce_sum(y_true_f + y_pred_f, axis=1)

        dice = (2.0 * intersection + smooth) / (denom + smooth)
        return tf.reduce_mean(dice)

    @staticmethod
    def conv_block(input, filters):
        encoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(input)
        encoder = tf.keras.layers.BatchNormalization()(encoder)
        encoder = tf.keras.layers.Activation("relu")(encoder)
        encoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(encoder)
        encoder = tf.keras.layers.BatchNormalization()(encoder)
        encoder = tf.keras.layers.Activation("relu")(encoder)
        return encoder

    @staticmethod
    def encoder_block(input, filters):
        encoder = Unet.conv_block(input, filters)
        encoder_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
        return encoder_pool, encoder

    @staticmethod
    def decoder_block(input, concat, filters):
        decoder = tf.keras.layers.Conv2DTranspose(
            filters, (2, 2), strides=(2, 2), padding="same"
        )(input)
        decoder = tf.keras.layers.concatenate([concat, decoder], axis=-1)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        decoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(decoder)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        decoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(decoder)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        return decoder

    @staticmethod
    def models(input_shape, n_classes):
        # Standard five-level U-Net with a softmax classification head.
        inputs = tf.keras.layers.Input(shape=input_shape)
        encoder_pool0, encoder0 = Unet.encoder_block(inputs, 32)
        encoder_pool1, encoder1 = Unet.encoder_block(encoder_pool0, 64)
        encoder_pool2, encoder2 = Unet.encoder_block(encoder_pool1, 128)
        encoder_pool3, encoder3 = Unet.encoder_block(encoder_pool2, 256)
        encoder_pool4, encoder4 = Unet.encoder_block(encoder_pool3, 512)
        center = Unet.conv_block(encoder_pool4, 1024)
        decoder4 = Unet.decoder_block(center, encoder4, 512)
        decoder3 = Unet.decoder_block(decoder4, encoder3, 256)
        decoder2 = Unet.decoder_block(decoder3, encoder2, 128)
        decoder1 = Unet.decoder_block(decoder2, encoder1, 64)
        decoder0 = Unet.decoder_block(decoder1, encoder0, 32)
        outputs = tf.keras.layers.Conv2D(n_classes, (1, 1), activation="softmax")(
            decoder0
        )
        models = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
        return models

    @staticmethod
    def checkpoint():
        # val_iou_score improves upward, so monitor it in "max" mode.
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor="val_iou_score", mode="max", patience=6
        )
        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.1, patience=4, min_lr=0.00005
        )
        return [early_stopping, reduce_lr]

    @staticmethod
    def train(
        train_ds,
        val_ds,
        epochs=100,
        shape=(128, 128, 3),
        n_classes=2,
        lr=0.0001,
        loss="sparse_categorical_crossentropy",
        metrics=None,
        name="model",
    ):
        if metrics is None:
            metrics = ["accuracy"]
        model = Unet.models(shape, n_classes)
        callbacks = Unet.checkpoint()
        adam = tf.keras.optimizers.legacy.Adam(learning_rate=lr)
        model.compile(optimizer=adam, loss=loss, metrics=metrics)
        tf.keras.utils.plot_model(
            model, to_file=f"segmentation/segment-water/model_{name}.png"
        )
        model.summary()
        history = model.fit(
            train_ds,
            epochs=epochs,
            shuffle=True,
            verbose=1,
            callbacks=callbacks,
            validation_data=val_ds,
        )
        model.save(f"{name}.h5")
        return history

    @staticmethod
    def plot_history(history, epochs, model=""):
        IOU = history.history["iou_score"]
        val_IOU = history.history["val_iou_score"]
        loss = history.history["loss"]
        val_loss = history.history["val_loss"]
        epochs_range = range(epochs)
        plt.figure(figsize=(16, 8))
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, IOU, label="Training IOU coefficient")
        plt.plot(epochs_range, val_IOU, label="Validation IOU coefficient")
        plt.legend(loc="upper right")
        plt.title("Training and Validation IOU coefficient")

        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label="Training Loss")
        plt.plot(epochs_range, val_loss, label="Validation Loss")
        plt.legend(loc="upper right")
        plt.title("Training and Validation Loss")
        plt.savefig(f"segmentation/segment-water/results/history{model}.png")
        plt.show()


class BackboneModels:
    parallel_calls = tf.data.AUTOTUNE

    def __init__(self, Backbone, train_ds, val_ds, test_ds, name="model"):
        assert isinstance(Backbone, str), "Backbone must be a string"
        self.backbone = Backbone
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.test_ds = test_ds
        self.name = name
        self.model = None

    @staticmethod
    def preprocess(images, label, Backbone):
        # Apply the backbone-specific input normalization from segmentation-models.
        preprocessor = sm.get_preprocessing(Backbone)
        return preprocessor(images), label

    def build_model(self, n_classes, n=1, lr=0.0001):
        self.train_ds = self.train_ds.map(
            lambda x, y: BackboneModels.preprocess(x, y, self.backbone),
            num_parallel_calls=BackboneModels.parallel_calls,
        )
        self.val_ds = self.val_ds.map(
            lambda x, y: BackboneModels.preprocess(x, y, self.backbone),
            num_parallel_calls=BackboneModels.parallel_calls,
        )
        self.test_ds = self.test_ds.map(
            lambda x, y: BackboneModels.preprocess(x, y, self.backbone),
            num_parallel_calls=BackboneModels.parallel_calls,
        )
        model = sm.Unet(
            self.backbone,
            classes=n_classes,
            activation="softmax",
            encoder_weights="imagenet",
        )
        adam = tf.keras.optimizers.legacy.Adam(learning_rate=lr)
        model.compile(optimizer=adam, loss=Unet.loss, metrics=Unet.metrics)
        tf.keras.utils.plot_model(
            model, to_file=f"segmentation/segment-water/model_{self.name}.png"
        )
        self.model = model
        Preprocess.plot_image(self.train_ds, n)
        return self.model.summary()

    def train(self):
        if self.model is None:
            raise RuntimeError(
                "A model has to be built with the build_model method first"
            )
        history1 = self.model.fit(
            self.train_ds,
            epochs=100,
            verbose=1,
            callbacks=Unet.checkpoint(),
            shuffle=True,
            validation_data=self.val_ds,
        )
        self.model.save(f"sat_water_{self.name}.h5")
        return history1
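
Usage note: a hedged sketch of how the two training paths above fit together with Preprocess.data_load; the "data/" root with mirrored "masks" and "images" JPEG folders is a hypothetical layout.

from satwater.models import Unet, BackboneModels
from satwater.preprocess import Preprocess

# Build train/val/test pipelines from a placeholder dataset root.
train_ds, val_ds, test_ds = Preprocess.data_load(
    "data/", "masks", "images", split=(0.1, 0.1),
    shape=(128, 128, 3), batch_size=16, channels=3,
)

# Path 1: the plain U-Net, trained from scratch at 128x128.
history = Unet.train(train_ds, val_ds, epochs=50, shape=(128, 128, 3), name="unet")
Unet.plot_history(history, epochs=len(history.history["loss"]), model="unet")

# Path 2: a ResNet34-encoder U-Net from segmentation-models.
bb = BackboneModels("resnet34", train_ds, val_ds, test_ds, name="resnet34_256")
bb.build_model(n_classes=2)
history = bb.train()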
satwater/preprocess.py
ADDED
@@ -0,0 +1,161 @@
"""
Created on Wed Nov 17 10:48:32 2023

@author: Busayo Alabi
"""

import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

try:
    import tensorflow as tf
except Exception as e:
    raise ImportError(
        "TensorFlow is required for sat-water inference/training.\n\n"
        "Install TensorFlow first, then reinstall sat-water.\n"
        "Recommended:\n"
        "  Linux/Windows: pip install 'tensorflow'\n"
        "  Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
        "If you are using segmentation-models with TF legacy Keras:\n"
        "  pip install tf-keras segmentation-models\n"
    ) from e

from skimage import color


class Preprocess:
    normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)

    @staticmethod
    def load(mask_path, mask_folder, image_folder, channels):
        # Read a mask, then derive the matching image path by swapping the
        # mask folder name for the image folder name.
        mask = tf.io.read_file(mask_path)
        mask = tf.dtypes.cast(tf.image.decode_jpeg(mask, channels=1), tf.float32)
        img_path = tf.strings.regex_replace(mask_path, mask_folder, image_folder)
        image = tf.io.read_file(img_path)
        image = tf.image.decode_jpeg(image, channels=channels)
        return {"image": Preprocess.normalization_layer(image), "mask": mask}

    @staticmethod
    def make_single_class(sample):
        # Binarize the mask: water (>= 0.5) vs. background.
        mask = sample["mask"]
        sample["mask"] = tf.where(mask >= 0.5, 1, 0)
        return sample

    @staticmethod
    def resize(data, size):
        image = tf.image.resize(data["image"], size[:2])
        mask = tf.image.resize(data["mask"], size[:2])
        return image, mask

    @staticmethod
    def resize_test(data, size):
        image = tf.image.resize(data, size)
        return image

    @staticmethod
    def data_load(
        dataset, mask_folder, image_folder, split, shape, batch_size, channels=0
    ):
        dataset = [dataset]
        ds = []
        for f in dataset:
            mask_path = list(Path(f + mask_folder).glob("**/*"))
            for i in mask_path:
                ds.append(str(i))
        size = len(ds)
        random.shuffle(ds)
        ds = tf.data.Dataset.from_tensor_slices(ds)
        val_size, test_size = split
        train_size = int((1 - (val_size + test_size)) * size)
        val_size = int(val_size * size)
        test_size = int(test_size * size)
        train_ds = ds.take(train_size)
        rest_ds = ds.skip(train_size)
        val_ds = rest_ds.take(val_size)
        test_ds = rest_ds.skip(val_size)

        AUTOTUNE = tf.data.experimental.AUTOTUNE

        print(f"The Train Dataset contains {len(train_ds)} images.")
        print(f"The Validation Dataset contains {len(val_ds)} images.")
        print(f"The Test Dataset contains {len(test_ds)} images.")

        train_ds = train_ds.map(
            lambda x: Preprocess.load(x, mask_folder, image_folder, channels)
        )
        train_ds = train_ds.map(Preprocess.make_single_class)
        train_ds = train_ds.map(
            lambda x: Preprocess.resize(x, shape), num_parallel_calls=AUTOTUNE
        )
        train_ds = train_ds.batch(batch_size)
        train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

        val_ds = val_ds.map(
            lambda x: Preprocess.load(x, mask_folder, image_folder, channels)
        )
        val_ds = val_ds.map(Preprocess.make_single_class)
        val_ds = val_ds.map(
            lambda x: Preprocess.resize(x, shape), num_parallel_calls=AUTOTUNE
        )
        val_ds = val_ds.batch(batch_size)
        val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)

        test_ds = test_ds.map(
            lambda x: Preprocess.load(x, mask_folder, image_folder, channels)
        )
        test_ds = test_ds.map(Preprocess.make_single_class)
        test_ds = test_ds.map(
            lambda x: Preprocess.resize(x, shape), num_parallel_calls=AUTOTUNE
        )
        test_ds = test_ds.batch(batch_size)
        test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

        return (train_ds, val_ds, test_ds)

    @staticmethod
    def label_overlay(img, mask, fname=None):
        # label2rgb expects an integer label image, so squeeze and cast the mask.
        mask = np.asarray(mask).squeeze(axis=-1).astype(int)

        overlay = color.label2rgb(
            mask,
            img,
            colors=[(0, 255, 0)],
            alpha=0.05,
            bg_label=0,
            bg_color=None,
            image_alpha=1,
            saturation=1,
            kind="overlay",
        )
        plt.figure(figsize=(6, 6))

        plt.imshow(overlay)
        plt.axis("off")

        if fname:
            plt.savefig(fname)
            plt.clf()
            plt.close()

        plt.show()

    @staticmethod
    def plot_image(train_ds, n=1, fname=None):
        for image, mask in (
            train_ds.unbatch().shuffle(buffer_size=len(train_ds)).take(n)
        ):
            # Rescale to [0, 1] for display.
            image = (image - tf.reduce_min(image)) / (
                tf.reduce_max(image) - tf.reduce_min(image)
            )

            plt.figure(figsize=(12, 6))
            plt.subplot(121)
            plt.imshow(image)
            plt.subplot(122)
            plt.imshow(np.squeeze(mask))
            plt.show()
            Preprocess.label_overlay(image, mask, fname)
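
Usage note: the pairing convention data_load relies on, shown on a single hypothetical sample. load derives each image path by substituting the mask folder name inside the mask path, so the two folders must mirror each other file for file.

from satwater.preprocess import Preprocess

# "data/masks/scene_001.jpg" is assumed to pair with "data/images/scene_001.jpg".
sample = Preprocess.load("data/masks/scene_001.jpg", "masks", "images", channels=3)
sample = Preprocess.make_single_class(sample)   # threshold the mask to {0, 1}
image, mask = Preprocess.resize(sample, (128, 128, 3))
print(image.shape, mask.shape)                  # (128, 128, 3) (128, 128, 1)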