yunta 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yunta/__init__.py +0 -0
- yunta/cli.py +252 -0
- yunta/dca.py +172 -0
- yunta/dca_torch.py +99 -0
- yunta/io.py +69 -0
- yunta/modelling.py +43 -0
- yunta/plots.py +55 -0
- yunta/scoring.py +73 -0
- yunta/screening.py +369 -0
- yunta/src_speedppi/alphafold/__init__.py +0 -0
- yunta/src_speedppi/alphafold/common/__init__.py +14 -0
- yunta/src_speedppi/alphafold/confidence.py +155 -0
- yunta/src_speedppi/alphafold/data/__init__.py +14 -0
- yunta/src_speedppi/alphafold/data/foldonly.py +146 -0
- yunta/src_speedppi/alphafold/data/mmcif_parsing.py +385 -0
- yunta/src_speedppi/alphafold/data/msaonly.py +176 -0
- yunta/src_speedppi/alphafold/data/pair_msas.py +117 -0
- yunta/src_speedppi/alphafold/data/parsers.py +368 -0
- yunta/src_speedppi/alphafold/data/pipeline.py +208 -0
- yunta/src_speedppi/alphafold/data/templates.py +910 -0
- yunta/src_speedppi/alphafold/data/tools/__init__.py +14 -0
- yunta/src_speedppi/alphafold/data/tools/hhblits.py +156 -0
- yunta/src_speedppi/alphafold/data/tools/hhsearch.py +92 -0
- yunta/src_speedppi/alphafold/data/tools/hmmbuild.py +138 -0
- yunta/src_speedppi/alphafold/data/tools/hmmsearch.py +90 -0
- yunta/src_speedppi/alphafold/data/tools/jackhmmer.py +199 -0
- yunta/src_speedppi/alphafold/data/tools/kalign.py +105 -0
- yunta/src_speedppi/alphafold/data/tools/utils.py +40 -0
- yunta/src_speedppi/alphafold/model/__init__.py +14 -0
- yunta/src_speedppi/alphafold/model/all_atom.py +1141 -0
- yunta/src_speedppi/alphafold/model/all_atom_test.py +135 -0
- yunta/src_speedppi/alphafold/model/common_modules.py +84 -0
- yunta/src_speedppi/alphafold/model/config.py +403 -0
- yunta/src_speedppi/alphafold/model/data.py +41 -0
- yunta/src_speedppi/alphafold/model/features.py +103 -0
- yunta/src_speedppi/alphafold/model/folding.py +1006 -0
- yunta/src_speedppi/alphafold/model/layer_stack.py +274 -0
- yunta/src_speedppi/alphafold/model/layer_stack_test.py +335 -0
- yunta/src_speedppi/alphafold/model/lddt.py +88 -0
- yunta/src_speedppi/alphafold/model/lddt_test.py +79 -0
- yunta/src_speedppi/alphafold/model/mapping.py +218 -0
- yunta/src_speedppi/alphafold/model/model.py +141 -0
- yunta/src_speedppi/alphafold/model/modules.py +2095 -0
- yunta/src_speedppi/alphafold/model/prng.py +69 -0
- yunta/src_speedppi/alphafold/model/prng_test.py +46 -0
- yunta/src_speedppi/alphafold/model/quat_affine.py +459 -0
- yunta/src_speedppi/alphafold/model/quat_affine_test.py +150 -0
- yunta/src_speedppi/alphafold/model/r3.py +322 -0
- yunta/src_speedppi/alphafold/model/tf/__init__.py +14 -0
- yunta/src_speedppi/alphafold/model/tf/data_transforms.py +624 -0
- yunta/src_speedppi/alphafold/model/tf/input_pipeline.py +166 -0
- yunta/src_speedppi/alphafold/model/tf/protein_features.py +130 -0
- yunta/src_speedppi/alphafold/model/tf/protein_features_test.py +51 -0
- yunta/src_speedppi/alphafold/model/tf/proteins_dataset.py +168 -0
- yunta/src_speedppi/alphafold/model/tf/shape_helpers.py +47 -0
- yunta/src_speedppi/alphafold/model/tf/shape_helpers_test.py +40 -0
- yunta/src_speedppi/alphafold/model/tf/shape_placeholders.py +20 -0
- yunta/src_speedppi/alphafold/model/tf/utils.py +47 -0
- yunta/src_speedppi/alphafold/model/utils.py +81 -0
- yunta/src_speedppi/alphafold/protein.py +234 -0
- yunta/src_speedppi/alphafold/residue_constants.py +896 -0
- yunta/structs/metrics.py +73 -0
- yunta/structs/msa.py +329 -0
- yunta/structs/pdb_structs.py +42 -0
- yunta/weights.py +66 -0
- yunta-0.0.1.dist-info/LICENSE +438 -0
- yunta-0.0.1.dist-info/METADATA +735 -0
- yunta-0.0.1.dist-info/RECORD +71 -0
- yunta-0.0.1.dist-info/WHEEL +5 -0
- yunta-0.0.1.dist-info/entry_points.txt +2 -0
- yunta-0.0.1.dist-info/top_level.txt +1 -0
yunta/__init__.py
ADDED
|
File without changes
|
yunta/cli.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Command-line interface for sppid."""
|
|
2
|
+
|
|
3
|
+
# Package version, reported by the ``sppid`` CLI (passed as ``version=`` to ``CLIApp``).
__version__ = '0.0.1'
|
|
4
|
+
|
|
5
|
+
from typing import Any, Mapping, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
from argparse import ArgumentParser, FileType, Namespace
|
|
8
|
+
from io import TextIOWrapper
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
from carabiner import print_err
|
|
13
|
+
from carabiner.cast import cast, flatten
|
|
14
|
+
from carabiner.cliutils import clicommand, CLIOption, CLICommand, CLIApp
|
|
15
|
+
|
|
16
|
+
from .io import write_metrics
|
|
17
|
+
from .plots import plot_matrix
|
|
18
|
+
from .screening import (
|
|
19
|
+
dca_one_vs_many,
|
|
20
|
+
dca_many_vs_many,
|
|
21
|
+
model_one_vs_many,
|
|
22
|
+
model_many_vs_many,
|
|
23
|
+
rf2track_one_vs_many,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def _load_msa_list(*args):
    """Read MSA filenames from plain-text list files.

    Each positional argument is one list file: an open text handle, a list
    of handles (as produced by argparse ``nargs='*'``; only the first is
    read), an empty list, or ``None``.

    Returns
    -------
    list
        One entry per argument, in order. Each entry is either the
        stripped, non-blank lines of that list file (passed through
        ``carabiner.cast.flatten``) or ``None`` when the argument was
        ``None`` or an empty list. The ``None`` case matches what callers
        pass on the non-list-file path (e.g. ``--msa2`` not given).
    """
    loaded = []
    for arg in args:
        if isinstance(arg, list):
            # argparse nargs='*' gives a list; only the first list file is used.
            handle = arg[0] if arg else None
        else:
            handle = arg
        if handle is None:
            # Previously ``for line in None`` raised TypeError here.
            loaded.append(None)
            continue
        # Blank lines would otherwise become empty (unopenable) filenames.
        loaded.append(flatten([line.strip() for line in handle if line.strip()]))
    return loaded
|
|
29
|
+
|
|
30
|
+
def _plot_results(results, result_interaction, metric,
                  output_dir: str = '.', *args, **kwargs) -> None:
    """Plot one pair's score matrix and its interaction sub-matrix.

    Files go under ``output_dir``. The filename prefix is built from
    ``metric.ID``, with ``.apc=<value>`` appended when ``metric`` has an
    ``apc`` attribute. Two plots are made with ``plot_matrix``:

    - ``results``, with guide lines at ``metric.chain_a_len`` on both axes;
    - ``result_interaction``, saved with an ``.interaction`` suffix and
      axes labelled ``metric.uniprot_id_1`` (y) / ``metric.uniprot_id_2`` (x).

    ``*args`` and ``**kwargs`` are accepted but unused.
    """
    if output_dir:
        # exist_ok avoids the race of a separate exists()/makedirs() pair;
        # an empty output_dir means the current directory (makedirs('') raises).
        os.makedirs(output_dir, exist_ok=True)
    filename_prefix = os.path.join(output_dir, metric.ID)
    if hasattr(metric, 'apc'):
        apc = metric.apc
        # The ``=`` specifier renders as e.g. ".apc=True", so the local
        # name ``apc`` is part of the filename.
        filename_prefix += f".{apc=}"
    plot_matrix(results,
                filename_prefix=filename_prefix,
                hline=metric.chain_a_len,
                vline=metric.chain_a_len)
    plot_matrix(result_interaction,
                filename_prefix=f"{filename_prefix}.interaction",
                ylabel=metric.uniprot_id_1,
                xlabel=metric.uniprot_id_2)
    return None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@clicommand(message="Making RosettaFold-2track prediction with the following parameters")
def _rf2t_single(args: Namespace) -> None:
    """Handle ``rf2t-single``: RF-2track of ``msa1`` against the MSA(s) in ``msa2``.

    All pairs' metrics are written as one table to ``args.output``. When
    ``args.plot`` is set, each pair's matrices are also plotted into that
    directory.
    """
    msa1, msa2 = (_load_msa_list(args.msa1, args.msa2)
                  if args.list_file
                  else (args.msa1, args.msa2))

    print_err(f"Running RF-2t using {msa1} as reference.")
    outputs = rf2track_one_vs_many(msa_file1=msa1, msa_file2=msa2, cpu=args.cpu)

    # The last element of each output is that pair's metrics object.
    write_metrics([pair_output[-1] for pair_output in outputs],
                  filename=args.output)

    if args.plot is None:
        return None
    for pair_output in outputs:
        _plot_results(*pair_output, output_dir=args.plot)
    return None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@clicommand(message="Calculating DCA for a pair of MSAs with the following parameters")
def _dca_single(args: Namespace) -> None:
    """Handle ``dca-single``: DCA between ``msa1`` and the MSA(s) in ``msa2``.

    Metrics for every pair go to ``args.output`` as one table. If
    ``args.plot`` names a directory, each pair's matrices are also plotted
    there. ``args.apc`` toggles average product correction.
    """
    if not args.list_file:
        msa1, msa2 = args.msa1, args.msa2
    else:
        msa1, msa2 = _load_msa_list(args.msa1, args.msa2)

    outputs = dca_one_vs_many(msa_file1=msa1, msa_file2=msa2, apc=args.apc)

    # Each output ends with its metrics object.
    collected = [entry[-1] for entry in outputs]
    write_metrics(collected, filename=args.output)

    if args.plot is not None:
        for entry in outputs:
            _plot_results(*entry, output_dir=args.plot)

    return None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@clicommand(message="Calculating DCA between pairs of MSAs with the following parameters")
def _dca_many_vs_many(args: Namespace) -> None:
    """Handle ``dca-many``: DCA between two sets of MSAs.

    Per the ``--msa2`` help text, all pairs within ``msa1`` are used when
    ``msa2`` is not given. Metrics go to ``args.output``. Plots go to
    ``args.plot`` when it is set.
    """
    loaded = (_load_msa_list(args.msa1, args.msa2)
              if args.list_file
              else (args.msa1, args.msa2))
    first_set, second_set = loaded

    outputs = dca_many_vs_many(msa_files1=first_set,
                               msa_files2=second_set,
                               apc=args.apc)

    # The metrics object is the final item of every output.
    write_metrics([result[-1] for result in outputs],
                  filename=args.output)

    plot_dir = args.plot
    if plot_dir is not None:
        for result in outputs:
            _plot_results(*result, output_dir=plot_dir)

    return None
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@clicommand(message="Modelling one PPI with the following parameters")
def _af2_single(args: Namespace) -> None:
    """Handle ``af2-single``: model ``msa1`` against ``msa2`` with AlphaFold2.

    Model outputs are produced by ``model_one_vs_many`` with
    ``output_dir=args.output``. The returned metrics are then written to
    ``<args.output>/<metric.ID>_metrics.csv``. That file is written by
    ``write_metrics``, which uses tab separators despite the ``.csv`` suffix.
    """
    if args.list_file:
        msa1, msa2 = _load_msa_list(args.msa1, args.msa2)
    else:
        msa1, msa2 = args.msa1, args.msa2

    metric = model_one_vs_many(
        msa_file1=msa1,
        msa_file2=msa2,
        max_recycles=args.recycles,
        output_dir=args.output,
        param_dir=args.params,
    )

    # ``af2-single`` only defines ``--output`` (dest ``output``); there is no
    # ``args.output_dir``, which previously raised AttributeError here.
    # NOTE(review): this assumes model_one_vs_many returns a single object
    # with an ``ID`` attribute — confirm against screening.py.
    output_filename = os.path.join(args.output, f"{metric.ID}_metrics.csv")
    print_err(f"Saving metrics as {output_filename}")
    write_metrics(metric,
                  filename=output_filename)

    return None
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@clicommand(message="Modelling sets of PPIs with the following parameters")
def _af2_many_vs_many(args: Namespace) -> None:
    """Handle ``af2-many``: AlphaFold2 models for every requested pair.

    ``args.output`` is passed to ``model_many_vs_many`` as ``output_dir``.
    The combined metrics are written there as ``_all_metrics.csv``, which is
    tab-separated because ``write_metrics`` uses ``delimiter='\\t'``.
    """
    msa1, msa2 = (_load_msa_list(args.msa1, args.msa2)
                  if args.list_file
                  else (args.msa1, args.msa2))

    all_metrics = model_many_vs_many(msa_files1=msa1,
                                     msa_files2=msa2,
                                     output_dir=args.output,
                                     max_recycles=args.recycles,
                                     param_dir=args.params)

    metrics_path = os.path.join(args.output, "_all_metrics.csv")
    print_err(f"Saving metrics as {metrics_path}")
    write_metrics(all_metrics, filename=metrics_path)
    return None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def main() -> None:
    """Build the ``sppid`` command-line app and dispatch to a subcommand.

    Subcommands are ``dca-single``, ``dca-many``, ``rf2t-single``,
    ``af2-single`` and ``af2-many``. Each is wired to one of the private
    ``_*`` handlers in this module.
    """
    # Positional MSA input: one file (defaults to STDIN) ...
    inputs = CLIOption('msa1',
                       default=sys.stdin,
                       type=FileType('r'),
                       nargs='?',
                       help='MSA file. Default: STDIN.')
    # ... or several files, for the "many" commands.
    inputs_list = CLIOption('msa1',
                            default=sys.stdin,
                            type=FileType('r'),
                            nargs='*',
                            help='MSA file(s).')
    inputs_list2 = CLIOption('--msa2', '-2',
                             type=FileType('r'),
                             default=None,
                             nargs='*',
                             help='Second MSA file(s). Default: if not provided, all pairwise from msa1.')
    list_file = CLIOption('--list-file', '-l',
                          action='store_true',
                          help='Treat inputs as plain-text list of MSA files, rather than MSA filenames. '
                               'Default: treat as MSA filenames.')
    # Output *directory*, used by the AlphaFold2 commands.
    output = CLIOption('--output', '-o',
                       type=str,
                       required=True,
                       help='Output directory.')
    plot = CLIOption('--plot', '-p',
                     type=str,
                     default=None,
                     help='Directory for saving plots. Default: don\'t plot.')
    cpu = CLIOption('--cpu', '-c',
                    action='store_true',
                    help='Whether to use CPU only. Default: use GPU.')
    # Output *file* for metrics tables, used by the DCA and RF-2track commands.
    output_file = CLIOption('--output', '-o',
                            default=sys.stdout,
                            type=FileType('w'),
                            nargs='?',
                            help='Output filename. Default: STDOUT.')
    apc = CLIOption('--apc', '-a',
                    action='store_true',
                    help='Whether to use APC correction in DCA. Default: don\'t apply correction.')
    params = CLIOption('--params', '-w',
                       type=str,
                       default=None,
                       help='Path to AlphaFold2 params file (.npz).')
    recycles = CLIOption('--recycles', '-x',
                         type=int,
                         default=10,
                         help='Maximum number of recycles through the model.')

    rf2t_single = CLICommand('rf2t-single',
                             description='Calculate RF-2track contacts between one protein and a series of others.',
                             main=_rf2t_single,
                             options=[inputs, inputs_list2, list_file, output_file, plot, cpu])
    dca_single = CLICommand('dca-single',
                            description='Calculate DCA for one protein-protein interaction.',
                            main=_dca_single,
                            options=[inputs, inputs_list2, list_file, output_file, plot, apc])
    dca_many = CLICommand('dca-many',
                          description='Calculate DCA between two sets of proteins, or all pairs in one set of proteins.',
                          main=_dca_many_vs_many,
                          options=[inputs_list, inputs_list2, list_file, apc, output_file, plot])
    # NOTE(review): the af2 commands accept --plot, but their handlers never read args.plot.
    af2_single = CLICommand('af2-single',
                            description='Model one protein-protein interaction.',
                            main=_af2_single,
                            options=[inputs, inputs_list2, list_file, output, params, recycles, plot])
    af2_many = CLICommand('af2-many',
                          description='Model all interactions between two sets of proteins, or all pairs in one set of proteins.',
                          main=_af2_many_vs_many,
                          options=[inputs_list, inputs_list2, list_file, output, params, recycles, plot])

    app = CLIApp("sppid",
                 version=__version__,
                 description="Screening protein-protein interactions using DCA, RosettaFold-2track, and AlphaFold2.",
                 commands=[dca_single, dca_many, rf2t_single, af2_single, af2_many])

    app.run()
    return None
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# Run the CLI when this module is executed directly.
if __name__ == "__main__":
    main()
|
yunta/dca.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Code from Humphreys et al., Science, 2021 (DOI: 10.1126/science.abm4805), reimplemented for TensorFlow 2.
|
|
2
|
+
|
|
3
|
+
According to the paper's supplementary material:
|
|
4
|
+
"We reimplemented DCA so that it can be computed using GPUs to speed up the calculation.
|
|
5
|
+
The code is in the box below. In addition, we applied average product correction (APC).[...]"
|
|
6
|
+
|
|
7
|
+
```python
|
|
8
|
+
import sys
|
|
9
|
+
import numpy as np
|
|
10
|
+
import string
|
|
11
|
+
import tensorflow as tf
|
|
12
|
+
|
|
13
|
+
def parse_a3m(a3mlines):
|
|
14
|
+
seqs = []
|
|
15
|
+
labels = []
|
|
16
|
+
for line in a3mlines:
|
|
17
|
+
if line[0] == '>':
|
|
18
|
+
labels.append(line.rstrip())
|
|
19
|
+
else:
|
|
20
|
+
seqs.append(line[:-1])
|
|
21
|
+
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
|
|
22
|
+
seq_num = np.array([list(s) for s in seqs], dtype='|S1').view(np.uint8)
|
|
23
|
+
for i in range(alphabet.shape[0]):
|
|
24
|
+
seq_num[seq_num == alphabet[i]] = i
|
|
25
|
+
seq_num[seq_num > 20] = 20
|
|
26
|
+
return {'seqs' : seq_num, 'labels' : labels }
|
|
27
|
+
|
|
28
|
+
def tf_cov(x,w=None):
|
|
29
|
+
if w is None:
|
|
30
|
+
num_points = tf.cast(tf.shape(x)[0],tf.float32) - 1
|
|
31
|
+
x_mean = tf.reduce_mean(x, axis=0, keep_dims=True)
|
|
32
|
+
x = (x - x_mean)
|
|
33
|
+
else:
|
|
34
|
+
num_points = tf.reduce_sum(w) - tf.sqrt(tf.reduce_mean(w))
|
|
35
|
+
x_mean = tf.reduce_sum(x * w[:,None], axis=0, keepdims=True) / num_points
|
|
36
|
+
x = (x - x_mean) * tf.sqrt(w[:,None])
|
|
37
|
+
return tf.matmul(tf.transpose(x),x)/num_points
|
|
38
|
+
|
|
39
|
+
fp = open(sys.argv[1] + ".alignments", "r")
|
|
40
|
+
pair2a3m = {}
|
|
41
|
+
pair = ""
|
|
42
|
+
pairs = []
|
|
43
|
+
a3m = []
|
|
44
|
+
for line in fp:
|
|
45
|
+
if line[:2] == ">>":
|
|
46
|
+
if pair and a3m:
|
|
47
|
+
pair2a3m[pair] = a3m
|
|
48
|
+
pairs.append(pair)
|
|
49
|
+
pair = line[2:-1]
|
|
50
|
+
a3m = []
|
|
51
|
+
else:
|
|
52
|
+
a3m.append(line)
|
|
53
|
+
fp.close()
|
|
54
|
+
if pair and a3m:
|
|
55
|
+
pair2a3m[pair] = a3m
|
|
56
|
+
pairs.append(pair)
|
|
57
|
+
|
|
58
|
+
config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9))
|
|
59
|
+
|
|
60
|
+
with tf.Graph().as_default():
|
|
61
|
+
x = tf.placeholder(tf.uint8,shape=(None,None),name="x")
|
|
62
|
+
x_shape = tf.shape(x)
|
|
63
|
+
x_nr = x_shape[0]
|
|
64
|
+
x_nc = x_shape[1]
|
|
65
|
+
x_ns = 21
|
|
66
|
+
x_msa = tf.one_hot(x,x_ns)
|
|
67
|
+
x_cutoff = tf.cast(x_nc,tf.float32) * 0.8
|
|
68
|
+
x_pw = tf.tensordot(x_msa, x_msa, [[1,2], [1,2]])
|
|
69
|
+
x_cut = x_pw > x_cutoff
|
|
70
|
+
x_weights = 1.0/tf.reduce_sum(tf.cast(x_cut, dtype=tf.float32),-1)
|
|
71
|
+
x_feat = tf.reshape(x_msa,(x_nr,x_nc*x_ns))
|
|
72
|
+
x_c = tf_cov(x_feat,x_weights) + tf.eye(x_nc*x_ns) *
|
|
73
|
+
4.5/tf.sqrt(tf.reduce_sum(x_weights))
|
|
74
|
+
x_c_inv = tf.linalg.inv(x_c)
|
|
75
|
+
x_w = tf.reshape(x_c_inv,(x_nc,x_ns,x_nc,x_ns))
|
|
76
|
+
x_wi = tf.sqrt(tf.reduce_sum(tf.square(x_w[:,:-1,:,:-1]),(1,3))) * (1-tf.eye(x_nc))
|
|
77
|
+
# APC
|
|
78
|
+
#x_ap = tf.reduce_sum(x_wi,0,keepdims=True) * tf.reduce_sum(x_wi,1,keepdims=True) /
|
|
79
|
+
tf.reduce_sum(x_wi)
|
|
80
|
+
#x_wip = (x_wi - x_ap) * (1-tf.eye(x_nc))
|
|
81
|
+
|
|
82
|
+
with tf.Session(config=config) as sess:
|
|
83
|
+
results = []
|
|
84
|
+
get_pairs = []
|
|
85
|
+
for pair in pairs:
|
|
86
|
+
print (pair)
|
|
87
|
+
msa = parse_a3m(pair2a3m[pair])
|
|
88
|
+
try:
|
|
89
|
+
wip = sess.run(x_wi,{x:msa['seqs']})
|
|
90
|
+
results.append(wip.astype(np.float16))
|
|
91
|
+
get_pairs.append(pair)
|
|
92
|
+
except tf.errors.ResourceExhaustedError as e:
|
|
93
|
+
pass
|
|
94
|
+
np.savez_compressed(sys.argv[1], *results, names=get_pairs)
|
|
95
|
+
rp = open(sys.argv[1] + ".log","w")
|
|
96
|
+
for pair in get_pairs:
|
|
97
|
+
rp.write(pair + "\n")
|
|
98
|
+
rp.close()
|
|
99
|
+
```
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
from typing import Optional, Union
|
|
103
|
+
from io import TextIOWrapper
|
|
104
|
+
import sys
|
|
105
|
+
|
|
106
|
+
from carabiner import print_err
|
|
107
|
+
import numpy as np
|
|
108
|
+
import tensorflow as tf
|
|
109
|
+
|
|
110
|
+
from .structs.msa import MSA, _A3M_ALPHABET
|
|
111
|
+
|
|
112
|
+
@tf.function
def _tf_cov(x: tf.Tensor, w: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Covariance of the columns of ``x``, optionally with per-row weights.

    Rows of ``x`` are observations and columns are variables, so the
    result is ``(n_features, n_features)``.

    Without ``w``, this is the usual sample covariance (normalised by
    ``n - 1``). With ``w``, it follows the reference implementation quoted
    in this module's docstring: the effective number of points is
    ``sum(w) - sqrt(mean(w))``, the weighted mean is taken over that count,
    and rows are scaled by ``sqrt(w)`` before ``x^T x``.
    """
    if w is None:
        num_points = tf.cast(tf.shape(x)[0], tf.float32) - 1.
        # TF2 spells this ``keepdims``; the TF1 ``keep_dims`` keyword was
        # removed and raises TypeError.
        x_mean = tf.math.reduce_mean(x, axis=0, keepdims=True)
        x = x - x_mean
    else:
        num_points = tf.math.reduce_sum(w) - tf.math.sqrt(tf.math.reduce_mean(w))
        x_mean = tf.math.reduce_sum(x * w[:, tf.newaxis], axis=0, keepdims=True) / num_points
        x = (x - x_mean) * tf.math.sqrt(w[:, tf.newaxis])
    return tf.linalg.matmul(x, x, transpose_a=True) / num_points
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@tf.function
def _get_wip(x: tf.Tensor, apc: bool = False) -> tf.float32:
    """Compute an ``(n_col, n_col)`` DCA coupling-score matrix from token ids.

    ``x`` holds integer token ids. Its second-to-last axis indexes
    sequences (``n_row``) and its last axis indexes alignment positions
    (``n_col``). Sequences are reweighted by near-duplicate count, a shrunk
    weighted covariance of the one-hot features is inverted, and each
    position pair is scored by the norm of its coupling block. When ``apc``
    is true, average product correction is applied. The diagonal is zero.
    """
    n_tokens = len(_A3M_ALPHABET)
    x_shape = tf.shape(x)
    n_row = x_shape[-2]
    n_col = x_shape[-1]

    one_hot = tf.one_hot(x, n_tokens)

    # Reweighting: each sequence is divided by the number of sequences
    # (itself included) sharing more than 80% of positions with it.
    identity_cutoff = tf.cast(n_col, dtype=tf.float32) * .8
    identities = tf.tensordot(one_hot, one_hot, [[1, 2], [1, 2]])
    is_similar = tf.cast(identities > identity_cutoff, dtype=tf.float32)
    seq_weights = 1. / tf.math.reduce_sum(is_similar, axis=-1)

    # Shrunk weighted covariance of the flattened one-hot features, inverted.
    n_features = n_col * n_tokens
    features = tf.reshape(one_hot, (n_row, n_features))
    shrinkage = tf.eye(n_features) * 4.5 / tf.math.sqrt(tf.math.reduce_sum(seq_weights))
    precision = tf.linalg.inv(_tf_cov(features, seq_weights) + shrinkage)

    # Norm of each position-pair block, excluding the last alphabet symbol
    # (presumably the gap token — confirm against _A3M_ALPHABET).
    blocks = tf.reshape(precision, (n_col, n_tokens, n_col, n_tokens))
    off_diagonal = 1. - tf.eye(n_col)
    squared = tf.math.square(blocks[:, :-1, :, :-1])
    scores = tf.math.sqrt(tf.math.reduce_sum(squared, (1, 3))) * off_diagonal

    if apc:
        # Average product correction.
        column_totals = tf.math.reduce_sum(scores, axis=0, keepdims=True)
        row_totals = tf.math.reduce_sum(scores, axis=1, keepdims=True)
        correction = column_totals * row_totals / tf.math.reduce_sum(scores)
        scores = (scores - correction) * off_diagonal

    return scores
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def calculate_dca(msa: MSA,
                  apc: bool = False,
                  gpu: bool = True) -> np.ndarray:
    """Compute the DCA coupling-score matrix of an MSA using TensorFlow.

    Parameters
    ----------
    msa : MSA
        Alignment whose ``sequence_token_ids`` are scored.
    apc : bool, optional
        Apply average product correction.
    gpu : bool, optional
        Use a GPU if one is visible to TensorFlow. When False, or when no
        GPU is found, the computation is pinned to ``/cpu:0``. If the GPU
        run raises ``ResourceExhaustedError``, it is retried on CPU.

    Returns
    -------
    np.ndarray
        The score matrix returned by ``_get_wip``, as a NumPy array.
    """
    print_err(f"Devices available:\n{tf.config.list_physical_devices()}")
    use_gpu = gpu
    if gpu and len(tf.config.list_physical_devices('GPU')) == 0:
        print_err("WARNING! GPU requested but none available. Falling back to CPU.")
        use_gpu = False

    msa_token_ids = np.asarray(msa.sequence_token_ids)

    if not use_gpu:
        # Previously gpu=False was ignored and TF could still place ops on a GPU.
        with tf.device('/cpu:0'):
            wip = _get_wip(msa_token_ids, apc=apc)
    else:
        try:
            wip = _get_wip(msa_token_ids, apc=apc)
        except tf.errors.ResourceExhaustedError:
            print_err("GPU memory exhausted; falling back to CPU.")
            with tf.device('/cpu:0'):
                wip = _get_wip(msa_token_ids, apc=apc)

    return wip.numpy()
