cccpm 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cccpm/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from cccpm.cpm_analysis import CPMRegression
cccpm/cpm_analysis.py ADDED
@@ -0,0 +1,272 @@
1
+ import os
2
+ import logging
3
+ import shutil
4
+
5
+ from typing import Union, Type
6
+ from tqdm import tqdm
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit, KFold, RepeatedKFold, StratifiedKFold
11
+ from sklearn.linear_model import LinearRegression
12
+
13
+ from cccpm.fold import run_inner_folds
14
+ from cccpm.logging import setup_logging
15
+ from cccpm.more_models import BaseCPMModel, LinearCPMModel
16
+ from cccpm.edge_selection import UnivariateEdgeSelection, PThreshold
17
+ from cccpm.results_manager import ResultsManager, PermutationManager
18
+ from cccpm.utils import train_test_split, check_data, impute_missing_values, select_stable_edges, generate_data_insights
19
+ from cccpm.scoring import score_regression_models
20
+ from cccpm.reporting import HTMLReporter
21
+
22
+
23
class CPMRegression:
    """
    This class handles the process of performing CPM Regression with cross-validation and permutation testing.

    Constructing an instance has side effects: it seeds NumPy's global random generator, creates the
    results directory (plus an ``edges`` sub-directory) and configures logging to
    ``<results_directory>/cpm_log.txt``. Call :meth:`run` to estimate the models.
    """
    def __init__(self,
                 results_directory: str,
                 cpm_model: Type[BaseCPMModel] = LinearCPMModel,
                 cv: Union[BaseCrossValidator, BaseShuffleSplit, RepeatedKFold, StratifiedKFold] = None,
                 inner_cv: Union[BaseCrossValidator, BaseShuffleSplit, RepeatedKFold, StratifiedKFold] = None,
                 edge_selection: UnivariateEdgeSelection = None,
                 select_stable_edges: bool = False,
                 stability_threshold: float = 0.8,
                 impute_missing_values: bool = True,
                 calculate_residuals: bool = False,
                 n_permutations: int = 0,
                 atlas_labels: str = None):
        """
        Initialize the CPMRegression object.

        Parameters
        ----------
        results_directory: str
            Directory to save results.
        cpm_model: Type[BaseCPMModel]
            Model class; it is instantiated per fold as ``cpm_model(edges=...)``.
        cv: Union[BaseCrossValidator, BaseShuffleSplit]
            Outer cross-validation strategy. ``None`` (default) uses
            ``KFold(n_splits=10, shuffle=True, random_state=42)``.
        inner_cv: Union[BaseCrossValidator, BaseShuffleSplit]
            Inner cross-validation strategy for edge selection.
        edge_selection: UnivariateEdgeSelection
            Method for edge selection. ``None`` (default) uses Pearson correlation with an
            uncorrected p-threshold of 0.05.
        select_stable_edges: bool
            If True, the edges used in each outer fold are taken from the edge stability of the
            inner folds instead of re-running the edge selection on the outer training data.
            Requires ``inner_cv``.
        stability_threshold: float
            Threshold passed to the stable-edge selection.
        impute_missing_values: bool
            Whether to impute missing values.
        calculate_residuals: bool
            If True, X is residualized with respect to the covariates (linear regression fitted
            on the training part of each outer fold).
        n_permutations: int
            Number of permutations to run for permutation testing.
        atlas_labels: str
            CSV file containing atlas and regions labels.

        Raises
        ------
        RuntimeError
            If several hyperparameter configurations or stable-edge selection are requested
            without an inner cv, or if the atlas labels file is missing or invalid.
        """
        # Build the default objects per instance. Default-argument instances would be created once
        # at definition time and shared by every CPMRegression; the edge selection object is
        # mutated through set_params() during a run, so sharing it leaks state between analyses.
        if cv is None:
            cv = KFold(n_splits=10, shuffle=True, random_state=42)
        if edge_selection is None:
            edge_selection = UnivariateEdgeSelection(
                edge_statistic='pearson',
                edge_selection=[PThreshold(threshold=[0.05], correction=[None])]
            )

        self.results_directory = results_directory
        self.cpm_model = cpm_model
        self.cv = cv
        self.inner_cv = inner_cv
        self.edge_selection = edge_selection
        self.select_stable_edges = select_stable_edges
        self.stability_threshold = stability_threshold
        self.impute_missing_values = impute_missing_values
        self.calculate_residuals = calculate_residuals
        self.n_permutations = n_permutations

        np.random.seed(42)
        os.makedirs(self.results_directory, exist_ok=True)
        os.makedirs(os.path.join(self.results_directory, "edges"), exist_ok=True)
        setup_logging(os.path.join(self.results_directory, "cpm_log.txt"))
        self.logger = logging.getLogger(__name__)

        # Log important configuration details
        self._log_analysis_details()

        # check inner cv and param grid
        if self.inner_cv is None:
            if len(self.edge_selection.param_grid) > 1:
                raise RuntimeError("Multiple hyperparameter configurations but no inner cv defined. "
                                   "Please provide only one hyperparameter configuration or an inner cv.")
            if self.select_stable_edges:
                raise RuntimeError("Stable edges can only be selected when using an inner cv.")

        # check and copy atlas labels file
        self.atlas_labels = self._validate_and_copy_atlas_file(atlas_labels)

        # results are saved to the results manager instance
        self.results_manager = None

    def _log_analysis_details(self):
        """
        Log important information about the analysis in a structured format.
        """
        self.logger.info("Starting CPM Regression Analysis")
        self.logger.info("="*50)
        self.logger.info(f"Results Directory: {self.results_directory}")
        self.logger.info(f"CPM Model: {self.cpm_model.name}")
        self.logger.info(f"Outer CV strategy: {self.cv}")
        self.logger.info(f"Inner CV strategy: {self.inner_cv}")
        self.logger.info(f"Edge selection method: {self.edge_selection}")
        self.logger.info(f"Select stable edges: {'Yes' if self.select_stable_edges else 'No'}")
        if self.select_stable_edges:
            self.logger.info(f"Stability threshold: {self.stability_threshold}")
        self.logger.info(f"Impute Missing Values: {'Yes' if self.impute_missing_values else 'No'}")
        self.logger.info(f"Calculate residuals: {'Yes' if self.calculate_residuals else 'No'}")
        self.logger.info(f"Number of Permutations: {self.n_permutations}")
        self.logger.info("="*50)

    def _validate_and_copy_atlas_file(self, csv_path):
        """
        Validates that a CSV file exists and contains the required columns ('x', 'y', 'z', 'region').
        If valid, copies it to <self.results_directory>/edges.

        Parameters
        ----------
        csv_path: str or None
            Path to the atlas labels file. ``None`` skips validation and copying.

        Returns
        -------
        str or None
            Path of the copied file, or ``None`` if no file was given or copying failed
            (a copy failure is logged, not raised).

        Raises
        ------
        RuntimeError
            If the file does not exist, cannot be read as CSV, or lacks required columns.
        """
        if csv_path is None:
            return None

        required_columns = {"x", "y", "z", "region"}
        csv_path = os.path.abspath(csv_path)

        # Check if file exists
        if not os.path.isfile(csv_path):
            raise RuntimeError(f"CSV file does not exist: {csv_path}")

        # Only the read itself is guarded, so a column problem is not reported as a read error
        try:
            df = pd.read_csv(csv_path)
        except Exception as e:
            raise RuntimeError(f"Error reading CSV file {csv_path}: {e}") from e

        missing = required_columns - set(df.columns)
        if missing:
            # sorted() keeps the message deterministic (set iteration order is not)
            raise RuntimeError(f"CSV file is missing required columns: {', '.join(sorted(missing))}")

        # File and columns valid, proceed to copy
        dest_path = os.path.join(self.results_directory, "edges", os.path.basename(csv_path))

        try:
            shutil.copy(csv_path, dest_path)
        except OSError as e:
            # best effort: the analysis can run without the atlas copy
            self.logger.error(f"Error copying file to {dest_path}: {e}")
            return None

        self.logger.info(f"Copied CSV file to {dest_path}")
        return dest_path

    def run(self,
            X: Union[pd.DataFrame, np.ndarray],
            y: Union[pd.Series, pd.DataFrame, np.ndarray],
            covariates: Union[pd.Series, pd.DataFrame, np.ndarray]):
        """
        Estimates a model using the provided data and conducts permutation testing. This method first fits the
        model to the actual data and subsequently performs estimation on permuted data for a specified number of
        permutations. Finally, it calculates permutation results and generates the HTML report.

        Parameters
        ----------
        X: Feature data used for the model. Can be a pandas DataFrame or a NumPy array.
        y: Target variable used in the estimation process. Can be a pandas Series, DataFrame, or a NumPy array.
        covariates: Additional covariate data to include in the model. Can be a pandas Series, DataFrame, or a
            NumPy array.

        """
        self.logger.info(f"Starting estimation with {self.n_permutations} permutations.")

        # check data and convert to numpy
        generate_data_insights(X=X, y=y, covariates=covariates, results_directory=self.results_directory)
        X, y, covariates = check_data(X, y, covariates, impute_missings=self.impute_missing_values)

        # Estimate models on actual data
        self._single_run(X=X, y=y, covariates=covariates, perm_run=0)
        self.logger.info("=" * 50)

        # Estimate models on permuted data; y is re-shuffled in place of the previous permutation
        for perm_id in tqdm(range(1, self.n_permutations + 1), desc="Permutation runs", unit="run",
                            total=self.n_permutations):
            y = np.random.permutation(y)
            self._single_run(X=X, y=y, covariates=covariates, perm_run=perm_id)

        if self.n_permutations > 0:
            PermutationManager.calculate_permutation_results(self.results_directory, self.logger)
        self.logger.info("Estimation completed.")
        self.logger.info("Generating results file.")
        reporter = HTMLReporter(results_directory=self.results_directory, atlas_labels=self.atlas_labels)
        reporter.generate_html_report()

    def generate_html_report(self):
        """
        Generate the HTML report from the contents of the results directory.
        """
        self.logger.info("Generating HTML report.")
        reporter = HTMLReporter(results_directory=self.results_directory, atlas_labels=self.atlas_labels)
        reporter.generate_html_report()

    def _single_run(self,
                    X: Union[pd.DataFrame, np.ndarray],
                    y: Union[pd.Series, pd.DataFrame, np.ndarray],
                    covariates: Union[pd.Series, pd.DataFrame, np.ndarray],
                    perm_run: int = 0):
        """
        Perform an estimation run (either real or permuted data). Includes outer cross-validation loop. For
        permutation runs, the same strategy is used, but printing is less verbose and the results folder changes.

        :param X: Features (predictors).
        :param y: Labels (target variable).
        :param covariates: Covariates to control for.
        :param perm_run: Permutation run identifier (0 = real data).
        """
        results_manager = ResultsManager(output_dir=self.results_directory, perm_run=perm_run,
                                         n_folds=self.cv.get_n_splits(), n_features=X.shape[1])

        # progress bar only for the real-data run
        iterator = (
            tqdm(
                enumerate(self.cv.split(X, y)),
                total=self.cv.get_n_splits(),
                desc="Running outer folds",
                unit="fold"
            )
            if not perm_run else
            enumerate(self.cv.split(X, y))
        )
        for outer_fold, (train, test) in iterator:
            # split according to single outer fold
            X_train, X_test, y_train, y_test, cov_train, cov_test = train_test_split(train, test, X, y, covariates)

            # impute missing values
            if self.impute_missing_values:
                X_train, X_test, cov_train, cov_test = impute_missing_values(X_train, X_test, cov_train, cov_test)

            # residualize X to remove effect of covariates (model fitted on training data only)
            if self.calculate_residuals:
                residual_model = LinearRegression().fit(cov_train, X_train)
                X_train = X_train - residual_model.predict(cov_train)
                X_test = X_test - residual_model.predict(cov_test)

            # if the user specified an inner cross-validation, estimate models within inner loop
            if self.inner_cv is not None:
                best_params, stability_edges = run_inner_folds(cpm_model=self.cpm_model,
                                                               X=X_train, y=y_train, covariates=cov_train,
                                                               inner_cv=self.inner_cv,
                                                               edge_selection=self.edge_selection,
                                                               results_directory=os.path.join(results_manager.results_directory, 'folds', str(outer_fold)),
                                                               perm_run=perm_run)
            else:
                best_params = self.edge_selection.param_grid[0]

            # Use best parameters to estimate performance on outer fold test set.
            # select_stable_edges implies an inner cv (enforced in __init__), so stability_edges exists.
            if self.select_stable_edges:
                edges = select_stable_edges(stability_edges, self.stability_threshold)
            else:
                self.edge_selection.set_params(**best_params)
                edges = self.edge_selection.fit_transform(X=X_train, y=y_train, covariates=cov_train).return_selected_edges()

            results_manager.store_edges(edges=edges, fold=outer_fold)

            # Build model and make predictions
            model = self.cpm_model(edges=edges).fit(X_train, y_train, cov_train)
            y_pred = model.predict(X_test, cov_test)
            network_strengths = model.get_network_strengths(X_test, cov_test)
            metrics = score_regression_models(y_true=y_test, y_pred=y_pred)
            results_manager.store_predictions(y_pred=y_pred, y_true=y_test, params=best_params, fold=outer_fold,
                                              param_id=0, test_indices=test)
            results_manager.store_metrics(metrics=metrics, params=best_params, fold=outer_fold, param_id=0)
            results_manager.store_network_strengths(network_strengths=network_strengths, y_true=y_test, fold=outer_fold)

        # once all outer folds are done, calculate final results and edge stability
        results_manager.calculate_final_cv_results()
        results_manager.calculate_edge_stability()

        if not perm_run:
            self.logger.info(results_manager.agg_results.round(4).to_string())
            results_manager.save_predictions()
            results_manager.save_network_strengths()
            # keep the real-data results; permutation runs must not overwrite them
            self.results_manager = results_manager
