cccpm 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,363 @@
+ import os
+
+ from pathlib import Path
+
+ import pandas as pd
+ import numpy as np
+ import arakawa as ar
+ import matplotlib.pyplot as plt
+
+ from cccpm.reporting.plots.plots import boxplot_model_performance
+ from cccpm.reporting.plots.plots import (scatter_plot, scatter_plot_covariates_model, scatter_plot_network_strengths,
+                                          histograms_network_strengths)
+ from cccpm.reporting.plots.cpm_chord_plot import plot_netplotbrain
+ from cccpm.reporting.reporting_utils import format_results_table, extract_log_block, load_results_from_folder, load_data_from_folder
+
+
+ class HTMLReporter:
+     def __init__(self, results_directory: str, atlas_labels: str = None):
+         self.results_directory = results_directory
+         self.plots_dir = os.path.join(results_directory, "plots")
+         self.X_names, self.y_name, self.covariates_names = self.load_variable_names()
+         os.makedirs(self.plots_dir, exist_ok=True)
+
+         # load the atlas labels from the provided CSV file, if any
+         if atlas_labels is not None:
+             self.atlas_labels = pd.read_csv(atlas_labels)
+         else:
+             self.atlas_labels = None
+
+         # Load results
+         self.df = pd.read_csv(os.path.join(results_directory, 'cv_results.csv'))
+         self.df_mean = load_results_from_folder(results_directory, 'cv_results_mean_std.csv')
+         self.df_mean = self.df_mean.reorder_levels(["model", "network"])
+         model_order = ["covariates", "connectome", "full", "residuals", "increment"]
+         network_order = ["positive", "negative", "both"]
+
+         # enforce a fixed display order by turning the index levels into ordered categoricals
+         self.df_mean.index = pd.MultiIndex.from_frame(
+             self.df_mean.index.to_frame().assign(
+                 model=pd.Categorical(self.df_mean.index.get_level_values("model"), categories=model_order, ordered=True),
+                 network=pd.Categorical(self.df_mean.index.get_level_values("network"), categories=network_order, ordered=True)
+             )
+         )
+         # now sort rows according to that order
+         self.df_mean = self.df_mean.sort_index()
+         self.df_predictions = load_data_from_folder(results_directory, 'cv_predictions.csv')
+         self.df_p_values = load_data_from_folder(results_directory, 'p_values.csv')
+         self.df_permutations = load_data_from_folder(results_directory, 'permutation_results.csv')
+         self.df_network_strengths = load_data_from_folder(results_directory, 'cv_network_strengths.csv')
+
+     def generate_html_report(self):
+
+         info_page = self.generate_info_page()
+         main_results_page = self.generate_main_results_page()
+         edges_page = self.generate_brain_plot_page()
+         edges_table_page = self.generate_edge_page()
+         network_strength_page = self.generate_network_strengths_page()
+         data_description_page = self.generate_data_description_page()
+         hyperparameters_page = self.generate_hyperparameters_page()
+         data_insight_page = self.generate_target_cov_features_insights_page()
+
+         report_blocks = [
+             info_page,
+             data_insight_page,
+             data_description_page,
+             main_results_page,
+             hyperparameters_page,
+             network_strength_page,
+             edges_page,
+             edges_table_page
+         ]
+
+         main_tabs = ar.Select(blocks=report_blocks)
+         script_dir = Path(__file__).parent
+         image_path = script_dir / 'assets/CCCPM.png'
+         image_path = image_path.resolve()  # Optional: resolve to absolute path
+
+         main_page = ar.Group(ar.Media(file=image_path, name="Logo"),
+                              main_tabs,
+                              widths=[1, 10], columns=2)
+         report = ar.Report(blocks=[main_page])
+         report.save(os.path.join(self.results_directory, 'report.html'),
+                     open=False, formatting=ar.Formatting(width=ar.Width.FULL, accent_color="orange"))
+         plt.close('all')
+         return
+
+     def generate_hyperparameters_page(self):
+         try:
+             if 'params' not in self.df.columns:
+                 return ar.Blocks(blocks=[ar.Text("No hyperparameters found in results.")],
+                                  label='Hyperparameters')
+
+             hyper_df = self.df[['fold', 'model', 'params']].copy()
+             hyper_df = (
+                 hyper_df
+                 .drop(columns='model')             # we don’t need the model any more
+                 .drop_duplicates(subset=['fold'])  # keep first row per fold
+                 .reset_index(drop=True)
+             )
+
+             if isinstance(hyper_df['params'].iloc[0], dict):
+                 hyper_df['params'] = hyper_df['params'].apply(
+                     lambda x: "\n".join(f"{k}: {v}" for k, v in x.items())
+                 )
+
+             hyper_table = ar.DataTable(
+                 df=hyper_df,
+                 label='Hyperparameters by Fold and Model'
+             )
+
+             try:
+                 if isinstance(self.df['params'].iloc[0], dict):
+                     all_params = pd.json_normalize(self.df['params'])
+                     param_summary = all_params.describe().T
+                     summary_table = ar.DataTable(
+                         df=param_summary,
+                         label='Hyperparameter Summary Statistics'
+                     )
+
+                     return ar.Blocks(
+                         blocks=[
+                             ar.Text("## Model Hyperparameters"),
+                             ar.Text("### Parameters Used in Each Fold"),
+                             hyper_table,
+                             ar.Text("### Parameter Summary Across Folds"),
+                             summary_table
+                         ],
+                         label='Hyperparameters'
+                     )
+             except Exception:
+                 pass
+
+             return ar.Blocks(
+                 blocks=[
+                     ar.Text("## Model Hyperparameters"),
+                     ar.Text("### Parameters Used in Each Fold"),
+                     hyper_table
+                 ],
+                 label='Hyperparameters'
+             )
+
+         except Exception as e:
+             return ar.Blocks(
+                 blocks=[ar.Text(f"Could not generate hyperparameters page: {str(e)}")],
+                 label='Hyperparameters'
+             )
+
+     def generate_data_description_page(self):
+         target_column_name = 'y_true'
+         target_series = self.df_predictions[target_column_name]
+
+         target_desc = f"""
+ ## Target Variable Description
+
+ - Number of observations: {len(self.df_predictions)}
+ - Target variable name: {target_column_name}
+ - Range: {target_series.min():.2f} to {target_series.max():.2f}
+ - Mean: {target_series.mean():.2f}
+ - Standard deviation: {target_series.std():.2f}
+ """
+
+         feature_desc = """
+ ## Feature Description
+
+ Connectivity features were derived from:
+ """
+
+         if self.atlas_labels is not None:
+             feature_desc += f"\n- Atlas labels provided: {len(self.atlas_labels)} regions"
+
+         return ar.Blocks(blocks=[
+             ar.Text(target_desc),
+             ar.Text(feature_desc)
+         ], label='Data Description')
+
+     def generate_info_page(self):
+         log_text = extract_log_block(os.path.join(self.results_directory, "cpm_log.txt"))
+         # --- Page: Info ---
+         info_text = ar.Group(ar.Text("""
+ # Confound-Corrected Connectome-Based Predictive Modeling
+ ## Python Toolbox
+ **Author**: Nils R. Winter
+ **GitHub**: https://github.com/wwu-mmll/cpm_python
+
+ **Confound-Corrected Connectome-Based Predictive Modelling** is a Python package for performing connectome-based
+ predictive modeling (**CPM**). This toolbox is designed for researchers in neuroscience and psychiatry, providing
+ robust methods for building **predictive models** based on structural or functional **connectome** data. It emphasizes
+ replicability, interpretability, and flexibility, making it a valuable tool for analyzing brain connectivity
+ and its relationship to behavior or clinical outcomes.
+ """),
+                              ar.Text("**Version: 0.1.0**"),
+                              widths=[7, 1], columns=2)
+
+         header = ar.Text("## Analysis Setup")
+         log_block = ar.Text(f"<pre>{log_text}</pre>")
+         log_group = ar.Group(ar.Blocks(blocks=[header, log_block]), ar.Text("Current Analysis"), columns=2, widths=[7, 1])
+
+         blocks = ar.Blocks(blocks=[info_text, log_group], label="Info")
+         return blocks
+
+     def generate_main_results_page(self):
+         self.df_p_values.set_index(['model', 'network'], inplace=True)
+         self.df_p_values.columns = pd.MultiIndex.from_tuples([(col, 'p') for col in self.df_p_values.columns])
+         df_combined = pd.concat([self.df_mean, self.df_p_values], axis=1)
+         df_combined = df_combined.sort_index(axis=1, level=0)
+         desired_order = ["mean", "std", "p"]
+         df_combined = df_combined.loc[:,
+                                       sorted(df_combined.columns, key=lambda x: (x[0], desired_order.index(x[1])))
+                                       ]
+
+         # Style the combined results table with a smaller font
+         styled_df = format_results_table(df_combined)
+         table = ar.HTML(styled_df.to_html(escape=False), label='Predictive Performance')
+
+         bar_plot_blocks = []
+         for metric in list(self.df.columns)[3:-1]:
+             if metric == 'params':
+                 continue
+             plot_name = boxplot_model_performance(self.df, metric, self.plots_dir, models=["covariates", "connectome", "full", "residuals"])
+             plot_name_increment = boxplot_model_performance(self.df, metric, self.plots_dir, models=["increment"], filename_suffix="increment")
+             plot_block = ar.Blocks(blocks=[ar.Media(file=plot_name, name=f"Image1_{metric}"),
+                                            ar.Media(file=plot_name_increment, name=f"Image1_increment_{metric}")
+                                            ],
+                                    label=f'{metric}')
+
+             bar_plot_blocks.append(plot_block)
+
+         # predictions scatter plots
+         scatter_plot_name = scatter_plot(self.df_predictions, self.plots_dir, self.y_name)
+         scatter_covariates_name = scatter_plot_covariates_model(self.df_predictions, self.plots_dir, self.y_name)
+
+         scatter_block = ar.Media(file=scatter_plot_name, name="Predictions", caption="Scatter plot of true versus predicted scores.",
+                                  label='predictions')
+         scatter_block_covariates = ar.Media(file=scatter_covariates_name, name="PredictionsCovariatesModel",
+                                             caption="Scatter plot of true versus predicted scores for the covariates model.",
+                                             label='predictions_covariates')
+
+         first_row = ar.Group(name='main_results', blocks=[ar.Select(blocks=bar_plot_blocks), scatter_block, scatter_block_covariates], columns=3,
+                              widths=[2, 1, 1])
+
+         second_row = ar.Group(name='perms_and_predictions', blocks=[table], columns=2, widths=[2, 1])
+
+         return ar.Blocks(blocks=[first_row, second_row], label='Predictive Performance')
+
+     def generate_network_strengths_page(self):
+         scatter_network_strengths = scatter_plot_network_strengths(self.df_network_strengths, self.plots_dir, self.y_name)
+         scatter_block_network_strength = ar.Media(file=scatter_network_strengths, name="NetworkStrengths",
+                                                   caption="Scatter plot of target versus network strength scores.",
+                                                   label='Network Strengths')
+         hist = histograms_network_strengths(self.df_network_strengths, self.plots_dir, self.y_name)
+         hist_block_network_strength = ar.Media(file=hist, name="NetworkStrengthsHist",
+                                                caption="Histograms of network strength scores.",
+                                                label='Distribution of Network Strengths')
+         row = ar.Group(name='network_strengths', blocks=[scatter_block_network_strength, hist_block_network_strength], columns=4)
+         return ar.Blocks(blocks=[row], label='Network Strengths')
+
+     def generate_brain_plot_page(self):
+         if self.atlas_labels is None:
+             return ar.Blocks(blocks=[ar.Group(blocks=[ar.Text("Provide atlas labels as a CSV file.")], columns=1)],
+                              label='Brain Plots')
+         plots = list()
+         edges = list()
+         for metric in ["sig_stability_positive_edges", "sig_stability_negative_edges"]:
+             plot_brainplot, edge_list = plot_netplotbrain(results_folder=self.results_directory,
+                                                           selected_metric=metric,
+                                                           atlas_labels=self.atlas_labels)
+             plots.append(plot_brainplot)
+             edges.append(edge_list)
+
+         third_header = ar.Group(blocks=[ar.Text("Significantly Stable Positive Edges"), ar.Text("Significantly Stable Negative Edges")], columns=2)
+         third_row = ar.Group(blocks=[ar.Media(file=plots[0]), ar.Media(file=plots[1])], columns=2)
+         blocks = ar.Blocks(blocks=[third_header, third_row], label='Brain Plots')
+         return blocks
+
+     def generate_edge_page(self):
+         dfs = dict()
+         for network in ['positive', 'negative']:
+             edges = {'stability': np.load(os.path.join(self.results_directory, f"stability_{network}_edges.npy")),
+                      'stability_significance': np.load(os.path.join(self.results_directory, f"sig_stability_{network}_edges.npy"))}
+             dfs[network] = self.create_edge_table(edges, self.atlas_labels)
+
+         first_header = ar.Group(blocks=[ar.Text("## Positive Edges"), ar.Text("## Negative Edges")], columns=2)
+         first_row = ar.Group(blocks=[ar.DataTable(df=dfs['positive']), ar.DataTable(df=dfs['negative'])], columns=2)
+
+         blocks = ar.Blocks(blocks=[first_header, first_row], label='Stable Edges')
+         return blocks
+
+     @staticmethod
+     def create_edge_table(matrix, atlas):
+         # Collect all non-zero edges from the lower triangle of the symmetric stability matrix.
+         n = matrix['stability'].shape[0]
+         stability = []
+         significance = []
+         region_a = []
+         region_b = []
+
+         for i in range(1, n):
+             for j in range(i):
+                 if matrix['stability'][i, j] == 0:
+                     continue
+                 if atlas is not None:
+                     region_a.append(atlas['region'][i])
+                     region_b.append(atlas['region'][j])
+                 else:
+                     region_a.append(f"Region {i}")
+                     region_b.append(f"Region {j}")
+                 stability.append(matrix['stability'][i, j])
+                 significance.append(matrix['stability_significance'][i, j])
+
+         df = pd.DataFrame({'Region A': region_a, 'Region B': region_b,
+                            'Stability': stability, 'Stability Significance': significance})
+         df[['Stability', 'Stability Significance']] = df[['Stability', 'Stability Significance']].round(5)
+
+         df.sort_values(by=['Stability Significance', 'Stability'], inplace=True, ascending=[True, False])
+         df.set_index(['Region A', 'Region B'], inplace=True)
+         if df.empty:
+             # Return a single placeholder row so the report table is not empty.
+             first_col = df.columns[0]
+             row = {col: ("No significantly stable edges." if col == first_col else np.nan) for col in df.columns}
+             df = pd.DataFrame([row])
+         return df
+
+     def generate_target_cov_features_insights_page(self):
+         """
+         Generate a page summarizing the input data:
+         - summary statistics from summary.csv
+         - scatter matrix image
+         """
+         summary_path = os.path.join(self.results_directory, "data_insights", "summary.csv")
+         scatter_matrix_path = os.path.join(self.results_directory, "data_insights", "scatter_matrix.png")
+
+         # Load summary table
+         if os.path.exists(summary_path):
+             summary_df = pd.read_csv(summary_path, index_col=0)
+             summary_block = ar.DataTable(df=summary_df, label="Input Data Summary")
+         else:
+             summary_block = ar.Text("Summary file not found.")
+
+         # Load scatter matrix image
+         if os.path.exists(scatter_matrix_path):
+             scatter_block = ar.Media(file=scatter_matrix_path, name="ScatterMatrix",
+                                      caption="Scatter matrix of covariates and target.")
+         else:
+             scatter_block = ar.Text("Scatter matrix image not found.")
+
+         # Combine both into a single report block
+         row = ar.Group(name='data_overview', blocks=[summary_block, scatter_block], columns=2, widths=[1, 1])
+         return ar.Blocks(blocks=[row], label='Target, Covariates & Features')
+
+     def load_variable_names(self):
+         X_names = pd.read_csv(os.path.join(self.results_directory, 'data_insights', "X_names.csv"), header=None)[0].tolist()
+         y_name = pd.read_csv(os.path.join(self.results_directory, 'data_insights', "y_name.csv"), header=None).iloc[0, 0]
+         covar_names = pd.read_csv(os.path.join(self.results_directory, 'data_insights', "covariate_names.csv"), header=None)[0].tolist()
+         return X_names, y_name, covar_names
+
+ if __name__ == "__main__":
+     reporter = HTMLReporter(results_directory='/spm-data/vault-data3/mmll/projects/cpm_python/results_new/hcp_SSAGA_TB_Yrs_Smoked_spearman_partial_p=0.01')
+     reporter.generate_html_report()
File without changes
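
For orientation, a minimal usage sketch of the HTMLReporter class shown in the diff above. This is not part of the package: the paths are hypothetical, the import path is assumed (the diff does not show where this module lives inside cccpm), and the expectation that the atlas CSV contains a 'region' column is inferred from create_edge_table.

# Hypothetical paths and an assumed module location; adjust to your own setup.
from cccpm.reporting.html_reporter import HTMLReporter  # import path assumed, not shown in the diff

reporter = HTMLReporter(
    results_directory="/path/to/cccpm_results",  # expected to contain cv_results.csv, cv_predictions.csv,
                                                 # p_values.csv, cpm_log.txt, data_insights/, stability_*.npy, etc.
    atlas_labels="/path/to/atlas_labels.csv",    # CSV with a 'region' column (inferred from create_edge_table)
)
reporter.generate_html_report()                  # writes report.html into the results directory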