stouputils 1.14.3__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/data_science/config/get.py +51 -51
- stouputils/data_science/data_processing/image/__init__.py +66 -66
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -79
- stouputils/data_science/data_processing/image/axis_flip.py +58 -58
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -74
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -73
- stouputils/data_science/data_processing/image/blur.py +59 -59
- stouputils/data_science/data_processing/image/brightness.py +54 -54
- stouputils/data_science/data_processing/image/canny.py +110 -110
- stouputils/data_science/data_processing/image/clahe.py +92 -92
- stouputils/data_science/data_processing/image/common.py +30 -30
- stouputils/data_science/data_processing/image/contrast.py +53 -53
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -74
- stouputils/data_science/data_processing/image/denoise.py +378 -378
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -123
- stouputils/data_science/data_processing/image/invert.py +64 -64
- stouputils/data_science/data_processing/image/laplacian.py +60 -60
- stouputils/data_science/data_processing/image/median_blur.py +52 -52
- stouputils/data_science/data_processing/image/noise.py +59 -59
- stouputils/data_science/data_processing/image/normalize.py +65 -65
- stouputils/data_science/data_processing/image/random_erase.py +66 -66
- stouputils/data_science/data_processing/image/resize.py +69 -69
- stouputils/data_science/data_processing/image/rotation.py +80 -80
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -68
- stouputils/data_science/data_processing/image/sharpening.py +55 -55
- stouputils/data_science/data_processing/image/shearing.py +64 -64
- stouputils/data_science/data_processing/image/threshold.py +64 -64
- stouputils/data_science/data_processing/image/translation.py +71 -71
- stouputils/data_science/data_processing/image/zoom.py +83 -83
- stouputils/data_science/data_processing/image_augmentation.py +118 -118
- stouputils/data_science/data_processing/image_preprocess.py +183 -183
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -359
- stouputils/data_science/data_processing/technique.py +481 -481
- stouputils/data_science/dataset/__init__.py +45 -45
- stouputils/data_science/dataset/dataset.py +292 -292
- stouputils/data_science/dataset/dataset_loader.py +135 -135
- stouputils/data_science/dataset/grouping_strategy.py +296 -296
- stouputils/data_science/dataset/image_loader.py +100 -100
- stouputils/data_science/dataset/xy_tuple.py +696 -696
- stouputils/data_science/metric_dictionnary.py +106 -106
- stouputils/data_science/mlflow_utils.py +206 -206
- stouputils/data_science/models/abstract_model.py +149 -149
- stouputils/data_science/models/all.py +85 -85
- stouputils/data_science/models/keras/all.py +38 -38
- stouputils/data_science/models/keras/convnext.py +62 -62
- stouputils/data_science/models/keras/densenet.py +50 -50
- stouputils/data_science/models/keras/efficientnet.py +60 -60
- stouputils/data_science/models/keras/mobilenet.py +56 -56
- stouputils/data_science/models/keras/resnet.py +52 -52
- stouputils/data_science/models/keras/squeezenet.py +233 -233
- stouputils/data_science/models/keras/vgg.py +42 -42
- stouputils/data_science/models/keras/xception.py +38 -38
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -20
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -219
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -148
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -31
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -249
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -66
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -12
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -56
- stouputils/data_science/models/keras_utils/visualizations.py +416 -416
- stouputils/data_science/models/sandbox.py +116 -116
- stouputils/data_science/range_tuple.py +234 -234
- stouputils/data_science/utils.py +285 -285
- stouputils/decorators.py +53 -39
- stouputils/decorators.pyi +2 -2
- stouputils/installer/__init__.py +18 -18
- stouputils/installer/linux.py +144 -144
- stouputils/installer/main.py +223 -223
- stouputils/installer/windows.py +136 -136
- stouputils/io.py +16 -9
- stouputils/print.py +229 -2
- stouputils/print.pyi +90 -1
- stouputils/py.typed +1 -1
- {stouputils-1.14.3.dist-info → stouputils-1.15.0.dist-info}/METADATA +1 -1
- {stouputils-1.14.3.dist-info → stouputils-1.15.0.dist-info}/RECORD +78 -78
- {stouputils-1.14.3.dist-info → stouputils-1.15.0.dist-info}/WHEEL +1 -1
- {stouputils-1.14.3.dist-info → stouputils-1.15.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,148 +1,148 @@
|
|
|
1
|
-
|
|
2
|
-
# pyright: reportMissingTypeStubs=false
|
|
3
|
-
|
|
4
|
-
# Imports
|
|
5
|
-
from typing import Any
|
|
6
|
-
|
|
7
|
-
import tensorflow as tf
|
|
8
|
-
from keras.callbacks import Callback
|
|
9
|
-
from keras.models import Model
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class LearningRateFinder(Callback):
	""" Callback to find optimal learning rate by increasing LR during training.

	Sources:
	- Inspired by: https://github.com/WittmannF/LRFinder
	- Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)

	This callback gradually increases the learning rate from a minimum to a maximum value
	during training, allowing you to identify the optimal learning rate range for your model.

	It works by:

	1. Starting with a very small learning rate
	2. Exponentially increasing it every ``update_interval`` batches (or epochs)
	3. Recording the loss at each learning rate
	4. Restoring the model's initial weights after training

	The optimal learning rate is typically found where the loss is decreasing most rapidly
	before it starts to diverge.

	.. image:: https://blog.dataiku.com/hubfs/training%20loss.png
		:alt: Learning rate finder curve example
	"""

	def __init__(
		self,
		min_lr: float,
		max_lr: float,
		steps_per_epoch: int,
		epochs: int,
		update_per_epoch: bool = False,
		update_interval: int = 5
	) -> None:
		""" Initialize the learning rate finder.

		Args:
			min_lr (float): Minimum learning rate (must be strictly positive)
			max_lr (float): Maximum learning rate (must be strictly positive)
			steps_per_epoch (int): Steps per epoch
			epochs (int): Number of epochs
			update_per_epoch (bool): If True, update LR once per epoch instead of every batch.
			update_interval (int): Number of steps between each lr increase, bigger value means more stable loss.
				Values below 1 are treated as 1.

		Raises:
			ValueError: If ``min_lr`` or ``max_lr`` is not strictly positive
				(the geometric multiplier below is undefined otherwise).
		"""
		super().__init__()
		if min_lr <= 0 or max_lr <= 0:
			raise ValueError(f"min_lr and max_lr must be strictly positive, got min_lr={min_lr}, max_lr={max_lr}")

		self.min_lr: float = min_lr
		""" Minimum learning rate. """
		self.max_lr: float = max_lr
		""" Maximum learning rate. """
		self.update_per_epoch: bool = update_per_epoch
		""" Whether to update learning rate per epoch instead of per batch. """
		self.update_interval: int = max(1, int(update_interval))
		""" Number of steps between each lr increase, bigger value means more stable loss. """

		# Divide by the *clamped* interval (the raw one may be 0) and never let the count reach 0,
		# otherwise the exponent below divides by zero when training is shorter than one interval.
		total_steps: int = epochs if update_per_epoch else steps_per_epoch * epochs
		self.total_updates: int = max(1, total_steps // self.update_interval)
		""" Total number of update steps (considering update_interval). """
		self.lr_mult: float = (max_lr / min_lr) ** (1 / self.total_updates)
		""" Learning rate multiplier. """
		self.learning_rates: list[float] = []
		""" List of learning rates. """
		self.losses: list[float] = []
		""" List of losses. """
		self.best_lr: float = min_lr
		""" Best learning rate. """
		self.best_loss: float = float("inf")
		""" Best loss. """
		self.model: Model
		""" Model to apply the learning rate finder to. """
		self.initial_weights: list[Any] | None = None
		""" Stores the initial weights of the model. """

		# Running count of batches (or epochs when update_per_epoch) since training began.
		# Counting across epochs (instead of using the per-epoch batch index) keeps the number of
		# LR increases equal to total_updates, so the sweep ends at max_lr instead of overshooting it.
		self._steps_seen: int = 0

	def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
		""" Set initial learning rate and save initial model weights at the start of training.

		Args:
			logs (dict | None): Training logs.
		"""
		self._steps_seen = 0
		self.initial_weights = self.model.get_weights()
		tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.min_lr)  # type: ignore

	def _update_lr_and_track_metrics(self, logs: dict[str, Any] | None = None) -> None:
		""" Record the current (learning rate, loss) pair, then multiply the learning rate by ``lr_mult``.

		Does nothing when no logs or no "loss" entry are available.

		Args:
			logs (dict | None): Logs from training
		"""
		if logs is None:
			return
		raw_loss: Any = logs.get("loss")
		if raw_loss is None:
			return

		# Get current learning rate and loss (float() so the history holds plain Python floats)
		current_lr: float = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
		current_loss: float = float(raw_loss)

		# Record values
		self.learning_rates.append(current_lr)
		self.losses.append(current_loss)

		# Track best values
		if current_loss < self.best_loss:
			self.best_loss = current_loss
			self.best_lr = current_lr

		# Update learning rate
		new_lr: float = current_lr * self.lr_mult
		tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)  # type: ignore

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		""" Record loss and increase learning rate every ``update_interval`` batches if not updating per epoch.

		Args:
			batch (int): Current batch index (unused: a running counter across epochs is used instead).
			logs (dict | None): Training logs.
		"""
		if self.update_per_epoch:
			return
		self._steps_seen += 1
		if self._steps_seen % self.update_interval == 0:
			self._update_lr_and_track_metrics(logs)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		""" Record loss and increase learning rate every ``update_interval`` epochs if updating per epoch.

		Args:
			epoch (int): Current epoch index (unused: a running counter is used instead).
			logs (dict | None): Training logs.
		"""
		if not self.update_per_epoch:
			return
		self._steps_seen += 1
		if self._steps_seen % self.update_interval == 0:
			self._update_lr_and_track_metrics(logs)

	def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
		""" Restore initial model weights at the end of training.

		Args:
			logs (dict | None): Training logs.
		"""
		if self.initial_weights is not None:
			self.model.set_weights(self.initial_weights)  # pyright: ignore [reportUnknownMemberType]
