cocoatree 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. cocoatree/__init__.py +8 -0
  2. cocoatree/__params.py +80 -0
  3. cocoatree/_pipeline.py +144 -0
  4. cocoatree/_scraper.py +23 -0
  5. cocoatree/_version.py +1 -0
  6. cocoatree/datasets/__init__.py +3 -0
  7. cocoatree/datasets/_base.py +188 -0
  8. cocoatree/datasets/data/DHFR/3QL0.pdb +3507 -0
  9. cocoatree/datasets/data/DHFR/DHFR_sectors.npz +0 -0
  10. cocoatree/datasets/data/DHFR/alignment.faa.gz +0 -0
  11. cocoatree/datasets/data/S1A_serine_proteases/3tgi.pdb +2844 -0
  12. cocoatree/datasets/data/S1A_serine_proteases/halabi_alignment.fasta +20580 -0
  13. cocoatree/datasets/data/S1A_serine_proteases/halabi_metadata.csv +1471 -0
  14. cocoatree/datasets/data/S1A_serine_proteases/halabi_sectors.npz +0 -0
  15. cocoatree/datasets/data/S1A_serine_proteases/rivoire_alignment.fasta +19460 -0
  16. cocoatree/datasets/data/S1A_serine_proteases/rivoire_metadata.csv +1391 -0
  17. cocoatree/datasets/data/S1A_serine_proteases/rivoire_sectors.npz +0 -0
  18. cocoatree/datasets/data/rhomboid_proteases/2NRF.pdb +3300 -0
  19. cocoatree/datasets/data/rhomboid_proteases/Data_S1_Rhomboid_MSA_short_names.fasta +5534 -0
  20. cocoatree/datasets/data/rhomboid_proteases/rhomboid_metadata_clean.csv +2766 -0
  21. cocoatree/datasets/data/rhomboid_proteases/rhomboid_sectors.npz +0 -0
  22. cocoatree/datasets/tests/test_datasets.py +14 -0
  23. cocoatree/decomposition.py +263 -0
  24. cocoatree/io.py +185 -0
  25. cocoatree/msa.py +579 -0
  26. cocoatree/pysca.py +238 -0
  27. cocoatree/randomize.py +30 -0
  28. cocoatree/scripts/cocoatree-sca.py +6 -0
  29. cocoatree/statistics/__init__.py +58 -0
  30. cocoatree/statistics/pairwise.py +318 -0
  31. cocoatree/statistics/position.py +258 -0
  32. cocoatree/tests/test_init.py +24 -0
  33. cocoatree/tests/test_msa.py +14 -0
  34. cocoatree/visualization.py +440 -0
  35. cocoatree-0.1.0.dist-info/METADATA +66 -0
  36. cocoatree-0.1.0.dist-info/RECORD +39 -0
  37. cocoatree-0.1.0.dist-info/WHEEL +5 -0
  38. cocoatree-0.1.0.dist-info/licenses/LICENSE +28 -0
  39. cocoatree-0.1.0.dist-info/top_level.txt +1 -0
