yunta 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. yunta/__init__.py +0 -0
  2. yunta/cli.py +252 -0
  3. yunta/dca.py +172 -0
  4. yunta/dca_torch.py +99 -0
  5. yunta/io.py +69 -0
  6. yunta/modelling.py +43 -0
  7. yunta/plots.py +55 -0
  8. yunta/scoring.py +73 -0
  9. yunta/screening.py +369 -0
  10. yunta/src_speedppi/alphafold/__init__.py +0 -0
  11. yunta/src_speedppi/alphafold/common/__init__.py +14 -0
  12. yunta/src_speedppi/alphafold/confidence.py +155 -0
  13. yunta/src_speedppi/alphafold/data/__init__.py +14 -0
  14. yunta/src_speedppi/alphafold/data/foldonly.py +146 -0
  15. yunta/src_speedppi/alphafold/data/mmcif_parsing.py +385 -0
  16. yunta/src_speedppi/alphafold/data/msaonly.py +176 -0
  17. yunta/src_speedppi/alphafold/data/pair_msas.py +117 -0
  18. yunta/src_speedppi/alphafold/data/parsers.py +368 -0
  19. yunta/src_speedppi/alphafold/data/pipeline.py +208 -0
  20. yunta/src_speedppi/alphafold/data/templates.py +910 -0
  21. yunta/src_speedppi/alphafold/data/tools/__init__.py +14 -0
  22. yunta/src_speedppi/alphafold/data/tools/hhblits.py +156 -0
  23. yunta/src_speedppi/alphafold/data/tools/hhsearch.py +92 -0
  24. yunta/src_speedppi/alphafold/data/tools/hmmbuild.py +138 -0
  25. yunta/src_speedppi/alphafold/data/tools/hmmsearch.py +90 -0
  26. yunta/src_speedppi/alphafold/data/tools/jackhmmer.py +199 -0
  27. yunta/src_speedppi/alphafold/data/tools/kalign.py +105 -0
  28. yunta/src_speedppi/alphafold/data/tools/utils.py +40 -0
  29. yunta/src_speedppi/alphafold/model/__init__.py +14 -0
  30. yunta/src_speedppi/alphafold/model/all_atom.py +1141 -0
  31. yunta/src_speedppi/alphafold/model/all_atom_test.py +135 -0
  32. yunta/src_speedppi/alphafold/model/common_modules.py +84 -0
  33. yunta/src_speedppi/alphafold/model/config.py +403 -0
  34. yunta/src_speedppi/alphafold/model/data.py +41 -0
  35. yunta/src_speedppi/alphafold/model/features.py +103 -0
  36. yunta/src_speedppi/alphafold/model/folding.py +1006 -0
  37. yunta/src_speedppi/alphafold/model/layer_stack.py +274 -0
  38. yunta/src_speedppi/alphafold/model/layer_stack_test.py +335 -0
  39. yunta/src_speedppi/alphafold/model/lddt.py +88 -0
  40. yunta/src_speedppi/alphafold/model/lddt_test.py +79 -0
  41. yunta/src_speedppi/alphafold/model/mapping.py +218 -0
  42. yunta/src_speedppi/alphafold/model/model.py +141 -0
  43. yunta/src_speedppi/alphafold/model/modules.py +2095 -0
  44. yunta/src_speedppi/alphafold/model/prng.py +69 -0
  45. yunta/src_speedppi/alphafold/model/prng_test.py +46 -0
  46. yunta/src_speedppi/alphafold/model/quat_affine.py +459 -0
  47. yunta/src_speedppi/alphafold/model/quat_affine_test.py +150 -0
  48. yunta/src_speedppi/alphafold/model/r3.py +322 -0
  49. yunta/src_speedppi/alphafold/model/tf/__init__.py +14 -0
  50. yunta/src_speedppi/alphafold/model/tf/data_transforms.py +624 -0
  51. yunta/src_speedppi/alphafold/model/tf/input_pipeline.py +166 -0
  52. yunta/src_speedppi/alphafold/model/tf/protein_features.py +130 -0
  53. yunta/src_speedppi/alphafold/model/tf/protein_features_test.py +51 -0
  54. yunta/src_speedppi/alphafold/model/tf/proteins_dataset.py +168 -0
  55. yunta/src_speedppi/alphafold/model/tf/shape_helpers.py +47 -0
  56. yunta/src_speedppi/alphafold/model/tf/shape_helpers_test.py +40 -0
  57. yunta/src_speedppi/alphafold/model/tf/shape_placeholders.py +20 -0
  58. yunta/src_speedppi/alphafold/model/tf/utils.py +47 -0
  59. yunta/src_speedppi/alphafold/model/utils.py +81 -0
  60. yunta/src_speedppi/alphafold/protein.py +234 -0
  61. yunta/src_speedppi/alphafold/residue_constants.py +896 -0
  62. yunta/structs/metrics.py +73 -0
  63. yunta/structs/msa.py +329 -0
  64. yunta/structs/pdb_structs.py +42 -0
  65. yunta/weights.py +66 -0
  66. yunta-0.0.1.dist-info/LICENSE +438 -0
  67. yunta-0.0.1.dist-info/METADATA +735 -0
  68. yunta-0.0.1.dist-info/RECORD +71 -0
  69. yunta-0.0.1.dist-info/WHEEL +5 -0
  70. yunta-0.0.1.dist-info/entry_points.txt +2 -0
  71. yunta-0.0.1.dist-info/top_level.txt +1 -0
