likelihood 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- likelihood/graph/nn.py +8 -2
- likelihood/models/deep/autoencoders.py +312 -109
- likelihood/tools/figures.py +348 -0
- likelihood/tools/models_tools.py +161 -9
- likelihood/tools/tools.py +26 -84
- {likelihood-1.4.1.dist-info → likelihood-1.5.0.dist-info}/METADATA +1 -1
- {likelihood-1.4.1.dist-info → likelihood-1.5.0.dist-info}/RECORD +10 -9
- {likelihood-1.4.1.dist-info → likelihood-1.5.0.dist-info}/WHEEL +1 -1
- {likelihood-1.4.1.dist-info → likelihood-1.5.0.dist-info}/LICENSE +0 -0
- {likelihood-1.4.1.dist-info → likelihood-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib.ticker import AutoMinorLocator
|
|
8
|
+
from scipy import stats
|
|
9
|
+
|
|
10
|
+
# NOTE(review): this mutates matplotlib's global rcParams at import time, so merely
# importing this module changes the default font size for every figure in the process.
plt.rcParams.update({"font.size": 14})
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def act_pred(
    y_act: np.ndarray,
    y_pred: np.ndarray,
    name: str = "example",
    x_hist: bool = True,
    y_hist: bool = True,
    reg_line: bool = True,
    save_dir: Optional[str] = None,
) -> None:
    """
    Creates a scatter plot of actual vs predicted values along with histograms and a regression line.

    Both axes of the scatter panel share the same limits: the combined range of `y_act` and
    `y_pred`, padded by 5% of that range on each side. No point is clipped, and the ideal
    (y = x) line runs along the diagonal.

    Parameters
    ----------
    y_act : `np.ndarray`
        The actual values (ground truth). Any shape; the array is flattened.
    y_pred : `np.ndarray`
        The predicted values. Any shape, flattened; must have as many elements as `y_act`.
    name : `str`, optional
        The name for saving the plot. Default is "example".
    x_hist : `bool`, optional
        Whether to draw the marginal histogram of the actual values (y_act) above the
        scatter panel. Default is True.
    y_hist : `bool`, optional
        Whether to draw the marginal histogram of the predicted values (y_pred) to the
        right of the scatter panel. Default is True.
    reg_line : `bool`, optional
        Whether to plot a least-squares regression line (best-fit line) in the scatter plot.
        Default is True.
    save_dir : `Optional[str]`, optional
        The directory (created if needed) where the figure is saved as ``{name}_act_pred.png``.
        If None, the figure will not be saved. Default is None.

    Returns
    -------
    `None` : The function doesn't return anything. It generates, shows and optionally saves a plot.

    Raises
    ------
    ValueError
        If `y_act` or `y_pred` is not a numpy array, if they differ in size, or if they are empty.
    """
    # Validate before calling ``.flatten()``: a non-array input would otherwise fail with an
    # AttributeError and never reach this check.
    if not isinstance(y_act, np.ndarray) or not isinstance(y_pred, np.ndarray):
        raise ValueError("y_act and y_pred must be numpy arrays.")

    y_pred, y_act = y_pred.flatten(), y_act.flatten()

    if y_act.shape != y_pred.shape:
        raise ValueError("y_act and y_pred must have the same shape.")
    if y_act.size == 0:
        raise ValueError("y_act and y_pred must not be empty.")

    mec = "#2F4F4F"
    mfc = "#C0C0C0"

    # Shared limits for both scatter axes, padded by 5% of the joint data span. Scaling the
    # extremes by 1.05 instead would pull a positive minimum inward and hide points.
    lo = float(min(y_act.min(), y_pred.min()))
    hi = float(max(y_act.max(), y_pred.max()))
    span = hi - lo
    if span > 0:
        pad = 0.05 * span
    else:
        # Constant data: use a margin relative to the value (or 1.0 around zero).
        pad = 0.05 * abs(lo) or 1.0
    limits = (lo - pad, hi + pad)

    fig = plt.figure(figsize=(6, 6))

    # Manual layout in figure fractions: main scatter panel plus optional marginal strips.
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = bottom + height
    left_h = left + width + 0.05

    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.tick_params(direction="in", length=7, top=True, right=True)
    ax2.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax2.yaxis.set_minor_locator(AutoMinorLocator(2))

    ax2.scatter(y_act, y_pred, color=mfc, edgecolor=mec, alpha=0.5, s=35, lw=1.2)
    ax2.plot(limits, limits, "k--", alpha=0.8, label="Ideal")

    ax2.set_xlabel("Actual value")
    ax2.set_ylabel("Predicted value")
    ax2.set_xlim(limits)
    ax2.set_ylim(limits)

    if x_hist:
        ax1 = fig.add_axes([left, bottom_h, width, 0.15])
        ax1.hist(y_act, bins=31, density=True, color=mfc, edgecolor=mec, alpha=0.6)
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.set_xlim(ax2.get_xlim())

    if y_hist:
        ax3 = fig.add_axes([left_h, bottom, 0.15, height])
        ax3.hist(
            y_pred, bins=31, density=True, color=mfc, edgecolor=mec, orientation="horizontal", alpha=0.6
        )
        ax3.set_xticks([])
        ax3.set_yticks([])
        ax3.set_ylim(ax2.get_ylim())

    if reg_line:
        x_line = np.unique(y_act)
        coeffs = np.polyfit(y_act, y_pred, deg=1)
        ax2.plot(x_line, np.poly1d(coeffs)(x_line), "r-", label="Regression Line", alpha=0.8)

    ax2.legend(loc="upper left", framealpha=0.35, handlelength=1.5)

    # No ``plt.tight_layout()`` here: it only handles subplot-grid axes and warns for panels
    # placed with ``add_axes``. ``bbox_inches="tight"`` already trims the saved image.

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_act_pred.png")
        fig.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.show()
    plt.close(fig)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def residual(
    y_act: np.ndarray, y_pred: np.ndarray, name: str = "example", save_dir: Optional[str] = None
) -> None:
    """
    Plots the residual errors between the actual and predicted values.

    This function generates a residual plot of (y_pred - y_act) against the actual values
    and can optionally save it to a file. Both inputs are flattened first, so e.g. a (n, 1)
    prediction array can be passed together with a (n,) target array.

    Parameters
    ----------
    y_act : `np.ndarray`
        The actual values, typically the ground truth values. Array-like; flattened.
    y_pred : `np.ndarray`
        The predicted values that are compared against the actual values. Array-like;
        flattened; must contain as many elements as `y_act`.
    name : `str`, optional
        The name of the plot file (without extension) used when saving the plot. Default is "example".
    save_dir : `Optional[str]`, optional
        The directory (created if needed) where the plot is saved as ``{name}_residual.png``.
        If None, the plot is not saved. Default is None.

    Returns
    -------
    `None` : This function does not return any value. It generates and optionally saves a plot.

    Raises
    ------
    ValueError
        If `y_act` and `y_pred` differ in size or are empty.

    Notes
    -----
    - The residuals (y_pred - y_act) are on the y-axis and the actual values (y_act) on the
      x-axis. A dashed horizontal line marks the ideal case where the residual is zero
      (i.e., perfect predictions).
    - Axis limits are the data range padded by 10% of its span; the y-range always includes
      zero so the ideal line stays visible.
    - The plot will be saved as a PNG image if a `save_dir` is provided.
    """

    def _padded_range(values: np.ndarray, include_zero: bool = False):
        # Pad by 10% of the span on each side. Scaling the extremes (min * 0.9, max / 0.9)
        # would move negative bounds inward and clip the most extreme points.
        lo, hi = float(values.min()), float(values.max())
        if include_zero:
            lo, hi = min(lo, 0.0), max(hi, 0.0)
        pad = 0.1 * (hi - lo) or (0.1 * abs(lo) or 1.0)
        return lo - pad, hi + pad

    mec = "#2F4F4F"
    mfc = "#C0C0C0"

    # Flatten both inputs: a (n, 1) prediction minus a (n,) target would otherwise
    # broadcast to an (n, n) matrix of meaningless differences.
    y_act = np.asarray(y_act).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_act.shape != y_pred.shape:
        raise ValueError("y_act and y_pred must have the same shape.")
    if y_act.size == 0:
        raise ValueError("y_act and y_pred must not be empty.")

    y_err = y_pred - y_act
    xmin, xmax = _padded_range(y_act)
    ymin, ymax = _padded_range(y_err, include_zero=True)

    fig, ax = plt.subplots(figsize=(4, 4))

    ax.plot(y_act, y_err, "o", mec=mec, mfc=mfc, alpha=0.5, label=None, mew=1.2, ms=5.2)
    ax.plot([xmin, xmax], [0, 0], "k--", alpha=0.8, label="ideal")

    ax.set_ylabel("Residual error")
    ax.set_xlabel("Actual value")
    ax.legend(loc="lower right")

    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))
    ax.tick_params(right=True, top=True, direction="in", length=7)
    ax.tick_params(which="minor", right=True, top=True, direction="in", length=4)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_residual.png")
        fig.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)

    plt.close(fig)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def residual_hist(
    y_act: np.ndarray, y_pred: np.ndarray, name: str = "example", save_dir: Optional[str] = None
) -> None:
    """
    Generates a residual error histogram with kernel density estimate (KDE) for the given true and predicted values.
    Optionally saves the plot to a specified directory.

    Parameters
    ----------
    y_act : `np.ndarray`
        Array of true (actual) values. Flattened before use.

    y_pred : `np.ndarray`
        Array of predicted values. Flattened before use; must contain as many elements as `y_act`.

    name : `str`, optional, default="example"
        The name used for the saved plot filename (``{name}_residual_hist.png``).

    save_dir : `str`, optional, default=None
        Directory path (created if needed) to save the generated plot. If None, the plot is not saved.

    Returns
    -------
    `None` : This function generates and optionally saves a plot but does not return any value.

    Raises
    ------
    ValueError
        If `y_act` and `y_pred` differ in size or are empty.

    Warns
    -----
    `UserWarning`
        If the KDE of the residuals cannot be computed, e.g. all residuals are identical
        (singular covariance) or there is only one residual. The histogram is still drawn.
    """
    mec = "#2F4F4F"
    mfc = "#C0C0C0"
    y_pred, y_act = y_pred.flatten(), y_act.flatten()
    if y_act.shape != y_pred.shape:
        raise ValueError("y_act and y_pred must have the same shape.")
    if y_act.size == 0:
        raise ValueError("y_act and y_pred must not be empty.")

    y_err = y_pred - y_act

    fig, ax = plt.subplots(figsize=(4, 4))

    has_kde = False
    try:
        kde_err = stats.gaussian_kde(y_err)
        x_range = np.linspace(y_err.min(), y_err.max(), 1000)
        density = kde_err(x_range)
    except (np.linalg.LinAlgError, ValueError):
        # LinAlgError: singular covariance (constant residuals);
        # ValueError: gaussian_kde needs more than one sample.
        warnings.warn(
            "Could not estimate the density (KDE) of the residuals: they may be constant "
            "or too few. Only the histogram is shown.",
            UserWarning,
        )
    else:
        ax.plot(x_range, density, "-", lw=1.2, color="k", label="kde")
        has_kde = True

    ax.hist(y_err, color=mfc, bins=35, alpha=1, edgecolor=mec, density=True)

    ax.set_xlabel("Residual error")
    ax.set_ylabel("Relative frequency")
    # Only the KDE line carries a label; without it a legend would be empty and warn.
    if has_kde:
        ax.legend(loc=2, framealpha=0.35, handlelength=1.5)

    ax.tick_params(direction="in", length=7, top=True, right=True)

    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))
    ax.tick_params(which="minor", direction="in", length=4, right=True, top=True)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_residual_hist.png")
        fig.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)
    plt.close(fig)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def loss_curve(
    x_data: np.ndarray,
    train_err: np.ndarray,
    val_err: np.ndarray,
    name: str = "example",
    save_dir: Optional[str] = None,
) -> None:
    """
    Plots the loss curve for both training and validation errors over epochs,
    and optionally saves the plot as an image.

    A faint horizontal line marks the largest validation error. The y-axis is zoomed to
    ``[0, 2 * mean(val_err)]`` when that bound is a positive finite number. Otherwise,
    e.g. with NaN/inf or non-positive losses, matplotlib's autoscaling is kept.

    Parameters
    ----------
    x_data : `np.ndarray`
        Array of x-values (usually epochs) for the plot.
    train_err : `np.ndarray`
        Array of training error values.
    val_err : `np.ndarray`
        Array of validation error values (must be non-empty).
    name : `str`, optional
        The name to use when saving the plot. Default is "example".
    save_dir : `Optional[str]`, optional
        Directory (created if needed) where the plot is saved as ``{name}_loss_curve.png``.
        If None, the plot is not saved. Default is None.

    Returns
    -------
    `None` : This function does not return any value. It generates and optionally saves a plot.
    """
    mec1 = "#2F4F4F"
    mfc1 = "#C0C0C0"
    mec2 = "maroon"
    mfc2 = "pink"

    val_err = np.asarray(val_err)

    fig, ax = plt.subplots(figsize=(4, 4))

    ax.plot(
        x_data,
        train_err,
        "-",
        color=mec1,
        marker="o",
        mec=mec1,
        mfc=mfc1,
        ms=4,
        alpha=0.5,
        label="train",
    )

    ax.plot(
        x_data,
        val_err,
        "--",
        color=mec2,
        marker="s",
        mec=mec2,
        mfc=mfc2,
        ms=4,
        alpha=0.5,
        label="validation",
    )

    ax.axhline(np.max(val_err), color="b", linestyle="--", alpha=0.3)

    ax.set_xlabel("Number of training epochs")
    ax.set_ylabel("Loss (Units)")
    # set_ylim rejects NaN/inf bounds, and a non-positive top would invert the axis.
    y_top = 2 * np.mean(val_err)
    if np.isfinite(y_top) and y_top > 0:
        ax.set_ylim(0, y_top)

    ax.legend(loc=1, framealpha=0.35, handlelength=1.5)

    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.yaxis.set_minor_locator(AutoMinorLocator(2))

    ax.tick_params(right=True, top=True, direction="in", length=7)
    ax.tick_params(which="minor", right=True, top=True, direction="in", length=4)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_loss_curve.png")
        fig.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)
    plt.close(fig)
