pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185

pgsui/impute/unsupervised/callbacks.py

```diff
@@ -1,286 +1,116 @@
-import …
-import sys
+from snpio.utils.logging import LoggerManager

-…
-import tensorflow as tf
+from pgsui.utils.logging_utils import configure_logger


-class …
-    """…
+class EarlyStopping:
+    """Class to stop the training when a monitored metric has stopped improving.

-…
-
-    This process is supposed to improve the latent distribution sampling for the variational autoencoder model and eliminate the KL vanishing issue.
-
-    Three types of cycle curves can be used that determine how the weight increases: 'linear', 'sigmoid', and 'cosine'.
-
-    Code is adapted from: https://github.com/haofuml/cyclical_annealing
-
-    The cyclical annealing process was first described in the following paper: https://aclanthology.org/N19-1021.pdf
-
-    Args:
-        n_iter (int): Number of iterations (epochs) being used in training.
-        start (float, optional): Where to start cycles. Defaults to 0.0.
-        stop (float, optional): Where to stop cycles. Defaults to 1.0.
-        n_cycle (int, optional): How many cycles to use across all the epochs. Defaults to 4.
-        ratio (float, optional): Ratio to determine proportion used to increase beta. Defaults to 0.5.
-        schedule_type (str, optional): Type of curve to use for scheduler. Possible options include: 'linear', 'sigmoid', or 'cosine'. Defaults to 'linear'.
+    This class is used to stop the training of a model when a monitored metric has stopped improving (such as validation loss or accuracy). If the metric does not improve for `patience` epochs, and we have already passed the `min_epochs` epoch threshold, training is halted. The best model checkpoint is reloaded when early stopping is triggered.

+    Example:
+        >>> early_stopping = EarlyStopping(patience=25, verbose=1, min_epochs=100)
+        >>> for epoch in range(1, 1001):
+        >>>     val_loss = train_epoch(...)
+        >>>     early_stopping(val_loss, model)
+        >>>     if early_stopping.early_stop:
+        >>>         break
     """

     def __init__(
         self,
-        …
+        patience: int = 25,
+        delta: float = 0.0,
+        verbose: int = 0,
+        mode: str = "min",
+        min_epochs: int = 100,
+        prefix: str = "pgsui_output",
+        debug: bool = False,
     ):
-
-        self.start = start
-        self.stop = stop
-        self.n_cycle = n_cycle
-        self.ratio = ratio
-        self.schedule_type = schedule_type
+        """Early stopping callback for PyTorch training.

-…
+        This class is used to stop the training of a model when a monitored metric has stopped improving (such as validation loss or accuracy). If the metric does not improve for `patience` epochs, and we have already passed the `min_epochs` epoch threshold, training is halted. The best model checkpoint is reloaded when early stopping is triggered. The `mode` parameter can be set to "min" or "max" to indicate whether the metric should be minimized or maximized, respectively.

-…
+        Args:
+            patience (int): Number of epochs to wait after the last time the monitored metric improved.
+            delta (float): Minimum change in the monitored metric to qualify as an improvement.
+            verbose (int): Verbosity level (0 = silent, 1 = improvement messages, 2+ = more).
+            mode (str): "min" or "max" to indicate how improvement is defined.
+            prefix (str): Prefix for directory naming.
+            output_dir (Path): Directory in which to create subfolders/checkpoints.
+            min_epochs (int): Minimum epoch count before early stopping can take effect.
+            debug (bool): Debug mode for logging messages.
+
+        Raises:
+            ValueError: If an invalid mode is provided. Must be "min" or "max".
         """
-…
+        self.patience = patience
+        self.delta = delta
+        self.verbose = verbose >= 2 or debug
+        self.debug = debug
+        self.mode = mode
+        self.counter = 0
+        self.epoch_count = 0
+        self.best_score = float("inf") if mode == "min" else 0.0
+        self.early_stop = False
+        self.best_model = None
+        self.min_epochs = min_epochs
+
+        is_verbose = verbose >= 2 or debug
+        logman = LoggerManager(name=__name__, prefix=prefix, verbose=is_verbose)
+        self.logger = configure_logger(
+            logman.get_logger(), verbose=is_verbose, debug=debug
+        )
+
+        # Define the comparison function for the monitored metric
+        if mode == "min":
+            self.monitor = lambda current, best: current < best - self.delta
+        elif mode == "max":
+            self.monitor = lambda current, best: current > best + self.delta
         else:
-…
-            )
-
-        self.arr = cycle_func()
+            msg = f"Invalid mode provided: '{mode}'. Use 'min' or 'max'."
+            self.logger.error(msg)
+            raise ValueError(msg)

-    def …
-        """…
-
-        Here, the new kl_beta weight is set.
+    def __call__(self, score, model):
+        """Checks if early stopping condition is met and checkpoints model accordingly.

         Args:
-…
-        """
-        idx = epoch - 1
-        new_weight = self.arr[idx]
-
-        tf.keras.backend.set_value(self.model.kl_beta, new_weight)
-
-    def _linear_cycle_range(self):
-        """Get an array with a linear cycle curve ranging from 0 to 1 for n_iter epochs.
-
-        The amount of time cycling and spent at 1.0 is determined by the ratio variable.
-
-        Returns:
-            numpy.ndarray: Linear cycle range.
+            score (float): The current metric value (e.g., validation loss/accuracy).
+            model (torch.nn.Module): The model being trained.
         """
-…
-
-    def _sigmoid_cycle_range(self):
-        """Get sigmoidal curve cycle ranging from 0 to 1 for n_iter epochs.
-
-        The amount of time cycling and spent at 1.0 is determined by the ratio variable.
-
-        Returns:
-            numpy.ndarray: Sigmoidal cycle range.
-        """
-        L = np.ones(self.n_iter)
-        period = self.n_iter / self.n_cycle
-        step = (self.stop - self.start) / (
-            period * self.ratio
-        )  # step is in [0,1]
-
-        for c in range(self.n_cycle):
-            v, i = self.start, 0
+        # Increment the epoch count each time we call this function
+        self.epoch_count += 1
+
+        # If this is the first epoch, initialize best_score and save model
+        if self.best_score is None:
+            self.best_score = score
+            return
+
+        # Check if there is improvement
+        if self.monitor(score, self.best_score):
+            # If improved, reset counter and update the best score/model
+            self.best_score = score
+            self.best_model = model
+            self.counter = 0
+        else:
+            # No improvement: increase counter
+            self.counter += 1

-…
+            if self.verbose:
+                self.logger.info(
+                    f"EarlyStopping counter: {self.counter}/{self.patience}"
                 )
-                v += step
-                i += 1
-        return L
-
-    def _cosine_cycle_range(self):
-        """Get cosine curve cycle ranging from 0 to 1 for n_iter epochs.
-
-        The amount of time cycling and spent at 1.0 is determined by the ratio variable.
-
-        Returns:
-            numpy.ndarray: Cosine cycle range.
-        """
-        L = np.ones(self.n_iter)
-        period = self.n_iter / self.n_cycle
-        step = (self.stop - self.start) / (
-            period * self.ratio
-        )  # step is in [0,1]
-
-        for c in range(self.n_cycle):
-            v, i = self.start, 0
-
-            while v <= self.stop:
-                L[int(i + c * period)] = 0.5 - 0.5 * math.cos(v * math.pi)
-                v += step
-                i += 1
-        return L
-
-
-class VAECallbacks(tf.keras.callbacks.Callback):
-    """Custom callbacks to use with subclassed VAE Keras model.
-
-    Requires y, missing_mask, and sample_weight to be input variables to be properties with setters in the subclassed model.
-    """
-
-    def __init__(self):
-        self.indices = None
-
-    def on_epoch_begin(self, epoch, logs=None):
-        """Shuffle input and target at start of epoch."""
-        y = self.model.y.copy()
-        missing_mask = self.model.missing_mask
-        sample_weight = self.model.sample_weight
-
-        n_samples = len(y)
-        self.indices = np.arange(n_samples)
-        np.random.shuffle(self.indices)
-
-        self.model.y = y[self.indices]
-        self.model.missing_mask = missing_mask[self.indices]
-
-        if sample_weight is not None:
-            self.model.sample_weight = sample_weight[self.indices]
-
-    def on_train_batch_begin(self, batch, logs=None):
-        """Get batch index."""
-        self.model.batch_idx = batch
-
-    def on_epoch_end(self, epoch, logs=None):
-        """Unsort the row indices."""
-        unshuffled = np.argsort(self.indices)
-
-        self.model.y = self.model.y[unshuffled]
-        self.model.missing_mask = self.model.missing_mask[unshuffled]
-
-        if self.model.sample_weight is not None:
-            self.model.sample_weight = self.model.sample_weight[unshuffled]
-
-
-class UBPCallbacks(tf.keras.callbacks.Callback):
-    """Custom callbacks to use with subclassed NLPCA/UBP Keras models.
-
-    Requires y, missing_mask, V_latent, and sample_weight to be input variables to be properties with setters in the subclassed model.
-    """
-
-    def __init__(self):
-        self.indices = None
-
-    def on_epoch_begin(self, epoch, logs=None):
-        """Shuffle input and target at start of epoch."""
-        y = self.model.y.copy()
-        missing_mask = self.model.missing_mask
-        sample_weight = self.model.sample_weight
-
-        n_samples = len(y)
-        self.indices = np.arange(n_samples)
-        np.random.shuffle(self.indices)

-…
-        self.model.missing_mask = missing_mask[self.indices]
+            # Now check if we surpass patience AND have reached min_epochs
+            if self.counter >= self.patience and self.epoch_count >= self.min_epochs:

-…
+                if self.best_model is None:
+                    self.best_model = model

-…
-        """Get batch index."""
-        self.model.batch_idx = batch
-
-    def on_epoch_end(self, epoch, logs=None):
-        """Unsort the row indices."""
-        unshuffled = np.argsort(self.indices)
-
-        self.model.y = self.model.y[unshuffled]
-        self.model.V_latent = self.model.V_latent[unshuffled]
-        self.model.missing_mask = self.model.missing_mask[unshuffled]
-
-        if self.model.sample_weight is not None:
-            self.model.sample_weight = self.model.sample_weight[unshuffled]
-
-
-class UBPEarlyStopping(tf.keras.callbacks.Callback):
-    """Stop training when the loss is at its min, i.e. the loss stops decreasing.
-
-    Args:
-        patience (int, optional): Number of epochs to wait after min has been hit. After this
-            number of no improvement, training stops. Defaults to 0.
-
-        phase (int, optional): Current UBP Phase. Defaults to 3.
-    """
-
-    def __init__(self, patience=0, phase=3):
-        super(UBPEarlyStopping, self).__init__()
-        self.patience = patience
-        self.phase = phase
-
-        # best_weights to store the weights at which the minimum loss occurs.
-        self.best_weights = None
-
-        # In UBP, the input gets refined during training.
-        # So we have to revert it too.
-        self.best_input = None
-
-    def on_train_begin(self, logs=None):
-        # The number of epoch it has waited when loss is no longer minimum.
-        self.wait = 0
-        # The epoch the training stops at.
-        self.stopped_epoch = 0
-        # Initialize the best as infinity.
-        self.best = np.Inf
-
-    def on_epoch_end(self, epoch, logs=None):
-        current = logs.get("loss")
-        if np.less(current, self.best):
-            self.best = current
-            self.wait = 0
-            # Record the best weights if current results is better (less).
-            self.best_weights = self.model.get_weights()
-
-            if self.phase != 2:
-                # Only refine input in phase 2.
-                self.best_input = self.model.V_latent
-        else:
-            self.wait += 1
-            if self.wait >= self.patience:
-                self.stopped_epoch = epoch
-                self.model.stop_training = True
-                self.model.set_weights(self.best_weights)
+                self.early_stop = True

-        if self.…
-            self.…
+                if self.verbose:
+                    self.logger.info(
+                        f"Early stopping triggered at epoch {self.epoch_count}"
+                    )
```
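
The new `EarlyStopping` class replaces the Keras-specific `UBPEarlyStopping` callback with a framework-agnostic object that is simply called once per epoch. A minimal usage sketch, following the docstring's own example; `train_one_epoch`, the model, and the optimizer here are placeholders rather than pg-sui API:

```python
import random

import torch

from pgsui.impute.unsupervised.callbacks import EarlyStopping

model = torch.nn.Linear(10, 2)  # stand-in for a pg-sui imputer network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train_one_epoch(model, optimizer) -> float:
    """Placeholder for a real training epoch; returns a fake validation loss."""
    return random.random()


early_stopping = EarlyStopping(patience=25, verbose=1, min_epochs=100, mode="min")

for epoch in range(1, 1001):
    val_loss = train_one_epoch(model, optimizer)
    early_stopping(val_loss, model)  # updates the patience counter and best model
    if early_stopping.early_stop:
        model = early_stopping.best_model  # the callback retains the best checkpoint
        break
```

Note that stopping only triggers once both conditions hold: `patience` consecutive epochs without improvement of at least `delta`, and at least `min_epochs` total epochs elapsed, so warm-up instability cannot end training prematurely.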
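For context on the removed `_linear_cycle_range`/`_sigmoid_cycle_range`/`_cosine_cycle_range` methods: cyclical annealing ramps the KL weight (beta) from 0 to 1 several times over training to combat KL vanishing in the VAE. A sketch of the linear variant, modeled on the `frange_cycle_linear` reference in the haofuml/cyclical_annealing repository that the removed docstring cited; the exact deleted implementation may have differed slightly:

```python
import numpy as np


def linear_cycle_range(
    n_iter: int,
    start: float = 0.0,
    stop: float = 1.0,
    n_cycle: int = 4,
    ratio: float = 0.5,
) -> np.ndarray:
    """Per-epoch KL weights: ramp start -> stop over `ratio` of each cycle,
    then hold at 1.0 (the np.ones default) for the rest of the cycle."""
    L = np.ones(n_iter)
    period = n_iter / n_cycle
    step = (stop - start) / (period * ratio)  # increment per epoch while ramping
    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and int(i + c * period) < n_iter:
            L[int(i + c * period)] = v
            v += step
            i += 1
    return L


# 100 epochs, 4 cycles of 25 epochs each: beta climbs 0 -> 1 over the first
# ~13 epochs of each cycle and holds at 1.0 for the remainder.
betas = linear_cycle_range(100)
```

The sigmoid and cosine variants visible in the removed code follow the same loop structure, substituting a sigmoidal or `0.5 - 0.5 * math.cos(v * math.pi)` transform for the raw ramp value.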