EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,1344 @@
1
+ import time, sys
2
+ from tqdm import tqdm
3
+ import multiprocessing as mp
4
+ from scipy.stats import bootstrap, kstest
5
+ import math, random
6
+ import logging
7
+ from EntDetect._logging import setup_logger
8
+ import argparse
9
+ import glob
10
+ import numpy as np
11
+ import pandas as pd
12
+ from sklearn.preprocessing import OneHotEncoder, LabelEncoder
13
+ from sklearn.compose import ColumnTransformer
14
+ from sklearn.pipeline import Pipeline
15
+ from sklearn.model_selection import train_test_split, StratifiedKFold, KFold, cross_validate, GridSearchCV
16
+ from sklearn.linear_model import LogisticRegression
17
+ from sklearn.metrics import accuracy_score, balanced_accuracy_score, average_precision_score, f1_score, recall_score, precision_score, roc_auc_score
18
+ from sklearn.preprocessing import StandardScaler
19
+ from sklearn.neighbors import NearestNeighbors
20
+ from scipy.spatial.distance import euclidean
21
+ import matplotlib.pyplot as plt
22
+ import os
23
+
24
+ try:
25
+ import rpy2.robjects as robjects
26
+ from rpy2.robjects import pandas2ri
27
+ from rpy2.robjects.packages import importr
28
+ from rpy2.robjects.conversion import localconverter
29
+ except ImportError:
30
+ robjects = None
31
+ pandas2ri = None
32
+ importr = None
33
+ localconverter = None
34
+ import statsmodels.api as sm
35
+ import statsmodels.formula.api as smf
36
+ from scipy.stats import poisson, binom, fisher_exact, chi2, norm
37
+ import scipy.stats as st
38
+ from matplotlib.ticker import MultipleLocator
39
+ from scipy.special import expit
40
+ from scipy.stats import entropy
41
+
42
+ #pd.set_option('display.max_rows', 4000)
43
+
44
class ProteomeLogisticRegression:
    """
    A class to handle the data analysis process: loading residue feature tables,
    label-encoding selected columns, and fitting a binomial (logistic) GLM.

    Typical use: construct, call ``load_data(...)`` to populate ``self.data``, then
    ``run()`` to fit ``reg_formula`` and obtain the coefficient / odds-ratio table.
    """

    ################################################################################################
    def __init__(self, dataframe_files:str, outdir:str, gene_list:str, ID:str, reg_formula:str, log_level:int=logging.INFO, logdir:str=None):
        """
        Initializes the ProteomeLogisticRegression class with necessary paths and parameters.

        Parameters:
        - dataframe_files (str): Path to either a single design-matrix file or a directory of residue feature files.
            Each file contains a column for each regression variable and 1 column for the response variable.
            The rows should equal the number of samples in the model.
            In the directory case the uniprot ID should be in the file name and match what is in the gene list file.
        - outdir (str): Path to the output directory (created if it does not exist).
        - gene_list (str): Path to the gene list to use.
            The file is a single column of uniprot IDs with no header.
        - ID (str): ID for output filenames.
        - reg_formula (str): Regression formula passed to statsmodels ``GLM.from_formula``.
        - log_level (int): Logging level passed to ``setup_logger``.
        - logdir (str | None): Directory for the log; defaults to ``outdir`` when None.
        """
        self.dataframe_files = dataframe_files
        self.outdir = outdir
        # np.loadtxt returns a 0-d array for a one-line file; atleast_1d keeps len()/iteration working.
        self.gene_list = np.atleast_1d(np.loadtxt(gene_list, dtype=str))
        self.ID = ID
        self.reg_formula = reg_formula

        # Create the output directory before configuring the logger, because the logger
        # is pointed at outdir whenever no separate logdir is given.
        made_outdir = not os.path.exists(self.outdir)
        os.makedirs(self.outdir, exist_ok=True)

        self.logger = setup_logger('ProteomeLogisticRegression', outdir=logdir if logdir is not None else outdir, ID=ID, log_level=log_level)
        self.logger.info('Initializing ProteomeLogisticRegression')
        self.logger.debug(f'gene_list: {self.gene_list} {len(self.gene_list)}')
        if made_outdir:
            self.logger.info(f'Made output directories {self.outdir}')
    ################################################################################################

    ################################################################################################
    def encode_boolean_columns(self, df: pd.DataFrame, boolean_columns: list) -> pd.DataFrame:
        """
        Encodes boolean-like columns in a DataFrame to integer labels (0/1 for two-valued columns).

        Each listed column present in ``df`` is replaced in place by sklearn ``LabelEncoder``
        codes; columns that are not present are logged and skipped.

        Parameters:
        - df (pd.DataFrame): The input DataFrame (modified in place).
        - boolean_columns (list): A list of column names to be encoded.

        Returns:
        - pd.DataFrame: The same DataFrame with encoded columns.
        """
        label_encoder = LabelEncoder()

        for column in boolean_columns:
            if column in df.columns:
                df[column] = label_encoder.fit_transform(df[column])
            else:
                self.logger.info(f"Column '{column}' does not exist in the DataFrame.")

        return df
    ################################################################################################

    ################################################################################################
    @staticmethod
    def _two_sided_pvalues(z) -> np.ndarray:
        """
        Return two-sided normal p-values ``2 * P(Z > |z|)`` for an array of z statistics.

        Uses the survival function instead of ``1 - cdf(z)``: the latter rounds to exactly 0
        for z above ~8.3, which is precisely the regime where extra digits matter.
        """
        return 2.0 * st.norm.sf(np.abs(np.asarray(z, dtype=float)))

    @classmethod
    def _coef_table(cls, result) -> pd.DataFrame:
        """
        Build the coefficient table from a fitted statsmodels result at full precision.

        Columns mirror the statsmodels summary table (``var``, ``coef``, ``std err``, ``z``,
        ``P>|z|``, ``[0.025``, ``0.975]``) but hold floats taken directly from the result
        rather than the rounded strings of ``result.summary()``.
        """
        conf_int = result.conf_int()
        z = np.asarray(result.tvalues, dtype=float)
        return pd.DataFrame({
            'var': [str(name) for name in result.params.index],
            'coef': np.asarray(result.params, dtype=float),
            'std err': np.asarray(result.bse, dtype=float),
            'z': z,
            'P>|z|': cls._two_sided_pvalues(z),
            '[0.025': np.asarray(conf_int.iloc[:, 0], dtype=float),
            '0.975]': np.asarray(conf_int.iloc[:, 1], dtype=float),
        })

    def regression(self, df, formula):
        """
        Fits a binomial GLM (logistic regression; Binomial family, not quasi-binomial) on ``df``.

        Side effect: stores ``self.coefficients``, a dict of coefficient values keyed by term name,
        with ``AA[T.X]`` terms shortened to ``X`` and ``'A'`` pre-set to 0
        (presumably the reference level of the AA factor — TODO confirm).

        Parameters:
        - df (pd.DataFrame): DataFrame containing the data for regression.
        - formula (str): The formula specifying the regression model.

        Returns:
        - table_df (pd.DataFrame): Regression results (var, coef, std err, z, P>|z|, [0.025, 0.975]),
          all numeric and at full precision.
        - cov_matrix (pd.DataFrame): Parameter covariance matrix from ``result.cov_params()``.
        """
        model = sm.GLM.from_formula(formula, family=sm.families.Binomial(), data=df)
        result = model.fit()

        # Get the cov_params
        cov_matrix = result.cov_params()
        self.logger.info(f'cov_matrix:\n{cov_matrix}')

        # Get the coefficients
        self.logger.info("Coefficients:")
        coefficients = {'A': 0}
        for k, v in result.params.items():
            if 'AA' in k:
                k = k.replace('AA[T.', '').replace(']', '')
            coefficients[k] = v
        for k, v in coefficients.items():
            self.logger.info(f'{k}: {v}')
        self.coefficients = coefficients

        self.logger.debug(result.summary())
        # The printed summary rounds every value; build the table from the result itself so
        # coefficients, CIs and p-values keep full precision.
        table_df = self._coef_table(result)
        return table_df, cov_matrix
    ################################################################################################

    ################################################################################################
    def load_data(self, sep:str, reg_var:list, response_var:str, var2binarize:list, mask_column:str):
        """
        Loads the residue feature files and filters the data for analysis.

        Rows with ``AA == 'NC'``, a null ``mask_column`` value or a null ``AA`` are dropped, so an
        ``'AA'`` column must be among the loaded columns (normally listed in ``reg_var``).

        Parameters:
        - sep (str): The separator used in the CSV files.
        - reg_var (list): List of regression variables to include in the analysis.
        - response_var (str): The response variable for the regression.
        - var2binarize (list): List of variables to binarize. Best not to use booleans and convert to 0/1
        - mask_column (str): Column header to use for masking the data. Should be a column containing null values for samples to exclude.

        Sets:
        - self.data (pd.DataFrame): Columns ``[response_var] + reg_var``.
        - self.n (int): Number of genes loaded.
        """
        usecols = reg_var + [response_var] + [mask_column, 'gene']
        self.data = pd.DataFrame()
        self.n = 0

        if os.path.isfile(self.dataframe_files):
            self.logger.info(f'Loading single design matrix file: {self.dataframe_files}')
            data = pd.read_csv(self.dataframe_files, sep=sep, usecols=usecols)
            data = data[data['gene'].isin(self.gene_list)]
            self.n = len(data['gene'].dropna().unique())
        else:
            files = glob.glob(os.path.join(self.dataframe_files, '*'))
            self.logger.info(f'Number of files: {len(files)}')

            frames = []
            for gene in self.gene_list:
                # Match on the file name only, so a gene ID that happens to occur in the
                # directory path does not match every file.
                gene_resFeat = [f for f in files if gene in os.path.basename(f)]
                if len(gene_resFeat) == 0:
                    self.logger.warning(f"WARNING: No residue feature file found for gene {gene}")
                    continue
                elif len(gene_resFeat) > 1:
                    self.logger.warning(f"WARNING: More than 1 residue feature file found for gene {gene}")
                    continue
                frames.append(pd.read_csv(gene_resFeat[0], sep=sep, usecols=usecols))
                self.n += 1
            # Concatenate once instead of growing the frame inside the loop (quadratic copying).
            data = pd.concat(frames) if frames else pd.DataFrame()

        data = data[data['AA'] != 'NC']
        data = data[data[mask_column].notna()]
        data = data[data['AA'].notna()]
        data = data.reset_index(drop=True)
        self.logger.info(f'Loaded Regression DataFrame:\n{data}')
        data = self.encode_boolean_columns(data, boolean_columns=var2binarize)
        self.data = data[[response_var] + reg_var]
        self.logger.info(f'Loaded Regression DataFrame:\n{self.data}')
    ################################################################################################

    ################################################################################################
    def run(self, ):
        """
        Fits ``self.reg_formula`` on ``self.data`` (populated by ``load_data``) and returns the results.

        Returns:
        - reg (pd.DataFrame): The table from ``regression`` with numeric columns plus
          ``OR`` = exp(coef). The ``[0.025`` / ``0.975]`` columns are exponentiated, so they hold
          the confidence bounds of the odds ratio. ``ID`` and ``n`` (genes loaded) are appended.
        """
        # Perform regression
        reg, _cov_matrix = self.regression(self.data, self.reg_formula)
        numeric_cols = ['coef', 'std err', 'z', 'P>|z|', '[0.025', '0.975]']
        reg[numeric_cols] = reg[numeric_cols].astype(float)
        reg['OR'] = np.exp(reg['coef'])
        # Convert the coefficient CI into an odds-ratio CI.
        reg['[0.025'] = np.exp(reg['[0.025'])
        reg['0.975]'] = np.exp(reg['0.975]'])
        reg['ID'] = self.ID
        reg['n'] = self.n
        self.logger.info(f'Regression Results:\n{reg.to_string()}')

        return reg
    ################################################################################################