cocoatree/pysca.py ADDED
@@ -0,0 +1,238 @@
1
+ import numpy as np
2
+ import sys
3
+ from scipy.stats import t, scoreatpercentile
4
+
5
+
6
+ def _basicICA(x, r0, Niter, tolerance=1e-15):
7
+ """
8
+ Basic ICA algorithm, based on work by Bell & Sejnowski (infomax). The input
9
+ data should preferentially be sphered, i.e., x.T.dot(x) = 1
10
+ Source: https://github.com/ranganathanlab/pySCA/
11
+
12
+ Parameters
13
+ ----------
14
+ x : LxM input matrix where L = # features and M = # samples
15
+
16
+ r : learning rate / relaxation parameter (e.g. r=.0001)
17
+
18
+ Niter : number of iterations (e.g. 1000)
19
+
20
+ Returns
21
+ -------
22
+ w : unmixing matrix
23
+
24
+ change: record of incremental changes during the iterations.
25
+
26
+ **Note** r and Niter should be adjusted to achieve convergence, which
27
+ should be assessed by visualizing 'change' with plot(range(iter), change)
28
+ **Example**::
29
+ [w, change] = basicICA(x, r, Niter)
30
+ """
31
+
32
+ [L, M] = x.shape
33
+ w = np.eye(L)
34
+ change = list()
35
+ r = r0 / M
36
+ with np.errstate(over="raise"):
37
+ try:
38
+ for _ in range(Niter):
39
+ w_old = np.copy(w)
40
+ u = w.dot(x)
41
+ w += r * (
42
+ M * np.eye(L) + (1.0 - 2.0 / (1.0 + np.exp(-u))).dot(u.T)
43
+ ).dot(w)
44
+ delta = (w - w_old).ravel()
45
+ val = delta.dot(delta.T)
46
+ change.append(val)
47
+ if np.isclose(val, 0, atol=tolerance):
48
+ break
49
+ if _ == Niter - 1:
50
+ print("basicICA failed to converge: " + str(val))
51
+ except FloatingPointError as e:
52
+ sys.exit("Error: basicICA " + str(e))
53
+ return [w, change]
54
+
55
+
56
def _compute_ica(V, kmax=6, learnrate=0.1, iterations=10000):
    """
    ICA rotation of eigenvectors (via _basicICA) with output normalization.

    Based on the Bell & Sejnowski infomax algorithm; the input data should
    preferentially be sphered, i.e., x.T.dot(x) = 1.
    Source: https://github.com/ranganathanlab/pySCA/

    Parameters
    ----------
    V : ndarray,
        eigenvectors obtained after matrix decomposition

    kmax : integer,
        number of independent components to retrieve

    learnrate : learning rate / relaxation parameter passed to _basicICA

    iterations : integer,
        number of iterations

    **Note** learnrate and iterations should be adjusted to achieve
    convergence, assessed by plotting the 'change' record of _basicICA.

    Returns
    -------
    Vica : ndarray,
        contributions along each independent component

    W : ndarray of shape (kmax, kmax),
        unmixing matrix

    **Example**::
      Vica, W = rotICA(V, kmax=6, learnrate=.0001, iterations=10000)
    """

    # Keep the first kmax eigenvectors; samples run along columns for ICA.
    components = V[:, :kmax].T
    W, _ = _basicICA(components, learnrate, iterations)
    Vica = W.dot(components).T
    # Normalize each IC to unit norm and orient it so that its
    # largest-magnitude entry is positive.
    for idx in range(kmax):
        column = Vica[:, idx]
        dominant = abs(column).argmax()
        Vica[:, idx] = (
            np.sign(Vica[dominant, idx]) * column / np.linalg.norm(column)
        )
    return Vica, W
103
+
104
+
105
class Unit:
    """
    Container for a named group of items (sectors, sequence families, etc.)

    Attributes
    ----------

    name : string describing the unit (ex: 'firmicutes')
    items : set of member items (ex: indices for all firmicutes
       sequence in an alignment)
    col : color code associated to the unit (for plotting)
    vect : an additional vector describing the member items (ex: a list
       of sequence weights)
    """

    def __init__(self):
        # Start from an empty unit; the caller fills the attributes in.
        self.items = set()
        self.name = ""
        self.vect = 0
        self.col = 0