|
likelihood/tools/models_tools.py
CHANGED
|
@@ -3,11 +3,148 @@ import os
|
|
|
3
3
|
|
|
4
4
|
import networkx as nx
|
|
5
5
|
import pandas as pd
|
|
6
|
+
from pandas.core.frame import DataFrame
|
|
6
7
|
|
|
7
8
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
8
9
|
logging.getLogger("tensorflow").setLevel(logging.ERROR)
|
|
9
10
|
|
|
11
|
+
import sys
|
|
12
|
+
import warnings
|
|
13
|
+
from functools import wraps
|
|
14
|
+
from typing import Dict
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
10
17
|
import tensorflow as tf
|
|
18
|
+
from pandas.core.frame import DataFrame
|
|
19
|
+
|
|
20
|
+
from .figures import *
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class suppress_prints:
    """
    Context manager that discards everything written to ``sys.stdout`` inside its block.

    Output goes to ``os.devnull``. The previous ``sys.stdout`` is restored on exit, even if
    the block raises, and the exception is not suppressed.
    """

    def __enter__(self):
        self.original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the file opened here, even if the
        # wrapped code rebinds sys.stdout meanwhile. UTF-8 avoids UnicodeEncodeError when
        # discarded output (e.g. progress bars) contains characters outside the locale codec.
        self._devnull = open(os.devnull, "w", encoding="utf-8")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self.original_stdout
        self._devnull.close()
        return False
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def suppress_warnings(func):
    """
    Decorator that silences every warning raised while `func` runs.

    The warnings filter is changed only inside a ``warnings.catch_warnings()`` context, so
    the caller's filter configuration is restored once the call returns. Name, docstring and
    other metadata of `func` are preserved via ``functools.wraps``.
    """

    @wraps(func)
    def _silenced(*call_args, **call_kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = func(*call_args, **call_kwargs)
        return result

    return _silenced
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def remove_collinearity(df: DataFrame, threshold: float = 0.9):
    """
    Removes highly collinear features from the DataFrame based on a correlation threshold.

    The absolute Pearson correlation matrix of the numeric (and boolean) columns is
    computed. A column is dropped when its absolute correlation with any column that comes
    before it exceeds `threshold`. Non-numeric columns take no part in the correlation and
    are always kept; pandas >= 2.0 would otherwise raise on them in ``DataFrame.corr``.

    Parameters
    ----------
    df : `DataFrame`
        The input DataFrame. Only numeric/boolean columns are considered for removal.
    threshold : `float`
        The correlation threshold above which features will be removed. Default is `0.9`.

    Returns
    -------
    DataFrame: A DataFrame with highly collinear features removed (column order preserved).
    """
    numeric = df.select_dtypes(include=["number", "bool"])
    corr_matrix = numeric.corr().abs()
    # Keep only the strict upper triangle: each pair is examined once and the diagonal
    # (self-correlation of 1) is ignored.
    upper_triangle = corr_matrix.where(np.triu(np.ones(corr_matrix.shape, dtype=bool), k=1))
    to_drop = [
        column
        for column in upper_triangle.columns
        if (upper_triangle[column] > threshold).any()
    ]
    return df.drop(columns=to_drop)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def train_and_insights(
    x_data: np.ndarray,
    y_act: np.ndarray,
    model: tf.keras.Model,
    patience: int = 3,
    reg: bool = False,
    frac: float = 1.0,
    **kwargs: Optional[Dict],
) -> tf.keras.Model:
    """
    Train a Keras model and provide insights on the training and validation metrics.

    Parameters
    ----------
    x_data : `np.ndarray`
        Input data for training the model.
    y_act : `np.ndarray`
        Actual labels corresponding to x_data.
    model : `tf.keras.Model`
        The Keras model to train.
    patience : `int`
        The patience parameter for the default early stopping callback (default is 3).
    reg : `bool`
        Flag to determine if residual analysis should be performed (default is `False`).
    frac : `float`
        Fraction of the data (taken from the start) used for the residual plots (default is 1.0).

    Keyword Arguments
    -----------------
    validation_split : `float`
        Fraction of the data held out for validation by `model.fit` (default is 0.2).
    callbacks : `list`
        Callbacks passed to `model.fit`. The legacy name ``callback`` is also accepted;
        if both are given, ``callbacks`` wins. Defaults to an `EarlyStopping` callback on
        ``val_loss`` with the given `patience`.
    verbose
        Verbosity passed to `model.fit` (default is `False`).
    Any other keyword argument (e.g. `epochs`, `batch_size`) is forwarded to `model.fit`.

    Returns
    -------
    `tf.keras.Model`
        The trained model after fitting.

    Warns
    -----
    `UserWarning`
        If the training history has no ``loss``/``val_loss`` entries (e.g. no validation
        data), in which case the loss curve is skipped.
    """
    validation_split = kwargs.pop("validation_split", 0.2)
    verbose = kwargs.pop("verbose", False)

    # ``callbacks`` (Keras' own name) must be removed from kwargs; leaving it there would
    # clash with the explicit ``callbacks=`` argument below.
    if "callbacks" in kwargs:
        callbacks = kwargs.pop("callbacks")
        kwargs.pop("callback", None)
    elif "callback" in kwargs:
        callbacks = kwargs.pop("callback")
    else:
        callbacks = [tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=patience)]

    history = model.fit(
        x_data,
        y_act,
        validation_split=validation_split,
        verbose=verbose,
        callbacks=callbacks,
        **kwargs,
    )

    hist = pd.DataFrame(history.history)
    hist["epoch"] = history.epoch

    if reg:
        n = int(len(x_data) * frac)
        with suppress_prints():
            y_pred = model.predict(x_data[:n])
        # Flatten both: (n, 1) predictions must line up element-wise with (n,) targets
        # rather than broadcasting; np.asarray also accepts pandas targets.
        y_pred = np.asarray(y_pred).reshape(-1)
        y_true = np.asarray(y_act)[:n].reshape(-1)
        residual(y_true, y_pred)
        residual_hist(y_true, y_pred)
        act_pred(y_true, y_pred)

    # Select losses by name: the positional order of history keys depends on the number
    # of compiled metrics and on the Keras version.
    if "loss" in hist.columns and "val_loss" in hist.columns:
        loss_curve(hist["epoch"].values, hist["loss"].values, hist["val_loss"].values)
    else:
        warnings.warn(
            "Training history has no 'loss'/'val_loss' entries; skipping the loss curve.",
            UserWarning,
        )

    return model