229
+
230
+ # ################################################################################################
231
+ # def plot_regression_results(self, data:pd.DataFrame, reg_df: pd.DataFrame, outfile: str, title: str, reg_var:str, response_var:str):
232
+ # """
233
+ # Plot the regression results as a single single plot of the probability of the response variable vs the regression variables.
234
+ # P(x) = 1 / (1 + exp(-(B0 + B1*X1 + B2*X2 + ... + Bn*Xn)))
235
+ # """
236
+ # import matplotlib.pyplot as plt
237
+ # print(data)
238
+ # print(reg_df)
239
+
240
+ # ## get the coefficients into an easy to read dictionary
241
+ # coefficients = {}
242
+ # for rowi, row in reg_df.iterrows():
243
+ # if 'AA' in row['var']:
244
+ # row['var'] = row['var'].replace('AA[T.', '').replace(']', '')
245
+ # coefficients[row['var']] = row['coef']
246
+ # print(f'coefficients: {coefficients}')
247
+
248
+ # ## plot cut_C_Rall vs region and then fit a curve to it where cut_C_Rall ~ 1 / (1 + exp(-(B0 + B1*X1 + B2*X2 + ... + Bn*Xn)))
249
+ # ## where X1 is the region and X2..Xn is the AA terms
250
+ # df = data.copy()
251
+ # # One-hot encode the AA column (exclude AA not in coefficients)
252
+ # aa_keys = set(coefficients.keys()) - {'Intercept', 'region'}
253
+ # df = df[df['AA'].isin(aa_keys)] # filter only known AA
254
+ # aa_dummies = pd.get_dummies(df['AA'])
255
+ # df = pd.concat([df, aa_dummies], axis=1)
256
+
257
+ # # Add coefficient-weighted terms
258
+ # df['linear_predictor'] = coefficients['Intercept'] + coefficients['region'] * df['region']
259
+ # for aa in aa_keys:
260
+ # if aa in df.columns:
261
+ # df['linear_predictor'] += coefficients[aa] * df[aa]
262
+
263
+ # # Compute predicted probability
264
+ # df['predicted'] = expit(df['linear_predictor'])
265
+ # print(df)
266
+
267
+ # # Aggregate by region for plotting
268
+ # agg = df.groupby('region')[['cut_C_Rall', 'predicted']].mean().reset_index()
269
+
270
+ # # Plot
271
+ # plt.figure(figsize=(8, 6))
272
+ # plt.plot(agg['region'], agg['cut_C_Rall'], label='Observed', marker='o')
273
+ # plt.plot(agg['region'], agg['predicted'], label='Predicted (logistic fit)', marker='x')
274
+ # plt.xlabel('Region')
275
+ # plt.ylabel('Probability of cut_C_Rall = 1')
276
+ # plt.title('Observed vs Predicted cut_C_Rall by Region')
277
+ # plt.legend()
278
+ # plt.grid(True)
279
+ # plt.tight_layout()
280
+
281
+
282
+ # plt.savefig(outfile)
283
+ # print(f'SAVED: {outfile}')
284
+ # ################################################################################################
285
+
286
+ #########################################################################################################################
287
+ #########################################################################################################################
288
+ class MonteCarlo:
289
+
290
+ """
291
+ A class to handle the data analysis process including encoding, regression, and statistical tests.
292
+ """
293
+
294
def __init__(self, dataframe_files:str, outdir:str, gene_list:str, ID:str, reg_formula:str, response_var:str, test_var:str, random:bool, n_groups:int, steps:int, C1:float, C2:float, beta:float, linearT:bool, log_level:int=logging.INFO, logdir:str=None):
    """
    Initializes the MonteCarlo class with necessary paths and parameters.

    Parameters:
    - dataframe_files (str): Path to a single design-matrix file or a directory of residue feature files.
    - outdir (str): Path to the output directory (created if it does not exist).
    - gene_list (str): Path to the gene list to use (single column of IDs, no header).
    - ID (str): ID for output filenames.
    - reg_formula (str): Regression formula.
    - response_var (str): The response variable for the regression.
    - test_var (str): The test variable for the regression.
    - random (bool): If True, ``load_data`` shuffles the response column (randomized control).
    - n_groups (int): Number of gene groups for the Monte Carlo simulation.
    - steps (int): Number of steps for the Monte Carlo simulation.
    - C1 (float): Weight of the -log(OR) term in the energy function.
    - C2 (float): Weight of the size KS-statistic term in the energy function.
    - beta (float): Starting inverse temperature (beta = 1/T) for the simulated annealing;
        it enters the Metropolis factor exp(-beta * deltaE).
    - linearT (bool): If True the annealing schedule is linear in T = 1/beta, otherwise linear in beta.
    - log_level (int): Logging level passed to ``setup_logger``.
    - logdir (str | None): Directory for the log; defaults to ``outdir`` when None.
    """
    self.dataframe_files = dataframe_files
    self.outdir = outdir
    # np.loadtxt returns a 0-d array for a one-line file; atleast_1d keeps len()/indexing working.
    self.gene_list = np.atleast_1d(np.loadtxt(gene_list, dtype=str))
    self.num_genes = len(self.gene_list)

    # Create the output directory before configuring the logger, because the logger
    # is pointed at outdir whenever no separate logdir is given.
    made_outdir = not os.path.exists(self.outdir)
    os.makedirs(self.outdir, exist_ok=True)

    self.logger = setup_logger('MonteCarlo', outdir=logdir if logdir is not None else outdir, ID=ID, log_level=log_level)
    self.logger.debug(f'gene_list: {self.gene_list} {self.num_genes}')

    self.ID = ID
    self.reg_formula = reg_formula
    self.response_var = response_var
    self.test_var = test_var
    self.steps = steps
    self.C1 = C1
    self.C2 = C2
    self.beta = beta
    self.linearT = linearT
    self.data = {}
    self.n_groups = n_groups
    self.random = random

    if made_outdir:
        self.logger.info(f'Made output directories {self.outdir}')

    self.logger.info(f'{"#"*100}\nNEW RUN')

    # store the parameters in the log file
    self.logger.info(f'dataframe_files: {self.dataframe_files}')
    self.logger.info(f'outdir: {self.outdir}')
    self.logger.info(f'gene_list: {self.gene_list} {self.num_genes}')
    self.logger.info(f'ID: {self.ID}')
    self.logger.info(f'response_var: {self.response_var}')
    self.logger.info(f'test_var: {self.test_var}')
    self.logger.info(f'steps: {self.steps}')
    self.logger.info(f'C1: {self.C1}')
    self.logger.info(f'C2: {self.C2}')
    self.logger.info(f'linearT: {self.linearT}')
    self.logger.info(f'beta: {self.beta}')
    self.logger.info(f'n_groups: {self.n_groups}')