125
+
126
+
127
def _icList(Vica, n_component, Cij, p_cut=0.95):
    """
    Produces a list of positions contributing to each independent component
    (IC) above a defined statistical cutoff (p_cut, the cutoff on the CDF of
    the t-distribution fit to the histogram of each IC). Any position above the
    cutoff on more than one IC are assigned to one IC based on which group of
    positions to which it shows a higher degree of coevolution. Additionally
    returns the numeric value of the cutoff for each IC, and the pdf fit, which
    can be used for plotting/evaluation.

    Parameters
    ----------
    Vica : ndarray,
        independent components

    n_component : int,
        number of independent components chosen

    Cij : numpy.ndarray,
        coevolution matrix

    p_cut : float,
        cutoff on the CDF of the t-distribution fit to the histogram of each
        IC

    Returns
    -------
    selected_res : list of cocoatree.decomposition.Unit,
        positions of the selected residues for each independent component.
        Beware that if the alignment used for the analysis has been filtered,
        those are the positions on the filtered alignment and not on the
        original alignment, a mapping of the positions may be needed.

    ic_size : list,
        number of selected residues for each component.

    sorted_pos : list,
        positions of the residues sorted by decreasing contribution for each
        component.

    cutoff : list,
        numeric value of the cutoff for each component.

    scaled_pdf : list of np.ndarrays,
        scaled probability distribution function for each component.

    all_fits : list,
        t-distribution fits for each component.

    **Example**::
      selected_res, ic_size, sorted_pos, cutoff, scaled_pdf, all_fits = \
            icList(Vica, n_component, Cij, p_cut=0.95)
    """

    # do the PDF/CDF fit, and assign cutoffs
    Npos = len(Vica)
    cutoff = list()
    scaled_pdf = list()
    all_fits = list()
    for k in range(n_component):
        # Fit a t-distribution to the contributions of component k;
        # pd = (df, loc, scale).
        pd = t.fit(Vica[:, k])
        all_fits.append(pd)
        # Freedman-Diaconis-style bin width from the interquartile range.
        iqr = scoreatpercentile(Vica[:, k], 75) - scoreatpercentile(
            Vica[:, k], 25
        )
        binwidth = 2 * iqr * (len(Vica[:, k]) ** (-0.33))
        nbins = round((max(Vica[:, k]) - min(Vica[:, k])) / binwidth)
        h_params = np.histogram(Vica[:, k], int(nbins))
        # Evaluate the fitted pdf/cdf on a 100-point grid spanning the
        # histogram support.
        x_dist = np.linspace(min(h_params[1]), max(h_params[1]), num=100)
        # Scale the pdf so that its area matches the histogram's area.
        area_hist = Npos * (h_params[1][2] - h_params[1][1])
        scaled_pdf.append(area_hist * (t.pdf(x_dist, pd[0], pd[1], pd[2])))
        cd = t.cdf(x_dist, pd[0], pd[1], pd[2])
        # tmp: grid index of the distribution mode.
        tmp = scaled_pdf[k].argmax()
        # Choose the heavier tail of the component (positive or negative).
        if abs(max(Vica[:, k])) > abs(min(Vica[:, k])):
            tail = cd[tmp: len(cd)]
        else:
            cd = 1 - cd
            tail = cd[0:tmp]
        # Grid point where the tail CDF is closest to p_cut.
        diff = abs(tail - p_cut)
        x_pos = diff.argmin()
        # NOTE(review): 'x_pos + tmp' indexes from the mode; for the
        # negative-tail branch (tail = cd[0:tmp]) this offset looks
        # suspicious — matches upstream pySCA, kept as-is. TODO confirm.
        cutoff.append(x_dist[x_pos + tmp])

    # select the positions with significant contributions to each IC
    ic_init = list()
    for k in range(n_component):
        ic_init.append([i for i in range(Npos) if Vica[i, k] > cutoff[k]])

    # construct the sorted, non-redundant iclist
    sorted_pos = list()
    ic_size = list()
    selected_res = list()
    icpos_tmp = list()
    # Copy of Cij with a zeroed diagonal so self-coevolution does not
    # bias the assignment below.
    Cij_nodiag = Cij.copy()
    for i in range(Npos):
        Cij_nodiag[i, i] = 0
    for k in range(n_component):
        icpos_tmp = list(ic_init[k])
        # A position selected by several ICs is kept only in the IC whose
        # members it coevolves with most strongly (Frobenius norm of its
        # couplings to each candidate group).
        for kprime in [kp for kp in range(n_component) if kp != k]:
            tmp = [v for v in icpos_tmp if v in ic_init[kprime]]
            for i in tmp:
                remsec = np.linalg.norm(
                    Cij_nodiag[i, ic_init[k]]
                ) < np.linalg.norm(Cij_nodiag[i, ic_init[kprime]])
                if remsec:
                    icpos_tmp.remove(i)
        # Sort by decreasing contribution to component k.
        sorted_pos += sorted(icpos_tmp, key=lambda i: -Vica[i, k])
        ic_size.append(len(icpos_tmp))
        s = Unit()
        s.items = sorted(icpos_tmp, key=lambda i: -Vica[i, k])
        # Evenly spaced color code in [0, 1) for plotting.
        s.col = k / n_component
        s.vect = -Vica[s.items, k]
        selected_res.append(s)
    return selected_res, ic_size, sorted_pos, cutoff, scaled_pdf, all_fits