|
|
11
148
|
|
|
12
149
|
|
|
13
150
|
@tf.keras.utils.register_keras_serializable(package="Custom", name="LoRALayer")
|
|
@@ -58,16 +195,31 @@ def apply_lora(model, rank=4):
|
|
|
58
195
|
return new_model
|
|
59
196
|
|
|
60
197
|
|
|
61
|
-
def graph_metrics(adj_matrix, eigenvector_threshold=1e-6):
|
|
198
|
+
def graph_metrics(adj_matrix: np.ndarray, eigenvector_threshold: float = 1e-6) -> DataFrame:
|
|
62
199
|
"""
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
200
|
+
Calculate various graph metrics based on the given adjacency matrix and return them in a single DataFrame.
|
|
201
|
+
|
|
202
|
+
Parameters
|
|
203
|
+
----------
|
|
204
|
+
adj_matrix : `np.ndarray`
|
|
205
|
+
The adjacency matrix representing the graph, where each element denotes the presence/weight of an edge between nodes.
|
|
206
|
+
eigenvector_threshold : `float`
|
|
207
|
+
A threshold for the eigenvector centrality calculation, used to determine the cutoff for small eigenvalues. Default is `1e-6`.
|
|
208
|
+
|
|
209
|
+
Returns
|
|
210
|
+
----------
|
|
211
|
+
DataFrame : A DataFrame containing the following graph metrics as columns.
|
|
212
|
+
- `Degree Centrality`: Degree centrality values for each node, indicating the number of direct connections each node has.
|
|
213
|
+
- `Clustering Coefficient`: Clustering coefficient values for each node, representing the degree to which nodes cluster together.
|
|
214
|
+
- `Eigenvector Centrality`: Eigenvector centrality values, indicating the influence of a node in the graph based on the eigenvectors of the adjacency matrix.
|
|
215
|
+
- `Degree`: The degree of each node, representing the number of edges connected to each node.
|
|
216
|
+
- `Betweenness Centrality`: Betweenness centrality values, representing the extent to which a node lies on the shortest paths between other nodes.
|
|
217
|
+
- `Closeness Centrality`: Closeness centrality values, indicating the inverse of the average shortest path distance from a node to all other nodes in the graph.
|
|
218
|
+
- `Assortativity`: The assortativity coefficient of the graph, measuring the tendency of nodes to connect to similar nodes.
|
|
219
|
+
|
|
220
|
+
Notes
|
|
221
|
+
----------
|
|
222
|
+
The returned DataFrame will have one row for each node and one column for each of the computed metrics.
|
|
71
223
|
"""
|
|
72
224
|
adj_matrix = adj_matrix.astype(int)
|
|
73
225
|
G = nx.from_numpy_array(adj_matrix)
|