|
yunta/dca_torch.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""DCA for GPU implemented in pytorch."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
from io import TextIOWrapper
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
from carabiner import print_err
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
from torch import FloatTensor, Tensor
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
|
|
14
|
+
from .structs.msa import MSA, _A3M_ALPHABET
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _torch_cov(x: Tensor, w: Optional[Tensor] = None) -> FloatTensor:
    """Covariance of the columns of ``x``, optionally with per-row weights.

    Rows of ``x`` are observations and columns are variables. Both branches
    return an ``(n_features, n_features)`` matrix.

    Without ``w``, this is the sample covariance (``n - 1`` normalisation).
    With ``w``, the effective number of points is ``sum(w) - sqrt(mean(w))``,
    the weighted mean is taken over that count, and rows are scaled by
    ``sqrt(w)`` before ``x^T x``.
    """
    if w is None:
        # torch.cov treats *rows* as variables, so transpose to get the same
        # (n_features, n_features) orientation as the weighted branch;
        # torch.cov(x) alone returned an (n_row, n_row) matrix.
        return torch.cov(x.transpose(-2, -1))
    num_points = torch.sum(w) - torch.sqrt(torch.mean(w))
    x_mean = torch.sum(x * w.unsqueeze(-1),
                       dim=0, keepdim=True) / num_points
    x = (x - x_mean) * torch.sqrt(w.unsqueeze(-1))
    return torch.matmul(x.transpose(-2, -1), x) / num_points
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _get_wip(x: Tensor, apc: bool = False, gpu: bool = True, dtype=torch.float16) -> FloatTensor:
    """Compute an ``(n_col, n_col)`` DCA coupling-score matrix from token ids.

    Parameters
    ----------
    x : Tensor
        2-D integer tensor of token ids, shape ``(n_row, n_col)``
        (sequences x alignment positions).
    apc : bool, optional
        Apply average product correction to the scores.
    gpu : bool, optional
        Compute on CUDA if it is available, otherwise on CPU.
    dtype : torch.dtype, optional
        Floating dtype for the one-hot encoding and the quantities derived
        from it. Default: ``torch.float16``.

    Returns
    -------
    FloatTensor
        Coupling scores with a zeroed diagonal, on the compute device.
    """
    device = torch.device("cuda" if (torch.cuda.is_available() and gpu) else "cpu")
    # Every tensor below is created on ``device``; move the input too, so a
    # CPU tensor passed with gpu=True does not cause a device mismatch.
    x = x.to(device)
    # A plain int: F.one_hot wants an integer num_classes, view() integer sizes.
    alphabet_size = len(_A3M_ALPHABET)
    n_row, n_col = x.shape
    msa_one_hot = (F.one_hot(x.to(torch.int64),
                             num_classes=alphabet_size)
                   .to(dtype))  # n_row, n_col, alphabet_size
    # Number of identical positions for every pair of sequences.
    identity_counts = torch.tensordot(msa_one_hot, msa_one_hot,
                                      dims=[[1, 2], [1, 2]])  # n_row, n_row
    identity_cutoff = torch.tensor(n_col * .8, device=device)
    x_cut = (identity_counts > identity_cutoff).to(dtype)

    # Each sequence is down-weighted by its number of >80%-identical
    # neighbours (itself included).
    # NOTE(review): float16 represents integers exactly only up to 2048, so
    # with the default dtype very long or very deep MSAs lose precision in
    # these counts — consider dtype=torch.float32 for such inputs.
    weights = 1. / torch.sum(x_cut, dim=-1)  # n_row
    msa_concat_one_hot = msa_one_hot.view(n_row, n_col * alphabet_size)

    shrinkage_coeff = torch.eye(msa_concat_one_hot.shape[-1],
                                device=device) * (4.5 / torch.sqrt(torch.sum(weights)))
    # (n_col * alphabet_size, n_col * alphabet_size)
    covariance_matrix = _torch_cov(msa_concat_one_hot, weights)
    shrunk_cov_matrix = covariance_matrix + shrinkage_coeff
    shrunk_cov_matrix_inv = torch.linalg.inv(shrunk_cov_matrix).view(n_col, alphabet_size, n_col, alphabet_size)

    # Drop the last alphabet symbol (presumably the gap token — confirm
    # against _A3M_ALPHABET) before taking each block's norm.
    shrunk_cov_matrix_inv_no_gaps = shrunk_cov_matrix_inv[:, :-1, :, :-1]
    I_ncol = torch.eye(n_col, device=device)
    interchain_scores = torch.sqrt(torch.sum(torch.square(shrunk_cov_matrix_inv_no_gaps),
                                             dim=(1, 3))) * (1. - I_ncol)

    if apc:
        # Average product correction.
        apc_factor = (torch.sum(interchain_scores,
                                dim=0, keepdim=True)
                      * torch.sum(interchain_scores,
                                  dim=1, keepdim=True)
                      / torch.sum(interchain_scores))
        interchain_scores = (interchain_scores - apc_factor) * (1. - I_ncol)

    return interchain_scores
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def calculate_dca(msa: MSA,
                  apc: bool = False,
                  gpu: bool = True) -> np.ndarray:
    """Compute the DCA coupling-score matrix of an MSA using PyTorch.

    Parameters
    ----------
    msa : MSA
        Alignment whose ``sequence_token_ids`` are scored.
    apc : bool, optional
        Apply average product correction.
    gpu : bool, optional
        Run on CUDA when at least one device is present. Otherwise, or if
        CUDA runs out of memory, the computation runs on CPU.

    Returns
    -------
    np.ndarray
        Score matrix from ``_get_wip``, detached and copied to host memory.
    """
    gpu_count = torch.cuda.device_count()
    print_err(f"GPUs available: {gpu_count}")

    if not gpu:
        device = 'cpu'
    elif gpu_count == 0:
        print_err("WARNING! GPU requested but none available. Falling back to CPU.")
        device = 'cpu'
    else:
        device = 'cuda'

    token_ids = torch.tensor(msa.sequence_token_ids,
                             dtype=torch.int64,
                             device=torch.device(device))
    on_gpu = device == 'cuda'

    try:
        scores = _get_wip(token_ids, apc=apc, gpu=on_gpu)
    except torch.cuda.OutOfMemoryError:
        print_err("GPU memory exhausted; falling back to CPU.")
        scores = _get_wip(token_ids.to('cpu'), apc=apc, gpu=False)

    return scores.detach().cpu().numpy()