cocoatree/randomize.py ADDED
@@ -0,0 +1,30 @@
1
+ """Module to perform randomization of alignments"""
2
+
3
+ import numpy as np
4
+
5
+
6
+ def _randomize_seqs_conserving_col_compo(sequences=[], seed=None):
7
+ """
8
+ Randomize the list of sequenecs (MSA) so that the content of each
9
+ column is overall conserved (conservation of aa frequencies)
10
+
11
+ Parameters
12
+ ----------
13
+ sequences : list of sequences (MSA)
14
+
15
+ seed : int
16
+ to generate exact same list of random numbers
17
+ (mostly for testing )
18
+
19
+ Returns
20
+ -------
21
+ rand_seqs : list of sequences where the columns have been shuffled
22
+ """
23
+
24
+ seq_array = np.array([list(seq) for seq in sequences])
25
+ T = seq_array.T
26
+ rng = np.random.default_rng(seed)
27
+ rand_seq_array = np.array([rng.permutation(T[i]) for i in range(len(T))]).T
28
+ rand_seqs = [''.join(seq) for seq in rand_seq_array]
29
+
30
+ return rand_seqs
@@ -0,0 +1,6 @@
1
+ import cocoatree
2
+ import argparse
3
+
4
+
5
+ def main():
6
+
@@ -0,0 +1,58 @@
1
+ from . import position
2
+ from . import pairwise
3
+ from .. import msa
4
+ from ..__params import __freq_regularization_ref
5
+
6
+
7
def compute_all_frequencies(sequences,
                            seq_weights=None,
                            freq_regul=__freq_regularization_ref):
    """
    Compute first-order, background, and joint amino-acid frequencies.


    Parameters
    ----------
    sequences : list of sequences

    seq_weights : {None, np.ndarray (n_seq)}
        if None, will re-compute the sequence weights.

    freq_regul : regularization parameter (default=__freq_regularization_ref)

    Returns
    -------
    aa_freqs : np.ndarray (nseq, 21)
        A (nseq, 21) ndarray containing the amino acid frequencies at each
        positions.

    bkgd_freqs : np.ndarray (21, )
        A (21,) np.array containing the background amino acid frequencies
        at each position; it is computed from the mean frequency of amino acid
        a in all proteins in the NCBI non-redundant database
        (see Rivoire et al., https://dx.plos.org/10.1371/journal.pcbi.1004817)

    aa_joint_freqs : np.ndarray (nseq, nseq, 21, 21)
        An ndarray containing the pairwise joint frequencies of amino acids
        for each pair of positions in the list of provided sequences.
    """
    # Fall back to computed weights when the caller supplies none.
    if seq_weights is None:
        seq_weights, _ = msa.compute_seq_weights(sequences)

    # Per-position amino-acid frequencies, regularized by freq_regul.
    aa_freqs = position._compute_aa_freqs(
        sequences,
        freq_regul=freq_regul,
        seq_weights=seq_weights)

    # NOTE(review): this call passes the module-level reference
    # regularization instead of the user-supplied freq_regul — confirm
    # whether that is intentional.
    bkgd_freqs = position._compute_background_freqs(
        aa_freqs,
        sequences,
        seq_weights=seq_weights,
        freq_regul=__freq_regularization_ref)

    # Pairwise joint frequencies for every pair of positions.
    aa_joint_freqs = pairwise._compute_aa_joint_freqs(
        sequences,
        seq_weights=seq_weights,
        freq_regul=freq_regul)

    return aa_freqs, bkgd_freqs, aa_joint_freqs