@@ -0,0 +1,271 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ from scipy.stats import ttest_1samp, t, rankdata
4
+ from typing import Union
5
+
6
+ from sklearn.base import BaseEstimator
7
+ from sklearn.model_selection import ParameterGrid
8
+ import statsmodels.stats.multitest as multitest
9
+ from warnings import filterwarnings
10
+
11
+
12
def one_sample_t_test(matrix, population_mean):
    """
    Two-sided one-sample t-test for every column of ``matrix``.

    :param matrix: 2D array; rows are observations, each column is tested separately.
    :param population_mean: mean under the null hypothesis.
    :return: tuple ``(t_stats, p_values)`` with one entry per column.
    """
    from warnings import catch_warnings

    # Mean and standard deviation are taken over the rows (axis 0), so the
    # sample size of each column is the number of rows.
    n = matrix.shape[0]

    # Constant columns (zero standard deviation) or a single row trigger
    # RuntimeWarnings; silence them for this computation only instead of
    # changing the warning filters of the whole process.
    with catch_warnings():
        filterwarnings('ignore', category=RuntimeWarning)
        sample_means = np.mean(matrix, axis=0)
        sample_stds = np.std(matrix, axis=0, ddof=1)
        t_stats = (sample_means - population_mean) / (sample_stds / np.sqrt(n))

    # Calculate the p-values using the t-distribution survival function
    p_values = 2 * t.sf(np.abs(t_stats), df=n - 1)

    return t_stats, p_values
