cccpm 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cccpm/__init__.py +1 -0
- cccpm/cpm_analysis.py +272 -0
- cccpm/edge_selection.py +271 -0
- cccpm/fold.py +46 -0
- cccpm/logging.py +37 -0
- cccpm/models.py +148 -0
- cccpm/more_models.py +205 -0
- cccpm/reporting/__init__.py +1 -0
- cccpm/reporting/assets/CCCPM.png +0 -0
- cccpm/reporting/html_report.py +363 -0
- cccpm/reporting/plots/__init__.py +0 -0
- cccpm/reporting/plots/chord_v2.py +821 -0
- cccpm/reporting/plots/cpm_chord_plot.py +149 -0
- cccpm/reporting/plots/plots.py +337 -0
- cccpm/reporting/plots/utils.py +19 -0
- cccpm/reporting/reporting_utils.py +124 -0
- cccpm/results_manager.py +463 -0
- cccpm/scoring.py +40 -0
- cccpm/simulation/__init__.py +0 -0
- cccpm/simulation/simulate_multivariate.py +252 -0
- cccpm/simulation/simulate_sem.py +319 -0
- cccpm/simulation/simulate_simple.py +37 -0
- cccpm/utils.py +386 -0
- cccpm-0.2.1.dist-info/METADATA +105 -0
- cccpm-0.2.1.dist-info/RECORD +26 -0
- cccpm-0.2.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import Union, Tuple
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import netplotbrain
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def vector_to_upper_triangular_matrix(vector):
    """
    Rebuild a symmetric square matrix from its strictly upper triangular elements.

    The elements are written in the order given by ``np.triu_indices(size, k=1)``
    (row-major over the strict upper triangle) and mirrored into the lower
    triangle; the diagonal stays zero.

    Parameters:
    vector (array-like): The strictly upper triangular elements. Lists and other
        array-likes are accepted and converted with ``np.asarray``.

    Returns:
    np.ndarray: The reconstructed ``(size, size)`` float matrix, where
        ``size * (size - 1) / 2 == vector.size``.

    Raises:
    ValueError: If the number of elements is not a triangular number.
    """
    vector = np.asarray(vector)
    # Solve size * (size - 1) / 2 == m for size. Rounding (instead of flooring)
    # keeps a tiny floating-point error in the sqrt from yielding the wrong size;
    # the check below rejects element counts that are not triangular numbers.
    size = int(round((1 + np.sqrt(8 * vector.size + 1)) / 2))
    if size * (size - 1) // 2 != vector.size:
        raise ValueError("Vector size does not match the number of elements for a valid square matrix.")

    matrix = np.zeros((size, size))
    # Indices of the strictly upper triangular part
    row_indices, col_indices = np.triu_indices(size, k=1)
    # Fill the upper triangle and mirror it so the result is symmetric
    matrix[row_indices, col_indices] = vector
    matrix[col_indices, row_indices] = vector
    return matrix
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_colors_from_colormap(n_colors, colormap_name='tab10'):
    """
    Sample evenly spaced colours from a matplotlib colormap.

    Colours are taken at positions ``i / (n_colors - 1)`` for ``i`` in
    ``range(n_colors)``, i.e. spread over the full [0, 1] range of the colormap.

    Parameters:
    n_colors (int): Number of colours needed.
    colormap_name (str): Name of the colormap to use (e.g., 'tab10').

    Returns:
    list: The values returned by the colormap (RGBA tuples). Empty when
        ``n_colors`` <= 0.
    """
    cmap = plt.get_cmap(colormap_name)
    if n_colors <= 0:
        return []
    if n_colors == 1:
        # The evenly spaced formula below would divide by zero for a single colour.
        return [cmap(0.0)]
    return [cmap(i / (n_colors - 1)) for i in range(n_colors)]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def convert_matrix(adj: Union[list, np.ndarray], tol: float = 1e-6) -> Tuple[np.ndarray, np.ndarray]:
    """
    Convert an adjacency (connectivity) matrix into a list of edges (i, j) and their weights.

    Only the strictly upper triangle (i < j) is read, so a symmetric matrix yields
    each edge once. Entries whose absolute value is not greater than ``tol`` are
    treated as "no edge" and dropped (NaN entries are dropped as well).

    :param adj: the square matrix, as a nested list or any array-like.
    :param tol: magnitude threshold; entries with ``|w| <= tol`` are discarded.
    :return: ``(idxs, weights)`` where ``idxs`` has shape ``(k, 2)`` holding the
             ``(i, j)`` pairs and ``weights`` has shape ``(k,)``.
    """
    adj = np.asarray(adj)
    rows, cols = np.triu_indices(adj.shape[0], k=1)
    weights = adj[rows, cols]
    # Single mask shared by indices and weights so both stay aligned.
    keep = np.abs(weights) > tol
    idxs = np.column_stack((rows, cols))[keep]
    return idxs, weights[keep]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def extract_edges(matrix, keep_only_non_zero_edges: bool = False):
    """
    Given a square matrix (graph), return its edges and edge weights.

    Edges are read from the strictly lower triangle (i > j) in row-major order:
    (1, 0), (2, 0), (2, 1), (3, 0), ...

    Args:
        matrix (2D numpy array): A square matrix representing a graph.
        keep_only_non_zero_edges (bool): If True, drop edges whose weight is exactly 0.

    Returns:
        edges (2D numpy array): Integer array of shape (k, 2) holding the two node
            ids of each edge; shape (0, 2) when there are no edges.
        weights (1D numpy array): Array of weights corresponding to the edges.

    Raises:
        ValueError: If ``matrix`` is not a 2D square NumPy array.
    """
    # Check ndim first: a 1-D array would otherwise fail on shape[1] with IndexError,
    # and a 3-D (n, n, k) array would slip through a shape-only check.
    if not (isinstance(matrix, np.ndarray) and matrix.ndim == 2
            and matrix.shape[0] == matrix.shape[1]):
        raise ValueError("Input must be a square matrix (2D NumPy array).")

    # np.asarray drops ndarray subclasses (e.g. np.matrix) so fancy indexing returns 1-D.
    values = np.asarray(matrix)
    rows, cols = np.tril_indices(values.shape[0], k=-1)
    weights = values[rows, cols]
    edges = np.column_stack((rows, cols)).astype(int)

    if keep_only_non_zero_edges:
        nonzero = weights != 0  # Only include non-zero edges
        edges, weights = edges[nonzero], weights[nonzero]
    return edges, weights
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _binarize_stability_edges(edges, selected_metric):
    """
    Binarize an edge matrix for the stability metrics; other metrics pass through unchanged.

    - 'sig_stability_positive_edges' / 'sig_stability_negative_edges':
      entries with ``|value| <= 0.01`` become 1, all others 0.
    - 'stability_positive_edges' / 'stability_negative_edges':
      entries with ``|value| >= 1`` become 1, all others 0.
    NaN entries fail both comparisons and therefore become 0 (not drawn).
    """
    if selected_metric in ("sig_stability_positive_edges", "sig_stability_negative_edges"):
        keep = np.abs(edges) <= 0.01
    elif selected_metric in ("stability_positive_edges", "stability_negative_edges"):
        keep = np.abs(edges) >= 1
    else:
        return edges
    return np.where(keep, 1.0, 0.0)


