rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,128 @@
1
+ import torch
2
+ from rfd3.model.layers.layer_utils import (
3
+ MultiDimLinear,
4
+ RMSNorm,
5
+ Transition,
6
+ linearNoBias,
7
+ )
8
+ from torch import nn
9
+
10
+ from foundry.training.checkpoint import activation_checkpointing
11
+ from foundry.utils.torch import device_of
12
+
13
+
14
class AttentionPairBiasPairformerDeepspeed(nn.Module):
    """Gated multi-head self-attention whose logits are biased by a pair representation.

    Queries, keys, values and gates are projected from the (RMS-normalised)
    input ``A_I``; a per-head additive bias is projected from the
    (RMS-normalised) pair input ``Z_II``. Only the plain einsum attention path
    is implemented; the DeepSpeed path raises ``NotImplementedError``.

    Args:
        c_a: Channel size of the attended representation.
        c_s: Not used by this module (only the commented-out layers below
            reference it); kept so existing call sites keep working.
        c_pair: Channel size of the pair representation.
        n_head: Number of attention heads; each head has ``c_a // n_head`` channels.
        kq_norm: Forwarded as ``norm`` to the key and value projections.
    """

    def __init__(self, c_a, c_s, c_pair, n_head, kq_norm=False):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head  # channels per head

        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=kq_norm)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        # self.linear_output_project = nn.Sequential(
        #     LinearBiasInit(c_s, c_a, biasinit=-2.),
        #     nn.Sigmoid(),
        # )
        self.ln_0 = RMSNorm((c_pair,))
        # self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
        self.ln_1 = RMSNorm((c_a,))
        self.use_deepspeed_evo = False
        self.force_bfloat16 = True

    def forward(
        self,
        A_I,  # [I, C_a]
        S_I,  # [I, C_a] | None
        Z_II,  # [I, I, C_z]
        Beta_II=None,  # [I, I]
    ):
        """Apply pair-biased gated attention to ``A_I``.

        Args:
            A_I: Representation to attend over (channels last).
            S_I: Must be ``None``; conditioning on it is not implemented.
            Z_II: Pair representation the per-head bias is projected from.
            Beta_II: Optional extra bias; when given it is broadcast over the
                head dimension and added to the projected pair bias. ``None``
                (the default) adds nothing.

        Returns:
            The attention output after the final ``to_a`` projection.

        Raises:
            NotImplementedError: If ``use_deepspeed_evo`` is set and the
                sequence dimension exceeds 24.
        """
        # Input projections
        assert S_I is None, "conditioning on S_I is not supported"
        A_I = self.ln_1(A_I)

        if self.use_deepspeed_evo or self.force_bfloat16:
            A_I = A_I.to(torch.bfloat16)

        Q_IH = self.to_q(A_I)  # / np.sqrt(self.c)
        K_IH = self.to_k(A_I)
        V_IH = self.to_v(A_I)
        B_IIH = self.to_b(self.ln_0(Z_II))
        if Beta_II is not None:
            # The signature defaults Beta_II to None, so indexing it
            # unconditionally would raise TypeError for callers using the default.
            B_IIH = B_IIH + Beta_II[..., None]
        G_IH = self.to_g(A_I)

        L = B_IIH.shape[1]

        if not self.use_deepspeed_evo or L <= 24:
            # Scale queries by sqrt(c); the divisor is deliberately kept as a
            # bfloat16 tensor so the numerics match the original implementation.
            Q_IH = Q_IH / torch.sqrt(
                torch.tensor(self.c).to(Q_IH.device, torch.bfloat16)
            )
            # Attention
            A_IIH = torch.softmax(
                torch.einsum("...ihd,...jhd->...ijh", Q_IH, K_IH) + B_IIH, dim=-2
            )  # softmax over j
            ## G_IH: [I, H, C]
            ## A_IIH: [I, I, H]
            ## V_IH: [I, H, C]
            A_I = torch.einsum("...ijh,...jhc->...ihc", A_IIH, V_IH)
            A_I = G_IH * A_I  # [B, I, H, C]
            A_I = A_I.flatten(start_dim=-2)  # [B, I, Ca]
        else:
            raise NotImplementedError(
                "DeepSpeed attention path is not implemented for this module"
            )

        A_I = self.to_a(A_I)

        return A_I