@@ -0,0 +1,318 @@
1
+ import numpy as np
2
+ from ..__params import lett2num, __freq_regularization_ref, __aa_count
3
+ from ..msa import compute_seq_weights
4
+ from .position import _compute_first_order_freqs
5
+
6
+
7
+ def _compute_aa_joint_freqs(sequences, seq_weights=None,
8
+ freq_regul=__freq_regularization_ref):
9
+ """Computes the joint frequencies of each pair of amino acids in a MSA
10
+
11
+ .. math::
12
+
13
+ f_{ij}^{ab} = (\\sum_s w_s x_{si}^a x_{sj}^b +
14
+ \\lambda/(21)^2)/(M_{eff} + \\lambda)
15
+
16
+ where
17
+
18
+ .. math::
19
+
20
+ M_{eff} = \\sum_s w_s
21
+
22
+ represents the effective number of sequences in the alignment and *lambda*
23
+ is a regularization parameter (pseudocount).
24
+
25
+ Parameters
26
+ ----------
27
+ sequences : list of sequences as imported by load_MSA()
28
+
29
+ seq_weights : numpy 1D array, optional
30
+ Gives more or less importance to certain sequences. If
31
+ seq_weights=None, all sequences are attributed an equal weighti
32
+ of 1.
33
+
34
+ freq_regul : regularization parameter (default=__freq_regularization_ref)
35
+
36
+ Returns
37
+ -------
38
+ aa_joint_freqs : np.ndarray of shape (Npos, Npos, aa_count, aa_count)
39
+ joint frequency of amino acids `a` and `b`
40
+ at respective positions `i` and `j`
41
+ """
42
+
43
+ # Convert sequences to binary format
44
+ tmp = np.array([[char for char in row] for row in sequences])
45
+ binary_array = np.array([tmp == aa for aa in lett2num.keys()]).astype(int)
46
+
47
+ # Adding weights
48
+ if seq_weights is None:
49
+ seq_weights = np.ones(len(sequences))
50
+ weighted_binary_array = binary_array * \
51
+ seq_weights[np.newaxis, :, np.newaxis]
52
+ # number of effective sequences
53
+ m_eff = np.sum(seq_weights)
54
+
55
+ # Joint frequencies
56
+ aa_joint_freqs = np.tensordot(weighted_binary_array, binary_array,
57
+ axes=([1], [1])).transpose(1, 3, 0, 2)
58
+ aa_joint_freqs = (aa_joint_freqs + freq_regul * m_eff / __aa_count ** 2)\
59
+ / ((1 + freq_regul) * m_eff)
60
+ return aa_joint_freqs
61
+
62
+
63
+ def _compute_aa_product_freqs(aa_freqs_1, aa_freqs_2):
64
+ """Computes the product of frequencies
65
+
66
+ (joint frequencies if residues are independent)
67
+
68
+ Parameters
69
+ ----------
70
+ aa_freqs_1 : frequency of amino acid *a* at position *i* (set 1)
71
+
72
+ aa_freqs_2 : frequency of amino acid *a* at position *i* (set 2)
73
+
74
+ Returns
75
+ -------
76
+ aa_prod_freqs : np.ndarray of shape (Npos, Npos, aa_count, aa_count)
77
+ product of frequency of amino acids *a* and $b$
78
+ at respective positions *i* and *j*
79
+ """
80
+
81
+ aa_product_freqs = np.multiply.outer(aa_freqs_1, aa_freqs_2)
82
+ aa_product_freqs = np.moveaxis(aa_product_freqs,
83
+ [0, 1, 2, 3],
84
+ [0, 2, 1, 3])
85
+
86
+ return aa_product_freqs
87
+
88
+
89
def _compute_second_order_freqs(sequences, seq_weights=None,
                                freq_regul=__freq_regularization_ref):
    """
    Computes joint frequencies and the product of frequencies

    Parameters
    ----------
    sequences : list of sequences

    seq_weights : np.ndarray
        weight values for each sequence of the alignment

    freq_regul : regularization parameter (default=__freq_regularization_ref)

    Returns
    -------
    aa_joint_freqs : np.ndarray of shape (Npos, Npos, aa_count, aa_count)
        joint frequency of amino acids `a` and `b` at respective positions
        `i` and `j`

    aa_product_freqs : np.ndarray of shape (Npos, Npos, aa_count, aa_count)
        product of frequency of amino acids `a` and `b` at respective
        positions `i` and `j`
    """

    # Observed joint frequencies for every pair of positions.
    joint = _compute_aa_joint_freqs(sequences,
                                    seq_weights=seq_weights,
                                    freq_regul=freq_regul)

    # Per-position frequencies (background frequencies discarded here).
    single, _ = _compute_first_order_freqs(sequences,
                                           seq_weights=seq_weights,
                                           freq_regul=freq_regul)

    # Expected joint frequencies under independence.
    independent = _compute_aa_product_freqs(single, single)

    return joint, independent
