sat-water 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
satwater/inference.py ADDED
@@ -0,0 +1,313 @@
1
+ """
2
+ Created on Wed Nov 18 13:22:57 2023
3
+
4
+ @author: Busayo Alabi
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Union
13
+
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+
17
+ try:
18
+ import tensorflow as tf
19
+ except Exception as e:
20
+ raise ImportError(
21
+ "TensorFlow is required for sat-water inference/training.\n\n"
22
+ "Install TensorFlow first, then reinstall sat-water.\n"
23
+ "Recommended:\n"
24
+ " Linux/Windows: pip install 'tensorflow'\n"
25
+ " Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
26
+ "If you are using segmentation-models with TF legacy Keras:\n"
27
+ " pip install tf-keras segmentation-models\n"
28
+ ) from e
29
+
30
+ from satwater.builders import load_pretrained
31
+ from satwater.preprocess import Preprocess
32
+
33
+ try:
34
+ from PIL import Image
35
+ except Exception:
36
+ Image = None
37
+
38
# Union of the image inputs that ``Inference._to_tf_image`` accepts: a file
# path string, a numpy array, a TensorFlow tensor, or a PIL image. The PIL
# member is written as a string forward reference because the optional PIL
# import above may have left ``Image`` set to None.
ArrayLikeImage = Union[str, np.ndarray, tf.Tensor, "Image.Image"]
39
+
40
+
41
@dataclass(frozen=True)
class SegmentationResult:
    """
    Immutable bundle produced by ``Inference.predict``.

    Every mapping is keyed by model key, so single-model and multi-model
    inference share one result type.
    """

    # model key -> predicted class-index mask
    masks: dict[str, np.ndarray]
    # model key -> uint8 RGB overlay of the mask on the resized input;
    # left empty when overlays were not requested
    overlays: dict[str, np.ndarray]
    # decoded input image, float32 clipped to [0, 1], original resolution
    base_image: np.ndarray
50
+
51
+
52
class Inference:
    """
    Run one or more water-segmentation models on a single image.

    ``model`` may be:

    * a pretrained key (``"unet"``, ``"resnet34_256"``, ``"resnet34_512"``),
      loaded through ``satwater.builders.load_pretrained``;
    * a local saved-model path (``str`` or ``os.PathLike``);
    * a dict mapping result keys to either of the above, to run several
      models on the same image in one ``predict`` call.
    """

    pretrained_keys = {"unet", "resnet34_256", "resnet34_512"}

    # Resize target (H, W) used for each pretrained key.
    _pretrained_sizes = {
        "unet": (128, 128),
        "resnet34_256": (256, 256),
        "resnet34_512": (512, 512),
    }
    # Fallback resize target when neither the key nor the model fixes a size.
    _default_size = (128, 128)

    def __init__(
        self, model="unet", *, name=None, repo_id=None, revision="main", save_dir=None
    ):
        """
        Load the requested model(s).

        Args:
            model: pretrained key, local model path, or a dict of
                ``{result_key: pretrained key or path}``.
            name: result key for a single local model (defaults to "local").
                Replaced by the key itself for pretrained models.
            repo_id: weights repository; defaults to the
                ``SATWATER_WEIGHTS_REPO`` environment variable, then
                "busayojee/sat-water-weights".
            revision: repository revision forwarded to ``load_pretrained``.
            save_dir: directory used by ``predict(save=True)``; nothing is
                written when it is None.

        Raises:
            ValueError: a string ``model`` is neither a pretrained key nor
                something that looks like a model file, or a dict spec is
                unrecognised.
            TypeError: ``model`` is not a str, os.PathLike or dict.
        """
        self.repo_id = repo_id or os.environ.get(
            "SATWATER_WEIGHTS_REPO", "busayojee/sat-water-weights"
        )
        self.revision = revision
        self.save_dir = Path(save_dir) if save_dir is not None else None
        self.name = name
        self.preprocess_func = None  # preprocessing for the single-model case
        self.preprocess_funcs = {}  # per-key preprocessing in dict mode

        # Accept pathlib.Path (and other PathLike objects) wherever a path
        # string is accepted.
        if isinstance(model, os.PathLike):
            model = os.fspath(model)

        if isinstance(model, dict):
            self.model = {}
            for key, spec in model.items():
                m, pp = self._load_one(spec, key_hint=key)
                self.model[key] = m
                self.preprocess_funcs[key] = pp
            return

        if isinstance(model, str):
            if model in self.pretrained_keys:
                m, pp = self._load_pretrained(model)
                self.model = m
                self.name = model
                self.preprocess_func = pp
                return

            if self._looks_like_file(model):
                self.model = tf.keras.models.load_model(model, compile=False)
                self.name = name or "local"
                return

            raise ValueError(
                "model must be a pretrained key (unet/resnet34_256/resnet34_512), "
                "a local path to a saved model, or a dict of them."
            )

        raise TypeError("model must be a str or dict[str, str]")

    @staticmethod
    def _looks_like_file(s):
        """
        Heuristically decide whether ``s`` refers to a saved model on disk.

        True when ``s`` contains a path separator ("/" or a backslash), ends in
        ``.keras`` / ``.h5``, or names an existing file or directory.
        """
        s = os.fspath(s)
        return (
            "/" in s
            or "\\" in s
            or s.endswith((".keras", ".h5"))
            or os.path.exists(s)
        )

    def _load_pretrained(self, key):
        """Fetch a pretrained model; return ``(model, preprocess_func)``."""
        pm = load_pretrained(
            model_key=key, repo_id=self.repo_id, revision=self.revision
        )
        return pm.model, pm.preprocess_func

    def _load_one(self, spec, key_hint=None):
        """
        Resolve one dict entry to ``(model, preprocess_func)``.

        ``spec`` is tried as a pretrained key, then as a local path. If neither
        matches but the dict key (``key_hint``) is a pretrained key, that key is
        loaded instead. Local models have no preprocessing function (None).

        Raises:
            ValueError: nothing matched.
        """
        if isinstance(spec, os.PathLike):
            spec = os.fspath(spec)
        if spec in self.pretrained_keys:
            return self._load_pretrained(spec)
        if self._looks_like_file(spec):
            return tf.keras.models.load_model(spec, compile=False), None
        if key_hint in self.pretrained_keys:
            return self._load_pretrained(key_hint)
        raise ValueError(
            f"Unknown model spec '{spec}'. Use pretrained key or local model path."
        )

    @staticmethod
    def _to_tf_image(x):
        """
        Convert a supported image input to a float32 RGB tensor in [0, 1].

        Paths are read and decoded from disk; PIL images are converted to RGB.
        Arrays/tensors whose maximum exceeds 1.5 are assumed to be 0-255 and
        divided by 255. Grayscale inputs are replicated to 3 channels and
        4-channel inputs lose their last (alpha) channel.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(x, (str, os.PathLike)):
            raw = tf.io.read_file(os.fspath(x))
            img = tf.image.decode_image(raw, channels=3, expand_animations=False)
            return tf.cast(img, tf.float32) / 255.0

        if Image is not None and isinstance(x, Image.Image):
            arr = np.asarray(x.convert("RGB"), dtype=np.float32) / 255.0
            return tf.convert_to_tensor(arr, dtype=tf.float32)

        if isinstance(x, np.ndarray):
            arr = x
            if arr.ndim == 2:
                arr = np.stack([arr, arr, arr], axis=-1)
            if arr.shape[-1] == 4:
                arr = arr[..., :3]
            arr = arr.astype(np.float32)
            if arr.max() > 1.5:
                arr = arr / 255.0
            return tf.convert_to_tensor(arr, dtype=tf.float32)

        if isinstance(x, tf.Tensor):
            t = tf.cast(x, tf.float32)
            if t.shape.rank == 2:
                t = tf.stack([t, t, t], axis=-1)
            if t.shape.rank == 3 and t.shape[-1] == 4:
                t = t[..., :3]
            tmax = tf.reduce_max(t)
            return tf.cond(tmax > 1.5, lambda: t / 255.0, lambda: t)

        raise TypeError(
            "Unsupported image type. Use path, PIL.Image, numpy array, or tf.Tensor."
        )

    @staticmethod
    def _resize(img, size):
        """Bilinearly resize ``img`` to ``size`` = (H, W)."""
        return tf.image.resize(img, size=size, method="bilinear")

    @staticmethod
    def _mask_to_uint8(mask):
        """Squeeze singleton dims and cast the mask to uint8."""
        m = np.squeeze(mask)
        if m.dtype != np.uint8:
            m = m.astype(np.uint8)
        return m

    @staticmethod
    def overlay_mask(image, mask, alpha=0.55, color=(0.0, 0.4, 1.0)):
        """
        Blend ``color`` over the pixels where ``mask > 0``.

        Args:
            image: (H, W, 3), (H, W, 4) or (H, W) image. Integer images are
                treated as 0-255; float images as 0-1 unless their maximum
                exceeds 1.5 (the same heuristic ``_to_tf_image`` uses).
            mask: array with H*W elements, e.g. (H, W) or (H, W, 1).
            alpha: blend strength of the overlay colour.
            color: RGB overlay colour with components in [0, 1].

        Returns:
            (H, W, 3) uint8 image.
        """
        src = np.asarray(image)
        img = src.astype(np.float32)
        # Only rescale data that is actually on a 0-255 scale; a float image
        # already in [0, 1] (of any float dtype) must not be divided again.
        if np.issubdtype(src.dtype, np.integer) or (img.size and img.max() > 1.5):
            img = img / 255.0
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        if img.shape[-1] == 4:
            img = img[..., :3]
        img = np.clip(img, 0.0, 1.0)

        # reshape (not squeeze) so 1-pixel-high/wide masks keep both axes.
        water = (np.asarray(mask).reshape(img.shape[:2]) > 0).astype(np.float32)
        water = water[..., None]  # broadcast over the RGB channels

        col = np.asarray(color, dtype=np.float32)
        blended = img * (1.0 - alpha * water) + col * (alpha * water)
        return (np.clip(blended, 0.0, 1.0) * 255.0).astype(np.uint8)

    def _maybe_save(self, arr, fname):
        """Write ``arr`` to ``save_dir / fname`` via matplotlib, if save_dir is set."""
        if self.save_dir is None:
            return
        self.save_dir.mkdir(parents=True, exist_ok=True)
        out = self.save_dir / fname
        plt.imsave(out.as_posix(), arr)

    def predict(
        self, image, *, return_overlay=True, save=False, show=False, fname="prediction"
    ):
        """
        Segment ``image`` with every loaded model.

        Args:
            image: path, numpy array, tf.Tensor or PIL image.
            return_overlay: also build a colour overlay for each mask.
            save: write ``{fname}_{key}_mask.png`` (plus ``..._overlay.png``
                when overlays are built) into ``save_dir``. Emits a warning and
                writes nothing when the instance has no ``save_dir``.
            show: display image / mask / overlay with matplotlib.
            fname: file-name prefix for saved files and the plot title.

        Returns:
            SegmentationResult keyed by the dict keys (dict mode) or by
            ``self.name`` (single-model mode).
        """
        raw_img = self._to_tf_image(image)
        # NOTE(review): raw_img is already in [0, 1], so this Rescaling(1/255)
        # shrinks it to [0, 1/255]. The per-image min-max scaling in
        # _predict_for_key largely cancels this; kept unchanged to avoid
        # altering model inputs — confirm before removing.
        model_img = Preprocess.normalization_layer(raw_img)
        base_img = np.clip(raw_img.numpy(), 0.0, 1.0).astype(np.float32)

        if save and self.save_dir is None:
            import warnings

            warnings.warn(
                "save=True ignored: Inference was created without save_dir",
                stacklevel=2,
            )

        if isinstance(self.model, dict):
            entries = list(self.model.items())
        else:
            entries = [(self.name or "model", self.model)]

        masks = {}
        overlays = {}
        for key, model in entries:
            mask_np = self._predict_for_key(key, model, model_img)
            masks[key] = mask_np

            if return_overlay:
                resized = self._image_for_key(key, raw_img, model).numpy()
                overlays[key] = self.overlay_mask(resized, mask_np)

            # Saving the mask no longer depends on return_overlay.
            if save:
                if key in overlays:
                    self._maybe_save(overlays[key], f"{fname}_{key}_overlay.png")
                self._maybe_save(np.squeeze(mask_np), f"{fname}_{key}_mask.png")

        if show:
            self._plot_results(base_img, masks, overlays, title=fname)

        return SegmentationResult(masks=masks, overlays=overlays, base_image=base_img)

    def _input_size(self, key, model=None):
        """
        Return the (H, W) the input is resized to for ``key``.

        Pretrained keys use their fixed size. Other keys use the model's static
        input shape when it has one, else ``_default_size``.
        """
        if key in self._pretrained_sizes:
            return self._pretrained_sizes[key]
        hw = self._model_input_hw(model)
        return hw if hw is not None else self._default_size

    @staticmethod
    def _model_input_hw(model):
        """Return the static (H, W) of a single 4-D input, else None."""
        try:
            shape = model.input_shape
        except (AttributeError, ValueError, RuntimeError):
            return None
        # Multi-input models report a list; only a single 4-D input is usable.
        if not isinstance(shape, tuple) or len(shape) != 4:
            return None
        h, w = shape[1], shape[2]
        if isinstance(h, int) and isinstance(w, int) and h > 0 and w > 0:
            return (h, w)
        return None

    def _image_for_key(self, key, img, model=None):
        """Resize ``img`` to the input size used for ``key`` (and ``model``)."""
        return self._resize(img, self._input_size(key, model))

    def _predict_for_key(self, key, model, img):
        """
        Run ``model`` on ``img`` and return an int32 class-index mask.

        Assuming the model outputs (1, H, W, classes), the result has shape
        (H, W, 1).
        """
        x = self._image_for_key(key, img, model)

        if isinstance(self.model, dict):
            pp = self.preprocess_funcs.get(key)
        else:
            pp = self.preprocess_func

        if pp is not None:
            x = pp(x)

        # Per-image min-max scaling to [0, 1]; epsilon avoids 0/0 on flat input.
        x = (x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x) + 1e-8)

        pred = model.predict(x[tf.newaxis, ...], verbose=0)
        pred = tf.argmax(pred, axis=-1)  # most likely class per pixel
        pred = tf.expand_dims(pred, axis=-1)[0, :, :, :]
        pred_np = pred.numpy().astype(np.int32)

        # NOTE(review): the resnet34_512 output is flipped here, presumably
        # because its class order is the reverse of the other models — confirm.
        if key == "resnet34_512":
            pred_np = 1 - pred_np
        return pred_np

    @staticmethod
    def _plot_results(base_img, masks, overlays, title="result"):
        """Plot one row per model: input image, mask, and overlay (if any)."""
        keys = list(masks.keys())
        n = len(keys)
        if n == 0:
            return  # nothing to draw; subplot(0, ...) would raise
        cols = 3 if overlays else 2
        plt.figure(figsize=(5 * cols, 4 * n))

        for i, k in enumerate(keys):
            r = i * cols

            plt.subplot(n, cols, r + 1)
            plt.title("Image")
            plt.imshow(base_img)
            plt.axis("off")

            plt.subplot(n, cols, r + 2)
            plt.title(f"Mask: {k}")
            plt.imshow(np.squeeze(masks[k]))
            plt.axis("off")

            if overlays:
                plt.subplot(n, cols, r + 3)
                plt.title(f"Overlay: {k}")
                plt.imshow(overlays[k])
                plt.axis("off")

        plt.suptitle(title)
        plt.tight_layout()
        plt.show()