yunta/__init__.py ADDED
File without changes
yunta/cli.py ADDED
@@ -0,0 +1,252 @@
1
+ """Command-line interface for sppid."""
2
+
3
# Package version; reported by the CLI through ``CLIApp(version=__version__)`` in ``main``.
__version__ = '0.0.1'
4
+
5
+ from typing import Any, Mapping, Optional, Tuple, Union
6
+
7
+ from argparse import ArgumentParser, FileType, Namespace
8
+ from io import TextIOWrapper
9
+ import os
10
+ import sys
11
+
12
+ from carabiner import print_err
13
+ from carabiner.cast import cast, flatten
14
+ from carabiner.cliutils import clicommand, CLIOption, CLICommand, CLIApp
15
+
16
+ from .io import write_metrics
17
+ from .plots import plot_matrix
18
+ from .screening import (
19
+ dca_one_vs_many,
20
+ dca_many_vs_many,
21
+ model_one_vs_many,
22
+ model_many_vs_many,
23
+ rf2track_one_vs_many,
24
+ )
25
+
26
def _load_msa_list(*args):
    """Read MSA filenames from plain-text list files.

    Each positional argument is one of:

    * an open text file (as produced by ``argparse.FileType``),
    * a list of such files (from ``nargs='*'``), of which only the first is read,
    * ``None`` (an optional input that was not given).

    Returns:
        A list with one entry per argument. Each entry is the stripped,
        non-blank lines of the corresponding list file, or ``None`` where the
        argument was ``None`` or an empty list of files.
    """
    loaded = []
    for arg in args:
        if isinstance(arg, list):
            # nargs='*' yields a list of handles; only the first list file is used.
            arg = arg[0] if arg else None
        if arg is None:
            # Keep "not provided" as None, exactly as in non-list mode, so the
            # screening functions can apply their own defaults.
            loaded.append(None)
            continue
        # Skip blank lines (e.g. a trailing newline), which would otherwise
        # become empty filenames.
        loaded.append([line.strip() for line in arg if line.strip()])
    return loaded
29
+
30
def _plot_results(results, result_interaction, metric,
                  output_dir: str = '.', *args, **kwargs) -> None:
    """Save heatmaps for one screening result.

    Args:
        results: Full score matrix. It is drawn with guide lines at
            ``metric.chain_a_len`` (presumably the chain A/B boundary; TODO
            confirm).
        result_interaction: Interchain score matrix. Its axes are labelled with
            ``metric.uniprot_id_1`` (y) and ``metric.uniprot_id_2`` (x).
        metric: Metrics object providing ``ID``, ``chain_a_len``,
            ``uniprot_id_1``, ``uniprot_id_2`` and, optionally, ``apc``.
        output_dir: Directory for the output files; created if missing.
            An empty string means the current directory.
        *args, **kwargs: Accepted for call compatibility and ignored.
    """
    if output_dir:
        # exist_ok: no failure if the directory appears between a check and
        # its creation (e.g. several jobs sharing one plot directory).
        os.makedirs(output_dir, exist_ok=True)
    filename_prefix = os.path.join(output_dir, metric.ID)
    if hasattr(metric, 'apc'):
        apc = metric.apc
        # Self-documenting f-string, e.g. ".apc=True".
        filename_prefix += f".{apc=}"
    plot_matrix(results,
                filename_prefix=filename_prefix,
                hline=metric.chain_a_len,
                vline=metric.chain_a_len)
    plot_matrix(result_interaction,
                filename_prefix=f"{filename_prefix}.interaction",
                ylabel=metric.uniprot_id_1,
                xlabel=metric.uniprot_id_2)
    return None
47
+
48
+
49
@clicommand(message="Making RosettaFold-2track prediction with the following parameters")
def _rf2t_single(args: Namespace) -> None:
    """Handle ``rf2t-single``: RF-2track of ``msa1`` against each MSA in ``msa2``.

    With ``--list-file``, the inputs are read as lists of MSA filenames. The
    last element of every result is written to ``args.output`` as metrics.
    When ``args.plot`` is given, every result is also plotted into that
    directory.
    """
    msa1, msa2 = (
        _load_msa_list(args.msa1, args.msa2)
        if args.list_file
        else (args.msa1, args.msa2)
    )

    print_err(f"Running RF-2t using {msa1} as reference.")
    predictions = rf2track_one_vs_many(msa_file1=msa1, msa_file2=msa2, cpu=args.cpu)

    write_metrics([prediction[-1] for prediction in predictions],
                  filename=args.output)

    if args.plot is not None:
        for prediction in predictions:
            _plot_results(*prediction, output_dir=args.plot)
    return None
71
+
72
+
73
+
74
@clicommand(message="Calculating DCA for a pair of MSAs with the following parameters")
def _dca_single(args: Namespace) -> None:
    """Handle ``dca-single``: DCA of ``msa1`` against each MSA in ``msa2``.

    The last element of each result is written to ``args.output`` as metrics.
    Each result is plotted into ``args.plot`` when that option is given.
    """
    if not args.list_file:
        msa_pair = (args.msa1, args.msa2)
    else:
        msa_pair = _load_msa_list(args.msa1, args.msa2)
    msa1, msa2 = msa_pair

    dca_results = dca_one_vs_many(msa_file1=msa1, msa_file2=msa2, apc=args.apc)

    write_metrics([entry[-1] for entry in dca_results], filename=args.output)
    if args.plot is not None:
        for entry in dca_results:
            _plot_results(*entry, output_dir=args.plot)

    return None