26
+
27
+
28
def compute_t_and_p_values(correlations, df):
    """
    Convert correlation coefficients into t-statistics and two-sided p-values.

    :param correlations: correlation coefficient(s).
    :param df: degrees of freedom of the t-distribution.
    :return: tuple ``(t_stats, p_values)``.
    """
    unexplained = 1 - correlations ** 2
    t_stats = correlations * np.sqrt(df / unexplained)
    # two-sided test: twice the upper tail probability of |t|
    upper_tail = t.sf(np.abs(t_stats), df=df)
    return t_stats, 2 * upper_tail
34
+
35
+
36
def compute_correlation_and_pvalues(x, Y, rank=False):
    """
    Correlate the vector ``x`` with every column of ``Y``.

    :param x: 1D vector of observations.
    :param Y: 2D array with one row per observation.
    :param rank: if True, ``x`` and the columns of ``Y`` are replaced by their ranks first.
    :return: tuple ``(correlations, p_values)`` with one entry per column of ``Y``;
             p-values use ``len(x) - 2`` degrees of freedom.
    """
    n_obs = len(x)

    if rank:
        x, Y = rankdata(x), rankdata(Y, axis=0)

    # deviations from the mean
    x_dev = x - np.mean(x)
    Y_dev = Y - np.mean(Y, axis=0)

    # cross-products divided by the product of the root sums of squares
    cross_products = np.dot(Y_dev.T, x_dev)
    norms = (np.sqrt(np.sum(Y_dev ** 2, axis=0)) * np.sqrt(np.sum(x_dev ** 2)))
    correlations = cross_products / norms

    p_values = compute_t_and_p_values(correlations, n_obs - 2)[1]
    return correlations, p_values