def plot_netplotbrain(results_folder, selected_metric, atlas_labels=None):
    """
    Render the edges of one saved edge matrix on a glass brain with netplotbrain.

    Loads ``<results_folder>/<selected_metric>.npy``; for the stability metrics the
    matrix is binarized first. Non-zero lower-triangle entries become edges, drawn
    in red when the metric name contains 'positive' and in blue otherwise.

    Parameters:
    results_folder (str): Folder holding the ``.npy`` file. The figure is written to
        its ``plots`` sub-folder, which is created if missing.
    selected_metric (str): Base name of the ``.npy`` file to plot.
    atlas_labels: Node table passed to ``netplotbrain.plot(nodes=...)``. When None
        (or when no edge survives), an empty figure is saved instead.

    Returns:
    tuple: (path of the saved PNG, DataFrame with columns 'i', 'j', 'weights', or
        None when the empty figure was saved).
    """
    edges = np.load(os.path.join(results_folder, f"{selected_metric}.npy"))
    edges = _binarize_stability_edges(edges, selected_metric)

    edge_color = "#b22222" if 'positive' in selected_metric else "#317199"

    edges_plot, edge_weights = extract_edges(edges, keep_only_non_zero_edges=True)

    plot_dir = os.path.join(results_folder, "plots")
    os.makedirs(plot_dir, exist_ok=True)
    plot_path = os.path.join(plot_dir, f"netplotbrain_{selected_metric}.png")

    if atlas_labels is not None and edges_plot.size > 0:
        edges_netplot = pd.DataFrame({'i': edges_plot[:, 0], 'j': edges_plot[:, 1],
                                      'weights': edge_weights})

        fig, _ = netplotbrain.plot(template='MNI152NLin2009cAsym',
                                   template_style='glass',
                                   nodes=atlas_labels,
                                   edges=edges_netplot,
                                   view=['LSR'],
                                   highlight_edges=True,
                                   highlight_nodes=None,
                                   node_type='circles',
                                   edge_color=edge_color,
                                   node_color='#332f2c'
                                   )
    else:
        fig = plt.figure()
        edges_netplot = None

    try:
        fig.savefig(plot_path)
    finally:
        # Only the path is returned, so release the figure to avoid accumulating open figures.
        plt.close(fig)
    return plot_path, edges_netplot
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
if __name__ == "__main__":
    # Manual smoke run: render one metric from a results folder.
    # Usage: <script> [results_directory] [selected_metric]
    import sys

    results_directory = (
        sys.argv[1] if len(sys.argv) > 1
        else '/spm-data/vault-data3/mmll/projects/cpm_python/results/hcp_SSAGA_TB_Yrs_Smoked_spearman_partial_p=0.001/'
    )
    selected_metric = sys.argv[2] if len(sys.argv) > 2 else "sig_stability_negative_edges"
    # plot_netplotbrain takes atlas_labels as its third argument; with None it saves an empty figure.
    plot_netplotbrain(results_directory, selected_metric, None)
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import seaborn as sns
|
|
5
|
+
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import matplotlib as mpl
|
|
8
|
+
import matplotlib.gridspec as gridspec
|
|
9
|
+
|
|
10
|
+
from pandas.api.types import is_numeric_dtype
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Shared plotting settings: network name -> hex colour used across the figures below.
COLOR_MAP = dict(
    positive="#FF5768",
    negative="#6C88C4",
    both="#d6d6d6",
)