|
|
148
|
-
|
|
1
|
+
|
|
2
|
+
# pyright: reportMissingTypeStubs=false
|
|
3
|
+
|
|
4
|
+
# Imports
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from keras.callbacks import Callback
|
|
9
|
+
from keras.models import Model
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LearningRateFinder(Callback):
	""" Callback to find optimal learning rate by increasing LR during training.

	Sources:
	- Inspired by: https://github.com/WittmannF/LRFinder
	- Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 (first description of the method)

	This callback gradually increases the learning rate from a minimum to a maximum value
	during training, allowing you to identify the optimal learning rate range for your model.

	It works by:

	1. Starting with a very small learning rate
	2. Exponentially increasing it every ``update_interval`` batches (or epochs)
	3. Recording the loss at each learning rate
	4. Restoring the model's initial weights after training

	The optimal learning rate is typically found where the loss is decreasing most rapidly
	before it starts to diverge.

	.. image:: https://blog.dataiku.com/hubfs/training%20loss.png
		:alt: Learning rate finder curve example
	"""

	def __init__(
		self,
		min_lr: float,
		max_lr: float,
		steps_per_epoch: int,
		epochs: int,
		update_per_epoch: bool = False,
		update_interval: int = 5
	) -> None:
		""" Initialize the learning rate finder.

		Args:
			min_lr (float): Minimum learning rate (must be strictly positive)
			max_lr (float): Maximum learning rate (must be strictly positive)
			steps_per_epoch (int): Steps per epoch
			epochs (int): Number of epochs
			update_per_epoch (bool): If True, update LR once per epoch instead of every batch.
			update_interval (int): Number of steps between each lr increase, bigger value means more stable loss.
				Values below 1 are treated as 1.

		Raises:
			ValueError: If ``min_lr`` or ``max_lr`` is not strictly positive
				(the geometric multiplier below is undefined otherwise).
		"""
		super().__init__()
		if min_lr <= 0 or max_lr <= 0:
			raise ValueError(f"min_lr and max_lr must be strictly positive, got min_lr={min_lr}, max_lr={max_lr}")

		self.min_lr: float = min_lr
		""" Minimum learning rate. """
		self.max_lr: float = max_lr
		""" Maximum learning rate. """
		self.update_per_epoch: bool = update_per_epoch
		""" Whether to update learning rate per epoch instead of per batch. """
		self.update_interval: int = max(1, int(update_interval))
		""" Number of steps between each lr increase, bigger value means more stable loss. """

		# Divide by the *clamped* interval (the raw one may be 0) and never let the count reach 0,
		# otherwise the exponent below divides by zero when training is shorter than one interval.
		total_steps: int = epochs if update_per_epoch else steps_per_epoch * epochs
		self.total_updates: int = max(1, total_steps // self.update_interval)
		""" Total number of update steps (considering update_interval). """
		self.lr_mult: float = (max_lr / min_lr) ** (1 / self.total_updates)
		""" Learning rate multiplier. """
		self.learning_rates: list[float] = []
		""" List of learning rates. """
		self.losses: list[float] = []
		""" List of losses. """
		self.best_lr: float = min_lr
		""" Best learning rate. """
		self.best_loss: float = float("inf")
		""" Best loss. """
		self.model: Model
		""" Model to apply the learning rate finder to. """
		self.initial_weights: list[Any] | None = None
		""" Stores the initial weights of the model. """

		# Running count of batches (or epochs when update_per_epoch) since training began.
		# Counting across epochs (instead of using the per-epoch batch index) keeps the number of
		# LR increases equal to total_updates, so the sweep ends at max_lr instead of overshooting it.
		self._steps_seen: int = 0

	def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
		""" Set initial learning rate and save initial model weights at the start of training.

		Args:
			logs (dict | None): Training logs.
		"""
		self._steps_seen = 0
		self.initial_weights = self.model.get_weights()
		tf.keras.backend.set_value(self.model.optimizer.learning_rate, self.min_lr)  # type: ignore

	def _update_lr_and_track_metrics(self, logs: dict[str, Any] | None = None) -> None:
		""" Record the current (learning rate, loss) pair, then multiply the learning rate by ``lr_mult``.

		Does nothing when no logs or no "loss" entry are available.

		Args:
			logs (dict | None): Logs from training
		"""
		if logs is None:
			return
		raw_loss: Any = logs.get("loss")
		if raw_loss is None:
			return

		# Get current learning rate and loss (float() so the history holds plain Python floats)
		current_lr: float = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))  # type: ignore
		current_loss: float = float(raw_loss)

		# Record values
		self.learning_rates.append(current_lr)
		self.losses.append(current_loss)

		# Track best values
		if current_loss < self.best_loss:
			self.best_loss = current_loss
			self.best_lr = current_lr

		# Update learning rate
		new_lr: float = current_lr * self.lr_mult
		tf.keras.backend.set_value(self.model.optimizer.learning_rate, new_lr)  # type: ignore

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		""" Record loss and increase learning rate every ``update_interval`` batches if not updating per epoch.

		Args:
			batch (int): Current batch index (unused: a running counter across epochs is used instead).
			logs (dict | None): Training logs.
		"""
		if self.update_per_epoch:
			return
		self._steps_seen += 1
		if self._steps_seen % self.update_interval == 0:
			self._update_lr_and_track_metrics(logs)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		""" Record loss and increase learning rate every ``update_interval`` epochs if updating per epoch.

		Args:
			epoch (int): Current epoch index (unused: a running counter is used instead).
			logs (dict | None): Training logs.
		"""
		if not self.update_per_epoch:
			return
		self._steps_seen += 1
		if self._steps_seen % self.update_interval == 0:
			self._update_lr_and_track_metrics(logs)

	def on_train_end(self, logs: dict[str, Any] | None = None) -> None:
		""" Restore initial model weights at the end of training.

		Args:
			logs (dict | None): Training logs.
		"""
		if self.initial_weights is not None:
			self.model.set_weights(self.initial_weights)  # pyright: ignore [reportUnknownMemberType]