57
+
58
+
59
def get_residuals(X, Z):
    """
    Residuals of a least-squares regression of ``X`` on ``Z`` plus an intercept.

    :param X: 2D array of dependent variables (one column each).
    :param Z: 2D array of regressors; a column of ones is appended for the intercept.
    :return: ``X`` minus its least-squares prediction from ``Z``.
    """
    intercept = np.ones((Z.shape[0], 1))
    design = np.hstack([Z, intercept])

    # least-squares coefficients; the remaining lstsq outputs are not needed
    coefficients, *_ = np.linalg.lstsq(design, X, rcond=None)

    fitted = design.dot(coefficients)
    return X - fitted
73
+
74
+
75
def semi_partial_correlation(x, Y, Z, rank=False):
    """
    Correlate ``x`` with every column of ``Y`` after regressing ``Z`` out of both.

    NOTE: despite its name this is a partial correlation, not a semi-partial one,
    because both ``x`` and ``Y`` are residualized on ``Z``.

    :param x: 1D vector of observations.
    :param Y: 2D array with one row per observation.
    :param Z: 2D array of covariates with one row per observation.
    :param rank: if True, ``x``, ``Y`` and ``Z`` are replaced by their (column-wise) ranks first.
    :return: tuple ``(partial_corr, p_values)`` with one entry per column of ``Y``;
             p-values use ``len(x) - Z.shape[1] - 2`` degrees of freedom.
    """
    if rank:
        x, Y, Z = rankdata(x), rankdata(Y, axis=0), rankdata(Z, axis=0)

    # remove the linear effect of Z from x and from each column of Y
    x_res = get_residuals(x.reshape(-1, 1), Z).ravel()
    Y_res = get_residuals(Y, Z)

    # deviations of the residuals from their mean
    x_dev = x_res - np.mean(x_res)
    Y_dev = Y_res - np.mean(Y_res, axis=0)

    cross_products = np.dot(Y_dev.T, x_dev)
    norms = (np.sqrt(np.sum(Y_dev ** 2, axis=0)) * np.sqrt(np.sum(x_dev ** 2)))
    partial_corr = cross_products / norms

    # one degree of freedom is lost per covariate
    dof = len(x) - Z.shape[1] - 2
    p_values = compute_t_and_p_values(partial_corr, dof)[1]

    return partial_corr, p_values