294
+
295
+
296
def segment_image(
    image,
    *,
    model="unet",
    repo_id=None,
    revision="main",
    return_overlay=True,
    save=False,
    save_dir=None,
    show=False,
    fname="prediction",
):
    """
    One-shot helper: build an ``Inference`` and segment a single image.

    A new ``Inference`` (and therefore a fresh model load) is created on every
    call, so reuse an ``Inference`` instance when processing many images.

    Args:
        image: path, numpy array, tf.Tensor or PIL image to segment.
        model: pretrained key, local model path, or dict of them.
        repo_id: weights repository override.
        revision: weights repository revision.
        return_overlay: also build colour overlays.
        save: save masks/overlays into ``save_dir``.
        save_dir: output directory for ``save``.
        show: plot the results with matplotlib.
        fname: prefix for saved files and the plot title.

    Returns:
        The ``SegmentationResult`` from ``Inference.predict``.
    """
    runner = Inference(
        model=model,
        repo_id=repo_id,
        revision=revision,
        save_dir=save_dir,
    )
    predict_options = dict(
        return_overlay=return_overlay,
        save=save,
        show=show,
        fname=fname,
    )
    return runner.predict(image, **predict_options)
satwater/models.py ADDED
@@ -0,0 +1,229 @@
1
+ """
2
+ Created on Wed Nov 17 11:05:25 2023
3
+
4
+ @author: Busayo Alabi
5
+ """
6
+
7
+ import matplotlib.pyplot as plt
8
+
9
+ try:
10
+ import tensorflow as tf
11
+ except Exception as e:
12
+ raise ImportError(
13
+ "TensorFlow is required for sat-water inference/training.\n\n"
14
+ "Install TensorFlow first, then reinstall sat-water.\n"
15
+ "Recommended:\n"
16
+ " Linux/Windows: pip install 'tensorflow'\n"
17
+ " Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
18
+ "If you are using segmentation-models with TF legacy Keras:\n"
19
+ " pip install tf-keras segmentation-models\n"
20
+ ) from e
21
+
22
+ try:
23
+ import segmentation_models as sm
24
+ except Exception as e:
25
+ raise ImportError(
26
+ "segmentation-models is required for training backbone models.\n"
27
+ "Install it with:\n"
28
+ " pip install segmentation-models tf-keras\n"
29
+ ) from e
30
+
31
+ from satwater.preprocess import Preprocess
32
+
33
+
34
class Unet:
    """
    Plain U-Net built from scratch, plus training and plotting helpers.

    The class-level ``loss`` and ``metrics`` (segmentation-models focal+dice
    loss, IoU and F1) are also used by ``BackboneModels``.
    """

    loss = sm.losses.categorical_focal_dice_loss
    metrics = [sm.metrics.iou_score, sm.metrics.f1_score]

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the parent directory of ``path`` (no-op if it already exists)."""
        from pathlib import Path

        Path(path).parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def dice_coef_water(y_true, y_pred, smooth=1e-6):
        """
        Batch-mean Dice coefficient between ``y_true`` and channel 1 of ``y_pred``.

        Args:
            y_true: label mask with a trailing singleton channel, presumably
                (batch, H, W, 1) with 0/1 values — confirm against the dataset.
            y_pred: per-class scores; channel 1 is the one compared.
            smooth: added to numerator and denominator to avoid 0/0.
        """
        y_true = tf.cast(y_true, tf.float32)
        y_true = tf.squeeze(y_true, axis=-1)
        y_pred = tf.cast(y_pred[..., 1], tf.float32)

        # Flatten each sample so Dice is computed per image, then averaged.
        y_true_f = tf.reshape(y_true, [tf.shape(y_true)[0], -1])
        y_pred_f = tf.reshape(y_pred, [tf.shape(y_pred)[0], -1])

        intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=1)
        denom = tf.reduce_sum(y_true_f + y_pred_f, axis=1)

        dice = (2.0 * intersection + smooth) / (denom + smooth)
        return tf.reduce_mean(dice)

    @staticmethod
    def conv_block(input, filters):
        """Two (3x3 conv -> batch norm -> ReLU) layers with ``filters`` channels."""
        encoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(input)
        encoder = tf.keras.layers.BatchNormalization()(encoder)
        encoder = tf.keras.layers.Activation("relu")(encoder)
        encoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(encoder)
        encoder = tf.keras.layers.BatchNormalization()(encoder)
        encoder = tf.keras.layers.Activation("relu")(encoder)
        return encoder

    @staticmethod
    def encoder_block(input, filters):
        """Conv block followed by 2x2 max-pooling; returns (pooled, skip)."""
        encoder = Unet.conv_block(input, filters)
        encoder_pool = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
        return encoder_pool, encoder

    @staticmethod
    def decoder_block(input, concat, filters):
        """2x transposed-conv upsampling, concat with the skip, then two convs."""
        decoder = tf.keras.layers.Conv2DTranspose(
            filters, (2, 2), strides=(2, 2), padding="same"
        )(input)
        decoder = tf.keras.layers.concatenate([concat, decoder], axis=-1)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        decoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(decoder)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        decoder = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(decoder)
        decoder = tf.keras.layers.BatchNormalization()(decoder)
        decoder = tf.keras.layers.Activation("relu")(decoder)
        return decoder

    @staticmethod
    def models(input_shape, n_classes):
        """
        Build the U-Net: five encoder levels (32-512 filters), a 1024-filter
        bottleneck, a mirrored decoder and a 1x1 softmax head with
        ``n_classes`` channels.
        """
        inputs = tf.keras.layers.Input(shape=input_shape)
        encoder_pool0, encoder0 = Unet.encoder_block(inputs, 32)
        encoder_pool1, encoder1 = Unet.encoder_block(encoder_pool0, 64)
        encoder_pool2, encoder2 = Unet.encoder_block(encoder_pool1, 128)
        encoder_pool3, encoder3 = Unet.encoder_block(encoder_pool2, 256)
        encoder_pool4, encoder4 = Unet.encoder_block(encoder_pool3, 512)
        center = Unet.conv_block(encoder_pool4, 1024)
        decoder4 = Unet.decoder_block(center, encoder4, 512)
        decoder3 = Unet.decoder_block(decoder4, encoder3, 256)
        decoder2 = Unet.decoder_block(decoder3, encoder2, 128)
        decoder1 = Unet.decoder_block(decoder2, encoder1, 64)
        decoder0 = Unet.decoder_block(decoder1, encoder0, 32)
        outputs = tf.keras.layers.Conv2D(n_classes, (1, 1), activation="softmax")(
            decoder0
        )
        models = tf.keras.models.Model(inputs=[inputs], outputs=[outputs])
        return models

    @staticmethod
    def checkpoint():
        """
        Training callbacks: early stopping on validation IoU (patience 6) and
        10x learning-rate decay on validation-loss plateaus (patience 4).
        """
        # IoU is "higher is better". With mode="auto", tf.keras only infers
        # "max" for names ending in acc/accuracy/auc, so "val_iou_score" would
        # be treated as a loss and training would stop while IoU kept rising.
        early_stopping = tf.keras.callbacks.EarlyStopping(
            monitor="val_iou_score", mode="max", patience=6
        )
        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", mode="min", factor=0.1, patience=4, min_lr=0.00005
        )
        return [early_stopping, reduce_lr]

    @staticmethod
    def train(
        train_ds,
        val_ds,
        epochs=100,
        shape=(128, 128, 3),
        n_classes=2,
        lr=0.0001,
        loss="sparse_categorical_crossentropy",
        metrics=None,
        name="model",
    ):
        """
        Build, compile and fit a fresh U-Net, then save it as ``{name}.h5``.

        Also writes an architecture diagram to
        ``segmentation/segment-water/model_{name}.png``.

        NOTE(review): the early-stopping callback monitors ``val_iou_score``,
        which only exists if ``metrics`` includes an IoU metric (e.g.
        ``Unet.metrics``); with the default ["accuracy"] it never triggers.

        Returns:
            The Keras ``History`` from ``model.fit``.
        """
        if metrics is None:
            metrics = ["accuracy"]
        model = Unet.models(shape, n_classes)
        callbacks = Unet.checkpoint()
        adam = tf.keras.optimizers.legacy.Adam(learning_rate=lr)
        model.compile(optimizer=adam, loss=loss, metrics=metrics)

        diagram = f"segmentation/segment-water/model_{name}.png"
        Unet._ensure_parent_dir(diagram)
        tf.keras.utils.plot_model(model, to_file=diagram)
        model.summary()  # prints itself; its return value is None

        history = model.fit(
            train_ds,
            epochs=epochs,  # honour the caller's epoch budget
            shuffle=True,
            verbose=1,
            callbacks=callbacks,
            validation_data=val_ds,
        )
        model.save(f"{name}.h5")
        return history

    @staticmethod
    def plot_history(history, epochs, model=""):
        """
        Plot training/validation IoU and loss curves and save them to
        ``segmentation/segment-water/results/history{model}.png``.

        Args:
            history: Keras ``History`` containing iou_score / val_iou_score /
                loss / val_loss.
            epochs: kept for backward compatibility. The x-axis uses the number
                of epochs actually recorded, since early stopping may end
                training before the requested count.
            model: suffix for the output file name.
        """
        iou = history.history["iou_score"]
        val_iou = history.history["val_iou_score"]
        loss = history.history["loss"]
        val_loss = history.history["val_loss"]
        epochs_range = range(len(loss))

        plt.figure(figsize=(16, 8))
        plt.subplot(1, 2, 1)
        plt.plot(epochs_range, iou, label="Training IOU coefficient")
        plt.plot(epochs_range, val_iou, label="Validation IOU coefficient")
        plt.legend(loc="upper right")
        plt.title("Training and Validation IOU coefficient")

        plt.subplot(1, 2, 2)
        plt.plot(epochs_range, loss, label="Training Loss")
        plt.plot(epochs_range, val_loss, label="Validation Loss")
        plt.legend(loc="upper right")
        plt.title("Training and Validation Loss")

        out = f"segmentation/segment-water/results/history{model}.png"
        Unet._ensure_parent_dir(out)
        plt.savefig(out)
        plt.show()
169
+
170
+
171
class BackboneModels:
    """
    segmentation-models ``sm.Unet`` with an ImageNet-initialised encoder.

    Holds (image, mask) train/val/test datasets and applies the backbone's
    preprocessing to them when the model is built.
    """

    parallel_calls = tf.data.AUTOTUNE

    def __init__(self, Backbone, train_ds, val_ds, test_ds, name="model"):
        """
        Args:
            Backbone: segmentation-models backbone name (e.g. "resnet34").
            train_ds, val_ds, test_ds: datasets yielding (image, mask) pairs.
            name: suffix for the saved model and diagram file names.
        """
        assert isinstance(Backbone, str), "Backbone must be string"
        self.backbone = Backbone
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.test_ds = test_ds
        self.name = name
        self.model = None

    @staticmethod
    def preprocess(images, label, Backbone):
        """Apply the backbone's input preprocessing to ``images``; keep ``label``."""
        preprocessor = sm.get_preprocessing(Backbone)
        return preprocessor(images), label

    def _preprocess_datasets(self):
        """
        Map the backbone preprocessing over the train/val/test datasets.

        If the current datasets are exactly the ones produced by a previous
        call, the preprocessing is re-applied to the stored originals instead.
        Calling ``build_model`` twice therefore does not preprocess twice.
        """
        current = (self.train_ds, self.val_ds, self.test_ds)
        state = getattr(self, "_preprocess_state", None)  # (raw, mapped)
        if state is not None and all(c is m for c, m in zip(current, state[1])):
            raw = state[0]
        else:
            raw = current

        backbone = self.backbone
        mapped = tuple(
            ds.map(
                lambda x, y: BackboneModels.preprocess(x, y, backbone),
                num_parallel_calls=BackboneModels.parallel_calls,
            )
            for ds in raw
        )
        self._preprocess_state = (raw, mapped)
        self.train_ds, self.val_ds, self.test_ds = mapped

    def build_model(self, n_classes, n=1, lr=0.0001):
        """
        Preprocess the datasets, build and compile the model, and show samples.

        Args:
            n_classes: number of output classes (softmax channels).
            n: number of training samples passed to ``Preprocess.plot_image``.
            lr: Adam learning rate.

        Returns:
            Whatever ``model.summary()`` returns (it prints the summary).
        """
        from pathlib import Path

        self._preprocess_datasets()
        model = sm.Unet(
            self.backbone,
            classes=n_classes,
            activation="softmax",
            encoder_weights="imagenet",
        )
        adam = tf.keras.optimizers.legacy.Adam(learning_rate=lr)
        model.compile(optimizer=adam, loss=Unet.loss, metrics=Unet.metrics)

        diagram = f"segmentation/segment-water/model_{self.name}.png"
        Path(diagram).parent.mkdir(parents=True, exist_ok=True)
        tf.keras.utils.plot_model(model, to_file=diagram)

        self.model = model
        Preprocess.plot_image(self.train_ds, n)
        return self.model.summary()

    def train(self, epochs=100):
        """
        Fit the built model and save it as ``sat_water_{name}.h5``.

        Args:
            epochs: maximum number of epochs (early stopping may end sooner).

        Returns:
            The Keras ``History``, or an explanatory string if ``build_model``
            has not been called yet.
        """
        if self.model is None:
            return "A model has to be built with the build_model method first"
        history = self.model.fit(
            self.train_ds,
            epochs=epochs,
            verbose=1,
            callbacks=Unet.checkpoint(),
            shuffle=True,
            validation_data=self.val_ds,
        )
        self.model.save(f"sat_water_{self.name}.h5")
        return history