352
+ self.logger.info(f'n_groups: {self.n_groups}')
353
+
354
+ ################################################################################################
355
+
356
+ ################################################################################################
357
def setup_logging(self):
    """
    Apply a basic INFO-level root logging configuration and hand back this module's logger.

    ``__init__`` does not use this helper; it obtains its logger from ``setup_logger``.

    Returns:
    - logger (logging.Logger): The logger named after this module.
    """
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(name=__name__)
367
+ ################################################################################################
368
+
369
+ ################################################################################################
370
def encode_boolean_columns(self, df: pd.DataFrame, boolean_columns: list) -> pd.DataFrame:
    """
    Replace each requested column of ``df`` with integer label codes.

    A single sklearn ``LabelEncoder`` is re-fitted per column, so a two-valued column becomes 0/1.
    Requested columns that are absent from ``df`` are reported via the logger and left alone.

    Parameters:
    - df (pd.DataFrame): Frame to encode; it is modified in place.
    - boolean_columns (list): Names of the columns to encode.

    Returns:
    - pd.DataFrame: ``df`` itself, after encoding.
    """
    encoder = LabelEncoder()

    for column in boolean_columns:
        if column not in df.columns:
            self.logger.info(f"Column '{column}' does not exist in the DataFrame.")
            continue
        df[column] = encoder.fit_transform(df[column])

    return df
390
+ ################################################################################################
391
+
392
+ ################################################################################################
393
def regression(self, df, formula, genes, term:str='region'):
    """
    Fits a binomial GLM (logistic regression; Binomial family, not quasi-binomial) and extracts one term.

    Parameters:
    - df (pd.DataFrame): DataFrame containing the data for regression.
    - formula (str): The formula specifying the regression model.
    - genes (list): Gene IDs of this group, used to look up their sizes in ``self.prot_size``.
    - term (str): Model term whose coefficient and standard error are returned (default ``'region'``).

    Returns:
    - table_df (pd.DataFrame): Full results table (var, coef, std err, z, P>|z|, [0.025, 0.975]),
      numeric and at full precision.
    - coef (float): Coefficient of ``term``.
    - std (float): Standard error of ``term``.
    - term_row (pd.DataFrame): The row of ``table_df`` for ``term``.
    - size_dist (np.ndarray): ``prot_size`` values for ``genes`` from ``self.prot_size``.
    """
    model = sm.GLM.from_formula(formula, family=sm.families.Binomial(), data=df)
    result = model.fit()

    # Build the table from the fitted result instead of the printed summary: the summary
    # rounds every value, and recomputing p as 1 - cdf(z) underflows to 0 for z > ~8.3.
    # 2 * sf(|z|) keeps the small p-values non-zero.
    z = np.asarray(result.tvalues, dtype=float)
    conf_int = result.conf_int()
    table_df = pd.DataFrame({
        'var': [str(name) for name in result.params.index],
        'coef': np.asarray(result.params, dtype=float),
        'std err': np.asarray(result.bse, dtype=float),
        'z': z,
        'P>|z|': 2.0 * st.norm.sf(np.abs(z)),
        '[0.025': np.asarray(conf_int.iloc[:, 0], dtype=float),
        '0.975]': np.asarray(conf_int.iloc[:, 1], dtype=float),
    })

    term_row = table_df[table_df['var'] == term]
    coef = term_row['coef'].values[0]
    std = term_row['std err'].values[0]

    # get size dist
    size_dist = self.prot_size[self.prot_size['gene'].isin(genes)]['prot_size'].values
    return table_df, float(coef), float(std), term_row, size_dist
428
+ ################################################################################################
429
+
430
+ ################################################################################################
431
def metrics(self, df, genes:list):
    """
    Score one gene group with a Fisher exact test and gather its protein sizes.

    Parameters:
    - df (pd.DataFrame): Rows of this group; must hold ``self.response_var`` and ``self.test_var``.
    - genes (list): Gene IDs of this group, used to select sizes from ``self.prot_size``.

    Returns:
    - OR (float): ``statistic`` of scipy ``fisher_exact`` on the response-by-test contingency table.
    - pvalue (float): Fisher exact test p-value (scipy default alternative).
    - size_dist (np.ndarray): ``prot_size`` values of the rows of ``self.prot_size`` whose gene is in ``genes``.
    """
    contingency = pd.crosstab(df[self.response_var], df[self.test_var])
    fisher = fisher_exact(contingency)

    in_group = self.prot_size['gene'].isin(genes)
    sizes = self.prot_size.loc[in_group, 'prot_size'].values

    return fisher.statistic, fisher.pvalue, sizes
441
+ ################################################################################################
442
+
443
+ ################################################################################################
444
def load_data(self, sep:str, reg_var:list, response_var:str, var2binarize:list, mask_column:str, ID_column:str, Length_column:str):
    """
    Loads the residue feature files and filters the data for analysis.

    Rows with ``AA == 'NC'``, a null ``mask_column`` value or a null ``AA`` are dropped, so an
    ``'AA'`` column must be among the loaded columns (normally listed in ``reg_var``).

    Parameters:
    - sep (str): The separator used in the CSV files.
    - reg_var (list): List of regression variables to include in the analysis.
    - response_var (str): The response variable for the regression.
    - var2binarize (list): List of variables to binarize. Best not to use booleans and convert to 0/1
    - mask_column (str): Column header to use for masking the data. Should be a column containing null values for samples to exclude.
    - ID_column (str): Column header for the gene ID.
    - Length_column (str): Column header for the protein length.

    Sets:
    - self.data (pd.DataFrame): Columns ``[ID_column, response_var] + reg_var``; the response is
      shuffled when ``self.random`` is True.
    - self.n (int): Number of genes loaded.
    - self.prot_size (pd.DataFrame): Columns ``gene`` and ``prot_size``, one row per loaded gene.
    """
    usecols = reg_var + [response_var] + [mask_column, ID_column, Length_column]
    self.data = pd.DataFrame()
    self.n = 0
    self.prot_size = {'gene': [], 'prot_size': []}

    if os.path.isfile(self.dataframe_files):
        self.logger.info(f'Loading single design matrix file: {self.dataframe_files}')
        data = pd.read_csv(self.dataframe_files, sep=sep, usecols=usecols)
        data = data[data[ID_column].isin(self.gene_list)]
        self.n = len(data[ID_column].dropna().unique())

        size_df = data[[ID_column, Length_column]].dropna().drop_duplicates(subset=[ID_column])
        self.prot_size = {
            'gene': size_df[ID_column].tolist(),
            'prot_size': size_df[Length_column].tolist(),
        }
    else:
        files = glob.glob(os.path.join(self.dataframe_files, '*'))
        self.logger.info(f'Number of files: {len(files)}')

        frames = []
        for gene in self.gene_list:
            # Match on the file name only, so a gene ID that happens to occur in the
            # directory path does not match every file.
            gene_resFeat = [f for f in files if gene in os.path.basename(f)]
            if len(gene_resFeat) == 0:
                self.logger.warning(f"WARNING: No residue feature file found for gene {gene}")
                continue
            elif len(gene_resFeat) > 1:
                self.logger.warning(f"WARNING: More than 1 residue feature file found for gene {gene}")
                continue
            df = pd.read_csv(gene_resFeat[0], sep=sep, usecols=usecols)
            if df.empty:
                # An empty file has no length value to record (values[0] would raise IndexError).
                self.logger.warning(f"WARNING: Residue feature file for gene {gene} is empty")
                continue
            frames.append(df)
            self.n += 1
            self.prot_size['gene'].append(gene)
            self.prot_size['prot_size'].append(df[Length_column].values[0])
        # Concatenate once instead of growing the frame inside the loop (quadratic copying).
        data = pd.concat(frames) if frames else pd.DataFrame()

    data = data[data['AA'] != 'NC']
    data = data[data[mask_column].notna()]
    data = data[data['AA'].notna()]
    data = data.reset_index(drop=True)
    self.logger.info(f'Loaded Regression DataFrame:\n{data}')
    self.logger.info(f'number of genes: {len(data[ID_column].unique())}')

    data = self.encode_boolean_columns(data, boolean_columns=var2binarize)
    # Copy so the optional shuffle below writes to an independent frame, not a slice.
    self.data = data[[ID_column] + [response_var] + reg_var].copy()
    self.logger.info(f'Loaded Regression DataFrame:\n{self.data}')

    self.prot_size = pd.DataFrame(self.prot_size)
    self.logger.info(f'self.prot_size: {self.prot_size}')

    if self.random:
        self.logger.info(f'Randomizing {response_var} column')
        self.data[response_var] = self.data[response_var].sample(frac=1).values
        self.logger.info(f'Randomized Regression DataFrame:\n{self.data}')