# Model names, in a fixed order.
MODEL_ORDER = "covariates connectome full residuals increment".split()
|
|
21
|
+
|
|
22
|
+
def apply_nature_style():
    """
    Apply the compact plotting theme shared by every figure in this module.

    Switches seaborn to its "white" theme, then shrinks text to 6-7 pt and thins
    lines through matplotlib's global rcParams (global state is modified).
    """
    sns.set_theme(style="white")
    text_sizes = {
        "font.size": 7,
        "axes.labelsize": 7,
        "axes.titlesize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
    }
    line_widths = {
        "lines.linewidth": 0.75,
        "axes.linewidth": 0.5,
    }
    mpl.rcParams.update({**text_sizes, **line_widths})
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def scatter_plot(df: pd.DataFrame, results_folder: str, y_name) -> str:
    """
    Plot observed against predicted target values, one panel per (model, network).

    Only the 'connectome', 'residuals' and 'full' models are drawn. Each panel is a
    seaborn regression plot coloured by its network name via COLOR_MAP (black for
    networks missing from COLOR_MAP).

    Parameters:
        df: Long-format predictions with 'model', 'network', 'y_true' and 'y_pred' columns.
        results_folder: Folder where predictions.png (600 dpi) and predictions.pdf are written.
        y_name: Name of the target variable, used in the axis labels.

    Returns:
        Path of the written PNG file.
    """
    apply_nature_style()

    df = df[df['model'].isin(['connectome', 'residuals', 'full'])]

    def regplot_colored(data, **kwargs):
        # Columns are faceted by network, so every row of a panel shares one network.
        color = COLOR_MAP.get(data['network'].iloc[0], "#000000")
        sns.regplot(
            data=data,
            x="y_true", y="y_pred",
            scatter_kws={"alpha": 0.7, "s": 14, "edgecolor": "white", "color": color},
            line_kws={"color": color, "linewidth": 0.75},
            **kwargs
        )

    g = sns.FacetGrid(df, row="model", col="network", margin_titles=True, height=1.5, aspect=1)
    g.map_dataframe(regplot_colored)
    g.set_titles(col_template="{col_name}", row_template="{row_name}", size=7)
    g.set_xlabels(y_name)
    g.set_ylabels(f"predicted {y_name}")
    sns.despine(trim=True)
    fig = g.figure
    fig.tight_layout(pad=0.5)

    png_path = os.path.join(results_folder, "predictions.png")
    pdf_path = os.path.join(results_folder, "predictions.pdf")
    try:
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
    finally:
        # The figure is not returned; close it so repeated report runs do not
        # accumulate open figures.
        plt.close(fig)

    return png_path
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def scatter_plot_covariates_model(df: pd.DataFrame, results_folder: str, y_name) -> str:
    """
    Generate a single scatter plot with regression line for the 'covariates' model.

    The panel occupies the centre cell of a 1x3 grid on a 6x2 inch figure, so it is
    centred and one third of the figure wide.

    Parameters:
        df: Long-format predictions with 'model', 'y_true' and 'y_pred' columns.
        results_folder: Folder where scatter_covariates.png (600 dpi) and .pdf are written.
        y_name: Name of the target variable, used in the axis labels.

    Returns:
        Path of the written PNG file.
    """
    apply_nature_style()

    df = df[df["model"] == "covariates"]

    # 6 x 2 inch figure with a 1x3 grid; only the centre cell is used.
    fig = plt.figure(figsize=(6, 2))
    gs = gridspec.GridSpec(1, 3, figure=fig)
    ax = fig.add_subplot(gs[0, 1])

    sns.regplot(
        data=df,
        x="y_true",
        y="y_pred",
        scatter_kws={"alpha": 0.7, "s": 14, "edgecolor": "white", "color": "black"},
        line_kws={"color": "black", "linewidth": 0.75},
        ax=ax
    )

    sns.despine(trim=True)
    ax.set_xlabel(y_name)
    ax.set_ylabel(f"predicted {y_name}")
    ax.set_title('covariates')
    png_path = os.path.join(results_folder, "scatter_covariates.png")
    pdf_path = os.path.join(results_folder, "scatter_covariates.pdf")
    fig.tight_layout(pad=0.5)
    try:
        fig.savefig(png_path, dpi=600)
        fig.savefig(pdf_path)
    finally:
        # Only the path is returned, so release the figure.
        plt.close(fig)

    return png_path
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def histograms_network_strengths(df: pd.DataFrame, results_folder: str, y_name) -> str:
    """
    Create a 2x2 grid of histograms showing the distribution of network_strength
    for two models ('connectome', 'residuals') and two networks ('positive', 'negative').

    Parameters:
        df: Long-format data with 'model', 'network' and 'network_strength' columns.
        results_folder: Folder where histograms_network_strengths.png (600 dpi) and .pdf are written.
        y_name: Accepted for signature parity with the other plotting helpers; the
            histogram y axis shows bin counts, so it is not used.

    Returns:
        Path of the written PNG file.
    """
    apply_nature_style()

    # Filter relevant data
    df = df[df["model"].isin(["connectome", "residuals"])]
    df = df[df["network"].isin(["positive", "negative"])]

    def histplot_colored(data, color=None, **kwargs):
        # `color` is captured (and ignored) so the per-network colour below overrides
        # the one passed by FacetGrid instead of clashing with it.
        sns.histplot(
            data=data,
            x="network_strength",
            bins=30,
            edgecolor="white",
            linewidth=0.3,
            color=COLOR_MAP[data["network"].iloc[0]],
            **kwargs
        )

    # Create 2x2 facet grid
    g = sns.FacetGrid(
        df,
        row="model",
        col="network",
        margin_titles=True,
        height=1.5,
        aspect=1
    )
    g.map_dataframe(histplot_colored)
    g.set_titles(col_template="{col_name}", row_template="{row_name}", size=7)
    # The y axis of a histogram holds bin counts, not the target variable.
    g.set_axis_labels("network strength", "count")
    sns.despine(trim=True)
    fig = g.figure
    fig.tight_layout(pad=0.5)

    # Save
    png_path = os.path.join(results_folder, "histograms_network_strengths.png")
    pdf_path = os.path.join(results_folder, "histograms_network_strengths.pdf")
    try:
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
    finally:
        plt.close(fig)

    return png_path
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def scatter_plot_network_strengths(df: pd.DataFrame, results_folder: str, y_name) -> str:
    """
    Create a 2x2 grid of scatter plots of y_true (y axis) against network_strength (x axis)
    for two models ('connectome', 'residuals') and two networks ('positive', 'negative').

    Other models/networks present in ``df`` are filtered out so the grid matches this
    description, mirroring histograms_network_strengths.

    Parameters:
        df: Long-format data with 'model', 'network', 'network_strength' and 'y_true' columns.
        results_folder: Folder where scatter_network_strengths.png (600 dpi) and .pdf are written.
        y_name: Name of the target variable, used as the y-axis label.

    Returns:
        Path of the written PNG file.
    """
    apply_nature_style()

    df = df[df["model"].isin(["connectome", "residuals"])]
    df = df[df["network"].isin(["positive", "negative"])]

    # Plotting function with custom color per network
    def regplot_colored(data, **kwargs):
        color = COLOR_MAP.get(data["network"].iloc[0], "black")
        sns.regplot(
            data=data,
            x="network_strength",
            y="y_true",
            scatter_kws={"alpha": 0.7, "s": 14, "edgecolor": "white", "color": color},
            line_kws={"color": color, "linewidth": 0.75},
            **kwargs
        )

    # Create 2x2 facet grid: rows = model, cols = network
    g = sns.FacetGrid(
        df,
        row="model",
        col="network",
        margin_titles=True,
        height=1.5,
        aspect=1
    )
    g.map_dataframe(regplot_colored)
    g.set_titles(col_template="{col_name}", row_template="{row_name}", size=7)
    g.set_axis_labels("network strength", y_name)
    sns.despine(trim=True)
    fig = g.figure
    fig.tight_layout(pad=0.5)

    # Save
    png_path = os.path.join(results_folder, "scatter_network_strengths.png")
    pdf_path = os.path.join(results_folder, "scatter_network_strengths.pdf")
    try:
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
    finally:
        plt.close(fig)

    return png_path