satwater/preprocess.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ Created on Wed Nov 17 10:48:32 2023
3
+
4
+ @author: Busayo Alabi
5
+ """
6
+
7
+ import random
8
+ from pathlib import Path
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+
13
+ try:
14
+ import tensorflow as tf
15
+ except Exception as e:
16
+ raise ImportError(
17
+ "TensorFlow is required for sat-water inference/training.\n\n"
18
+ "Install TensorFlow first, then reinstall sat-water.\n"
19
+ "Recommended:\n"
20
+ " Linux/Windows: pip install 'tensorflow'\n"
21
+ " Apple Silicon: pip install 'tensorflow-macos' 'tensorflow-metal'\n\n"
22
+ "If you are using segmentation-models with TF legacy Keras:\n"
23
+ " pip install tf-keras segmentation-models\n"
24
+ ) from e
25
+
26
+ from skimage import color
27
+
28
+
29
class Preprocess:
    """Dataset loading, splitting and plotting helpers for training."""

    # Rescales 0-255 pixel values to [0, 1].
    normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)

    @staticmethod
    def load(mask_path, mask_folder, image_folder, channels):
        """
        Read one (image, mask) pair.

        The image path is derived from ``mask_path`` by replacing
        ``mask_folder`` with ``image_folder``. This is a regex replacement, so
        regex metacharacters in the folder names are interpreted.

        Returns:
            {"image": image rescaled to [0, 1],
             "mask": single-channel float32 mask on its raw stored scale}
        """
        mask = tf.io.read_file(mask_path)
        # decode_image handles JPEG as well as PNG/BMP/GIF (first frame).
        mask = tf.dtypes.cast(
            tf.io.decode_image(mask, channels=1, expand_animations=False), tf.float32
        )
        img_path = tf.strings.regex_replace(mask_path, mask_folder, image_folder)
        image = tf.io.read_file(img_path)
        image = tf.io.decode_image(image, channels=channels, expand_animations=False)
        return {"image": Preprocess.normalization_layer(image), "mask": mask}

    @staticmethod
    def make_single_class(sample):
        """Binarise ``sample["mask"]`` in place: 1 where mask >= 0.5, else 0."""
        # NOTE(review): ``load`` leaves the mask on its raw stored scale (e.g.
        # 0-255), so this threshold marks every non-zero pixel, including
        # compression noise, as class 1. Confirm masks are stored as 0/1.
        mask = sample["mask"]
        sample["mask"] = tf.where(mask >= 0.5, 1, 0)
        return sample

    @staticmethod
    def resize(data, size):
        """Resize image and mask to ``size[:2]``; return ``(image, mask)``."""
        image = tf.image.resize(data["image"], size[:2])
        # Nearest-neighbour keeps the mask on its class labels (bilinear would
        # blend 0/1 into fractional values at edges). Cast back to float32, the
        # dtype bilinear resize produced, so downstream losses see no change.
        mask = tf.cast(
            tf.image.resize(data["mask"], size[:2], method="nearest"), tf.float32
        )
        return image, mask

    @staticmethod
    def resize_test(data, size):
        """Bilinearly resize a single image tensor to ``size``."""
        image = tf.image.resize(data, size)
        return image

    @staticmethod
    def _pipeline(split_ds, mask_folder, image_folder, channels, shape, batch_size):
        """Load, binarise, resize, batch and prefetch one split of mask paths."""
        autotune = tf.data.AUTOTUNE
        split_ds = split_ds.map(
            lambda x: Preprocess.load(x, mask_folder, image_folder, channels)
        )
        split_ds = split_ds.map(Preprocess.make_single_class)
        split_ds = split_ds.map(
            lambda x: Preprocess.resize(x, shape), num_parallel_calls=autotune
        )
        split_ds = split_ds.batch(batch_size)
        return split_ds.prefetch(buffer_size=autotune)

    @staticmethod
    def data_load(
        dataset, mask_folder, image_folder, split, shape, batch_size, channels=0
    ):
        """
        Build batched train/val/test datasets from a folder of masks.

        Every file under ``dataset + mask_folder`` (recursively) is taken as a
        mask. Its image is located by swapping ``mask_folder`` for
        ``image_folder`` in the path (see ``load``). Paths are shuffled with
        ``random.shuffle`` before splitting.

        Args:
            dataset: root path prefix, string-concatenated with mask_folder.
            mask_folder, image_folder: folder names swapped to pair files.
            split: (val_fraction, test_fraction). Train gets the remainder and
                test also absorbs any rounding leftovers.
            shape: target size; only the first two entries (H, W) are used.
            batch_size: batch size for all three splits.
            channels: image channels to decode (0 = as stored in the file).

        Returns:
            (train_ds, val_ds, test_ds) yielding (image, mask) batches.

        Raises:
            ValueError: no mask files were found.
        """
        mask_root = Path(dataset + mask_folder)
        # glob("**/*") also yields directories; only files are masks.
        mask_paths = [str(p) for p in mask_root.glob("**/*") if p.is_file()]
        if not mask_paths:
            raise ValueError(f"No mask files found under {mask_root}")
        size = len(mask_paths)
        random.shuffle(mask_paths)
        ds = tf.data.Dataset.from_tensor_slices(mask_paths)

        val_frac, test_frac = split
        train_size = int((1 - (val_frac + test_frac)) * size)
        val_size = int(val_frac * size)
        train_ds = ds.take(train_size)
        rest_ds = ds.skip(train_size)
        val_ds = rest_ds.take(val_size)
        test_ds = rest_ds.skip(val_size)

        print(f"The Train Dataset contains {len(train_ds)} images.")
        print(f"The Validation Dataset contains {len(val_ds)} images.")
        print(f"The Test Dataset contains {len(test_ds)} images.")

        return tuple(
            Preprocess._pipeline(
                split_ds, mask_folder, image_folder, channels, shape, batch_size
            )
            for split_ds in (train_ds, val_ds, test_ds)
        )

    @staticmethod
    def label_overlay(img, mask, fname=None):
        """
        Display ``img`` with non-zero ``mask`` pixels tinted via skimage
        ``label2rgb``. If ``fname`` is given, the figure is saved there and
        closed, so it is not shown by the final ``plt.show()``.
        """
        mask = np.squeeze(mask, axis=-1)

        # NOTE(review): colour (0, 255, 0) is on a 0-255 scale while ``img`` is
        # presumably in [0, 1], so the green channel saturates regardless of
        # alpha. Kept as-is to preserve the existing visual output.
        overlay = color.label2rgb(
            mask,
            img,
            colors=[(0, 255, 0)],
            alpha=0.05,
            bg_label=0,
            bg_color=None,
            image_alpha=1,
            saturation=1,
            kind="overlay",
        )
        plt.figure(figsize=(6, 6))

        plt.imshow(overlay)
        plt.axis("off")

        if fname:
            plt.savefig(fname)
            plt.clf()
            plt.close()

        plt.show()

    @staticmethod
    def plot_image(train_ds, n=1, fname=None):
        """
        Show ``n`` random (image, mask) samples from a batched dataset. Each
        sample is followed by a ``label_overlay`` plot.

        ``fname``, if given, is passed to every ``label_overlay`` call, so with
        ``n > 1`` each sample overwrites the previous file.
        """
        buffer = max(1, len(train_ds))  # shuffle needs a positive buffer size
        samples = train_ds.unbatch().shuffle(buffer_size=buffer).take(n)
        for image, mask in samples:
            lo = tf.reduce_min(image)
            span = tf.reduce_max(image) - lo
            # Min-max scale for display; guard against 0/0 on flat images.
            image = (image - lo) / tf.maximum(span, 1e-8)

            plt.figure(figsize=(12, 6))
            plt.subplot(121)
            plt.imshow(image)
            plt.subplot(122)
            plt.imshow(mask)
            plt.show()
            Preprocess.label_overlay(image, mask, fname)