rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import networkx as nx
|
|
2
|
+
import numpy as np
|
|
3
|
+
from atomworks.io.utils.bonds import _atom_array_to_networkx_graph
|
|
4
|
+
|
|
5
|
+
from foundry.utils.ddp import RankedLogger
|
|
6
|
+
|
|
7
|
+
# Module-level logger.
# NOTE(review): presumably `rank_zero_only=False` means messages are emitted on every
# distributed rank rather than only rank 0 — confirm against foundry.utils.ddp.RankedLogger.
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
#################################################################################
|
|
11
|
+
# Training sample conditioning utilities
|
|
12
|
+
#################################################################################
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def sample_island_tokens(
    array_length,
    island_len_min=5,
    island_len_max=30,
    n_islands_min=1,
    n_islands_max=30,
    max_length=None,
):
    """
    Generate a boolean mask of length `array_length` with random contiguous islands (True segments)
    while optionally constraining the total number of True values.

    The number of islands attempted is drawn uniformly from
    ``[n_islands_min, n_islands_max]``. Islands may overlap; positions that are
    already True do not count again towards ``max_length``. All randomness comes
    from NumPy's global RNG (``np.random``); seed it with ``np.random.seed`` for
    reproducibility.

    Args:
        array_length (int): Total length of the boolean array.
        island_len_min (int): Minimum island length (inclusive).
        island_len_max (int): Maximum island length (inclusive). Sampled lengths are
            clipped to `array_length`.
        n_islands_min (int): Minimum number of islands to attempt (inclusive).
        n_islands_max (int): Maximum number of islands to attempt (inclusive).
        max_length (int, optional): Maximum allowed total number of True values in the output.
            If None, no constraint is applied. An island that would exceed the budget is
            trimmed from its end so it adds exactly the remaining budget; if the trimmed
            island is shorter than `island_len_min` it is skipped instead.

    Returns:
        np.ndarray: Boolean array of length `array_length` with island positions set to True.
    """
    n_islands = np.random.randint(n_islands_min, n_islands_max + 1)

    mask = np.zeros(array_length, dtype=bool)
    for _ in range(n_islands):
        remaining = None  # Only meaningful when a budget is set.
        if max_length is not None:
            remaining = max_length - int(mask.sum())
            if remaining <= 0:
                break  # Budget exhausted; further islands cannot add anything.

        # Randomly select a candidate island length, clipped so it fits in the array.
        candidate_length = min(
            np.random.randint(island_len_min, island_len_max + 1), array_length
        )

        # Choose a random start index such that the whole island fits.
        start = np.random.randint(0, array_length - candidate_length + 1)
        stop = start + candidate_length

        # Only positions that are currently False add to the True count.
        is_new = ~mask[start:stop]
        new_trues = int(is_new.sum())

        if remaining is not None and new_trues > remaining:
            # Trim the island from its end so it adds exactly `remaining` new True
            # values: keep everything up to (and including) the position where the
            # running count of new positions first reaches the budget.
            adjusted_length = int(np.searchsorted(np.cumsum(is_new), remaining)) + 1
            if adjusted_length < island_len_min:
                continue  # Too short after trimming; skip and try the next island.
            stop = start + adjusted_length

        mask[start:stop] = True

    return mask
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def sample_subgraph_atoms(
    subarray, p_seed_furthest_from_o=0.8, n_bond_expectation=3, p_fix_all=0.0
):
    """
    Sample a bonded fragment of a single token's atoms to expose as motif.

    A seed atom is chosen, and every atom within a randomly drawn number of
    bonds of it (graph distance in the token's bond graph) is selected.

    Args:
        subarray: Atom array for a single token (e.g. a ligand or a residue).
        p_seed_furthest_from_o: Probability of seeding at an atom furthest (in
            bonds) from the atom named "O" (the backbone oxygen for standard
            residues). This option only applies when every atom in `subarray`
            is protein. Otherwise the seed is drawn uniformly from atoms with
            occupancy > 0, falling back to all atoms if none qualify.
        n_bond_expectation: Mean of the bond-radius distribution. The radius is
            drawn as ``Geometric(p=1 / (1 + n_bond_expectation)) - 1``, which
            is supported on {0, 1, 2, ...} with mean `n_bond_expectation`.
        p_fix_all: Probability of skipping the sampling entirely and marking
            every atom as motif.

    Returns:
        np.ndarray: Boolean mask with one entry per atom of `subarray`; True
        marks atoms to be shown as motif.
    """
    # Occasionally expose the whole token as motif.
    if random_condition(p_fix_all):
        return np.ones(subarray.array_length(), dtype=bool)

    # Bond graph of the token; the atom name is later read back from each node's "node_data".
    bond_graph = _atom_array_to_networkx_graph(
        subarray,
        annotations=["atom_name"],
        bond_order=False,
        cast_aromatic_bonds_to_same_type=True,
    )
    all_protein = subarray.is_protein.all()

    # The Bernoulli draw always happens, regardless of whether the token is protein.
    seed_from_oxygen = random_condition(p_seed_furthest_from_o)
    if seed_from_oxygen and all_protein:
        seed_name = choose_furthest_from_oxygen(bond_graph)
    else:
        seed_name = choose_uniformly_random_atom_name(subarray)

    # Bond radius ~ Geometric(p) - 1 (support starts at 0).
    radius = np.random.geometric(p=1 / (1 + n_bond_expectation)) - 1
    fragment_names = get_atom_names_within_n_bonds(
        bond_graph, src_atom_name=seed_name, n_bonds=radius
    )
    return np.isin(subarray.atom_name, fragment_names)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