104
+
105
+
106
def pearson_correlation_with_pvalues(x, Y):
    """Correlate ``x`` with each column of ``Y`` on the raw values; returns ``(correlations, p_values)``."""
    correlations, p_values = compute_correlation_and_pvalues(x, Y, rank=False)
    return correlations, p_values
108
+
109
+
110
def spearman_correlation_with_pvalues(x, Y):
    """Correlate ``x`` with each column of ``Y`` on rank-transformed values; returns ``(correlations, p_values)``."""
    correlations, p_values = compute_correlation_and_pvalues(x, Y, rank=True)
    return correlations, p_values
112
+
113
+
114
def semi_partial_correlation_pearson(x, Y, Z):
    """Covariate-adjusted correlation of ``x`` with each column of ``Y`` on raw values; returns ``(r, p_values)``."""
    partial_corr, p_values = semi_partial_correlation(x, Y, Z, rank=False)
    return partial_corr, p_values
116
+
117
+
118
def semi_partial_correlation_spearman(x, Y, Z):
    """Covariate-adjusted correlation of ``x`` with each column of ``Y`` on ranks; returns ``(r, p_values)``."""
    partial_corr, p_values = semi_partial_correlation(x, Y, Z, rank=True)
    return partial_corr, p_values
120
+
121
+
122
class BaseEdgeSelector(BaseEstimator):
    """Common base class of the edge selectors."""

    def select(self, r, p):
        """
        Select edges from per-edge statistics.

        The base implementation does nothing and returns ``None``; subclasses override it.

        :param r: per-edge correlation values.
        :param p: per-edge p-values.
        """
        return None
