rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,200 @@
1
+ import networkx as nx
2
+ import numpy as np
3
+ from atomworks.io.utils.bonds import _atom_array_to_networkx_graph
4
+
5
+ from foundry.utils.ddp import RankedLogger
6
+
7
+ # Module-level logger for this file. NOTE(review): `rank_zero_only=False` presumably means every rank emits messages — confirm against RankedLogger.
+ global_logger = RankedLogger(__name__, rank_zero_only=False)
8
+
9
+
10
+ #################################################################################
11
+ # Training sample conditioning utilities
12
+ #################################################################################
13
+
14
+
15
def sample_island_tokens(
    array_length,
    island_len_min=5,
    island_len_max=30,
    n_islands_min=1,
    n_islands_max=30,
    max_length=None,
):
    """
    Generate a boolean mask of length `array_length` with random contiguous islands (True segments)
    while optionally constraining the total number of True values.

    The number of islands is drawn uniformly from [n_islands_min, n_islands_max]. Islands are
    placed independently, so they may overlap or touch and the mask can contain fewer distinct
    True runs than islands drawn. All randomness comes from the global NumPy RNG (`np.random`);
    seed it externally for reproducibility.

    Args:
        array_length (int): Total length of the boolean array.
        island_len_min (int): Minimum island length (inclusive). Islands are clipped to
            `array_length`, so they can be shorter than this when the array itself is shorter.
        island_len_max (int): Maximum island length (inclusive).
        n_islands_min (int): Minimum number of islands to attempt (inclusive).
        n_islands_max (int): Maximum number of islands to attempt (inclusive).
        max_length (int, optional): Maximum allowed total number of True values in the output.
            An island that would exceed this budget is truncated so that it exactly uses up
            the budget, and is skipped if the truncated length is below `island_len_min`.
            If None, no constraint is applied.

    Returns:
        np.ndarray: Boolean array of length `array_length` with island positions set to True.
    """
    n_islands = np.random.randint(n_islands_min, n_islands_max + 1)

    mask = np.zeros(array_length, dtype=bool)
    for _ in range(n_islands):
        if max_length is not None:
            # Stop as soon as the True budget is used up.
            remaining = max_length - mask.sum()
            if remaining <= 0:
                break

        # Randomly select a candidate island length, clipped so it fits into the array.
        candidate_length = min(
            np.random.randint(island_len_min, island_len_max + 1), array_length
        )

        # Choose a random starting index ensuring the island fits.
        start = np.random.randint(0, array_length - candidate_length + 1)
        stop = start + candidate_length

        if max_length is not None:
            # Positions in the candidate segment that are not yet True.
            is_new = ~mask[start:stop]
            if is_new.sum() > remaining:
                # Trim the island so it ends on the position that contributes the
                # `remaining`-th new True value (argmax returns the first such index).
                adjusted_length = int(np.argmax(np.cumsum(is_new) >= remaining)) + 1
                # Only add the island if its adjusted length meets the minimum requirement.
                if adjusted_length < island_len_min:
                    continue  # Skip this island and try the next one.
                stop = start + adjusted_length

        mask[start:stop] = True

    assert mask.sum() <= array_length, "Generated mask exceeds array length."
    return mask
85
+
86
+
87
def sample_subgraph_atoms(
    subarray, p_seed_furthest_from_o=0.8, n_bond_expectation=3, p_fix_all=0.0
):
    """
    Sample a bonded fragment of a single token's atoms to show as motif.

    Args:
        subarray: atom array for a single token (e.g. ligand or residue)
        p_seed_furthest_from_o: probability of seeding the fragment at an atom that is
            furthest (in bonds) from the atom named "O"; only honoured when every atom
            in `subarray` is flagged `is_protein`, otherwise a random atom is the seed
        n_bond_expectation: expected number of bonds to grow the fragment by, sampled
            from a geometric distribution shifted to start at zero
        p_fix_all: probability of skipping the sampling and marking every atom as motif

    Returns:
        np.ndarray: boolean mask of atoms to be shown as motif (length of subarray)
    """
    # ... Occasionally fix the entire token
    if random_condition(p_fix_all):
        return np.ones(subarray.array_length(), dtype=bool)

    # ... Build the bond graph of the token
    bond_graph = _atom_array_to_networkx_graph(
        subarray,
        annotations=["atom_name"],
        bond_order=False,
        cast_aromatic_bonds_to_same_type=True,
    )

    # ... Pick the seed atom (the random draw is always made, protein or not)
    all_protein = subarray.is_protein.all()
    wants_furthest_seed = random_condition(p_seed_furthest_from_o)
    if wants_furthest_seed and all_protein:
        seed_atom = choose_furthest_from_oxygen(bond_graph)
    else:
        seed_atom = choose_uniformly_random_atom_name(subarray)

    # ... Draw the fragment radius: geometric(p) has mean 1/p = 1 + n_bond_expectation,
    #     so subtracting one gives a mean of n_bond_expectation
    success_prob = 1 / (1 + n_bond_expectation)
    n_bonds = np.random.geometric(p=success_prob) - 1

    # ... Collect every atom within that many bonds of the seed
    fragment_atom_names = get_atom_names_within_n_bonds(
        bond_graph, src_atom_name=seed_atom, n_bonds=n_bonds
    )
    return np.isin(subarray.atom_name, fragment_atom_names)