|
yunta/io.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Tools for input and output."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
from csv import DictWriter
|
|
6
|
+
from dataclasses import is_dataclass, asdict
|
|
7
|
+
from io import TextIOWrapper
|
|
8
|
+
import os
|
|
9
|
+
import shutil
|
|
10
|
+
import tempfile
|
|
11
|
+
|
|
12
|
+
from bioino import FastaCollection
|
|
13
|
+
from carabiner import print_err
|
|
14
|
+
from carabiner.cast import cast
|
|
15
|
+
import numpy as np
|
|
16
|
+
from pandas import DataFrame
|
|
17
|
+
|
|
18
|
+
from .structs.pdb_structs import ATMRecord
|
|
19
|
+
from .structs.metrics import ModelMetrics
|
|
20
|
+
|
|
21
|
+
def save_design(pdb_info,
                output_name: str,
                chainA_length: int) -> None:
    """Write PDB text to ``output_name``, relabelling chain IDs A/B.

    ``pdb_info`` is split on newlines. Each line that ``ATMRecord`` can
    parse has column 22 (index 21) set to the current chain ID. The chain
    starts as ``'A'`` and switches to ``'B'`` for good once a record with
    ``res_no > chainA_length`` is seen. Lines that fail to parse are written
    unchanged.

    NOTE(review): assumes chain A's residues come first and are numbered no
    higher than ``chainA_length`` — confirm with the caller.
    """
    current_chain = 'A'
    with open(output_name, 'w') as handle:
        for pdb_line in pdb_info.split('\n'):
            try:
                atom = ATMRecord(pdb_line)
                if atom.res_no > chainA_length:
                    current_chain = 'B'
                print(pdb_line[:21] + current_chain + pdb_line[22:], file=handle)
            except Exception:
                # Non-atom lines (headers, TER, END, ...) pass through as-is.
                print(pdb_line, file=handle)
    return None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def write_metrics(metrics: Union[Iterable, Any],
                  filename: Union[str, TextIOWrapper],
                  mode: str = 'w'):
    """Write dataclass metrics as a tab-separated table.

    Parameters
    ----------
    metrics : dataclass instance or iterable of dataclass instances
        Rows to write. Column names come from the fields of the first item.
        An empty iterable writes nothing.
    filename : str, path-like, or writable text handle
        Destination. A path is opened with ``mode`` and closed after
        writing; its parent directory is created first if missing. An open
        handle (e.g. ``sys.stdout``) is written to and left open.
    mode : str, optional
        File mode used when opening a path. The header row is written only
        when ``mode == 'w'``, so ``'a'`` appends rows to an existing table.

    Raises
    ------
    TypeError
        If any item in ``metrics`` is not a dataclass.
    """
    if is_dataclass(metrics):
        metrics = [metrics]
    # Materialise, so a generator survives both the type check and the write.
    metrics = list(metrics)
    if not all(is_dataclass(m) for m in metrics):
        raise TypeError("All metrics must be dataclass objects.")
    if len(metrics) == 0:
        # No rows, and no fields from which to build a header.
        return None
    write_header = mode == 'w'

    if isinstance(filename, (str, os.PathLike)):
        dirname = os.path.dirname(filename)
        # Create the directory *before* opening; open() fails otherwise.
        if len(dirname) > 0 and not os.path.exists(dirname):
            print_err(f"Creating output directory {dirname}")
            os.makedirs(dirname, exist_ok=True)
        # newline='' as recommended for csv writers; the context manager
        # guarantees the file is flushed and closed.
        with open(filename, mode, newline='') as handle:
            _write_metric_rows(handle, metrics, write_header)
        return None

    # Already-open handles are used as-is; anything else goes through cast().
    handle = filename if hasattr(filename, 'write') else cast(filename, to=TextIOWrapper, mode=mode)
    _write_metric_rows(handle, metrics, write_header)
    return None