125
+
126
+
127
class PThreshold(BaseEdgeSelector):
    """
    Edge selector that keeps edges whose (optionally corrected) p-value is below a threshold.

    ``threshold`` and ``correction`` accept a single value or a list of candidate values; reading
    either property always yields a list.
    """
    def __init__(self, threshold: Union[float, list] = 0.05, correction: Union[str, list] = None):
        """

        :param threshold: p-value threshold (number) or list of candidate thresholds
        :param correction: can be one of statsmodels methods
                bonferroni : one-step correction
                sidak : one-step correction
                holm-sidak : step down method using Sidak adjustments
                holm : step-down method using Bonferroni adjustments
                simes-hochberg : step-up method (independent)
                hommel : closed method based on Simes tests (non-negative)
                fdr_bh : Benjamini/Hochberg (non-negative)
                fdr_by : Benjamini/Yekutieli (negative)
                fdr_tsbh : two stage fdr correction (non-negative)
                fdr_tsbky : two stage fdr correction (non-negative)
        """
        # backing fields first, then go through the validating property setters
        self._threshold = None
        self._correction = None
        self.threshold = threshold
        self.correction = correction

    @property
    def threshold(self):
        stored = self._threshold
        if not isinstance(stored, (int, float)):
            # a list is returned as is; None or an empty list falls back to [0.05]
            return stored or [0.05]
        return [float(stored)]

    @threshold.setter
    def threshold(self, value):
        if isinstance(value, (int, float)):
            self._threshold = float(value)
            return
        if not isinstance(value, list):
            raise ValueError("threshold must be float or list")
        self._threshold = value

    @property
    def correction(self):
        stored = self._correction
        if stored is None or isinstance(stored, str):
            return [stored]
        return stored

    @correction.setter
    def correction(self, value):
        if value is not None and not isinstance(value, (str, list)):
            raise ValueError("correction must be None, str, or list")
        self._correction = value

    def select(self, r, p):
        """
        Split the significant edges by the sign of their correlation.

        :param r: per-edge correlation values.
        :param p: per-edge p-values.
        :return: dict with index arrays under the keys 'positive' and 'negative'.
        """
        if self._correction is not None:
            # only the corrected p-values (second return value) are used
            p = multitest.multipletests(p, alpha=0.05, method=self._correction)[1]
        significant = p < self.threshold
        pos_edges = np.where(significant & (r > 0))[0]
        neg_edges = np.where(significant & (r < 0))[0]
        return {'positive': pos_edges, 'negative': neg_edges}
189
+
190
+
191
class SelectPercentile(BaseEdgeSelector):
    """Edge selector configured by a percentile; it defines no ``select`` of its own."""

    def __init__(self, percentile: Union[float, list] = 0.05):
        """
        :param percentile: percentile value or list of candidate values (stored unchanged).
        """
        self.percentile = percentile
194
+
195
+
196
class SelectKBest(BaseEdgeSelector):
    """Edge selector configured by a number of edges ``k``; it defines no ``select`` of its own."""

    def __init__(self, k: Union[int, list] = None):
        """
        :param k: number of edges or list of candidate values (stored unchanged).
        """
        self.k = k