518
+
519
+ ################################################################################################
520
+
521
+ ################################################################################################
522
+ def run(self, encoded_df, ID_column:str):
523
+ """
524
+ Orchestrates the workflow by loading data, performing regression, and saving results.
525
+ """
526
+
527
+ ## load reference size dist
528
+ self.ref_sizes = self.prot_size['prot_size'].values
529
+ self.logger.info(f'ref_sizes:\n{self.ref_sizes}')
530
+
531
+ #####################################################################################################
532
+ ## do it randomly untill the groups ceof are in an assending order
533
+ groups = {}
534
+ subgroups = self.create_unique_subgroups(np.arange(len(self.gene_list)), self.n_groups)
535
+ coefs = []
536
+ last_step = -1
537
+ for n, subgroup in enumerate(subgroups):
538
+ #logging.info(n, len(subgroup), subgroup, type(subgroup))
539
+
540
+ subgroup = np.array(subgroup, dtype=int)
541
+ groups[n] = {}
542
+ groups[n]['genes'] = [self.gene_list[s] for s in subgroup]
543
+
544
+ # get the inital OR, number of cuts, and regression row
545
+ #_, coef, std, cuts, reg_row, size_dist = self.regression(encoded_df[encoded_df['gene'].isin(groups_temp[n]['genes'])][reg_vars], self.reg_formula, groups_temp[n]['genes'])
546
+ #response_var:str, test_var:str, genes:list
547
+ OR, pvalue, size_dist = self.metrics(encoded_df[encoded_df[ID_column].isin(groups[n]['genes'])], groups[n]['genes'])
548
+ self.logger.debug(f'OR: {OR} pvalue: {pvalue} size_dist: {size_dist}')
549
+
550
+ state_size_dist_bootres = bootstrap((size_dist,) , np.mean)
551
+ state_size_mean, state_size_lb, state_size_ub = np.mean(size_dist), state_size_dist_bootres.confidence_interval.low, state_size_dist_bootres.confidence_interval.high
552
+ ks_stat_size = kstest(self.ref_sizes, size_dist).statistic
553
+
554
+ E = -1*self.C1*np.log(OR) + self.C2*(ks_stat_size)
555
+
556
+ groups[n]['OR'] = [OR]
557
+ groups[n]['pvalue'] = [pvalue]
558
+ groups[n]['size_dist'] = [size_dist]
559
+ groups[n]['psize_mean'] = [state_size_mean]
560
+ groups[n]['psize_lb'] = [state_size_lb]
561
+ groups[n]['psize_ub'] = [state_size_ub]
562
+ groups[n]['step'] = [last_step]
563
+ groups[n]['E'] = [E]
564
+ groups[n]['ks_stat_size'] = [ks_stat_size]
565
+ groups[n]['beta'] = [self.beta]
566
+
567
+ ## log the starting information
568
+ for state in range(len(groups)):
569
+ state_data = groups[state]
570
+ state_OR = state_data['OR'][-1]
571
+ state_pvalue = state_data['pvalue'][-1]
572
+ num_genes = len(state_data['genes'])
573
+ state_step = state_data['step'][-1]
574
+ state_E = state_data['E'][-1]
575
+ ks_stat_size = state_data['ks_stat_size'][-1]
576
+ state_beta = state_data['beta'][-1]
577
+
578
+ self.logger.info(f'STEP: {state_step} | State: {state} | OR: {state_OR} | pvalue: {state_pvalue} | num_genes: {num_genes} | E: {state_E} | ks_stat_size: {ks_stat_size} | beta: {state_beta}')
579
+ self.logger.debug(f'STEP: {state_step} | State: {state} | OR: {state_OR} | pvalue: {state_pvalue} | num_genes: {num_genes} | E: {state_E} | ks_stat_size: {ks_stat_size} | beta: {state_beta}')
580
+
581
+
582
+ #####################################################################################################
583
+ #####################################################################################################
584
+ ## Start optimizer and run for X steps using simulated annealing where
585
+ # the energy function for each state is defined as E(i) = -C1*OR(i) + C2*ks-test(ref_sizes, sizes))|
586
+ # for each step we swap non-overlapping states. For exampline in a 5 state system
587
+ # step 1 we swap 1-2, 3-4
588
+ # step 2 we swap 2-3, 4-5
589
+ # this constitutes one full montecarlo step
590
+ #
591
+ # after every 1000 MC steps move to the next beta where beta ranges from (1/15 to 1000)
592
+ #
593
+ # The objective function is then M = exp(-beta*deltaE)
594
+ # if deltaE is <=0 Accept the step
595
+ # else
596
+ # get random float [0,1]
597
+ # M > random_float and M < 1 accept, else reject
598
+ #
599
+ # beta should start at 50 times the starting energy scale and decrease by half every 1000 steps till it reaches 1 (?)
600
+
601
+ def swap_n(list1, list2, n):
602
+ # Ensure n is not larger than the smaller list size
603
+ n = min(n, len(list1), len(list2))
604
+ list1 = np.array(list1)
605
+ list2 = np.array(list2)
606
+
607
+ # Randomly select n indices from both lists
608
+ indices1 = random.sample(range(len(list1)), n)
609
+ indices2 = random.sample(range(len(list2)), n)
610
+
611
+ # Swap the elements at the selected indices
612
+ for i in range(n):
613
+ idx1 = indices1[i]
614
+ idx2 = indices2[i]
615
+ gene1 = list1[indices1[i]]
616
+ gene2 = list2[indices2[i]]
617
+ #print(i, idx1, idx2, gene1, gene2)
618
+ list1[idx1] = gene2
619
+ list2[idx2] = gene1
620
+
621
+ return list1, list2
622
+
623
+ ## create swapping scheme for each monte carlo step
624
+ self.logger.info(f'Making paris for MC swaps of {self.n_groups} groups in {self.steps} steps')
625
+ pairs = [(i, i+1) for i in range(0, self.n_groups - 1 )]
626
+ self.logger.debug(f'pairs: {pairs}')
627
+
628
+ reps = self.steps
629
+ beta = self.beta
630
+ beta_i = 0
631
+ if self.linearT == False:
632
+ betas = np.linspace(beta, 1000, 75)
633
+ elif self.linearT == True:
634
+ T = 1/beta
635
+ Ts = np.linspace(T, 0.001, 75)
636
+ self.logger.debug(f'Ts: {Ts}')
637
+ betas = 1/Ts
638
+ self.logger.debug(f'betas: {betas} wiht start beta: {betas[beta_i]}')
639
+ logging.info(f'Starting simulations with {reps} steps and beta = {self.beta}')
640
+
641
+ for step in tqdm(range(last_step + 1, reps + last_step + 1)):
642
+ #for step in range(last_step + 1, reps + last_step + 1):
643
+ #logging.info(f'{"#"*100}\n{"#"*100}\nSTEP: {step}')
644
+ #print(f'{"#"*100}\n{"#"*100}\nSTEP: {step}')
645
+
646
+ ## For each pair in the pairs to test
647
+ for pair in pairs:
648
+ #print(f'{"#"*100}\npair: {pair}')
649
+
650
+ # Get previous energy
651
+ Eold = groups[pair[0]]['E'][-1] + groups[pair[1]]['E'][-1]
652
+ #print(f'Eold: {Eold}')
653
+
654
+ # Get old state genes
655
+ p0_genes = groups[pair[0]]['genes']
656
+ p1_genes = groups[pair[1]]['genes']
657
+
658
+ # Swap n genes
659
+ p0_genes_prime, p1_genes_prime = swap_n(p0_genes, p1_genes, 5)
660
+
661
+ # get new regression info
662
+ #_, p0_coef, p0_std, p0_cuts, p0_reg_row, p0_size_dist = self.regression(encoded_df[encoded_df['gene'].isin(p0_genes_prime)][reg_vars], self.reg_formula, p0_genes_prime)
663
+ OR, pvalue, size_dist = self.metrics(encoded_df[encoded_df[ID_column].isin(p0_genes_prime)], p0_genes_prime)
664
+ p0_OR, p0_pvalue, p0_size_dist = self.metrics(encoded_df[encoded_df[ID_column].isin(p0_genes_prime)], p0_genes_prime)
665
+ p0_ks_stat_size = kstest(self.ref_sizes, p0_size_dist).statistic
666
+ Ep0 = -1*self.C1*np.log(p0_OR) + self.C2*(p0_ks_stat_size)
667
+
668
+ #_, p1_coef, p1_std, p1_cuts, p1_reg_row, p1_size_dist = self.regression(encoded_df[encoded_df['gene'].isin(p1_genes_prime)][reg_vars], self.reg_formula, p1_genes_prime)
669
+ p1_OR, p1_pvalue, p1_size_dist = self.metrics(encoded_df[encoded_df[ID_column].isin(p1_genes_prime)], p1_genes_prime)
670
+ p1_ks_stat_size = kstest(self.ref_sizes, p1_size_dist).statistic
671
+ Ep1 = -1*self.C1*np.log(p1_OR) + self.C2*(p1_ks_stat_size)
672
+
673
+ # Calculate new E and deltaE
674
+ Enew = Ep0 + Ep1
675
+ #print(f'Enew: {Enew}')
676
+ deltaE = Enew - Eold
677
+ #print(f'deltaE: {deltaE}')
678
+
679
+ # Apply metropolis critera
680
+ rand_float = random.uniform(0, 1)
681
+ M = np.exp(-1*beta*deltaE)
682
+ if deltaE <= 0:
683
+ accept_M = True
684
+ else:
685
+ ## Apply metropolis critera
686
+ if M < 1 and M > rand_float:
687
+ accept_M = True
688
+ else:
689
+ accept_M = False
690
+ #print(f'M: {M} with accept_M: {accept_M}')
691
+
692
+ if accept_M:
693
+ groups[pair[0]]['genes'] = p0_genes_prime
694
+ groups[pair[1]]['genes'] = p1_genes_prime
695
+
696
+ ## Get step summary after all pairs have been
697
+ #print(f'{"#"*100}\nStep summary for {step}')
698
+ logstr = [f'{"#"*50}']
699
+ for state, state_data in groups.items():
700
+ state_genes = state_data['genes']
701
+ #_, step_coef, step_std, step_cuts, step_reg_row, step_size_dist = self.regression(encoded_df[encoded_df['gene'].isin(state_genes)][reg_vars], self.reg_formula, state_genes)
702
+ state_OR, state_pvalue, state_size_dist = self.metrics(encoded_df[encoded_df[ID_column].isin(state_genes)], state_genes)
703
+ state_size_dist_bootres = bootstrap((state_size_dist,) , np.mean)
704
+ state_size_mean, state_size_lb, state_size_ub = np.mean(state_size_dist), state_size_dist_bootres.confidence_interval.low, state_size_dist_bootres.confidence_interval.high
705
+
706
+ ks_stat_size = kstest(self.ref_sizes, state_size_dist).statistic
707
+
708
+ E = -1*self.C1*np.log(state_OR) + self.C2*(ks_stat_size)
709
+ #print(f'STEP: {step} | state {state} OR: {state_OR} pvalue: {state_pvalue} cuts: {state_cuts} size_mean: {state_size_mean} | ks_stat: {ks_stat} | E: {E}')
710
+ logstr += [f'STEP: {step} | state {state} OR: {state_OR} pvalue: {state_pvalue} size_mean: {state_size_mean} | ks_stat_size: {ks_stat_size} | E: {E} | beta: {beta}']
711
+
712
+ groups[state]['OR'] += [state_OR]
713
+ groups[state]['pvalue'] += [state_pvalue]
714
+ groups[state]['size_dist'] += [state_size_dist]
715
+ groups[state]['psize_mean'] += [state_size_mean]
716
+ groups[state]['psize_lb'] += [state_size_lb]
717
+ groups[state]['psize_ub'] += [state_size_ub]
718
+ groups[state]['step'] += [step]
719
+ groups[state]['ks_stat_size'] += [ks_stat_size]
720
+ groups[state]['E'] += [E]
721
+ groups[state]['beta'] += [beta]
722
+
723
+
724
+ ## check ranks
725
+ #old_ranks = sorted(range(len(old_coefs)), key=lambda i: old_coefs[i], reverse=True)
726
+ #new_ranks = sorted(range(len(new_coefs)), key=lambda i: new_coefs[i], reverse=True)
727
+ #rank_cond = new_ranks == old_ranks
728
+
729
+ # logging.info status of step to log file
730
+ if step % 100 == 0:
731
+ logstr += [f'{"#"*50}']
732
+ logstr = '\n'.join(logstr)
733
+ logging.info(logstr)
734
+
735
+
736
+ # update beta
737
+ if step % 750 == 0 and step > 10:
738
+ beta_i += 1
739
+ if beta_i < len(betas):
740
+ beta = betas[beta_i]
741
+ else:
742
+ beta = 1000
743
+ logging.info(f'Beta update: {beta}')
744
+
745
+
746
+ logging.info(f'{"#*50"}Simulation complete')
747
+ logging.info(f'{"#*50"}Final state stats')
748
+ for state in range(len(groups)):
749
+ state_data = groups[state]
750
+ state_OR = state_data['OR'][-1]
751
+ state_pvalue = state_data['pvalue'][-1]
752
+ state_size_dist = state_data['size_dist'][-1]
753
+ state_size_mean = state_data['psize_mean'][-1]
754
+ state_size_lb = state_data['psize_lb'][-1]
755
+ state_size_ub = state_data['psize_ub'][-1]
756
+ state_ks_stat_size = state_data['ks_stat_size'][-1]
757
+ state_E = state_data['E'][-1]
758
+ state_beta = state_data['beta'][-1]
759
+ logging.info(f'State: {state} | OR: {state_OR} | pvalue: {state_pvalue} | state_size_mean: {state_size_mean:.0f} ({state_size_lb:.0f}, {state_size_ub:.0f}) | state_ks_stat_size: {state_ks_stat_size} | E: {state_E} | beta: {state_beta}')
760
+ self.logger.debug(f'State: {state} | OR: {state_OR} | pvalue: {state_pvalue} | state_size_mean: {state_size_mean:.0f} ({state_size_lb:.0f}, {state_size_ub:.0f}) | state_ks_stat_size: {state_ks_stat_size} | E: {state_E} | beta: {state_beta}')
761
+ #####################################################################################################
762
+
763
+ #####################################################################################################
764
+ ## Save final results
765
+ dfs = []
766
+ for state, state_data in groups.items():
767
+ #logging.info(n, len(subgroup), subgroup, type(subgroup))
768
+
769
+ # get the inital OR, number of cuts, and regression row
770
+ _, coef, std, reg_row, size_dist = self.regression(encoded_df[encoded_df[ID_column].isin(groups[state]['genes'])], self.reg_formula, groups[state]['genes'])
771
+ state_size_dist_bootres = bootstrap((size_dist,) , np.mean)
772
+ state_size_mean, state_size_lb, state_size_ub = np.mean(size_dist), state_size_dist_bootres.confidence_interval.low, state_size_dist_bootres.confidence_interval.high
773
+
774
+ ks_stat_size = kstest(self.ref_sizes, size_dist).statistic
775
+
776
+ E = -1*self.C1*coef + self.C2*(ks_stat_size)
777
+
778
+ reg_row['state'] = state
779
+ reg_row['beta'] = beta
780
+ reg_row['ks_stat_size'] = ks_stat_size
781
+ reg_row['E'] = E
782
+ reg_row['psize_mean'] = state_size_mean
783
+ reg_row['psize_lb'] = state_size_lb
784
+ reg_row['psize_ub'] = state_size_ub
785
+ reg_row['step'] = step
786
+ reg_row['OR'] = np.exp(coef)
787
+ reg_row['OR_lb'] = np.exp(reg_row['[0.025'].astype(float))
788
+ reg_row['OR_ub'] = np.exp(reg_row['0.975]'].astype(float))
789
+ reg_row['ID'] = self.ID
790
+ reg_row['n'] = len(state_data['genes'])
791
+ dfs += [reg_row]
792
+
793
+ ## save the final gene list for the state
794
+ state_final_genelist_outfile = os.path.join(self.outdir, f'State{state}_final_genelist_{self.ID}.txt')
795
+ logging.info(state_data['genes'])
796
+ np.savetxt(state_final_genelist_outfile, list(state_data['genes']), fmt='%s')
797
+ logging.info(f'SAVED: {state_final_genelist_outfile}')
798
+
799
+ ## save the state data for this state
800
+ state_df = {'state':[], 'step':[], 'OR':[], 'pvalue':[], 'psize_mean':[], 'psize_lb':[], 'psize_ub':[], 'ks_stat_size':[], 'E':[], 'beta':[]}
801
+ for i in range(len(state_data['step'])):
802
+ state_df['state'] += [state]
803
+ state_df['step'] += [state_data['step'][i]]
804
+ state_df['OR'] += [state_data['OR'][i]]
805
+ state_df['pvalue'] += [state_data['pvalue'][i]]
806
+ state_df['psize_mean'] += [state_data['psize_mean'][i]]
807
+ state_df['psize_lb'] += [state_data['psize_lb'][i]]
808
+ state_df['psize_ub'] += [state_data['psize_ub'][i]]
809
+ state_df['ks_stat_size'] += [state_data['ks_stat_size'][i]]
810
+ state_df['E'] += [state_data['E'][i]]
811
+ state_df['beta'] += [state_data['beta'][i]]
812
+ state_df = pd.DataFrame(state_df)
813
+ #print(state_df)
814
+ state_final_traj_outfile = os.path.join(self.outdir, f'State{state}_final_traj_{self.ID}.csv')
815
+ state_df.to_csv(state_final_traj_outfile, index=False)
816
+ logging.info(f'SAVED: {state_final_traj_outfile}')
817
+
818
+ outdf = pd.concat(dfs)
819
+ self.logger.debug(outdf)
820
+
821
+ ## Save the final step regression data
822
+ final_step_outfile = os.path.join(self.outdir, f'Final_step_reg_{self.ID}.csv')
823
+ outdf.to_csv(final_step_outfile, index=False)
824
+ self.logger.debug(f'SAVED: {final_step_outfile}')
825
+ logging.info(f'SAVED: {final_step_outfile}')
826
+
827
+ #####################################################################################################
828
+ ################################################################################################
829
+
830
+ ################################################################################################
831
def create_unique_subgroups(self, array, m):
    """Randomly partition ``array`` into ``m`` disjoint subgroups.

    The items are shuffled with ``np.random.shuffle`` (so the result follows
    the global NumPy random state) and then cut into ``m`` consecutive chunks
    of ``len(array) // m`` items; the last chunk also receives the remainder,
    so it may be larger than the others.

    Parameters
    ----------
    array : list or numpy.ndarray
        Items to partition. The caller's object is NOT modified: a copy is
        shuffled instead (the random draws are the same as shuffling in place).
    m : int
        Number of subgroups; must be >= 1.

    Returns
    -------
    list
        ``m`` slices of the shuffled copy (lists for list input, arrays for
        ndarray input). When ``m > len(array)`` every subgroup except the last
        is empty.

    Raises
    ------
    ValueError
        If ``m`` is smaller than 1.
    """
    if m < 1:
        raise ValueError(f'm must be a positive integer, got {m}')

    # Shuffle a copy so the caller's sequence keeps its order.
    shuffled = array.copy()
    np.random.shuffle(shuffled)

    subgroup_size = len(shuffled) // m

    # The first m-1 subgroups have exactly subgroup_size items ...
    subgroups = [
        shuffled[i * subgroup_size:(i + 1) * subgroup_size]
        for i in range(m - 1)
    ]
    # ... and the last one absorbs whatever is left over.
    subgroups.append(shuffled[(m - 1) * subgroup_size:])

    return subgroups