def _write_metric_rows(handle, metrics, write_header):
    """Write a non-empty list of dataclasses to ``handle`` as TSV rows."""
    writer = DictWriter(handle, fieldnames=list(asdict(metrics[0])), delimiter='\t')
    if write_header:
        writer.writeheader()
    for metric in metrics:
        writer.writerow(asdict(metric))
|
yunta/modelling.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Tools for setting up and using models."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from time import time
|
|
7
|
+
|
|
8
|
+
from carabiner import print_err
|
|
9
|
+
|
|
10
|
+
from .weights import get_model_weights
|
|
11
|
+
|
|
12
|
+
# Path to a ``data`` directory alongside this module.
# NOTE(review): not referenced by make_model_runner, which takes weights from
# ``param_dir`` or ``get_model_weights`` — confirm whether it is still needed.
_params_path = os.path.join(os.path.dirname(__file__), "data")
|
|
13
|
+
|
|
14
|
+
def make_model_runner(num_ensemble: int = 1,
                      max_recycles: int = 10,
                      param_dir: Optional[str] = None,
                      model_name: Optional[str] = None):
    """Build an AlphaFold2 ``RunModel`` ready for inference.

    Parameters
    ----------
    num_ensemble : int, optional
        Written to ``config.data.eval.num_ensemble``.
    max_recycles : int, optional
        Written to both ``config.data.common.num_recycle`` and
        ``config.model.num_recycle``.
    param_dir : str, optional
        Directory holding the model parameters. If omitted, it is obtained
        from ``get_model_weights(model_name)``.
    model_name : str, optional
        AlphaFold2 model preset. Default: ``'model_1'``.

    Notes
    -----
    Before importing the AlphaFold modules, GPUs are hidden from TensorFlow
    with ``set_visible_devices([], "GPU")``. This is presumably so TF does
    not claim GPU memory that JAX needs; confirm.
    """
    # Deferred, order-sensitive imports: hide GPUs from TF before AlphaFold loads.
    import tensorflow as tf
    tf.config.experimental.set_visible_devices([], "GPU")
    from .src_speedppi.alphafold.model import config, data, model
    from jax.lib import xla_bridge

    platform = xla_bridge.get_backend().platform
    print_err(f"Setting up AlphaFold2 model. XLA platform available: {platform}")

    model_name = 'model_1' if model_name is None else model_name
    if param_dir is None:
        param_dir = get_model_weights(model_name)

    cfg = config.model_config(model_name)
    cfg.data.eval.num_ensemble = num_ensemble
    cfg.data.common.num_recycle = max_recycles
    cfg.model.num_recycle = max_recycles

    haiku_params = data.get_model_haiku_params(model_name=model_name,
                                               data_dir=param_dir)
    return model.RunModel(cfg, haiku_params)