|
|
148
|
+
|
|
@@ -1,31 +1,31 @@
|
|
|
1
|
-
|
|
2
|
-
# pyright: reportMissingTypeStubs=false
|
|
3
|
-
# pyright: reportUnknownMemberType=false
|
|
4
|
-
|
|
5
|
-
# Imports
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
from keras.callbacks import ModelCheckpoint
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ModelCheckpointV2(ModelCheckpoint):
	""" Model checkpoint callback but only starts after a given number of epochs.

	Both epoch-level and batch-level checkpointing (depending on ``save_freq``) are skipped
	while the current 0-based epoch index is below ``epochs_before_start``.

	Args:
		epochs_before_start (int): Number of epochs before starting the checkpointing
		*args: Positional arguments forwarded to ``ModelCheckpoint``. The first positional argument
			of this class is ``epochs_before_start``, so pass ``filepath`` as a keyword argument
			(or after ``epochs_before_start``).
		**kwargs: Keyword arguments forwarded to ``ModelCheckpoint``
	"""

	def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
		super().__init__(*args, **kwargs)
		self.epochs_before_start = epochs_before_start
		self.current_epoch = 0  # 0-based index of the epoch currently running

	def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		# Record the epoch as soon as it starts so batch-level gating uses the epoch in progress
		# rather than the last finished one (which lagged the gate by one epoch).
		self.current_epoch = epoch
		super().on_epoch_begin(epoch, logs)

	def on_train_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Keras' ModelCheckpoint performs batch-level saving in on_train_batch_end, which does not
		# route through on_batch_end, so the warm-up gate must be applied here as well.
		if self.current_epoch >= self.epochs_before_start:
			super().on_train_batch_end(batch, logs)

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Kept for backward compatibility with callers invoking the legacy hook directly.
		if self.current_epoch >= self.epochs_before_start:
			super().on_batch_end(batch, logs)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		self.current_epoch = epoch
		if epoch >= self.epochs_before_start:
			super().on_epoch_end(epoch, logs)