|
|
208
|
+
|
|
209
|
+
def boxplot_model_performance(
        df: pd.DataFrame,
        metric: str,
        results_folder: str,
        models: list[str],
        filename_suffix: str = ""
) -> str:
    """
    Creates a horizontal boxplot comparing models across network types.

    Models are on the y axis in the order given by ``models``; networks are the hue,
    ordered both / negative / positive and coloured with COLOR_MAP. For
    'pearson_score', 'spearman_score' and 'explained_variance_score' a reference line
    is drawn at 0 and the x axis is fixed to [-0.5, 1] (values outside are not shown).

    Parameters:
        df: Input dataframe with 'model', 'network' and ``metric`` columns.
        metric: Name of the column to be plotted on the x-axis.
        results_folder: Output folder path.
        models: List of model names to include (e.g. ['increment'] or others).
        filename_suffix: Optional string to append to the output filename.

    Returns:
        Path of the written PNG; PDF and SVG copies are written next to it.
    """
    apply_nature_style()

    df = df[df["model"].isin(models)]

    # A single model gets a short strip; several models get a taller figure.
    height = 0.75 if len(models) == 1 else 2
    fig, ax = plt.subplots(figsize=(7, height))

    sns.boxplot(
        data=df,
        x=metric,
        y="model",
        hue="network",
        order=models,
        hue_order=["both", "negative", "positive"],
        palette=COLOR_MAP,
        orient="h",
        fliersize=2,
        linewidth=0.5,
        width=0.5,
        ax=ax
    )

    if metric in ["pearson_score", "spearman_score", "explained_variance_score"]:
        ax.axvline(x=0, color="black", linewidth=0.5)
        ax.set_xlim(-0.5, 1)

    sns.despine(trim=True)
    ax.set_xlabel(metric.replace("_", " "))
    ax.set_ylabel("")
    # Move legend outside the plot
    ax.legend(
        title="",
        loc="center left",
        bbox_to_anchor=(1.01, 0.5),
        frameon=False,
        handletextpad=0.5
    )

    # Save plot
    suffix = f"_{filename_suffix}" if filename_suffix else ""
    png_path = os.path.join(results_folder, f"boxplot_{metric}{suffix}.png")
    pdf_path = os.path.join(results_folder, f"boxplot_{metric}{suffix}.pdf")
    svg_path = os.path.join(results_folder, f"boxplot_{metric}{suffix}.svg")
    fig.tight_layout(pad=0.2)
    try:
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
        fig.savefig(svg_path, bbox_inches="tight")
    finally:
        # Only the path is returned, so release the figure.
        plt.close(fig)

    return png_path
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def pairplot_flexible(df: pd.DataFrame, output_path: str) -> str:
    """
    Draw a pairwise overview of all columns of ``df`` and save it to ``output_path``.

    Panel (row i, column j) shows:
    - diagonal: histogram for numeric columns, bar chart of value counts otherwise;
    - both numeric: scatter plot;
    - exactly one numeric: step histograms of the numeric column split by the other;
    - neither numeric: annotated heatmap of their cross-tabulation.

    Parameters:
        df: Data whose columns are all plotted against each other.
        output_path: File the figure is written to (600 dpi).

    Returns:
        ``output_path``.
    """
    sns.set_theme(style="white")
    variables = df.columns
    n = len(variables)
    # squeeze=False keeps `axes` two-dimensional, so axes[i, j] also works for a
    # single-column df (plt.subplots(1, 1) would otherwise return a bare Axes).
    fig, axes = plt.subplots(n, n, figsize=(2.5 * n, 2.5 * n), squeeze=False)

    for i, row_var in enumerate(variables):
        for j, col_var in enumerate(variables):
            ax = axes[i, j]
            ax.set_xlabel("")
            ax.set_ylabel("")

            x = df[col_var]
            y = df[row_var]

            is_x_cont = is_numeric_dtype(x)
            is_y_cont = is_numeric_dtype(y)

            if i == j:
                if is_x_cont:
                    sns.histplot(x, bins=20, ax=ax, color="gray", edgecolor="white")
                else:
                    counts = x.value_counts().sort_index()
                    # hue mirrors x so the pastel palette is applied per bar; no legend needed.
                    sns.barplot(
                        x=counts.index.astype(str),
                        y=counts.values,
                        hue=counts.index.astype(str),
                        palette="pastel",
                        legend=False,
                        ax=ax
                    )
                    ax.set_xticks(range(len(counts)))
                    ax.set_xticklabels(counts.index.astype(str), rotation=45, ha="right")
                # NOTE(review): the diff lost indentation; title/despine/continue are taken
                # to apply to both diagonal branches so a numeric diagonal is not also
                # overdrawn by the scatter branch below.
                ax.set_title(row_var, fontsize=9)
                sns.despine(ax=ax)
                continue

            if is_x_cont and is_y_cont:
                sns.scatterplot(x=x, y=y, ax=ax, s=15, alpha=0.6, edgecolor="white", linewidth=0.3)
            elif is_x_cont:
                sns.histplot(data=df, x=col_var, hue=row_var, ax=ax, element="step", stat="count",
                             common_norm=False, bins=20, palette="Set2")
            elif is_y_cont:
                sns.histplot(data=df, x=row_var, hue=col_var, ax=ax, element="step", stat="count",
                             common_norm=False, bins=20, palette="Set2")
            else:
                ctab = pd.crosstab(y, x)
                sns.heatmap(ctab, annot=True, fmt='d', cmap="Blues", cbar=False, ax=ax)
                ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
                ax.set_yticklabels(ax.get_yticklabels(), rotation=0, va='center')

            ax.tick_params(axis='both', labelsize=6)
            sns.despine(ax=ax)

    try:
        fig.savefig(output_path, dpi=600)
    finally:
        plt.close(fig)
    return output_path
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_colors_from_colormap(n_colors, colormap_name='tab10'):
    """
    Sample evenly spaced colours from a matplotlib colormap.

    Colours are taken at positions ``i / (n_colors - 1)`` for ``i`` in
    ``range(n_colors)``, i.e. spread over the full [0, 1] range of the colormap.

    Parameters:
    n_colors (int): Number of colours needed.
    colormap_name (str): Name of the colormap to use (e.g., 'tab10').

    Returns:
    list: The values returned by the colormap (RGBA tuples). Empty when
        ``n_colors`` <= 0.
    """
    cmap = plt.get_cmap(colormap_name)
    if n_colors <= 0:
        return []
    if n_colors == 1:
        # The evenly spaced formula below would divide by zero for a single colour.
        return [cmap(0.0)]
    return [cmap(i / (n_colors - 1)) for i in range(n_colors)]