95
+
96
+
97
@clicommand(message="Calculating DCA between pairs of MSAs with the following parameters")
def _dca_many_vs_many(args: Namespace) -> None:
    """Handle ``dca-many``: DCA between two MSA sets, or all pairs within one.

    Writes the last element of each result to ``args.output`` as metrics.
    Each result is plotted into ``args.plot`` when that option is given.
    """
    if args.list_file:
        first_set, second_set = _load_msa_list(args.msa1, args.msa2)
    else:
        first_set, second_set = args.msa1, args.msa2

    pair_results = dca_many_vs_many(
        msa_files1=first_set,
        msa_files2=second_set,
        apc=args.apc,
    )

    collected = [pair_result[-1] for pair_result in pair_results]
    write_metrics(collected, filename=args.output)

    if args.plot is not None:
        for pair_result in pair_results:
            _plot_results(*pair_result, output_dir=args.plot)

    return None
119
+
120
+
121
@clicommand(message="Modelling one PPI with the following parameters")
def _af2_single(args: Namespace) -> None:
    """Handle ``af2-single``: model one PPI with AlphaFold2 and save its metrics.

    Model outputs go to ``args.output`` (the required ``--output`` directory).
    Metrics are written to ``<output>/<metric.ID>_metrics.csv``.
    """
    if args.list_file:
        msa1, msa2 = _load_msa_list(args.msa1, args.msa2)
    else:
        msa1, msa2 = args.msa1, args.msa2

    metric = model_one_vs_many(
        msa_file1=msa1,
        msa_file2=msa2,
        max_recycles=args.recycles,
        output_dir=args.output,
        param_dir=args.params,
    )

    # The parser only defines ``--output`` (dest ``output``). Reading
    # ``args.output_dir`` raised AttributeError after modelling had finished.
    # NOTE(review): write_metrics writes tab-separated values despite the .csv suffix.
    output_filename = os.path.join(args.output, f"{metric.ID}_metrics.csv")
    print_err(f"Saving metrics as {output_filename}")
    write_metrics(metric,
                  filename=output_filename)

    return None
143
+
144
+
145
@clicommand(message="Modelling sets of PPIs with the following parameters")
def _af2_many_vs_many(args: Namespace) -> None:
    """Handle ``af2-many``: model PPIs between two MSA sets, or all pairs in one.

    Model outputs and the combined metrics table
    (``<output>/_all_metrics.csv``) are written under ``args.output``.
    """
    first, second = (_load_msa_list(args.msa1, args.msa2) if args.list_file
                     else (args.msa1, args.msa2))

    all_metrics = model_many_vs_many(
        msa_files1=first,
        msa_files2=second,
        output_dir=args.output,
        max_recycles=args.recycles,
        param_dir=args.params,
    )

    # NOTE(review): write_metrics writes tab-separated values despite the .csv suffix.
    summary_file = os.path.join(args.output, "_all_metrics.csv")
    print_err(f"Saving metrics as {summary_file}")
    write_metrics(all_metrics, filename=summary_file)
    return None
167
+
168
+
169
def main() -> None:
    """Entry point for the ``sppid`` command-line application.

    Defines the shared options, attaches them to each subcommand together with
    its handler, and dispatches via ``CLIApp.run``.
    """
    inputs = CLIOption('msa1',
                       default=sys.stdin,
                       type=FileType('r'),
                       nargs='?',
                       help='MSA file. Default: STDIN.')
    # NOTE(review): not attached to any subcommand below.
    input2 = CLIOption('--msa2', '-2',
                       type=FileType('r'),
                       default=None,
                       help='Second MSA file.')
    inputs_list = CLIOption('msa1',
                            default=sys.stdin,
                            type=FileType('r'),
                            nargs='*',
                            help='MSA file(s).')
    inputs_list2 = CLIOption('--msa2', '-2',
                             type=FileType('r'),
                             default=None,
                             nargs='*',
                             help='Second MSA file(s). Default: if not provided, all pairwise from msa1.')
    # _load_msa_list reads each line of the given file(s) as an MSA filename.
    list_file = CLIOption('--list-file', '-l',
                          action='store_true',
                          help='Treat inputs as plain-text lists of MSA filenames (one per line), '
                               'rather than as MSA files. Default: treat inputs as MSA files.')
    output = CLIOption('--output', '-o',
                       type=str,
                       required=True,
                       help='Output directory.')
    plot = CLIOption('--plot', '-p',
                     type=str,
                     default=None,
                     help='Directory for saving plots. Default: don\'t plot.')
    cpu = CLIOption('--cpu', '-c',
                    action='store_true',
                    help='Whether to use CPU only. Default: use GPU.')
    output_file = CLIOption('--output', '-o',
                            default=sys.stdout,
                            type=FileType('w'),
                            nargs='?',
                            help='Output filename. Default: STDOUT.')
    apc = CLIOption('--apc', '-a',
                    action='store_true',
                    help='Whether to use APC correction in DCA. Default: don\'t apply correction.')
    params = CLIOption('--params', '-w',
                       type=str,
                       default=None,
                       help='Path to AlphaFold2 params file (.npz).')
    recycles = CLIOption('--recycles', '-x',
                         type=int,
                         default=10,
                         help='Maximum number of recycles through the model.')

    rf2t_single = CLICommand('rf2t-single',
                             description='Calculate RF-2track contacts between one protein and a series of others.',
                             main=_rf2t_single,
                             options=[inputs, inputs_list2, list_file, output_file, plot, cpu])
    dca_single = CLICommand('dca-single',
                            description='Calculate DCA for one protein-protein interaction.',
                            main=_dca_single,
                            options=[inputs, inputs_list2, list_file, output_file, plot, apc])
    dca_many = CLICommand('dca-many',
                          description='Calculate DCA between two sets of proteins, or all pairs in one set of proteins.',
                          main=_dca_many_vs_many,
                          options=[inputs_list, inputs_list2, list_file, apc, output_file, plot])
    # NOTE(review): the af2 handlers never read args.plot; the option is accepted but unused.
    af2_single = CLICommand('af2-single',
                            description='Model one protein-protein interaction.',
                            main=_af2_single,
                            options=[inputs, inputs_list2, list_file, output, params, recycles, plot])
    af2_many = CLICommand('af2-many',
                          description='Model all interactions between two sets of proteins, or all pairs in one set of proteins.',
                          main=_af2_many_vs_many,
                          options=[inputs_list, inputs_list2, list_file, output, params, recycles, plot])

    app = CLIApp("sppid",
                 version=__version__,
                 description="Screening protein-protein interactions using DCA, RosettaFold-2track, and AlphaFold2.",
                 commands=[dca_single, dca_many, rf2t_single, af2_single, af2_many])

    app.run()
    return None