126
+
127
+
128
def compute_sca_matrix(sequences, seq_weights=None, raw_correlation=False,
                       freq_regul=__freq_regularization_ref):
    """Compute the SCA coevolution matrix

    .. math::

        C_{ij}^{ab} = f_{ij}^{ab} - f_i^a f_j^b

    .. math::

        \\tilde{C_{ij}} = \\sqrt{\\sum_{a,b} (\\tilde{C}_{ij}^{ab})^2}

    Parameters
    ----------
    sequences : list of sequences

    seq_weights : ndarray (nseq), optional, default: None
        if None, will compute sequence weights

    raw_correlation : boolean, optional, default: False
        whether to return raw correlations

    freq_regul : regularization parameter (default=__freq_regularization_ref)

    Returns
    -------
    SCA_matrix : SCA coevolution matrix
    """

    # computing frequencies
    if seq_weights is None:
        seq_weights, _ = compute_seq_weights(sequences)
    aa_joint_freqs, aa_product_freqs = _compute_second_order_freqs(
        sequences, seq_weights=seq_weights, freq_regul=freq_regul)

    # Cijab: covariance tensor (joint minus independent expectation)
    Cijab = aa_joint_freqs - aa_product_freqs

    if not raw_correlation:

        # derivative of relative entropy — the SCA positional weighting
        # phi_i^a of Rivoire et al.; bkgd_freqs broadcast over positions
        aa_freqs, bkgd_freqs = _compute_first_order_freqs(
            sequences, seq_weights=seq_weights, freq_regul=freq_regul)
        # transpose to (aa, pos) so the background can be broadcast,
        # then back to (pos, aa) after the log
        aa_freqs = aa_freqs.transpose([1, 0])
        phi = np.log(
            aa_freqs * (1 - bkgd_freqs[:, np.newaxis]) / (
                (1 - aa_freqs) *
                bkgd_freqs[:, np.newaxis])).transpose([1, 0])
        # outer product over positions/aa, reordered to align with Cijab's
        # (pos_i, pos_j, aa_a, aa_b) axes
        phi = np.multiply.outer(phi, phi).transpose([0, 2, 1, 3])

        # applying sca positional weights
        Cijab = phi * Cijab

    # Frobenius norm over the amino-acid axes collapses the 4D tensor to
    # a (Npos, Npos) matrix
    SCA_matrix = np.sqrt(np.sum(Cijab ** 2, axis=(2, 3)))

    return SCA_matrix