|
yunta/plots.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Plotting functions."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
|
|
7
|
+
from carabiner import colorblind_palette, print_err
|
|
8
|
+
from carabiner.mpl import grid
|
|
9
|
+
from pandas import DataFrame
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
def plot_matrix(m: np.ndarray,
                filename_prefix: Optional[str] = None,
                format: str = 'png',
                dpi: int = 300,
                vline: Optional[float] = None,
                hline: Optional[float] = None,
                vmax: Optional[float] = None,
                *args, **kwargs):
    """Draw ``m`` as a heatmap and optionally save the plot and its data.

    Parameters
    ----------
    m : np.ndarray
        2-D matrix to display (``imshow`` with ``vmin=0``).
    filename_prefix : str, optional
        If given, the matrix is saved as ``<prefix>.tsv`` and the figure as
        ``<prefix>.<format>``. A missing parent directory is created.
    format : str, optional
        Image file extension passed to ``savefig`` via the filename.
    dpi : int, optional
        Resolution of the saved image.
    vline, hline : float, optional
        Positions of light-grey guide lines.
    vmax : float, optional
        Upper colour limit.
    *args, **kwargs
        Forwarded to ``axes.set`` (e.g. ``xlabel=``, ``ylabel=``).

    Returns
    -------
    tuple
        ``(fig, axes)`` from ``carabiner.mpl.grid``. The figure is left open.
    """
    fig, axes = grid()
    im = axes.imshow(m, cmap='magma', vmin=0., vmax=vmax)
    fig.colorbar(im, shrink=.7)
    if hline is not None:
        axes.axhline(hline, color='lightgrey', zorder=10)
    if vline is not None:
        axes.axvline(vline, color='lightgrey', zorder=10)
    axes.set(*args, **kwargs)

    if filename_prefix is not None:
        data_file = f"{filename_prefix}.tsv"
        plot_dir = os.path.dirname(data_file)
        # A prefix with no directory part has dirname '' (the current
        # directory), and os.makedirs('') raises; only create a real,
        # missing directory.
        if plot_dir and not os.path.exists(plot_dir):
            print_err(f"Creating output directory {plot_dir}")
            os.makedirs(plot_dir, exist_ok=True)
        print_err(f"Saving matrix data as {data_file}")
        (DataFrame(m,
                   index=np.arange(m.shape[0]),
                   columns=np.arange(m.shape[1]))
         .to_csv(data_file, sep='\t', index=False))
        plot_file = f"{filename_prefix}.{format}"
        print_err(f"Saving matrix plot as {plot_file}")
        fig.savefig(plot_file, bbox_inches="tight", dpi=dpi)

    return fig, axes
|
|
46
|
+
|
|
47
|
+
def plot_dca(dca: np.ndarray,
             filename_prefix: Optional[str] = None,
             format: str = 'png',
             dpi: int = 300):
    """Plot a DCA score matrix with ``plot_matrix``.

    When ``filename_prefix`` is given, ``_dca`` is appended to it, so the
    outputs are ``<prefix>_dca.tsv`` and ``<prefix>_dca.<format>``.
    Returns ``plot_matrix``'s ``(fig, axes)``.
    """
    dca_prefix = None if filename_prefix is None else filename_prefix + "_dca"
    return plot_matrix(dca,
                       filename_prefix=dca_prefix,
                       format=format,
                       dpi=dpi)
|