199
+
200
+
201
class EdgeStatistic(BaseEstimator):
    """
    Computes a per-edge association statistic between the edges in X and the target y.

    Supported values of ``edge_statistic``: 'pearson', 'spearman', 'pearson_partial' and
    'spearman_partial' (the partial variants additionally use the covariates).
    """
    def __init__(self, edge_statistic: str = 'spearman', t_test_filter: bool = False):
        """
        :param edge_statistic: name of the statistic to compute.
        :param t_test_filter: stored only; the t-test based filter is currently not applied.
        """
        self.edge_statistic = edge_statistic
        self.t_test_filter = t_test_filter

    def fit_transform(self,
                      X: Union[pd.DataFrame, np.ndarray],
                      y: Union[pd.Series, pd.DataFrame, np.ndarray],
                      covariates: Union[pd.Series, pd.DataFrame, np.ndarray]):
        """
        Compute correlation and p-value of every edge with ``y``.

        Only edges that pass a variance filter (``VarianceThreshold(threshold=0.01)``) are tested;
        all other edges keep ``r = 0`` and ``p = 1``.

        :param X: 2D array with one row per observation and one column per edge.
        :param y: target values.
        :param covariates: covariates, used by the partial statistics only.
        :return: tuple ``(r_edges, p_edges)`` with one entry per column of ``X``.
        :raises NotImplementedError: if ``edge_statistic`` is not one of the supported names.
        """
        from sklearn.feature_selection import VarianceThreshold

        # defaults for edges that are excluded by the variance filter
        r_edges, p_edges = np.zeros(X.shape[1]), np.ones(X.shape[1])

        selector = VarianceThreshold(threshold=0.01)
        selector.fit(X)
        valid_edges = selector.get_support()

        if self.edge_statistic == 'pearson':
            r_edges_masked, p_edges_masked = pearson_correlation_with_pvalues(y, X[:, valid_edges])
        elif self.edge_statistic == 'spearman':
            r_edges_masked, p_edges_masked = spearman_correlation_with_pvalues(y, X[:, valid_edges])
        elif self.edge_statistic == 'pearson_partial':
            r_edges_masked, p_edges_masked = semi_partial_correlation_pearson(y, X[:, valid_edges], covariates)
        elif self.edge_statistic == 'spearman_partial':
            r_edges_masked, p_edges_masked = semi_partial_correlation_spearman(y, X[:, valid_edges], covariates)
        else:
            # NotImplemented is a constant, not an exception class; calling it raised a TypeError
            raise NotImplementedError("Unsupported edge selection method")
        r_edges[valid_edges] = r_edges_masked
        p_edges[valid_edges] = p_edges_masked
        return r_edges, p_edges
236
+
237
+
238
class UnivariateEdgeSelection(BaseEstimator):
    """
    Combines an edge statistic with one or more edge selectors.

    At construction, ``edge_selection`` is stored as a list of selectors and ``param_grid`` is built
    from it. ``return_selected_edges`` expects ``edge_selection`` to be a single selector; within this
    package it is called after ``set_params(**config)`` with an entry of ``param_grid``.
    """
    def __init__(self,
                 edge_statistic: str = 'spearman',
                 t_test_filter: bool = False,
                 edge_selection: Union[list, None, PThreshold] = None,
                 ):
        """
        :param edge_statistic: name of the statistic, passed on to ``EdgeStatistic``.
        :param t_test_filter: passed on to ``EdgeStatistic``.
        :param edge_selection: a selector, a list/tuple of selectors, or ``None`` for a default
                               ``PThreshold()``.
        """
        self.r_edges = None
        self.p_edges = None
        self.t_test_filter = t_test_filter
        self.edge_statistic = EdgeStatistic(edge_statistic=edge_statistic, t_test_filter=t_test_filter)

        # None used to be wrapped into [None] and crashed in _generate_config_grid
        if edge_selection is None:
            edge_selection = PThreshold()
        if isinstance(edge_selection, (list, tuple)):
            self.edge_selection = edge_selection
        else:
            self.edge_selection = [edge_selection]
        self.param_grid = self._generate_config_grid()

    def _generate_config_grid(self):
        """
        Build the hyperparameter grid: one sub-grid per selector, made of the selector itself and
        its parameters (keys prefixed with ``edge_selection__``).
        """
        grid_elements = []
        for selector in self.edge_selection:
            params = {'edge_selection': [selector]}
            for key, value in selector.get_params().items():
                params['edge_selection__' + key] = value
            grid_elements.append(params)
        return ParameterGrid(grid_elements)

    def fit_transform(self, X, y=None, covariates=None):
        """
        Compute the per-edge statistics and store them in ``r_edges`` / ``p_edges``.

        :return: ``self``, so that ``return_selected_edges()`` can be chained.
        """
        self.r_edges, self.p_edges = self.edge_statistic.fit_transform(X=X, y=y, covariates=covariates)
        return self

    def return_selected_edges(self):
        """
        Apply the configured selector to the statistics stored by ``fit_transform``.

        :return: whatever the selector's ``select`` returns.
        """
        selected_edges = self.edge_selection.select(r=self.r_edges, p=self.p_edges)
        return selected_edges