#################################################################################
|
|
132
|
+
# Graph traversal utilities | assume each node stores its atom name in the "node_data" attribute
|
|
133
|
+
#################################################################################
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_node_idx_from_atom_name(G, atom_name):
    """
    Return the id of the unique node whose "node_data" attribute equals `atom_name`.

    Args:
        G: Graph whose nodes store an atom name under the "node_data" attribute.
        atom_name: Atom name to look up.

    Returns:
        The matching node id.

    Raises:
        ValueError: If no node matches, or if more than one node matches.
    """
    hits = [
        idx for idx, attrs in G.nodes(data=True) if attrs.get("node_data") == atom_name
    ]

    if not hits:
        raise ValueError(
            f"No node with atom_name = '{atom_name}' found. Got {G.nodes(data=True)}"
        )
    if len(hits) > 1:
        raise ValueError(
            f"Multiple nodes with atom_name = '{atom_name}' found: {hits}. Got {G.nodes(data=True)}"
        )

    (only_hit,) = hits
    return only_hit
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def get_atom_names_within_n_bonds(G, src_atom_name, n_bonds):
    """
    Return the atom names of every node within `n_bonds` bonds of `src_atom_name`.

    The source atom itself (distance 0) is included, and the cutoff is inclusive.
    Only nodes reachable from the source are considered.

    Raises:
        ValueError: If `src_atom_name` does not identify exactly one node
            (see `get_node_idx_from_atom_name`).
    """
    origin = get_node_idx_from_atom_name(G, src_atom_name)
    hop_counts = nx.single_source_shortest_path_length(G, source=origin, cutoff=n_bonds)
    # Iterating the dict yields the reached node ids in the same order as .keys().
    return [G.nodes[node]["node_data"] for node in hop_counts]
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def choose_furthest_from_oxygen(G):
    """
    Return the atom name of a node at maximal bond distance from the atom named "O".

    Ties are broken uniformly at random via NumPy's global RNG. Only nodes in the
    same connected component as "O" are considered. A ValueError is raised (via
    `get_node_idx_from_atom_name`) unless exactly one node is named "O".
    """
    hops_from_oxygen = nx.single_source_shortest_path_length(
        G, source=get_node_idx_from_atom_name(G, "O")
    )

    deepest = max(hops_from_oxygen.values())
    tied_nodes = [
        node for node in hops_from_oxygen if hops_from_oxygen[node] == deepest
    ]

    chosen = np.random.choice(tied_nodes)
    return G.nodes[chosen]["node_data"]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def choose_uniformly_random_atom_name(subarray):
    """
    Return the name of an atom drawn uniformly at random from `subarray`.

    Only atoms with ``occupancy > 0`` are candidates. If no atom qualifies,
    every atom in `subarray` becomes a candidate instead; no error or warning
    is raised. Randomness comes from NumPy's global RNG.
    """
    candidates = np.flatnonzero(subarray.occupancy > 0)
    if candidates.size == 0:
        # Silent fallback: consider every atom.
        candidates = np.arange(subarray.array_length())
    return subarray.atom_name[np.random.choice(candidates)]
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
#################################################################################
|
|
187
|
+
# Utility functions
|
|
188
|
+
#################################################################################
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def random_condition(p_cond):
    """
    Draw a Bernoulli sample: return True with probability `p_cond`.

    This wraps ``np.random.rand() < p_cond`` so callers never have to remember
    which way the inequality goes. Because ``np.random.rand()`` lies in [0, 1),
    ``p_cond == 1`` always returns True.

    Args:
        p_cond: Probability in [0, 1].

    Returns:
        bool: True with probability `p_cond`. ``p_cond == 0`` returns False
        without drawing from NumPy's global RNG.

    Raises:
        AssertionError: If `p_cond` is outside [0, 1].
    """
    assert 0 <= p_cond <= 1, "p_cond must be between 0 and 1"
    if p_cond != 0:
        return np.random.rand() < p_cond
    # Short-circuit: no random number is consumed for a zero probability.
    return False
|