83
+
84
+
85
+ class PairformerBlock(nn.Module):
86
+ """
87
+ Attempt to replicate AF3 architecture from scratch.
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ c_s,  # channel size passed to the S_I transition and, as c_a, to the attention module
93
+ c_z,  # channel size passed to the Z_II transition and, as c_pair, to the attention module
94
+ attention_pair_bias,  # mapping of extra keyword arguments for AttentionPairBiasPairformerDeepspeed
95
+ p_drop=0.1,  # only referenced by the commented-out dropout layers below
96
+ triangle_multiplication=None,  # not read in this constructor
97
+ triangle_attention=None,  # not read in this constructor
98
+ n_transition=4,  # forwarded as n to the Transition modules
99
+ use_deepspeed_evo=True,  # not read in this constructor
100
+ use_triangle_mult=False,  # not read in this constructor
101
+ use_triangle_attn=False,  # not read in this constructor
102
+ ):
103
+ super().__init__()
104
+
105
+ # self.drop_row = Dropout(broadcast_dim=-2, p_drop=p_drop)
106
+ # self.drop_col = Dropout(broadcast_dim=-3, p_drop=p_drop)
107
+
108
+ self.z_transition = Transition(c=c_z, n=n_transition)
109
+
110
+ if c_s > 0:  # the S_I transition only exists for a positive c_s
111
+ self.s_transition = Transition(c=c_s, n=n_transition)
112
+
113
+ self.attention_pair_bias = AttentionPairBiasPairformerDeepspeed(
114
+ c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
115
+ )
116
+
117
+ @activation_checkpointing
118
+ def forward(self, S_I, Z_II):  # returns the updated (S_I, Z_II) pair
119
+ with torch.amp.autocast(  # bfloat16 autocast on the device type reported by device_of(self)
120
+ device_type=device_of(self).type, enabled=True, dtype=torch.bfloat16
121
+ ):
122
+ Z_II = Z_II + self.z_transition(Z_II)  # residual update of Z_II
123
+ if S_I is not None:  # S_I may be None; the attention update is skipped in that case
124
+ S_I = S_I + self.attention_pair_bias(
125
+ S_I, None, Z_II, Beta_II=torch.tensor([0.0], device=Z_II.device)  # zero bias term; a new one-element tensor is built on each call
126
+ )
127
+ S_I = S_I + self.s_transition(S_I)  # residual update of S_I
128
+ return S_I, Z_II
rfd3/run_inference.py ADDED
@@ -0,0 +1,45 @@
1
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
2
+
3
+ import os
4
+
5
+ import hydra
6
+ import rootutils
7
+ from dotenv import load_dotenv
8
+ from omegaconf import DictConfig, OmegaConf
9
+
10
+ from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
11
+
12
+ # Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
13
+ # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
14
+ rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
15
+
16
+ load_dotenv(override=True)
17
+
18
+ # Build the Hydra config path from `PROJECT_ROOT`; `PROJECT_PATH` is not consulted here, and a missing `PROJECT_ROOT` raises KeyError
19
+ _config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
20
+
21
+
22
@hydra.main(
    config_path=_config_path,
    config_name="inference",
    version_base="1.3",
)
def run_inference(cfg: DictConfig) -> None:
    """Build an RFD3 inference engine from the Hydra config and run it.

    The top-level keys ``inputs``, ``n_batches`` and ``out_dir`` are passed to
    ``engine.run``; every other top-level key except ``_target_`` is used
    (after interpolation resolution) to construct the engine configuration.
    """
    run_keys = {"inputs", "n_batches", "out_dir"}
    non_init_keys = run_keys | {"_target_"}

    # Arguments for `engine.run`, taken directly from the config object.
    run_params = {key: value for key, value in cfg.items() if key in run_keys}

    # Arguments for the engine constructor: everything else, fully resolved.
    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
    init_kwargs = {
        key: value for key, value in resolved_cfg.items() if key not in non_init_keys
    }

    # Run
    engine_config = RFD3InferenceConfig(**init_kwargs)
    engine = RFD3InferenceEngine(**engine_config)
    engine.run(**run_params)


if __name__ == "__main__":
    run_inference()
rfd3/testing/debug.py ADDED
@@ -0,0 +1,139 @@
1
+ #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../scripts/shebang/modelhub_exec.sh" "$0" "$@"'
2
+ # JB's debugging file; please create your own and go crazy!
3
+ import logging
4
+ import os
5
+ import sys
6
+ import time
7
+
8
+ import hydra
9
+ import ipdb # noqa: F401
10
+ import numpy as np
11
+ import rootutils
12
+ import torch
13
+ import tree
14
+ from atomworks.ml.utils.token import (
15
+ get_token_starts,
16
+ )
17
+ from rfd3.testing.testing_utils import (
18
+ TEST_CFG_TRAIN,
19
+ TEST_JSON_DATA,
20
+ build_pipelines,
21
+ instantiate_example,
22
+ load_train_or_val_cfg,
23
+ )
24
+
25
+ from foundry.utils.ddp import set_accelerator_based_on_availability
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Same as train.py: find the project root via the `.project-root` marker, then take the config directory from `PROJECT_PATH`, falling back to `PROJECT_ROOT` and finally "../.."
31
+ rootutils.setup_root(__file__ + "/../..", indicator=".project-root", pythonpath=True)
32
+ _config_path = os.path.join(
33
+ os.environ.get("PROJECT_PATH", os.environ.get("PROJECT_ROOT", "../..")), "configs"
34
+ )
35
+ print(f"Config path: {_config_path}")
36
+ print(f"Project root: {os.environ.get('PROJECT_ROOT', '../..')}")
37
+
38
+
39
+ is_inference = True  # module-level switch: default for forward()'s is_inference and the index used to pick a pipeline in test_conditional_forward
40
+ args = TEST_JSON_DATA["1qys-1-refactored"]
41
+ input = instantiate_example(args, is_inference=is_inference)  # NOTE(review): shadows the builtin `input`; read as a global by test_conditional_forward
42
+
43
+
44
+ TEST_CFG_TRAIN = (  # optional CLI override: the text after the last "=" in sys.argv[1] names the config to load; otherwise the imported default is kept
45
+ load_train_or_val_cfg(name=sys.argv[1].split("=")[-1])
46
+ if len(sys.argv) > 1
47
+ else TEST_CFG_TRAIN
48
+ )
49
+
50
+
51
def forward(example, trainer, model, is_inference=is_inference, device="cuda:0"):
    """Run a single forward pass of ``model`` on a processed example.

    Args:
        example: Pipeline output. Must contain ``"coord_atom_lvl_to_be_noised"``
            (an object with a ``.to`` method) and whatever
            ``trainer._assemble_network_inputs`` reads.
        trainer: Object providing ``_assemble_network_inputs(example)``.
        model: Module to run; switched to eval mode for inference and train
            mode otherwise.
        is_inference: When true, the pass runs with gradients disabled.
        device: Device the network inputs are moved to. Defaults to
            ``"cuda:0"``, the value that used to be hard-coded.

    Returns:
        Whatever ``model.forward`` returns.
    """
    network_input = trainer._assemble_network_inputs(example)

    def _to_device(path, leaf):
        # Move everything that exposes `.cpu`, except the entry stored at
        # path ("f", "msa_stack"), which is left where it is.
        if hasattr(leaf, "cpu") and path != ("f", "msa_stack"):
            return leaf.to(device)
        return leaf

    network_input = tree.map_structure_with_path(_to_device, network_input)

    if is_inference:
        model.eval()
    else:
        model.train()

    # Inference disables gradients; training keeps the caller's ambient grad
    # mode. One context replaces the two duplicated forward calls.
    with torch.set_grad_enabled(torch.is_grad_enabled() and not is_inference):
        network_output = model.forward(
            input=network_input,
            n_cycle=1,
            coord_atom_lvl_to_be_noised=example["coord_atom_lvl_to_be_noised"].to(
                device
            ),
        )
    return network_output
83
+
84
+
85
def prep_forward(cfg):
    """Instantiate the trainer from ``cfg``, launch Fabric and build the model.

    If the first ``fabric.launch()`` raises, the error is printed,
    ``MASTER_PORT`` is set to a random port in [1024, 65535] and the launch is
    retried once.

    Returns:
        A ``(model, trainer)`` tuple, where ``model`` is ``trainer.state["model"]``.
    """
    instantiate_kwargs = {
        "loggers": None,
        "callbacks": None,
        "_convert_": "partial",
        "_recursive_": False,
    }
    trainer = hydra.utils.instantiate(cfg.trainer, **instantiate_kwargs)

    set_accelerator_based_on_availability(cfg)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # Single-process debugging setup.
    cfg.trainer.devices_per_node = 1
    cfg.trainer.num_nodes = 1

    try:
        trainer.fabric.launch()
    except Exception as err:
        print(f"Error: {err}")
        print("Switching port")
        fallback_port = 1024 + np.random.randint(64512)
        os.environ["MASTER_PORT"] = str(fallback_port)
        trainer.fabric.launch()

    trainer.construct_model()
    return trainer.state["model"], trainer
107
+
108
+
109
def test_conditional_forward():
    """Push the module-level example through the unindexed pipeline and the model.

    Loads the ``test-unindexed`` config, forces the island condition settings,
    processes the module-level ``input`` example, dumps it to a CIF file and
    finally runs one forward pass with the model built from ``TEST_CFG_TRAIN``.
    """
    cfg = load_train_or_val_cfg("test-unindexed")
    cfg.datasets.global_transform_args.train_conditions.island.frequency = 1e10
    cfg.datasets.global_transform_args.train_conditions.island.p_unindex_motif_tokens = 1.0
    pipelines = build_pipelines(composed_config=cfg)

    start = time.time()
    # `is_inference` (a bool) is used as the index into the built pipelines.
    example = pipelines[is_inference](input)
    example["example_id"] = "debug_example"
    elapsed = time.time() - start
    print(f"Time taken to process example: {elapsed}")

    atom_array = example["atom_array"]
    # Unused here; kept as a convenience for interactive debugging.
    t_aa = atom_array[get_token_starts(atom_array)]  # noqa: F841

    from rfd3.testing.debug_utils import pipe_out_to_file

    pipe_out_to_file(example, save=True)

    print("Preparing model")
    model, trainer = prep_forward(TEST_CFG_TRAIN)
    if is_inference:
        model.eval()
        trainer.state["model"].eval()
    network_output = forward(example, trainer, model, is_inference=is_inference)  # noqa: F841


if __name__ == "__main__":
    test_conditional_forward()
    print("Finished main")
@@ -0,0 +1,73 @@
1
+ import numpy as np
2
+ from atomworks.common import sum_string_arrays
3
+ from atomworks.io.utils.io_utils import to_cif_file
4
+ from atomworks.ml.transforms.center_random_augmentation import CenterRandomAugmentation
5
+ from biotite.structure import AtomArrayStack
6
+ from rfd3.trainer.rfd3 import _reassign_unindexed_token_chains
7
+ from rfd3.transforms.design_transforms import (
8
+ MotifCenterRandomAugmentation,
9
+ )
10
+
11
+
12
def pipe_out_to_file(output, save=True):
    """Build a placeholder atom stack from noised pipeline coordinates.

    The samples are ordered by ascending ``output["t"]``; the noise is added
    to the coordinates, except for the first sample in that order, whose noise
    is multiplied by zero. Every atom is written as a carbon on chain ``A``.

    Args:
        output: Pipeline output with ``"atom_array"``,
            ``"coord_atom_lvl_to_be_noised"``, ``"t"``, ``"noise"`` and
            ``output["feats"]["atom_to_token_map"]``.
        save: When true, write ``<example_id>_debug_out.cif`` and return
            ``None``; otherwise return the stack without writing.
    """
    atom_array = output["atom_array"]
    xyz = output["coord_atom_lvl_to_be_noised"]

    order = np.argsort(output["t"].numpy())
    noise = output["noise"].numpy()[order]
    noise[0] = noise[0] * 0

    x = AtomArrayStack(xyz.shape[0], xyz.shape[1])
    x.coord = xyz[order].numpy() + noise

    x.set_annotation("chain_id", ["A"] * xyz.shape[1])
    n_atoms = x.shape[-1]
    x.set_annotation("atom_name", [f"C{i}" for i in range(n_atoms)])
    x.set_annotation("res_id", output["feats"]["atom_to_token_map"])
    x.set_annotation("element", ["C"] * n_atoms)
    x.set_annotation("res_name", [atom_array.res_name[i] for i in range(n_atoms)])

    if not save:
        return x

    example_id = output.get("example_id", "example")
    f = f"{example_id}_debug_out.cif"
    to_cif_file(x, f, id="x")
    print("Saved cif file to:", f)
38
+
39
+
40
def save_pipe_out(atom_array):
    """Write ``atom_array`` to ``debug_out.cif`` after reassigning unindexed token chains."""
    out_path = "debug_out.cif"
    reassigned = _reassign_unindexed_token_chains(atom_array)
    to_cif_file(reassigned, out_path, id="x")
    print("Saved cif file to:", out_path)
50
+
51
+
52
def to_debug_pipe(pipe):
    """Drop the centering/random-augmentation transforms from ``pipe``.

    ``pipe.transforms`` is replaced in place by a new list that omits every
    ``CenterRandomAugmentation`` and ``MotifCenterRandomAugmentation``
    instance; the same ``pipe`` object is returned.
    """
    augmentation_types = (CenterRandomAugmentation, MotifCenterRandomAugmentation)
    pipe.transforms = [
        transform
        for transform in pipe.transforms
        if not isinstance(transform, augmentation_types)
    ]
    return pipe
59
+
60
+
61
def save_debug_cif(atom_array, filepath, name="debug_out.cif"):
    """Write a copy of ``atom_array`` to a CIF file with disambiguated chain IDs.

    Each chain ID of the copy becomes ``<chain_id>-<transformation_id>``, so
    atoms that would otherwise share identifiers can be told apart while
    debugging. The input array is not modified.

    Args:
        atom_array: Array with ``chain_id`` and ``transformation_id`` annotations.
        filepath: Prefix for the output path. It is concatenated directly with
            ``name``, so a directory needs its trailing separator.
        name: File name appended to ``filepath``.
    """
    dummy_array = atom_array.copy()
    transformation_ids = dummy_array.transformation_id.astype(str)
    dummy_array.chain_id = sum_string_arrays(
        dummy_array.chain_id, "-", transformation_ids
    )

    out_path = filepath + name
    to_cif_file(dummy_array, out_path)
    print("Saved cif file to:", out_path)