|
|
31
|
-
|
|
1
|
+
|
|
2
|
+
# pyright: reportMissingTypeStubs=false
|
|
3
|
+
# pyright: reportUnknownMemberType=false
|
|
4
|
+
|
|
5
|
+
# Imports
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from keras.callbacks import ModelCheckpoint
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ModelCheckpointV2(ModelCheckpoint):
	""" Model checkpoint callback but only starts after a given number of epochs.

	Both epoch-level and batch-level checkpointing (depending on ``save_freq``) are skipped
	while the current 0-based epoch index is below ``epochs_before_start``.

	Args:
		epochs_before_start (int): Number of epochs before starting the checkpointing
		*args: Positional arguments forwarded to ``ModelCheckpoint``. The first positional argument
			of this class is ``epochs_before_start``, so pass ``filepath`` as a keyword argument
			(or after ``epochs_before_start``).
		**kwargs: Keyword arguments forwarded to ``ModelCheckpoint``
	"""

	def __init__(self, epochs_before_start: int = 3, *args: Any, **kwargs: Any) -> None:
		super().__init__(*args, **kwargs)
		self.epochs_before_start = epochs_before_start
		self.current_epoch = 0  # 0-based index of the epoch currently running

	def on_epoch_begin(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		# Record the epoch as soon as it starts so batch-level gating uses the epoch in progress
		# rather than the last finished one (which lagged the gate by one epoch).
		self.current_epoch = epoch
		super().on_epoch_begin(epoch, logs)

	def on_train_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Keras' ModelCheckpoint performs batch-level saving in on_train_batch_end, which does not
		# route through on_batch_end, so the warm-up gate must be applied here as well.
		if self.current_epoch >= self.epochs_before_start:
			super().on_train_batch_end(batch, logs)

	def on_batch_end(self, batch: int, logs: dict[str, Any] | None = None) -> None:
		# Kept for backward compatibility with callers invoking the legacy hook directly.
		if self.current_epoch >= self.epochs_before_start:
			super().on_batch_end(batch, logs)

	def on_epoch_end(self, epoch: int, logs: dict[str, Any] | None = None) -> None:
		self.current_epoch = epoch
		if epoch >= self.epochs_before_start:
			super().on_epoch_end(epoch, logs)
|
|
31
|
+
|