|
|
19
|
+
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import pandas as pd
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _format_mean_sd(mean, std, precision):
    """Return "mean [std]" strings with both values shown to exactly ``precision`` decimals."""
    def fmt(value):
        return f"{value:.{precision}f}"
    return mean.map(fmt) + " [" + std.map(fmt) + "]"


def _format_p_value(val):
    """Format a p-value to three decimals with significance stars ("" for missing values).

    "<0.001**" below 0.001, "**" below 0.01, "*" below 0.05, no stars otherwise.
    """
    if pd.isna(val):
        return ""
    if val < 0.001:
        return "<0.001**"
    if val < 0.01:
        return f"{val:.3f}**"
    if val < 0.05:
        return f"{val:.3f}*"
    return f"{val:.3f}"


def _bold_if_significant(val):
    """Return a bold CSS rule for p-value strings that carry a significance star."""
    # Type check first: str methods must not be called on non-string cells.
    if isinstance(val, str) and val.endswith("*"):
        return 'font-weight: bold'
    return ''


def _group_divider_styles(frame):
    """Style the first row of each new group of the first index level with a top border."""
    styles = pd.DataFrame("", index=frame.index, columns=frame.columns)
    previous_group = None
    for position, idx in enumerate(frame.index):
        current_group = idx[0]  # assumes 'model' is the first index level
        if previous_group is not None and current_group != previous_group:
            styles.iloc[position] = 'border-top: 1px solid black'
        previous_group = current_group
    return styles