129
+
130
+
131
+ #################################################################################
132
+ # Graph traversal utilities | assume each node stores its atom name under the "node_data" attribute
133
+ #################################################################################
134
+
135
+
136
def get_node_idx_from_atom_name(G, atom_name):
    """
    Return the single node of `G` whose "node_data" attribute equals `atom_name`.

    Args:
        G: graph whose `nodes(data=True)` yields `(node, attribute_dict)` pairs
        atom_name: atom name to look up

    Returns:
        The matching node identifier.

    Raises:
        ValueError: if no node, or more than one node, carries `atom_name`.
    """
    hits = [
        node_id
        for node_id, attrs in G.nodes(data=True)
        if attrs.get("node_data") == atom_name
    ]

    if not hits:
        raise ValueError(
            f"No node with atom_name = '{atom_name}' found. Got {G.nodes(data=True)}"
        )
    if len(hits) > 1:
        raise ValueError(
            f"Multiple nodes with atom_name = '{atom_name}' found: {hits}. Got {G.nodes(data=True)}"
        )

    return hits[0]
153
+
154
+
155
def get_atom_names_within_n_bonds(G, src_atom_name, n_bonds):
    """
    List the atom names reachable from `src_atom_name` in at most `n_bonds` bonds.

    Args:
        G: bond graph whose nodes carry their atom name under "node_data"
        src_atom_name: atom name of the starting node (must identify exactly one node)
        n_bonds: maximum shortest-path length (in edges) from the starting node

    Returns:
        list: atom names of all nodes within the cutoff, the starting atom included.
    """
    start_node = get_node_idx_from_atom_name(G, src_atom_name)

    # Maps every node within the cutoff to its bond distance from the start node.
    reachable = nx.single_source_shortest_path_length(
        G, source=start_node, cutoff=n_bonds
    )
    return [G.nodes[node_id]["node_data"] for node_id in reachable]
162
+
163
+
164
def choose_furthest_from_oxygen(G):
    """
    Pick the atom name of a node at maximal bond distance from the node named "O".

    Ties between equally distant nodes are broken uniformly at random with the
    global NumPy RNG.

    Args:
        G: bond graph whose nodes carry their atom name under "node_data"

    Returns:
        The "node_data" value (atom name) of the sampled node.
    """
    oxygen_node = get_node_idx_from_atom_name(G, "O")
    hops_from_oxygen = nx.single_source_shortest_path_length(G, source=oxygen_node)

    # Keep every node that sits at the largest distance found.
    largest_hops = max(hops_from_oxygen.values())
    candidates = [
        node_id for node_id, hops in hops_from_oxygen.items() if hops == largest_hops
    ]

    return G.nodes[np.random.choice(candidates)]["node_data"]
174
+
175
+
176
def choose_uniformly_random_atom_name(subarray):
    """
    Sample one atom name uniformly at random, preferring atoms with occupancy > 0.

    Args:
        subarray: atom array exposing `occupancy`, `atom_name` and `array_length()`

    Returns:
        The `atom_name` entry of the sampled atom.
    """
    candidate_indices = np.nonzero(subarray.occupancy > 0)[0]
    if len(candidate_indices) == 0:
        # No atom has positive occupancy: fall back to sampling from all atoms
        # rather than raising or warning.
        candidate_indices = np.arange(subarray.array_length())
    return subarray.atom_name[np.random.choice(candidate_indices)]
184
+
185
+
186
+ #################################################################################
187
+ # Utility functions
188
+ #################################################################################
189
+
190
+
191
def random_condition(p_cond):
    """
    Return a truthy value with probability `p_cond`.

    Exists so callers do not have to remember which way the inequality against
    the uniform draw goes. When `p_cond` is exactly 0 the answer is False and
    no random number is drawn.
    """
    assert 0 <= p_cond <= 1, "p_cond must be between 0 and 1"
    return False if p_cond == 0 else (np.random.rand() < p_cond)