cccpm/fold.py ADDED
@@ -0,0 +1,46 @@
1
+ from cccpm.utils import train_test_split
2
+ from cccpm.scoring import score_regression_models
3
+ from cccpm.results_manager import ResultsManager
4
+ from cccpm.edge_selection import BaseEdgeSelector
5
+
6
+
7
def run_inner_folds(cpm_model, X, y, covariates, inner_cv, edge_selection: BaseEdgeSelector, results_directory,
                    perm_run):
    """
    Run inner cross-validation over all folds and hyperparameter configurations.

    Parameters
    ----------
    cpm_model :
        Model class, instantiated as ``cpm_model(edges=...)``.
    X, y, covariates :
        Data to be split by ``inner_cv``.
    inner_cv :
        Cross-validation splitter providing ``get_n_splits()`` and ``split(X, y)``.
    edge_selection :
        Edge selection object; its ``param_grid`` is iterated and each configuration is applied
        with ``set_params`` (the object is modified in place).
    results_directory :
        Output directory handed to the ``ResultsManager``.
    perm_run :
        Permutation run identifier handed to the ``ResultsManager``.

    Returns
    -------
    best_params :
        Best hyperparameter configuration as reported by ``ResultsManager.find_best_params()``.
    stability_edges :
        Edge stability for the best configuration, as returned by
        ``ResultsManager.calculate_edge_stability()``.
    """
    configurations = edge_selection.param_grid

    results_manager = ResultsManager(output_dir=results_directory, perm_run=perm_run,
                                     n_folds=inner_cv.get_n_splits(), n_features=X.shape[1],
                                     n_params=len(configurations))

    for fold_id, (train, test) in enumerate(inner_cv.split(X, y)):
        # data of this inner fold
        X_train, X_test, y_train, y_test, cov_train, cov_test = train_test_split(train, test, X, y, covariates)

        for param_id, config in enumerate(configurations):
            # select edges with this configuration on the training part
            edge_selection.set_params(**config)
            fitted_selection = edge_selection.fit_transform(X_train, y_train, cov_train)
            selected_edges = fitted_selection.return_selected_edges()

            # fit on the training part, evaluate on the held-out part
            model = cpm_model(edges=selected_edges).fit(X_train, y_train, cov_train)
            y_pred = model.predict(X_test, cov_test)
            metrics = score_regression_models(y_true=y_test, y_pred=y_pred)

            results_manager.store_edges(selected_edges, fold_id, param_id)
            results_manager.store_metrics(metrics=metrics, params=config, fold=fold_id, param_id=param_id)

    # once all inner folds are done, aggregate and pick the best configuration
    results_manager.aggregate_inner_folds()

    best_params, best_param_id = results_manager.find_best_params()
    stability_edges = results_manager.calculate_edge_stability(write=False, best_param_id=best_param_id)

    return best_params, stability_edges
cccpm/logging.py ADDED
@@ -0,0 +1,37 @@
1
+ import sys
2
+ import logging
3
+
4
+
5
def setup_logging(log_file: str = "analysis_log.txt"):
    """
    Configure the root logger with a console handler and a file handler.

    Handlers already attached to the root logger are removed and closed first, so calling this
    function repeatedly neither duplicates output nor leaves earlier log files open.

    :param log_file: path of the log file; it is overwritten (opened with mode 'w').
    """
    # Get the root logger
    logger = logging.getLogger()

    # Remove existing handlers to avoid duplication. Iterate over a copy because
    # removeHandler() mutates logger.handlers; close() releases the old log file.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Console handler: logs all levels (DEBUG and above) to the console
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(SimpleFormatter())

    # File handler: logs only INFO level logs to the file
    file_handler = logging.FileHandler(log_file, mode='w')
    file_handler.setLevel(logging.INFO)
    file_handler.addFilter(lambda record: record.levelno == logging.INFO)
    file_handler.setFormatter(SimpleFormatter())

    # Set the base level to DEBUG so both handlers can operate independently
    logger.setLevel(logging.DEBUG)  # This ensures all messages are passed to handlers
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    # keep matplotlib's debug output out of the logs
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)
31
+
32
+
33
class SimpleFormatter(logging.Formatter):
    """Formatter that emits only the log message, without level, time or logger name."""

    def format(self, record):
        """Return the formatted message of ``record``."""
        message_only = logging.Formatter("%(message)s")
        return message_only.format(record)