854
+ ################################################################################################
855
+
856
+ ################################################################################################
857
def is_decending(self, array):
    """Return True if ``array`` never increases between neighbouring elements.

    Equal neighbours are allowed (non-strict ordering), and sequences with
    fewer than two elements are considered descending.
    """
    last_index = len(array) - 1
    # A single increasing neighbour pair is enough to fail the check.
    return not any(array[k] < array[k + 1] for k in range(last_index))
863
+ ################################################################################################
864
+
865
+ ################################################################################################
866
def get_per_res_cuts(self, df, cutkey):
    """Compute the per-residue cut rate of every gene.

    For each ``gene`` group of ``df`` the values in column ``cutkey`` are
    summed with ``np.sum`` and divided by the number of rows in that group.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``gene`` column and the column named by ``cutkey``.
    cutkey : str
        Column holding the per-row cut values.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene`` and ``per_res_cuts``, one row per gene in groupby order.
    """
    per_gene = [
        (gene_name, np.sum(group[cutkey]) / len(group))
        for gene_name, group in df.groupby('gene')
    ]
    result = pd.DataFrame(per_gene, columns=['gene', 'per_res_cuts'])
    self.logger.debug(result)
    return result
874
+ ################################################################################################
875
+
876
+ #########################################################################################################################
877
+ #########################################################################################################################
878
class FoldingPathwayStats:

    """
    A class to analyze the folding pathway statistics resulting from the Markov State Model (MSM) of folding generated by EntDetect.clustering.MSMNonNativeEntanglementClustering
    """

    ####################################################################
    def __init__(self, outdir:str='./FoldingPathwayStats',
                 n_window:int=200,
                 n_traj:int=1000,
                 tarj_type_col:str='traj_type',
                 traj_type_list:list=None,
                 msm_data:pd.DataFrame=None,
                 meta_set_file:str='',
                 state_type:str='metastablestate',
                 rm_traj_list:list=None, log_level:int=logging.INFO, logdir:str=None):
        """
        Initialize the folding-pathway analysis.

        Parameters:
        - outdir (str): Directory where result files are written (created if missing).
        - n_window (int): Number of trailing frames of a trajectory used by
          ``JS_divergence`` to choose the state that pads trajectories shorter
          than the longest one.
        - n_traj (int): Stored as ``self.n_traj``; not used by the methods of this class.
        - tarj_type_col (str): Column of ``msm_data`` holding the trajectory type label.
        - traj_type_list (list | None): Trajectory types to analyze; defaults to ``['A', 'B']``.
        - msm_data (pandas.DataFrame | None): Per-frame state assignments with at
          least the columns ``traj``, ``tarj_type_col`` and ``state_type``.
        - meta_set_file (str): CSV with ``metastable_state`` and ``microstates``
          columns; only read by ``JS_divergence`` when ``state_type == 'microstate'``.
        - state_type (str): ``'microstate'`` or ``'metastablestate'``.
        - rm_traj_list (list | None): Trajectory IDs to exclude (compared as strings).
        - log_level (int): Logging level passed to ``setup_logger``.
        - logdir (str | None): Log directory; defaults to ``outdir``.

        Raises:
        - ValueError: If ``state_type`` is not one of the supported values.
        """
        self.outdir = outdir
        self.n_window = n_window
        self.n_traj = n_traj
        self.tarj_type_col = tarj_type_col
        # Fresh lists per instance instead of shared mutable defaults.
        self.traj_type_list = ['A', 'B'] if traj_type_list is None else traj_type_list
        self.msm_data = msm_data
        self.meta_set_file = meta_set_file
        self.state_type = state_type
        # Trajectory IDs are kept as strings; _filtered_msm_data compares them as strings.
        self.rm_traj_list = [str(t) for t in ([] if rm_traj_list is None else rm_traj_list)]
        if self.state_type not in ['microstate', 'metastablestate']:
            raise ValueError(f'state_type must be [microstate | metastablestate]')

        self.logger = setup_logger('FoldingPathwayStats', outdir=logdir if logdir is not None else self.outdir, log_level=log_level)
        if not os.path.exists(f'{self.outdir}'):
            os.makedirs(f'{self.outdir}')
            self.logger.debug(f'Made output directories {self.outdir}')

        ## set delta time between frames in experimental seconds
        dt = 0.015/1000  # time step
        nsave = 5000  # number of steps between frames
        alpha = 4331293.0  # NOTE(review): simulation-to-experiment time scaling factor; confirm its origin
        self.dt = dt*nsave*alpha/1e9  # in seconds
    ####################################################################

    ####################################################################
    def _filtered_msm_data(self):
        """
        Return ``self.msm_data`` without the trajectories listed in ``rm_traj_list``.

        The ``traj`` column is compared as strings, because ``rm_traj_list`` is
        stored as strings while trajectory IDs are often numeric.

        Raises:
        - ValueError: If no ``msm_data`` was provided.
        """
        if self.msm_data is None:
            raise ValueError('msm_data must be provided as a pandas DataFrame')
        keep = ~self.msm_data['traj'].astype(str).isin(self.rm_traj_list)
        return self.msm_data[keep]
    ####################################################################

    ####################################################################
    @staticmethod
    def _loop_erased_path(states):
        """
        Return the loop-free pathway (1-indexed state labels as strings) of one trajectory.

        Start from the first frame's state. For every later frame, if its state
        is already on the pathway, truncate the pathway right after that state's
        first occurrence; otherwise append it.
        """
        path = [str(states[0] + 1)]  # +1 because the states are 0-indexed in the data
        for state in states[1:]:
            label = str(state + 1)
            if label in path:
                del path[path.index(label) + 1:]
            else:
                path.append(label)
        return path
    ####################################################################

    ####################################################################
    def post_trans(self):
        """
        The folding pathways are identified as follows:

        (1) For each discrete trajectory, put the starting state at the first frame into the pathway;

        (2) Move forward along the trajectory and find the state that is different from the last state recorded in the pathway.
            If the state has not yet been recorded in the pathway, then put it into the pathway.
            Otherwise, cut the pathway at the first place where this state is recorded and then move forward;

        (3) Repeat step (2) until reaching the end of the trajectory.

        This will yield a pathway that has no loop on the route and only records the on-pathway states for each discrete trajectory.

        Returns:
        - pandas.DataFrame with columns ``tarj_type_col``, ``pathway`` ("1 -> 3 -> 2", 1-indexed
          states) and ``probability`` (fraction of that type's trajectories following the pathway),
          sorted by descending probability within each type. It is also saved as
          ``FoldingPathways_<state_type>_<types>.csv`` in ``outdir``.

        Raises:
        - ValueError: If ``msm_data`` is missing or a trajectory type has no data.
        """
        msm_data = self._filtered_msm_data()
        self.logger.info(f'msm_data:\n{msm_data}')

        folding_pathways_df = {self.tarj_type_col:[], 'pathway':[], 'probability':[]}
        for traj_type in self.traj_type_list:
            self.logger.info(f'Processing {traj_type} trajectories')
            traj_type_msm_data = msm_data[msm_data[self.tarj_type_col] == traj_type]
            self.logger.info(f'traj_type_msm_data:\n{traj_type_msm_data}')

            # Quality check that there is data
            if traj_type_msm_data.empty:
                raise ValueError(f"No data found for trajectory type: {traj_type}")

            # Every state observed for this trajectory type (1-indexed) is accepted
            # as an end state, so with this definition no pathway is filtered out.
            end_states = {str(s + 1) for s in traj_type_msm_data[self.state_type].unique()}

            pathways = {}
            for traj, traj_df in traj_type_msm_data.groupby('traj'):
                path = self._loop_erased_path(traj_df[self.state_type].values)
                if path[-1] not in end_states:
                    continue
                key = ' -> '.join(path)
                pathways[key] = pathways.get(key, 0) + 1

            for k, v in pathways.items():
                self.logger.info(f'{k}: {v}')

            tot_num = sum(pathways.values())
            # sorted() is stable, so equally likely pathways keep first-seen order.
            sort_pathways = sorted(
                ((path, count / tot_num) for path, count in pathways.items()),
                key=lambda x: x[1], reverse=True)

            for path, prob in sort_pathways:
                folding_pathways_df[self.tarj_type_col].append(traj_type)
                folding_pathways_df['pathway'].append(path)
                folding_pathways_df['probability'].append(prob)

        folding_pathways_df = pd.DataFrame(folding_pathways_df)
        self.logger.debug(folding_pathways_df)
        outfile = os.path.join(self.outdir, f'FoldingPathways_{self.state_type}_{"-".join(self.traj_type_list)}.csv')
        folding_pathways_df.to_csv(outfile, index=False)
        self.logger.info(f'Saved folding pathways to {outfile}')

        return folding_pathways_df
    ####################################################################

    ####################################################################
    @staticmethod
    def _normalize_rows(P, empty_to_first=False):
        """
        Return a float copy of ``P`` whose non-zero rows are scaled to sum to 1.

        All-zero rows stay zero, or put all their mass on column 0 when
        ``empty_to_first`` is True.
        """
        P = np.array(P, dtype=float)
        sums = P.sum(axis=1)
        nonzero = sums != 0
        P[nonzero] /= sums[nonzero, None]
        if empty_to_first:
            P[~nonzero, 0] = 1
        return P
    ####################################################################

    ####################################################################
    @staticmethod
    def _js_rows(P, Q):
        """Row-wise Jensen-Shannon divergence between two row-stochastic matrices."""
        M = 0.5 * (P + Q)
        return 0.5 * (entropy(P, M, axis=1) + entropy(Q, M, axis=1))
    ####################################################################

    ####################################################################
    def JS_divergence(self):
        """
        Compute the per-frame Jensen-Shannon divergence between the state populations
        of the first two trajectory types in ``traj_type_list``.

        For every trajectory type and frame, the number of trajectories in each state is
        counted. Trajectories shorter than the longest one are padded with their most
        populated state within their last ``n_window`` frames (ties go to the lowest label).
        The JSD uses ``scipy.stats.entropy`` (natural log). It is computed over all states
        (column 'JSD') and, when ``state_type == 'microstate'``, also within each
        metastable set read from ``meta_set_file`` (columns 'P1', 'P2', ...).

        The table is written to ``JS_div_<state_type>_<types>.dat`` in ``outdir``; the
        first column is the time ``(frame + 1) * self.dt``.

        Raises:
        - ValueError: If fewer than two trajectory types are configured, ``msm_data`` is
          missing, or a trajectory type has no trajectories.
        """
        if len(self.traj_type_list) < 2:
            raise ValueError('JS_divergence needs at least two entries in traj_type_list')

        msm_data = self._filtered_msm_data()
        self.logger.info(f'msm_data:\n{msm_data}')

        # Build trajectories and their types from the SAME groupby so that list
        # positions and trajectory types always line up, whatever the row order.
        dtrajs = []
        dtraj_types = []
        for traj, traj_df in msm_data.groupby('traj'):
            dtrajs.append(traj_df[self.state_type].to_numpy())
            dtraj_types.append(traj_df[self.tarj_type_col].values[0])
        self.logger.info(f'Number of trajectories: {len(dtrajs)}')

        # make list of traj_idx for each mutant type
        mtype2trajid = {traj_type: [] for traj_type in self.traj_type_list}
        for idx, traj_type in enumerate(dtraj_types):
            if traj_type in mtype2trajid:
                mtype2trajid[traj_type].append(idx)
        self.logger.debug(f'mtype2trajid: {mtype2trajid}')
        for traj_type, idxs in mtype2trajid.items():
            if not idxs:
                raise ValueError(f"No data found for trajectory type: {traj_type}")

        # State labels are used directly as column indices, so size the
        # histograms by the largest label rather than by the number of labels.
        n_states = int(max(np.max(d) for d in dtrajs)) + 1
        self.logger.info('Number of %sstates: %d'%(self.state_type, n_states))

        # Trajectories may differ in length; use the longest one.
        max_T_len = max(len(d) for d in dtrajs)
        self.logger.debug(f'max_T_len: {max_T_len}')

        # State counts per frame for each trajectory type
        frames = np.arange(max_T_len)
        bins = np.arange(-0.5, n_states, 1)
        P_list = []
        for traj_type in self.traj_type_list:
            idxs = mtype2trajid[traj_type]
            self.logger.info(f'Processing {traj_type} trajectories with {len(idxs)} trajectories')
            counts = np.zeros((max_T_len, n_states))
            for idx in idxs:
                dtraj = dtrajs[idx]
                hist, _ = np.histogram(dtraj[-self.n_window:], bins=bins)
                padded = np.full(max_T_len, np.argmax(hist), dtype=int)
                padded[:len(dtraj)] = dtraj
                # (frame, state) pairs are unique within one trajectory, so += is safe.
                counts[frames, padded] += 1
            P_list.append(counts)
            self.logger.debug(counts)

        # Jensen-Shannon divergence
        JS_list = []
        if self.state_type == 'microstate':
            meta_set_df = pd.read_csv(self.meta_set_file)
            self.logger.info(f'meta_set_df:\n{meta_set_df}')
            meta_set = [s['microstates'].values for i, s in meta_set_df.groupby('metastable_state')]
            self.logger.debug(f'meta_set: {meta_set}')
            for ms in meta_set:
                P_sub = self._normalize_rows(P_list[0][:, ms], empty_to_first=True)
                Q_sub = self._normalize_rows(P_list[1][:, ms], empty_to_first=True)
                JS_list.append(self._js_rows(P_sub, Q_sub))

        P = self._normalize_rows(P_list[0])
        self.logger.debug(f'P: {P} {P.shape}')
        Q = self._normalize_rows(P_list[1])
        self.logger.debug(f'Q: {Q} {Q.shape}')
        entropy_arr = self._js_rows(P, Q)
        self.logger.info(f'entropy_arr (n={len(entropy_arr)}): {entropy_arr}')

        JS_list.append(entropy_arr)
        JS_list = np.array(JS_list).T

        ## write the output file
        outfile = os.path.join(self.outdir, f'JS_div_{self.state_type}_{"-".join(self.traj_type_list)}.dat')
        with open(outfile, 'w') as fo:
            fo.write('%10s '%('Time(s)'))
            for j in range(JS_list.shape[1]-1):
                fo.write('%10s '%('P%d'%(j+1)))
            fo.write('%10s\n'%('JSD'))
            for i in range(max_T_len):
                fo.write('%10.4f '%((i+1)*self.dt))
                for j in range(JS_list.shape[1]):
                    fo.write('%10.4f '%(JS_list[i,j]))
                fo.write('\n')
        self.logger.debug(f'SAVED: {outfile}')