185
+
186
+
187
+ def compute_mutual_information_matrix(sequences, seq_weights=None,
188
+ freq_regul=__freq_regularization_ref,
189
+ normalize=True):
190
+ """Compute the mutual information matrix
191
+
192
+ .. math::
193
+
194
+ I(X, Y) = \\sum_{x,y} p(x, y) \\log \\frac{p(x, y)}{p(x)p(y)}
195
+
196
+ Parameters
197
+ ----------
198
+ sequences : list of sequences
199
+
200
+ seq_weights : ndarray (nseq), optional, default: None
201
+ if None, will compute sequence weights
202
+
203
+ freq_regul : regularization parameter (default=__freq_regularization_ref)
204
+
205
+ normalize : boolean, default : True
206
+ Whether to normalize the mutual information by the entropy.
207
+
208
+ Returns
209
+ -------
210
+ mi_matrix : np.ndarray of shape (nseq, nseq)
211
+ the matrix of mutual information
212
+ """
213
+
214
+ # computing frequencies
215
+ if seq_weights is None:
216
+ seq_weights, _ = compute_seq_weights(sequences)
217
+ aa_joint_freqs, aa_product_freqs = _compute_second_order_freqs(
218
+ sequences, seq_weights=seq_weights,
219
+ freq_regul=freq_regul)
220
+
221
+ # mutual information
222
+ mi_matrix = np.sum(
223
+ aa_joint_freqs * np.log(aa_joint_freqs / aa_product_freqs),
224
+ axis=(2, 3))
225
+
226
+ if normalize:
227
+ joint_entropy = -np.sum(aa_joint_freqs * np.log(aa_joint_freqs),
228
+ axis=(2, 3))
229
+ mi_matrix /= joint_entropy
230
+
231
+ return mi_matrix
232
+
233
+
234
+ def compute_apc(MIij):
235
+ """
236
+ Computes the average product correction (APC) as described in Dunn et
237
+ al. (2008).
238
+
239
+ .. math::
240
+
241
+ APC(a, b) = \\frac{MI(a, \\bar{x}) MI(b, \\bar{x}){\\overline{MI}}
242
+
243
+ where :math:`MI(a, \\bar{x})` is the mean mutual information of column *a*
244
+ and :math:`\\overline{MI}` is the overall mean mutual information
245
+
246
+ The corrected mutual information is then:
247
+
248
+ .. math::
249
+
250
+ MIp(a, b) = MI(a, b) - APC(a, b)
251
+
252
+ Parameters
253
+ ----------
254
+ MIij : np.ndarray,
255
+ the mutual information matrix
256
+
257
+ Returns
258
+ -------
259
+ APC_ij : np.ndarray,
260
+ the average product correction (APC) matrix
261
+
262
+ MIp : np.ndarray,
263
+ the APC corrected mutual information matrix
264
+ """
265
+
266
+ n = MIij.shape[0]
267
+ m = n - 1
268
+ # Replace the matrix diagonal by 0
269
+ np.fill_diagonal(MIij, 0)
270
+
271
+ MI_colmean = (1/m) * np.sum(MIij, axis=0)
272
+ MI_colmean = np.multiply.outer(MI_colmean, MI_colmean)
273
+
274
+ MI_overmean = (2/(m*n)) * np.sum(np.tril(MIij))
275
+
276
+ APC_ij = MI_colmean / MI_overmean
277
+
278
+ MIp = MIij - APC_ij
279
+
280
+ return APC_ij, MIp
281
+
282
+
283
+ def compute_entropy_correction(coevolution_matrix, s):
284
+
285
+ """
286
+ Computes the entropy correction according to Vorberg et al. (2018)
287
+
288
+ .. math::
289
+
290
+ C_{ij}^{EC} = C_{ij} - \\alpha s_{i}^{\\frac{1}{2}} \
291
+ s_{j}^{\\frac{1}{2}}
292
+
293
+ where :math:`\\alpha` is a coefficient determining the strength of the
294
+ correction:
295
+
296
+ .. math::
297
+
298
+ \\alpha = \\frac{\\sum_{i \\neq j}^{L} c_ij \
299
+ s_{i}^{\\frac{1}{2}}}{\\sum_{i \\neq j}^{L} s_i s_j}
300
+
301
+ Parameters
302
+ ----------
303
+ coevolution_matrix : square matrix of shape (Nseq, Nseq)
304
+
305
+ s : entropy computed for every position of the MSA
306
+
307
+ Returns
308
+ -------
309
+ a square matrix of shape (Nseq, Nseq)
310
+ """
311
+
312
+ s_prod = np.multiply.outer(s, s)
313
+ no_diag_eye = (1 - np.eye(s_prod.shape[0]))
314
+ alpha = np.sum(
315
+ (no_diag_eye * np.sqrt(s_prod) * coevolution_matrix) / np.sum(
316
+ (no_diag_eye * s_prod)))
317
+
318
+ return coevolution_matrix - alpha * np.sqrt(s_prod)