249
+
250
+
251
# Run the CLI when this module is executed directly (e.g. ``python -m yunta.cli``).
if __name__ == "__main__":
    main()
yunta/dca.py ADDED
@@ -0,0 +1,172 @@
1
+ """Code from Humphreys et al., Science, 2021 (DOI: 10.1126/science.abm4805), reimplemented for TensorFlow 2.
2
+
3
+ According to the paper's supplementary material:
4
+ "We reimplemented DCA so that it can be computed using GPUs to speed up the calculation.
5
+ The code is in the box below. In addition, we applied average product correction (APC).[...]"
6
+
7
+ ```python
8
+ import sys
9
+ import numpy as np
10
+ import string
11
+ import tensorflow as tf
12
+
13
+ def parse_a3m(a3mlines):
14
+ seqs = []
15
+ labels = []
16
+ for line in a3mlines:
17
+ if line[0] == '>':
18
+ labels.append(line.rstrip())
19
+ else:
20
+ seqs.append(line[:-1])
21
+ alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
22
+ seq_num = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)
23
+ for i in range(alphabet.shape[0]):
24
+ seq_num[seq_num == alphabet[i]] = i
25
+ seq_num[seq_num > 20] = 20
26
+ return {'seqs' : seq_num, 'labels' : labels }
27
+
28
+ def tf_cov(x,w=None):
29
+ if w is None:
30
+ num_points = tf.cast(tf.shape(x)[0],tf.float32) - 1
31
+ x_mean = tf.reduce_mean(x, axis=0, keep_dims=True)
32
+ x = (x - x_mean)
33
+ else:
34
+ num_points = tf.reduce_sum(w) - tf.sqrt(tf.reduce_mean(w))
35
+ x_mean = tf.reduce_sum(x * w[:,None], axis=0, keepdims=True) / num_points
36
+ x = (x - x_mean) * tf.sqrt(w[:,None])
37
+ return tf.matmul(tf.transpose(x),x)/num_points
38
+
39
+ fp = open(sys.argv[1] + ".alignments", "r")
40
+ pair2a3m = {}
41
+ pair = ""
42
+ pairs = []
43
+ a3m = []
44
+ for line in fp:
45
+ if line[:2] == ">>":
46
+ if pair and a3m:
47
+ pair2a3m[pair] = a3m
48
+ pairs.append(pair)
49
+ pair = line[2:-1]
50
+ a3m = []
51
+ else:
52
+ a3m.append(line)
53
+ fp.close()
54
+ if pair and a3m:
55
+ pair2a3m[pair] = a3m
56
+ pairs.append(pair)
57
+
58
+ config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9))
59
+
60
+ with tf.Graph().as_default():
61
+ x = tf.placeholder(tf.uint8,shape=(None,None),name="x")
62
+ x_shape = tf.shape(x)
63
+ x_nr = x_shape[0]
64
+ x_nc = x_shape[1]
65
+ x_ns = 21
66
+ x_msa = tf.one_hot(x,x_ns)
67
+ x_cutoff = tf.cast(x_nc,tf.float32) * 0.8
68
+ x_pw = tf.tensordot(x_msa, x_msa, [[1,2], [1,2]])
69
+ x_cut = x_pw > x_cutoff
70
+ x_weights = 1.0/tf.reduce_sum(tf.cast(x_cut, dtype=tf.float32),-1)
71
+ x_feat = tf.reshape(x_msa,(x_nr,x_nc*x_ns))
72
+ x_c = (tf_cov(x_feat,x_weights) + tf.eye(x_nc*x_ns) *
73
+ 4.5/tf.sqrt(tf.reduce_sum(x_weights)))
74
+ x_c_inv = tf.linalg.inv(x_c)
75
+ x_w = tf.reshape(x_c_inv,(x_nc,x_ns,x_nc,x_ns))
76
+ x_wi = tf.sqrt(tf.reduce_sum(tf.square(x_w[:,:-1,:,:-1]),(1,3))) * (1-tf.eye(x_nc))
77
+ # APC
78
+ #x_ap = tf.reduce_sum(x_wi,0,keepdims=True) * tf.reduce_sum(x_wi,1,keepdims=True) /
79
+ #       tf.reduce_sum(x_wi)
80
+ #x_wip = (x_wi - x_ap) * (1-tf.eye(x_nc))
81
+
82
+ with tf.Session(config=config) as sess:
83
+ results = []
84
+ get_pairs = []
85
+ for pair in pairs:
86
+ print (pair)
87
+ msa = parse_a3m(pair2a3m[pair])
88
+ try:
89
+ wip = sess.run(x_wi,{x:msa['seqs']})
90
+ results.append(wip.astype(np.float16))
91
+ get_pairs.append(pair)
92
+ except tf.errors.ResourceExhaustedError as e:
93
+ pass
94
+ np.savez_compressed(sys.argv[1], *results, names=get_pairs)
95
+ rp = open(sys.argv[1] + ".log","w")
96
+ for pair in get_pairs:
97
+ rp.write(pair + "\n")
98
+ rp.close()
99
+ ```
100
+ """
101
+
102
+ from typing import Optional, Union
103
+ from io import TextIOWrapper
104
+ import sys
105
+
106
+ from carabiner import print_err
107
+ import numpy as np
108
+ import tensorflow as tf
109
+
110
+ from .structs.msa import MSA, _A3M_ALPHABET
111
+
112
@tf.function
def _tf_cov(x: tf.Tensor, w: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Covariance matrix of the columns of ``x``, optionally with row weights.

    Rows of ``x`` are observations and columns are features.

    Args:
        x: Float tensor, assumed 2-D with shape (n_rows, n_features).
        w: Optional 1-D tensor of per-row weights of length n_rows.

    Returns:
        An (n_features, n_features) covariance matrix.

        * Unweighted: the unbiased (N - 1) sample covariance.
        * Weighted: normalised by ``sum(w) - sqrt(mean(w))``, reproducing the
          reference code quoted in this module's docstring.
    """
    if w is None:
        num_points = tf.cast(tf.shape(x)[0], tf.float32) - 1.
        # TF2 spells this ``keepdims``; the TF1 ``keep_dims`` raises TypeError.
        x_mean = tf.math.reduce_mean(x, axis=0, keepdims=True)
        x = (x - x_mean)
    else:
        num_points = tf.math.reduce_sum(w) - tf.math.sqrt(tf.math.reduce_mean(w))
        x_mean = tf.math.reduce_sum(x * w[:, tf.newaxis], axis=0, keepdims=True) / num_points
        x = (x - x_mean) * tf.math.sqrt(w[:, tf.newaxis])
    return tf.linalg.matmul(x, x, transpose_a=True) / num_points
123
+
124
+
125
@tf.function
def _get_wip(x: tf.Tensor, apc: bool = False) -> tf.Tensor:
    """Compute the DCA interaction-score matrix for a tokenised MSA (TensorFlow).

    TensorFlow 2 port of the reference code quoted in this module's docstring,
    with optional average product correction (APC).

    Args:
        x: Integer token ids. Assumed shape (n_sequences, n_columns), with values
            in ``range(len(_A3M_ALPHABET))`` (TODO confirm against MSA).
        apc: Whether to apply APC to the scores.

    Returns:
        An (n_columns, n_columns) tensor of scores with a zero diagonal.
    """
    alphabet_size = len(_A3M_ALPHABET)
    n_row, n_col = tf.shape(x)[-2], tf.shape(x)[-1]
    # Two sequences count as neighbours when more than 80% of positions match.
    x_cutoff = tf.cast(n_col, dtype=tf.float32) * .8
    msa_one_hot = tf.one_hot(x, alphabet_size)
    # Number of identical positions for every pair of sequences: (n_row, n_row).
    x_pw = tf.tensordot(msa_one_hot, msa_one_hot, [[1,2], [1,2]])
    x_cut = tf.cast(x_pw > x_cutoff,
                    dtype=tf.float32)

    # Each sequence's weight is 1 / (number of its neighbours, itself included).
    weights = 1. / tf.math.reduce_sum(x_cut, axis=-1)
    x_feat = tf.reshape(msa_one_hot, (n_row, n_col * alphabet_size))

    # Add a scaled identity (shrinkage) so the weighted covariance can be inverted.
    x_c = (_tf_cov(x_feat, weights) + tf.eye(n_col * alphabet_size) * 4.5 /
           tf.sqrt(tf.math.reduce_sum(weights)))
    x_c_inv = tf.linalg.inv(x_c)

    x_w = tf.reshape(x_c_inv, (n_col, alphabet_size, n_col, alphabet_size))
    # Frobenius norm of each column-pair block, dropping the last token class
    # (presumably the gap; TODO confirm for _A3M_ALPHABET), then zero the diagonal.
    x_wi = tf.sqrt(tf.reduce_sum(tf.square(x_w[:,:-1,:,:-1]), (1, 3))) * (1. - tf.eye(n_col))

    if apc:
        # Average product correction: subtract row_sum * col_sum / total.
        x_ap = (tf.math.reduce_sum(x_wi, axis=0, keepdims=True) * tf.math.reduce_sum(x_wi, axis=1, keepdims=True)
                / tf.math.reduce_sum(x_wi))
        x_wi = (x_wi - x_ap) * (1. - tf.eye(n_col))

    return x_wi
151
+
152
+
153
def calculate_dca(msa: MSA,
                  apc: bool = False,
                  gpu: bool = True) -> np.ndarray:
    """Compute a DCA interaction-score matrix for an MSA using TensorFlow.

    Args:
        msa: Alignment whose ``sequence_token_ids`` are the integer tokens to
            score. Assumed shape (n_sequences, n_columns).
        apc: Whether to apply average product correction.
        gpu: Use a GPU if one is available. When ``False``, or when no GPU is
            found, the calculation is pinned to the CPU. If the GPU runs out of
            memory, the calculation is retried on the CPU.

    Returns:
        An (n_columns, n_columns) array of scores with a zero diagonal.
    """
    print_err(f"Devices available:\n{tf.config.list_physical_devices()}")
    use_gpu = False
    if gpu:
        gpus = tf.config.list_physical_devices('GPU')
        if len(gpus) == 0:
            print_err("WARNING! GPU requested but none available. Falling back to CPU.")
        else:
            use_gpu = True
    msa_token_ids = np.asarray(msa.sequence_token_ids)

    if not use_gpu:
        # Honour gpu=False: without a device scope TF would still pick a GPU.
        with tf.device('/cpu:0'):
            wip = _get_wip(msa_token_ids, apc=apc)
        return wip.numpy()

    try:
        wip = _get_wip(msa_token_ids, apc=apc)
    except tf.errors.ResourceExhaustedError:
        print_err("GPU memory exhausted; falling back to CPU.")
        with tf.device('/cpu:0'):
            wip = _get_wip(msa_token_ids, apc=apc)

    return wip.numpy()
yunta/dca_torch.py ADDED
@@ -0,0 +1,99 @@
1
+ """DCA for GPU implemented in pytorch."""
2
+
3
+ from typing import Optional, Union
4
+
5
+ from io import TextIOWrapper
6
+ import sys
7
+
8
+ from carabiner import print_err
9
+ import numpy as np
10
+ import torch
11
+ from torch import FloatTensor, Tensor
12
+ import torch.nn.functional as F
13
+
14
+ from .structs.msa import MSA, _A3M_ALPHABET
15
+
16
+
17
def _torch_cov(x: Tensor, w: Optional[Tensor] = None) -> FloatTensor:
    """Covariance matrix of the columns of ``x``, optionally with row weights.

    In both branches, rows of ``x`` are observations (sequences) and columns are
    features, so the result is always (n_features, n_features).

    Args:
        x: 2-D float tensor of shape (n_rows, n_features).
        w: Optional 1-D tensor of per-row weights of length n_rows.

    Returns:
        The covariance matrix.

        * Unweighted: the unbiased (N - 1) sample covariance.
        * Weighted: normalised by ``sum(w) - sqrt(mean(w))``, as in the
          TensorFlow implementation in ``dca.py``.
    """
    if w is None:
        # torch.cov treats *rows* as variables. Transpose so rows stay
        # observations, matching the weighted branch below.
        return torch.cov(x.transpose(-2, -1))
    num_points = torch.sum(w) - torch.sqrt(torch.mean(w))
    x_mean = torch.sum(x * w.unsqueeze(-1),
                       dim=0, keepdim=True) / num_points
    x = (x - x_mean) * torch.sqrt(w.unsqueeze(-1))
    return torch.matmul(x.transpose(-2, -1), x) / num_points


def _get_wip(x: Tensor, apc: bool = False, gpu: bool = True, dtype=torch.float16,
             alphabet_size: Optional[int] = None) -> FloatTensor:
    """Compute the DCA interaction-score matrix for a tokenised MSA.

    Args:
        x: Integer tensor of token ids, shape (n_row, n_col): sequences by columns.
        apc: Whether to apply average product correction.
        gpu: Compute on CUDA if available, otherwise on the CPU. ``x`` is moved
            to the chosen device.
        dtype: Floating dtype for the one-hot encoding, identity counts and
            covariance. The shrunk covariance is inverted in at least float32,
            because ``torch.linalg.inv`` does not support float16.
        alphabet_size: Number of token classes. Defaults to
            ``len(_A3M_ALPHABET)``. The last class is excluded from the scores.

    Returns:
        An (n_col, n_col) tensor of scores with a zero diagonal.
    """
    device = torch.device("cuda" if (torch.cuda.is_available() and gpu) else "cpu")
    if alphabet_size is None:
        alphabet_size = len(_A3M_ALPHABET)
    # Keep sizes as plain ints: one_hot/view/eye take ints, and 0-dim device
    # tensors would only add host-device synchronisation.
    alphabet_size = int(alphabet_size)
    # All tensors below must live on one device; this is a no-op if x is already there.
    x = x.to(device)
    n_row, n_col = x.shape
    n_features = n_col * alphabet_size

    msa_one_hot = (F.one_hot(x.to(torch.int64), num_classes=alphabet_size)
                   .to(dtype))  # n_row, n_col, alphabet_size
    # Number of identical positions for every pair of sequences.
    identity_counts = torch.tensordot(msa_one_hot, msa_one_hot,
                                      dims=[[1, 2], [1, 2]])  # n_row, n_row
    # Neighbours share more than 80% of positions; a sequence is its own neighbour.
    x_cut = (identity_counts > n_col * .8).to(dtype)

    weights = 1. / torch.sum(x_cut, dim=-1)  # n_row
    msa_concat_one_hot = msa_one_hot.view(n_row, n_features)

    # Build the shrinkage term in at least float32. Adding it promotes the
    # covariance to that dtype before inversion (float16 -> float32, as before).
    inv_dtype = torch.promote_types(dtype, torch.float32)
    shrinkage = torch.eye(n_features, device=device, dtype=inv_dtype) * (
        4.5 / torch.sqrt(torch.sum(weights)))
    covariance_matrix = _torch_cov(msa_concat_one_hot, weights)  # n_features, n_features
    shrunk_cov_matrix = covariance_matrix + shrinkage
    shrunk_cov_matrix_inv = torch.linalg.inv(shrunk_cov_matrix).view(
        n_col, alphabet_size, n_col, alphabet_size)

    # Drop the last token class (presumably the gap) before taking block norms.
    shrunk_cov_matrix_inv_no_gaps = shrunk_cov_matrix_inv[:, :-1, :, :-1]
    off_diagonal = 1. - torch.eye(n_col, device=device)
    interchain_scores = torch.sqrt(torch.sum(torch.square(shrunk_cov_matrix_inv_no_gaps),
                                             dim=(1, 3))) * off_diagonal

    if apc:
        # Average product correction: subtract row_sum * col_sum / total.
        apc_factor = (torch.sum(interchain_scores, dim=0, keepdim=True)
                      * torch.sum(interchain_scores, dim=1, keepdim=True)
                      / torch.sum(interchain_scores))
        interchain_scores = (interchain_scores - apc_factor) * off_diagonal

    return interchain_scores
69
+
70
+
71
def calculate_dca(msa: MSA,
                  apc: bool = False,
                  gpu: bool = True) -> np.ndarray:
    """Compute a DCA interaction-score matrix for an MSA using PyTorch.

    Args:
        msa: Alignment whose ``sequence_token_ids`` are the integer tokens to
            score. Assumed shape (n_sequences, n_columns).
        apc: Whether to apply average product correction.
        gpu: Use CUDA if a device is available. On CUDA out-of-memory, the
            calculation is retried on the CPU.

    Returns:
        An (n_columns, n_columns) NumPy array of scores with a zero diagonal.
    """
    n_gpu = torch.cuda.device_count()
    print_err(f"GPUs available: {n_gpu}")
    if gpu and n_gpu == 0:
        print_err("WARNING! GPU requested but none available. Falling back to CPU.")
    device = 'cuda' if (gpu and n_gpu > 0) else 'cpu'
    msa_token_ids = torch.tensor(msa.sequence_token_ids,
                                 dtype=torch.int64,
                                 device=torch.device(device))

    if device == 'cpu':
        # Half-precision matmuls are slow on CPU and may be unsupported by
        # older PyTorch builds. Memory is rarely the limit here, so use float32.
        wip = _get_wip(msa_token_ids, apc=apc, gpu=False, dtype=torch.float32)
        return wip.detach().cpu().numpy()

    try:
        wip = _get_wip(msa_token_ids, apc=apc, gpu=True)
    except torch.cuda.OutOfMemoryError:
        print_err("GPU memory exhausted; falling back to CPU.")
        # Release blocks cached by the failed attempt.
        torch.cuda.empty_cache()
        wip = _get_wip(msa_token_ids.to('cpu'), apc=apc, gpu=False,
                       dtype=torch.float32)

    return wip.detach().cpu().numpy()
yunta/io.py ADDED
@@ -0,0 +1,69 @@
1
+ """Tools for input and output."""
2
+
3
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
4
+
5
+ from csv import DictWriter
6
+ from dataclasses import is_dataclass, asdict
7
+ from io import TextIOWrapper
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+
12
+ from bioino import FastaCollection
13
+ from carabiner import print_err
14
+ from carabiner.cast import cast
15
+ import numpy as np
16
+ from pandas import DataFrame
17
+
18
+ from .structs.pdb_structs import ATMRecord
19
+ from .structs.metrics import ModelMetrics
20
+
21
def save_design(pdb_info,
                output_name: str,
                chainA_length: int) -> None:
    """Write PDB text to ``output_name``, labelling chains A and B.

    For every line that ``ATMRecord`` can parse, the chain identifier (the
    character at index 21) is set as follows:

    * ``'A'`` until a residue number greater than ``chainA_length`` is seen,
    * ``'B'`` for that line and every later one.

    Lines that cannot be parsed are written unchanged.
    """
    current_chain = 'A'
    with open(output_name, 'w') as handle:
        for raw_line in pdb_info.split('\n'):
            try:
                if ATMRecord(raw_line).res_no > chainA_length:
                    current_chain = 'B'
                print(raw_line[:21] + current_chain + raw_line[22:], file=handle)
            except Exception:
                # Not an atom record (e.g. headers, TER/END): copy it verbatim.
                print(raw_line, file=handle)
    return None
42
+
43
+
44
def write_metrics(metrics: Union[Iterable, Any],
                  filename: Union[str, TextIOWrapper],
                  mode: str = 'w'):
    """Write dataclass metrics as tab-separated rows.

    Args:
        metrics: A dataclass instance, or an iterable (including a generator)
            of dataclass instances. Column names come from the fields of the
            first item.
        filename: Output path, or an already-open text file.
            A path is opened and closed here, after its parent directory has
            been created if needed. An open file is written to and left open.
        mode: Mode used to open a path. The header row is written only when
            ``mode == 'w'``.

    Returns:
        None. Nothing is written when ``metrics`` is empty.

    Raises:
        TypeError: If any item is not a dataclass.
    """
    if is_dataclass(metrics):
        metrics = [metrics]
    # Materialise once; a generator would otherwise be exhausted by the check.
    metrics = list(metrics)
    if not all(is_dataclass(m) for m in metrics):
        raise TypeError("All metrics must be dataclass objects.")
    if not metrics:
        # No rows, and no field names for a header either.
        return None

    fieldnames = list(asdict(metrics[0]))

    def _write_rows(handle) -> None:
        # Shared by the path and open-file cases.
        writer = DictWriter(handle, fieldnames=fieldnames, delimiter='\t')
        if mode == 'w':
            writer.writeheader()
        for metric in metrics:
            writer.writerow(asdict(metric))

    if isinstance(filename, (str, os.PathLike)):
        # Create the parent directory *before* opening; opening first fails
        # when the directory does not exist yet.
        dirname = os.path.dirname(filename)
        if len(dirname) > 0 and not os.path.exists(dirname):
            print_err(f"Creating output directory {dirname}")
            os.makedirs(dirname, exist_ok=True)
        # newline='' as the csv module requires; closing flushes the rows.
        with open(filename, mode, newline='') as handle:
            _write_rows(handle)
    else:
        _write_rows(filename)

    return None
yunta/modelling.py ADDED
@@ -0,0 +1,43 @@
1
+ """Tools for setting up and using models."""
2
+
3
+ from typing import Optional
4
+
5
+ import os
6
+ from time import time
7
+
8
+ from carabiner import print_err
9
+
10
+ from .weights import get_model_weights
11
+
12
# Path to a ``data`` directory next to this module. Not referenced elsewhere in
# this module: ``make_model_runner`` takes parameters from ``param_dir`` or
# ``get_model_weights`` instead.
_params_path = os.path.join(os.path.dirname(__file__), "data")
13
+
14
def make_model_runner(num_ensemble: int = 1,
                      max_recycles: int = 10,
                      param_dir: Optional[str] = None,
                      model_name: Optional[str] = None):
    """Build an AlphaFold2 ``RunModel`` ready for inference.

    Args:
        num_ensemble: Value for ``data.eval.num_ensemble`` in the model config.
        max_recycles: Number of recycles, set on both ``data.common`` and
            ``model`` in the config.
        param_dir: Directory passed to ``get_model_haiku_params`` as
            ``data_dir``. When ``None``, it comes from
            ``get_model_weights(model_name)``.
        model_name: AlphaFold2 model to load; ``'model_1'`` when ``None``.

    Returns:
        A ``model.RunModel`` built from the configured model and its parameters.

    Note:
        Before the AlphaFold/JAX imports, all GPUs are hidden from TensorFlow
        for the whole process via ``tf.config.experimental.set_visible_devices``.
    """
    # Keep TensorFlow off the GPU (presumably leaving it to JAX; TODO confirm).
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")
    from .src_speedppi.alphafold.model import config, data, model
    from jax.lib import xla_bridge

    platform = xla_bridge.get_backend().platform
    print_err(f"Setting up AlphaFold2 model. XLA platform available: {platform}")

    model_name = 'model_1' if model_name is None else model_name
    param_dir = get_model_weights(model_name) if param_dir is None else param_dir

    cfg = config.model_config(model_name)
    cfg.data.eval.num_ensemble = num_ensemble
    cfg.data.common.num_recycle = max_recycles
    cfg.model.num_recycle = max_recycles

    haiku_params = data.get_model_haiku_params(model_name=model_name,
                                               data_dir=param_dir)
    return model.RunModel(cfg, haiku_params)
yunta/plots.py ADDED
@@ -0,0 +1,55 @@
1
+ """Plotting functions."""
2
+
3
+ from typing import Optional
4
+
5
+ import os
6
+
7
+ from carabiner import colorblind_palette, print_err
8
+ from carabiner.mpl import grid
9
+ from pandas import DataFrame
10
+ import numpy as np
11
+
12
def plot_matrix(m: np.ndarray,
                filename_prefix: Optional[str] = None,
                format: str = 'png',
                dpi: int = 300,
                vline: Optional[float] = None,
                hline: Optional[float] = None,
                vmax: Optional[float] = None,
                *args, **kwargs):
    """Draw a matrix as a heatmap and optionally save the data and the image.

    Args:
        m: 2-D array to plot.
        filename_prefix: If given, the matrix is saved as ``<prefix>.tsv`` and
            the figure as ``<prefix>.<format>``. A missing parent directory is
            created.
        format: File extension of the saved figure.
        dpi: Resolution of the saved figure.
        vline, hline: Optional positions of light-grey guide lines.
        vmax: Upper limit of the colour scale. The lower limit is fixed at 0.
        *args, **kwargs: Passed to ``axes.set`` (e.g. ``xlabel``, ``ylabel``).

    Returns:
        ``(fig, axes)`` from ``carabiner.mpl.grid``. The figure is not closed.
    """
    fig, axes = grid()
    im = axes.imshow(m, cmap='magma', vmin=0., vmax=vmax)
    fig.colorbar(im, shrink=.7)
    if hline is not None:
        axes.axhline(hline, color='lightgrey', zorder=10)
    if vline is not None:
        axes.axvline(vline, color='lightgrey', zorder=10)
    axes.set(*args, **kwargs)

    if filename_prefix is not None:
        data_file = f"{filename_prefix}.tsv"
        plot_dir = os.path.dirname(data_file)
        # A prefix with no directory part gives plot_dir == '', and
        # os.makedirs('') raises; only create real, missing directories.
        if plot_dir and not os.path.exists(plot_dir):
            print_err(f"Creating output directory {plot_dir}")
            os.makedirs(plot_dir, exist_ok=True)
        print_err(f"Saving matrix data as {data_file}")
        (DataFrame(m,
                   index=np.arange(m.shape[0]),
                   columns=np.arange(m.shape[1]))
         .to_csv(data_file, sep='\t', index=False))
        plot_file = f"{filename_prefix}.{format}"
        print_err(f"Saving matrix plot as {plot_file}")
        fig.savefig(plot_file, bbox_inches="tight", dpi=dpi)

    return fig, axes
46
+
47
def plot_dca(dca: np.ndarray,
             filename_prefix: Optional[str] = None,
             format: str = 'png',
             dpi: int = 300):
    """Plot a DCA score matrix with ``plot_matrix``.

    When a prefix is given, ``_dca`` is appended to it before saving.
    Returns whatever ``plot_matrix`` returns.
    """
    prefix = None if filename_prefix is None else filename_prefix + "_dca"
    return plot_matrix(dca, filename_prefix=prefix, format=format, dpi=dpi)