def format_results_table(df, precision=2):
    """
    Format a results DataFrame with (metric, statistic) MultiIndex columns as a styled table.

    For every metric (first column level) the 'mean', 'std' and 'p' sub-columns are
    turned into:
    - 'mean [sd]': "mean [std]", both shown with ``precision`` decimals (APA-style fixed width);
    - 'p': the p-value to three decimals with '*' (p < .05) or '**' (p < .01) appended
      ("<0.001**" below .001); starred cells are rendered bold via CSS.
    Columns are sorted by metric name with 'mean [sd]' before 'p', and a top border
    separates groups of the first index level.

    Parameters:
        df: DataFrame with two-level columns; every metric needs 'mean', 'std' and 'p'
            (a missing one raises KeyError).
        precision: Decimals used for mean and standard deviation.

    Returns:
        pandas Styler wrapping the formatted string table.
    """
    formatted = {}
    for metric in df.columns.get_level_values(0).unique():
        formatted[(metric, "mean [sd]")] = _format_mean_sd(
            df[(metric, "mean")], df[(metric, "std")], precision
        )
        formatted[(metric, "p")] = df[(metric, "p")].apply(_format_p_value)

    combined = pd.DataFrame(formatted, index=df.index)
    combined.columns = pd.MultiIndex.from_tuples(combined.columns)

    # Column sort: metric name, then summary → p
    sub_order = ["mean [sd]", "p"]
    combined = combined.loc[:, sorted(combined.columns, key=lambda c: (c[0], sub_order.index(c[1])))]

    # Build Styler
    styler = combined.style.set_properties(
        **{
            'font-size': '10px',
            'padding': '2px 4px',
            'text-align': 'center'
        }
    ).set_table_styles([
        {'selector': 'th',
         'props': [('font-size', '11px'),
                   ('padding', '2px 4px'),
                   ('text-align', 'center'),
                   ('background-color', '#f9f9f9')]},
        {'selector': '.row_heading',
         'props': [('font-size', '10px'),
                   ('padding', '2px 4px')]},
        {'selector': '.index_name',
         'props': [('font-size', '10px'),
                   ('padding', '2px 4px')]}
    ])

    # Bold significant p-values (p-value columns only)
    p_columns = [col for col in combined.columns if col[1] == "p"]
    if p_columns:
        styler = styler.map(_bold_if_significant, subset=p_columns)

    # Horizontal lines between top-level index groups
    styler = styler.apply(_group_divider_styles, axis=None)
    return styler
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def extract_log_block(filepath):
    """
    Return the text found between the first two separator lines of a log file.

    A separator is any line whose stripped content starts with "=". The lines
    strictly between the first and second separator are joined and stripped of
    surrounding whitespace. With fewer than two separators an empty string is
    returned.
    """
    with open(filepath, "r") as log_file:
        all_lines = list(log_file)

    separator_positions = [
        position
        for position, text in enumerate(all_lines)
        if text.strip().startswith("=")
    ]
    if len(separator_positions) < 2:
        return ""

    first_sep, second_sep = separator_positions[:2]
    return "".join(all_lines[first_sep + 1:second_sep]).strip()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _require_csv_path(folder_path, filename):
    """
    Join ``folder_path`` and ``filename`` and check that the result is a regular file.

    Raises:
        RuntimeError: if the path is missing or is not a regular file (e.g. a directory).
    """
    csv_path = os.path.join(folder_path, filename)
    # isfile (not exists) so a directory is reported as missing instead of failing inside read_csv.
    if not os.path.isfile(csv_path):
        raise RuntimeError(f"No CSV file found at path: {csv_path}")
    return csv_path


def load_data_from_folder(folder_path, filename):
    """
    Read ``filename`` inside ``folder_path`` into a DataFrame using default CSV options.

    Raises:
        RuntimeError: if no CSV file exists at that path.
    """
    return pd.read_csv(_require_csv_path(folder_path, filename))


def load_results_from_folder(folder_path, filename):
    """
    Read a results CSV whose first two rows form a two-level column index and whose
    first two columns form a two-level row index.

    Raises:
        RuntimeError: if no CSV file exists at that path.
    """
    return pd.read_csv(_require_csv_path(folder_path, filename), header=[0, 1], index_col=[0, 1])
|