1142
+ ####################################################################
1143
+
1144
+ #########################################################################################################################
1145
+ #########################################################################################################################
1146
class MSMStats:

    """
    A class to analyze various statistical properties of the Markov State Model (MSM) of folding generated by EntDetect.clustering.MSMNonNativeEntanglementClustering
    """

    ####################################################################
    def __init__(self, outdir:str='./MSMStats',
                 n_window:int=200,
                 n_traj:int=1000,
                 traj_type_list:list=None,
                 tarj_type_col:str='traj_type',
                 msm_data_file:str='',
                 meta_set_file:str='',
                 state_type:str='metastablestate',
                 rm_traj_list:list=None, log_level:int=logging.INFO, logdir:str=None):
        """
        Initialize the MSM statistics analysis.

        Parameters:
        - outdir (str): Directory where result files are written (created if missing).
        - n_window (int), n_traj (int), meta_set_file (str): Stored on the instance;
          not used by the methods of this class.
        - traj_type_list (list | None): Trajectory types to analyze; defaults to ``['A', 'B']``.
        - tarj_type_col (str): Column holding the trajectory type label.
        - msm_data_file (str): CSV of per-frame state assignments with at least the
          columns ``traj``, ``tarj_type_col`` and ``state_type``.
        - state_type (str): ``'microstate'`` or ``'metastablestate'``.
        - rm_traj_list (list | None): Trajectory IDs to exclude (compared as strings).
        - log_level (int): Logging level passed to ``setup_logger``.
        - logdir (str | None): Log directory; defaults to ``outdir``.

        Raises:
        - ValueError: If ``state_type`` is not one of the supported values.
        """
        self.outdir = outdir
        self.n_window = n_window
        self.n_traj = n_traj
        # Fresh lists per instance instead of shared mutable defaults.
        self.traj_type_list = ['A', 'B'] if traj_type_list is None else traj_type_list
        self.tarj_type_col = tarj_type_col
        self.msm_data_file = msm_data_file
        self.meta_set_file = meta_set_file
        self.state_type = state_type
        # Trajectory IDs are kept as strings; StateProbabilityStats compares them as strings.
        self.rm_traj_list = [str(t) for t in ([] if rm_traj_list is None else rm_traj_list)]
        if self.state_type not in ['microstate', 'metastablestate']:
            raise ValueError(f'state_type must be [microstate | metastablestate]')

        self.logger = setup_logger('MSMStats', outdir=logdir if logdir is not None else self.outdir, log_level=log_level)
        if not os.path.exists(f'{self.outdir}'):
            os.makedirs(f'{self.outdir}')
            self.logger.debug(f'Made output directories {self.outdir}')

        ## set delta time between frames in experimental seconds
        dt = 0.015/1000  # time step
        nsave = 5000  # number of steps between frames
        alpha = 4331293.0  # NOTE(review): simulation-to-experiment time scaling factor; confirm its origin
        self.dt = dt*nsave*alpha/1e9  # in seconds
        self.end_t = 60  # in seconds; kept for compatibility, not used by the methods below
        self.n_boot = 100  # number of bootstrap resamples for the 95% CI
        self.num_proc = 20  # kept for compatibility, not used by the methods below
        self.num_points_plot = 1000  # kept for compatibility, not used by the methods below
        self.if_boot = True  # when False, the CI columns are NaN
    ####################################################################

    ####################################################################
    @staticmethod
    def _state_occupancy(dtrajs, n_states):
        """
        Return an (n_frames, n_states) array with the fraction of trajectories in each state per frame.

        ``dtrajs`` is an (n_traj, n_frames) integer array of 0-indexed state labels.
        """
        counts = np.stack([(dtrajs == s).sum(axis=0) for s in range(n_states)], axis=1)
        return counts / len(dtrajs)
    ####################################################################

    ####################################################################
    def _stack_dtrajs(self, type_data, traj_type):
        """
        Stack the state sequences of one trajectory type into an (n_traj, n_frames) array.

        Trajectories are ordered by first appearance of their ``traj`` ID.

        Raises:
        - ValueError: If there are no trajectories or they differ in length.
        """
        seqs = [type_data.loc[type_data['traj'] == t, self.state_type].to_numpy()
                for t in type_data['traj'].unique()]
        if not seqs:
            raise ValueError(f"No data found for trajectory type: {traj_type}")
        if len({len(s) for s in seqs}) != 1:
            raise ValueError(f'Trajectories of type {traj_type} must all have the same number of frames')
        return np.vstack(seqs)
    ####################################################################

    ####################################################################
    def _bootstrap_ci(self, meta_dtrajs, n_states):
        """
        Return (lower, upper) 95% percentile-bootstrap bounds of the state occupancy.

        Whole trajectories are resampled with replacement ``self.n_boot`` times using
        ``np.random``. When bootstrapping is disabled, both bounds are NaN arrays.
        """
        shape = (meta_dtrajs.shape[1], n_states)
        if not self.if_boot or self.n_boot < 1:
            return np.full(shape, np.nan), np.full(shape, np.nan)

        self.logger.info(f'Bootstrapping 95% ci...')
        n = len(meta_dtrajs)
        boot_arrs = []
        for _ in range(self.n_boot):
            sample_idx = np.random.choice(np.arange(n), n, replace=True)
            boot_arrs.append(self._state_occupancy(meta_dtrajs[sample_idx, :], n_states))
        boot_arrs = np.array(boot_arrs)
        self.logger.info(f'boot_arrs shape: {boot_arrs.shape}')

        lower = np.percentile(boot_arrs, 2.5, axis=0)
        upper = np.percentile(boot_arrs, 97.5, axis=0)
        self.logger.info(f'lower:\n{lower} shape: {lower.shape}')
        self.logger.info(f'upper:\n{upper} shape: {upper.shape}')
        return lower, upper
    ####################################################################

    ####################################################################
    def StateProbabilityStats(self):
        """
        Calculate the state probabilities over time from the MSM data and save them to a file.
        It also calculates bootstrap 95% confidence intervals for the state probabilities.

        If ``<outdir>/MSTS.csv`` already exists it is loaded and returned unchanged.

        Returns:
        - pandas.DataFrame with columns ``tarj_type_col``, 'Time(s)', 'State' (1-indexed),
          'Probability', 'Lower CI', 'Upper CI', sorted by type, state and time.

        Raises:
        - ValueError: If a trajectory type has no trajectories or its trajectories
          have different lengths.
        """
        outfile = os.path.join(self.outdir, f'MSTS.csv')
        if os.path.exists(outfile):
            self.logger.info(f'File {outfile} already exists. Skipping calculation.')
            df = pd.read_csv(outfile)
            self.logger.info(f'Loaded existing data from {outfile}')
            self.logger.info(f'df:\n{df}')
            return df

        # Load MSM data
        self.logger.info(f'Loading MSM data from {self.msm_data_file}')
        msm_data = pd.read_csv(self.msm_data_file)
        self.logger.info(f'msm_data:\n{msm_data}')
        # Compare IDs as strings: rm_traj_list holds strings, while the 'traj'
        # column read from CSV is often numeric.
        msm_data = msm_data[~msm_data['traj'].astype(str).isin(self.rm_traj_list)]
        self.logger.info(f'msm_data:\n{msm_data}')

        # Number of states = largest 0-indexed label over all remaining frames + 1.
        state_labels = msm_data[self.state_type]
        n_states = 1 if state_labels.empty else max(0, int(state_labels.max())) + 1
        self.logger.debug(f'n_states: {n_states}')

        columns = [self.tarj_type_col, 'Time(s)', 'State', 'Probability', 'Lower CI', 'Upper CI']
        frames = []
        for traj_type in self.traj_type_list:
            self.logger.info(f'Processing {traj_type} trajectories')

            ## Get the dtrajs for this trajectory type
            type_data = msm_data[msm_data[self.tarj_type_col] == traj_type]
            self.logger.info(f'meta_dtrajs:\n{type_data}')
            meta_dtrajs = self._stack_dtrajs(type_data, traj_type)
            self.logger.info(f'meta_dtrajs:\n{meta_dtrajs} shape: {meta_dtrajs.shape}')

            # MSTS: fraction of trajectories in each state at every frame
            PPT = self._state_occupancy(meta_dtrajs, n_states)
            self.logger.info(f'PPT:\n{PPT} {PPT.shape}')
            n_frames = PPT.shape[0]
            t_span = (np.arange(n_frames) + 1) * self.dt
            self.logger.debug(f't_span: {t_span} shape: {t_span.shape}')

            ## Bootstrapping
            lower, upper = self._bootstrap_ci(meta_dtrajs, n_states)

            ## rows are ordered time-major, state-minor (same as nested loops over i, j)
            frames.append(pd.DataFrame({
                self.tarj_type_col: [traj_type] * (n_frames * n_states),
                'Time(s)': np.repeat(t_span, n_states),
                'State': np.tile(np.arange(1, n_states + 1), n_frames),  # +1 because states are 0-indexed in the data
                'Probability': PPT.ravel(),
                'Lower CI': lower.ravel(),
                'Upper CI': upper.ravel(),
            }))

        if frames:
            df = pd.concat(frames, ignore_index=True)
        else:
            df = pd.DataFrame({c: [] for c in columns})
        df = df.sort_values(by=[self.tarj_type_col, 'State', 'Time(s)'])
        self.logger.debug(df)

        ## save the dataframe to a csv file
        df.to_csv(outfile, index=False)
        self.logger.debug(f'SAVED: {outfile}')
        return df
    ####################################################################

    ####################################################################
    def Plot_StateProbabilityStats(self, df=None):
        """
        Plot the state probabilities over time, one PNG per trajectory type.

        Each state is drawn as a line, with the region between 'Lower CI' and
        'Upper CI' shaded. Files are saved as ``<traj_type>_MSTS_plot.png`` in ``outdir``.

        Parameters:
        - df (pandas.DataFrame | None): Output of ``StateProbabilityStats``; computed
          (or loaded from its cached CSV) when None.
        """
        if df is None:
            df = self.StateProbabilityStats()

        for traj_type in self.traj_type_list:
            outfile = os.path.join(self.outdir, f'{traj_type}_MSTS_plot.png')
            self.logger.info(f'Plotting state probabilities to {outfile}')
            plt.figure(figsize=(10, 6))

            traj_df = df[df[self.tarj_type_col] == traj_type]
            for state in traj_df['State'].unique():
                state_df = traj_df[traj_df['State'] == state]
                plt.plot(state_df['Time(s)'], state_df['Probability'], label=f'{traj_type} State {state}')
                plt.fill_between(state_df['Time(s)'], state_df['Lower CI'], state_df['Upper CI'], alpha=0.2)

            plt.xlabel('Time (s)')
            plt.ylabel('Probability')
            plt.title('State Probabilities Over Time')
            plt.legend()
            plt.grid()
            plt.savefig(outfile)
            plt.close()
            self.logger.debug(f'SAVED: {outfile}')
1342
+ ####################################################################
1343
+ #########################################################################################################################
1344
+ #########################################################################################################################