factorforge-cds 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- factorforge/__init__.py +19 -0
- factorforge/__main__.py +8 -0
- factorforge/cli/__init__.py +5 -0
- factorforge/cli/legacy_cli.py +157 -0
- factorforge/cli/main.py +305 -0
- factorforge/core/interfaces/__init__.py +7 -0
- factorforge/core/interfaces/exporter.py +13 -0
- factorforge/core/interfaces/optimizer.py +85 -0
- factorforge/core/interfaces/validator.py +9 -0
- factorforge/database.py +150 -0
- factorforge/engines/__init__.py +60 -0
- factorforge/engines/ml/__init__.py +0 -0
- factorforge/engines/ml/plant_optimizer.py +325 -0
- factorforge/engines/registry.py +141 -0
- factorforge/engines/v1_archived/__init__.py +15 -0
- factorforge/engines/v2/__init__.py +13 -0
- factorforge/engines/v2/codon_table_builder.py +107 -0
- factorforge/engines/v2/construct_builder.py +403 -0
- factorforge/engines/v2/exporter.py +455 -0
- factorforge/engines/v2/optimizer.py +190 -0
- factorforge/engines/v2/pipeline.py +275 -0
- factorforge/engines/v2/rules/__init__.py +3 -0
- factorforge/engines/v2/rules/domesticator.py +403 -0
- factorforge/engines/v2/rules/reverse_translator.py +765 -0
- factorforge/engines/v2/rules/rule_engine.py +867 -0
- factorforge/engines/v2/scoring.py +232 -0
- factorforge/engines/v2/utils.py +231 -0
- factorforge/engines/v2/validator.py +383 -0
- factorforge/engines/v3/__init__.py +12 -0
- factorforge/engines/v3/explain.py +119 -0
- factorforge/engines/v3/inference/__init__.py +6 -0
- factorforge/engines/v3/inference/constrained_decoder.py +80 -0
- factorforge/engines/v3/inference/v2_adapter.py +72 -0
- factorforge/engines/v3/metrics.py +145 -0
- factorforge/engines/v3/modeling_bart_decoder.py +127 -0
- factorforge/engines/v3/pipeline.py +192 -0
- factorforge/engines/v3/synonym_mask.py +61 -0
- factorforge/engines/v3/tokenizer.py +192 -0
- factorforge/ml/__init__.py +33 -0
- factorforge/ml/feasibility.py +199 -0
- factorforge/ml/metrics.py +295 -0
- factorforge/utils/__init__.py +31 -0
- factorforge/utils/construct_id.py +8 -0
- factorforge/utils/exceptions.py +32 -0
- factorforge/utils/sequence_validator.py +189 -0
- factorforge/utils/validation.py +104 -0
- factorforge_cds-3.0.0.dist-info/METADATA +475 -0
- factorforge_cds-3.0.0.dist-info/RECORD +52 -0
- factorforge_cds-3.0.0.dist-info/WHEEL +5 -0
- factorforge_cds-3.0.0.dist-info/entry_points.txt +2 -0
- factorforge_cds-3.0.0.dist-info/licenses/LICENSE +201 -0
- factorforge_cds-3.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
"""Metrics for FactorForge v3 (CAI, perplexity, GC%)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import math
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING: # pragma: no cover - for type checkers only
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _default_codon_table_path() -> Path:
|
|
16
|
+
return Path(__file__).resolve().parents[4] / "data" / "nbenthamiana_codons.json"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
class CodonUsageTable:
    """Immutable bundle of codon-usage lookups derived from a codon-table JSON.

    The positional field order below is part of the constructor interface.
    ``frozen`` only prevents rebinding the attributes; the dicts themselves
    remain mutable.
    """

    # codon -> amino-acid symbol, as read from the JSON table.
    codon_to_aa: dict[str, str]
    # codon -> relative adaptiveness (frequency / max synonymous frequency);
    # stop codons ("*") have no entry.
    codon_weights: dict[str, float]
    # amino acid -> its most frequent codon (stop codons excluded).
    best_codon_for_aa: dict[str, str]
    # Optional provenance string taken from the JSON "source" key.
    source: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def load_codon_usage_table(path: Path | str | None = None) -> CodonUsageTable:
    """Load a codon-usage JSON file into a :class:`CodonUsageTable`.

    Expected layout (entries that do not match are skipped silently)::

        {"source": "...", "codons": {"GCT": {"aa": "A", "frequency": 0.27}, ...}}

    Codon keys are normalised to upper-case DNA (``U`` read as ``T``). This is
    the same form ``compute_cai`` uses for lookups, so tables written with
    lower-case or RNA codons still match. Amino-acid symbols are upper-cased.

    Args:
        path: JSON file to read, as :class:`pathlib.Path` or ``str``.
            ``None`` (or an empty string) uses :func:`_default_codon_table_path`.

    Returns:
        A table holding, per codon, its amino acid and relative-adaptiveness
        weight, plus the most frequent codon for each amino acid.

    Raises:
        ValueError: if the JSON root is not an object or has no ``codons`` object.
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Path objects are always truthy, so this only falls back for None / "".
    table_path = Path(path) if path else _default_codon_table_path()
    raw = json.loads(table_path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("Codon table JSON must be an object")

    codons = raw.get("codons")
    if not isinstance(codons, dict):
        raise ValueError("Codon table JSON missing 'codons'")

    codon_to_aa: dict[str, str] = {}
    codon_freq: dict[str, float] = {}
    for codon, entry in codons.items():
        if not isinstance(entry, dict) or not isinstance(codon, str):
            continue
        aa = entry.get("aa")
        freq = entry.get("frequency")
        if not isinstance(aa, str) or not isinstance(freq, (int, float)):
            continue
        key = codon.strip().upper().replace("U", "T")
        codon_to_aa[key] = aa.strip().upper()
        codon_freq[key] = float(freq)

    codon_weights = _build_codon_weights(codon_to_aa, codon_freq)
    best_codon_for_aa = _best_codon_map(codon_to_aa, codon_freq)
    source_value = raw.get("source")
    source = source_value if isinstance(source_value, str) else None

    return CodonUsageTable(
        codon_to_aa=codon_to_aa,
        codon_weights=codon_weights,
        best_codon_for_aa=best_codon_for_aa,
        source=source,
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def compute_cai(dna_sequence: str, table: CodonUsageTable) -> float:
    """Return the Codon Adaptation Index of *dna_sequence* under *table*.

    CAI is the geometric mean of the relative-adaptiveness weights
    (``table.codon_weights``) of the sequence's codons. Before scoring:

    * the sequence is upper-cased;
    * whitespace is removed;
    * ``U`` is read as ``T``;
    * a trailing partial codon is ignored.

    Stop codons are skipped, as in the usual CAI definition. A stop codon is
    one the table maps to ``"*"``, or a standard TAA/TAG/TGA the table does not
    annotate. ``codon_weights`` has no entry for stops, so counting them would
    force the score to 0.0 for any sequence that carries its terminator.

    Returns:
        CAI in ``(0, 1]``. Returns ``0.0`` when:

        * there are no scorable codons;
        * a codon is missing from the table;
        * any codon has a non-positive weight.
    """
    seq = "".join(dna_sequence.split()).upper().replace("U", "T")
    standard_stops = frozenset({"TAA", "TAG", "TGA"})
    usable_length = len(seq) - len(seq) % 3

    log_sum = 0.0
    scored = 0
    for start in range(0, usable_length, 3):
        codon = seq[start : start + 3]
        aa = table.codon_to_aa.get(codon)
        # Respect the table's own annotation first, so a non-standard code that
        # assigns e.g. TGA to an amino acid is still scored.
        if aa == "*" or (aa is None and codon in standard_stops):
            continue
        weight = table.codon_weights.get(codon)
        if weight is None or weight <= 0:
            return 0.0
        log_sum += math.log(weight)
        scored += 1

    if scored == 0:
        return 0.0
    return math.exp(log_sum / scored)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def compute_gc(dna_sequence: str) -> float:
    """Return the G+C content of *dna_sequence* as a percentage (0-100).

    Matching is case-insensitive. Whitespace (spaces, tabs, line breaks, e.g.
    from FASTA wrapping) is removed first, so it does not dilute the
    denominator. Every other character, including ambiguity codes such as
    ``N``, counts toward the length.

    Returns ``0.0`` for an empty or all-whitespace sequence.
    """
    seq = "".join(dna_sequence.split()).upper()
    if not seq:
        return 0.0
    gc = seq.count("G") + seq.count("C")
    return (gc / len(seq)) * 100.0
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def compute_perplexity(
    logits: "torch.Tensor",
    labels: "torch.Tensor",
    ignore_index: int = -100,
) -> float:
    """Return ``exp(mean token cross-entropy)`` of *logits* against *labels*.

    Args:
        logits: scores whose last dimension is the vocabulary. All leading
            dimensions are flattened (e.g. ``[B, T, V]`` -> ``[B*T, V]``).
        labels: integer class ids with as many elements as the flattened
            leading dimensions of *logits* (e.g. ``[B, T]``).
        ignore_index: label value excluded from the mean, passed straight to
            ``torch.nn.functional.cross_entropy``.

    Returns:
        Perplexity as a Python float. If every label equals *ignore_index* the
        mean loss is undefined (PyTorch yields ``nan``), and so is the result.

    Raises:
        ImportError: if torch is not installed.
    """
    torch, functional = _require_torch()
    vocab_size = int(logits.shape[-1])
    # reshape rather than view: view raises on non-contiguous tensors
    # (e.g. sliced or transposed logits), reshape copies only when needed.
    loss = functional.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=ignore_index,
    )
    return float(torch.exp(loss).item())
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _build_codon_weights(
|
|
103
|
+
codon_to_aa: dict[str, str],
|
|
104
|
+
codon_freq: dict[str, float],
|
|
105
|
+
) -> dict[str, float]:
|
|
106
|
+
by_aa: dict[str, list[float]] = {}
|
|
107
|
+
for codon, aa in codon_to_aa.items():
|
|
108
|
+
if aa == "*":
|
|
109
|
+
continue
|
|
110
|
+
by_aa.setdefault(aa, []).append(codon_freq.get(codon, 0.0))
|
|
111
|
+
|
|
112
|
+
weights: dict[str, float] = {}
|
|
113
|
+
for codon, aa in codon_to_aa.items():
|
|
114
|
+
if aa == "*":
|
|
115
|
+
continue
|
|
116
|
+
max_freq = max(by_aa.get(aa, [0.0]))
|
|
117
|
+
freq = codon_freq.get(codon, 0.0)
|
|
118
|
+
weights[codon] = freq / max_freq if max_freq > 0 else 0.0
|
|
119
|
+
return weights
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _best_codon_map(
|
|
123
|
+
codon_to_aa: dict[str, str],
|
|
124
|
+
codon_freq: dict[str, float],
|
|
125
|
+
) -> dict[str, str]:
|
|
126
|
+
best: dict[str, tuple[str, float]] = {}
|
|
127
|
+
for codon, aa in codon_to_aa.items():
|
|
128
|
+
if aa == "*":
|
|
129
|
+
continue
|
|
130
|
+
current = best.get(aa)
|
|
131
|
+
freq = codon_freq.get(codon, 0.0)
|
|
132
|
+
if current is None or freq > current[1]:
|
|
133
|
+
best[aa] = (codon, freq)
|
|
134
|
+
return {aa: codon for aa, (codon, _) in best.items()}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _require_torch():
|
|
138
|
+
try:
|
|
139
|
+
import torch
|
|
140
|
+
from torch.nn import functional
|
|
141
|
+
except ImportError as exc: # pragma: no cover
|
|
142
|
+
raise ImportError(
|
|
143
|
+
"ML dependencies not installed. Install with: pip install -e \".[ml]\""
|
|
144
|
+
) from exc
|
|
145
|
+
return torch, functional
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""BART decoder skeleton for FactorForge v3."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
from transformers import BartConfig
|
|
11
|
+
from transformers.models.bart.modeling_bart import BartDecoder
|
|
12
|
+
except ImportError as exc: # pragma: no cover - exercised in ML installs
|
|
13
|
+
raise ImportError(
|
|
14
|
+
"ML dependencies not installed. Install with: pip install -e \".[ml]\""
|
|
15
|
+
) from exc
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BartDecoderSkeleton(nn.Module):
    """Lightweight BART decoder that consumes unified encoder embeddings.

    Data flow:

    * Per-token encoder features of width ``encoder_dim`` are projected to
      ``d_model`` by a linear layer.
    * The projected states are passed as ``encoder_hidden_states`` to a
      Hugging Face ``BartDecoder``.
    * A bias-free linear head plus the ``final_logits_bias`` buffer
      (initialised to zeros) maps decoder states to vocabulary logits.

    No BART encoder module is instantiated in this class.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        encoder_dim: int = 320,  # ESM2 t6_8M per-token embedding dim
        decoder_layers: int = 4,
        decoder_attention_heads: int = 4,
        ffn_dim: int = 1024,
        max_position_embeddings: int = 512,
        dropout: float = 0.1,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
    ) -> None:
        """Create the BART config, encoder projection, decoder and LM head.

        Args:
            vocab_size: size of the output vocabulary.
            d_model: hidden width of the decoder.
            encoder_dim: width of the incoming encoder features.
            decoder_layers, decoder_attention_heads, ffn_dim: decoder shape.
            max_position_embeddings: maximum decoder sequence length.
            dropout: dropout probability stored in the config.
            pad_token_id, bos_token_id, eos_token_id: special token ids.
                ``bos_token_id`` is also used as ``decoder_start_token_id``.
        """
        super().__init__()
        self.config = BartConfig(
            vocab_size=vocab_size,
            d_model=d_model,
            decoder_layers=decoder_layers,
            decoder_attention_heads=decoder_attention_heads,
            decoder_ffn_dim=ffn_dim,
            max_position_embeddings=max_position_embeddings,
            # Only the decoder is built below; these encoder settings just
            # populate the config.
            encoder_layers=1,
            encoder_attention_heads=decoder_attention_heads,
            dropout=dropout,
            activation_function="gelu",
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=bos_token_id,
            is_encoder_decoder=True,
        )

        self.encoder_projection = nn.Linear(encoder_dim, d_model)
        self.decoder = BartDecoder(self.config)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

    def forward(
        self,
        encoder_hidden_states: "Tensor",
        decoder_input_ids: "Tensor",
        attention_mask: "Tensor | None" = None,
        encoder_attention_mask: "Tensor | None" = None,
    ) -> "Tensor":
        """Teacher-forcing forward pass returning logits [B, T, vocab_size].

        Args:
            encoder_hidden_states: encoder features whose last dimension is
                ``encoder_dim`` (presumably ``[B, S, encoder_dim]``; confirm
                with callers).
            decoder_input_ids: decoder token ids, ``[B, T]``.
            attention_mask: optional decoder-side mask, passed to ``BartDecoder``.
            encoder_attention_mask: optional mask over encoder positions.
        """
        encoder_states = self.encoder_projection(encoder_hidden_states)
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=False,
            return_dict=True,
        )
        logits = self.lm_head(outputs.last_hidden_state) + self.final_logits_bias.to(
            outputs.last_hidden_state.device
        )
        return logits

    @torch.no_grad()
    def generate(
        self,
        encoder_hidden_states: "Tensor",
        max_new_tokens: int,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int | None = None,
    ) -> "Tensor":
        """Greedy decoding for a few steps (no beam search).

        Every row starts from ``bos_token_id``; the result has shape
        ``[B, 1 + steps]`` with ``steps <= max_new_tokens``. Once a row has
        produced ``eos_token_id`` it is finished:

        * its later positions receive ``pad_token_id`` (or ``eos_token_id``
          when no pad id is given) instead of further model predictions;
        * decoding stops as soon as every row has finished, not only when
          all rows happen to emit EOS on the same step.

        When ``pad_token_id`` is given, every position after a row's first EOS
        (the BOS position included in that search) is set to it.

        Runs under ``torch.no_grad()``: argmax token ids carry no gradient, so
        an autograd graph per step would only waste memory. Dropout follows
        the module's train/eval mode; call ``.eval()`` for deterministic output.
        """
        batch_size = encoder_hidden_states.shape[0]
        device = encoder_hidden_states.device

        decoder_input_ids = torch.full(
            (batch_size, 1),
            bos_token_id,
            dtype=torch.long,
            device=device,
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        filler_id = eos_token_id if pad_token_id is None else pad_token_id

        for _ in range(max_new_tokens):
            logits = self.forward(
                encoder_hidden_states=encoder_hidden_states,
                decoder_input_ids=decoder_input_ids,
                attention_mask=None,
                encoder_attention_mask=None,
            )
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            # Rows that already emitted EOS keep emitting the filler token.
            next_token = next_token.masked_fill(finished, filler_id)
            decoder_input_ids = torch.cat(
                [decoder_input_ids, next_token.unsqueeze(1)],
                dim=1,
            )
            finished = finished | (next_token == eos_token_id)
            if bool(finished.all()):
                break

        if pad_token_id is not None:
            is_eos = (decoder_input_ids == eos_token_id).long()
            # Count of EOS tokens strictly before each position; > 0 means the
            # position lies after the row's first EOS.
            eos_before = torch.cumsum(is_eos, dim=1) - is_eos
            decoder_input_ids = decoder_input_ids.masked_fill(eos_before > 0, pad_token_id)
        return decoder_input_ids
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""FactorForge v3 pipeline scaffolding."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from factorforge.core.interfaces import OptimizationResult, OptimizerEngine
|
|
9
|
+
from factorforge.engines.registry import EngineRegistry
|
|
10
|
+
from factorforge.engines.v2.rules.domesticator import Domesticator
|
|
11
|
+
from factorforge.engines.v2.rules.rule_engine import RuleEngine
|
|
12
|
+
|
|
13
|
+
from .explain import ExplainabilityInputs, build_fda_report
|
|
14
|
+
from .metrics import CodonUsageTable, compute_cai, compute_gc, load_codon_usage_table
|
|
15
|
+
from .tokenizer import AA_TOKENS, AATokenizer, CodonTokenizer
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .modeling_bart_decoder import BartDecoderSkeleton
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Characters accepted by _normalize_aa: every entry of the tokenizer's AA_TOKENS.
# Membership is tested per character, so only single-character entries can
# ever match.
VALID_AA = {*AA_TOKENS}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class V3Result:
    """Outcome of one ``V3Pipeline.run`` call."""

    # Final DNA sequence (after the optional post-guard).
    sequence: str
    # Numeric metrics keyed by name; run() fills "cai" and "gc_content".
    metrics: dict[str, float]
    # Free-form run details; each instance gets its own fresh dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Report built by build_fda_report, when available.
    report: dict[str, Any] | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class V3Pipeline:
    """Minimal v3 pipeline scaffold (decoder + post-guard hook).

    ``run`` works in three stages:

    1. Turn a protein sequence into DNA: by constrained greedy decoding when
       encoder embeddings are supplied, otherwise with the v2 optimizer's
       ``gc_target`` profile.
    2. Optionally apply a post-guard: polyA fix, rule scan, Golden Gate
       domestication.
    3. Score the result and build an explainability report.
    """

    def __init__(
        self,
        aa_tokenizer: AATokenizer | None = None,
        codon_tokenizer: CodonTokenizer | None = None,
        codon_usage: CodonUsageTable | None = None,
        decoder: "BartDecoderSkeleton | None" = None,
        model_id: str = "v3-bart-decoder-skeleton",
    ) -> None:
        """Store collaborators, building defaults only for those not supplied.

        ``is None`` checks (rather than ``x or default``) guarantee that a
        caller-supplied object is always used, even if it happens to be falsy
        (e.g. a tokenizer whose ``__len__`` returns 0).
        """
        self.aa_tokenizer = AATokenizer.default() if aa_tokenizer is None else aa_tokenizer
        self.codon_tokenizer = (
            CodonTokenizer.default() if codon_tokenizer is None else codon_tokenizer
        )
        self.codon_usage = load_codon_usage_table() if codon_usage is None else codon_usage
        self.decoder = decoder
        self.model_id = model_id

    def run(
        self,
        sequence: str,
        encoder_embeddings: Any | None = None,
        max_new_tokens: int = 24,
        apply_post_guard: bool = True,
        seed: int = 0,
        config: dict[str, Any] | None = None,
    ) -> V3Result:
        """Optimise the protein *sequence* into DNA and score it.

        Args:
            sequence: protein sequence, validated by ``_normalize_aa``.
            encoder_embeddings: encoder output for the decoder path. When
                ``None``, the v2 ``gc_target`` fallback is used instead.
            max_new_tokens: not forwarded anywhere in this method.
                NOTE(review): confirm whether it should reach the decoder.
            apply_post_guard: run the polyA fix / scan / domestication step.
            seed: passed only into the explainability report inputs.
            config: passed only into the explainability report inputs
                (``{}`` when ``None``).

        Returns:
            V3Result holding:

            * the final DNA;
            * ``cai`` and ``gc_content`` metrics;
            * metadata (``model_id``, ``post_guard``);
            * the report.

        Raises:
            ValueError: on an invalid protein, or when embeddings are given
                but no decoder is configured.
        """
        aa_sequence = _normalize_aa(sequence)

        if encoder_embeddings is not None:
            if self.decoder is None:
                raise ValueError("Decoder not configured. Provide a BartDecoderSkeleton instance.")
            # Imported lazily, presumably to keep the ML-dependent module off
            # the fallback path.
            from factorforge.engines.v3.inference.constrained_decoder import (
                constrained_greedy_decode,
                validate_candidate_or_fallback,
            )

            token_ids = constrained_greedy_decode(
                self.decoder,
                encoder_embeddings,
                aa_sequence,
                self.codon_tokenizer,
            )
            # Only the first row of the decoded batch is used.
            dna_sequence = self.codon_tokenizer.decode(token_ids[0].tolist())
            fallback = validate_candidate_or_fallback(aa_sequence, dna_sequence)
            dna_sequence = fallback["dna_sequence"]
        else:
            dna_sequence = self._fallback_gc_target(aa_sequence)

        post_guard: dict[str, Any] | None = None
        if apply_post_guard:
            dna_sequence, post_guard = self._apply_post_guard(dna_sequence)

        metrics = {
            "cai": compute_cai(dna_sequence, self.codon_usage),
            "gc_content": compute_gc(dna_sequence),
        }

        report_inputs = ExplainabilityInputs(
            aa_sequence=aa_sequence,
            dna_sequence=dna_sequence,
            metrics=metrics,
            model_id=self.model_id,
            tokenizer_hash=self.codon_tokenizer.mapping_hash(),
            seed=seed,
            config=config or {},
            post_guard=post_guard,
        )
        report = build_fda_report(report_inputs)

        metadata = {
            "model_id": self.model_id,
            "post_guard": post_guard or {},
        }

        return V3Result(sequence=dna_sequence, metrics=metrics, metadata=metadata, report=report)

    def _reverse_translate_best(self, aa_sequence: str) -> str:
        """Reverse-translate using the single most frequent codon per residue.

        Not called by ``run``.

        Raises:
            ValueError: for a residue with no entry in
                ``codon_usage.best_codon_for_aa``.
        """
        codons: list[str] = []
        for aa in aa_sequence:
            codon = self.codon_usage.best_codon_for_aa.get(aa)
            if codon is None:
                raise ValueError(f"No codon mapping for amino acid: {aa}")
            codons.append(codon)
        return "".join(codons)

    def _fallback_gc_target(self, aa_sequence: str) -> str:
        """Optimise with the v2 engine's ``gc_target`` profile and return its DNA."""
        from factorforge.engines.v3.inference.v2_adapter import optimize_with_v2

        return optimize_with_v2(aa_sequence, options={"profile": "gc_target"})["dna_sequence"]

    def _apply_post_guard(self, dna_sequence: str) -> tuple[str, dict[str, Any]]:
        """Apply polyA fixing, a full rule scan and Golden Gate domestication.

        Returns:
            ``(domesticated_sequence, post_guard)``. The ``post_guard`` dict holds:

            * ``scan_results``: rule scan of the polyA-fixed sequence, taken
              before domestication;
            * ``domestication``: the raw domesticator result;
            * ``polya_fix``: the fix result, or ``None`` when no polyA
              pattern was present or the fix reported failure;
            * ``edited``: whether the returned sequence differs from the input.
        """
        rule_engine = RuleEngine()
        domesticator = Domesticator()

        # Step 1: iterative polyA fix (same approach as the v2 pipeline).
        polya_fix: dict[str, Any] | None = None
        current_seq = dna_sequence
        has_polya = any(p in dna_sequence for p in rule_engine.POLYA_PATTERNS)
        if has_polya:
            polya_result = rule_engine.fix_polya_iterative(current_seq)
            if polya_result["success"]:
                current_seq = polya_result["modified_seq"]
                polya_fix = polya_result

        # Step 2: full scan on the (possibly) fixed sequence.
        scan_results = rule_engine.scan_all(current_seq)

        # Step 3: domestication; keep the current sequence if none is returned.
        domestication = domesticator.domesticate(current_seq, standard="golden_gate")
        domesticated = domestication.get("domesticated_seq", current_seq)

        post_guard = {
            "scan_results": scan_results,
            "domestication": domestication,
            "polya_fix": polya_fix,
            "edited": domesticated != dna_sequence,
        }
        return domesticated, post_guard
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class V3Optimizer(OptimizerEngine):
    """Optimizer wrapper for v3 pipeline.

    Adapts :class:`V3Pipeline` to the ``OptimizerEngine`` interface.
    """

    name = "v3 BART Decoder"
    version = "3.0.0"

    def __init__(self, pipeline: V3Pipeline | None = None) -> None:
        """Use *pipeline* if given, otherwise a default :class:`V3Pipeline`."""
        self.pipeline = V3Pipeline() if pipeline is None else pipeline

    def optimize(
        self,
        sequence: str,
        profile: str | None = None,
        **kwargs: Any,
    ) -> OptimizationResult:
        """Run the pipeline on *sequence* and wrap the outcome.

        *profile* is accepted for interface compatibility but is not used.
        Extra keyword arguments are forwarded to ``V3Pipeline.run``. The
        pipeline's report is not copied into the result.
        """
        result = self.pipeline.run(sequence, **kwargs)
        return OptimizationResult(
            sequence=result.sequence,
            metrics=result.metrics,
            metadata=result.metadata,
        )

    def validate(self, sequence: str) -> bool:
        """Return True if *sequence* is a non-empty, valid protein sequence.

        Non-string input (e.g. ``None``) returns False instead of raising
        ``AttributeError``.
        """
        if not isinstance(sequence, str):
            return False
        try:
            _normalize_aa(sequence)
        except ValueError:
            return False
        return True

    def get_supported_profiles(self) -> list[str]:
        """Return the profiles advertised by this engine (only ``"default"``)."""
        return ["default"]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _normalize_aa(sequence: str) -> str:
    """Return *sequence* upper-cased with all whitespace removed, after validation.

    All whitespace is dropped, not just spaces and newlines. Tabs and carriage
    returns are removed too, so sequences pasted with Windows line endings or
    tab indentation are accepted.

    Raises:
        ValueError: if nothing remains after removing whitespace, or if any
            character is not in ``VALID_AA``.
    """
    seq = "".join(sequence.split()).upper()
    if not seq:
        raise ValueError("Protein sequence is empty.")
    invalid = {aa for aa in seq if aa not in VALID_AA}
    if invalid:
        raise ValueError(f"Invalid amino acids found: {''.join(sorted(invalid))}")
    return seq
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Synonymous codon masks for v3 training and constrained decoding."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Mapping
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from factorforge.ml.metrics import STANDARD_GENETIC_CODE
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Inverse of STANDARD_GENETIC_CODE: amino-acid symbol -> tuple of its codons.
# Codons keep STANDARD_GENETIC_CODE iteration order; amino acids appear in the
# order they are first seen. Stop codons are grouped under whatever symbol the
# table uses for them (the helpers below treat "*" as stop).
_codons_by_aa: dict[str, list[str]] = {}
for _codon, _aa in STANDARD_GENETIC_CODE.items():
    _codons_by_aa.setdefault(_aa, []).append(_codon)

AA_TO_CODONS: dict[str, tuple[str, ...]] = {
    _aa: tuple(_codon_list) for _aa, _codon_list in _codons_by_aa.items()
}
del _codons_by_aa
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def normalize_protein_sequence(protein_sequence: str) -> str:
    """Normalize a protein sequence for codon-mask construction.

    Whitespace is removed, letters are upper-cased and a single trailing
    stop symbol ``*`` is dropped.

    Raises:
        ValueError: if the result is empty, or contains a symbol with no entry
            in ``AA_TO_CODONS`` or an internal ``*``.
    """
    cleaned = "".join(protein_sequence.upper().split())
    if cleaned[-1:] == "*":
        cleaned = cleaned[:-1]
    if not cleaned:
        raise ValueError("protein_sequence must not be empty")
    bad = sorted({aa for aa in cleaned if aa == "*" or aa not in AA_TO_CODONS})
    if bad:
        raise ValueError(f"No synonymous codons for amino acid(s): {''.join(bad)}")
    return cleaned
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def synonymous_codons_for_aa(amino_acid: str) -> tuple[str, ...]:
    """Return synonymous non-stop codons for one standard amino acid.

    The lookup is case-insensitive.

    Raises:
        ValueError: for ``*`` or any symbol without codons in ``AA_TO_CODONS``.
    """
    key = amino_acid.upper()
    if key == "*":
        raise ValueError(f"No synonymous codons for amino acid: {amino_acid}")
    codons = AA_TO_CODONS.get(key, ())
    if not codons:
        raise ValueError(f"No synonymous codons for amino acid: {amino_acid}")
    return codons
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def build_synonym_token_mask(
    protein_sequence: str,
    token_to_id: Mapping[str, int],
    *,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Build a boolean mask [protein_length, vocab_size] for synonymous codons only.

    Row ``i`` is True exactly at the token ids of the codons encoding residue
    ``i`` of the normalised protein. ``vocab_size`` is
    ``max(token_to_id.values()) + 1``. Codons absent from *token_to_id* are
    simply not marked.

    Args:
        protein_sequence: protein to mask; normalised with
            ``normalize_protein_sequence``.
        token_to_id: codon token -> vocabulary id (non-negative ids assumed).
        device: device of the returned tensor.

    Raises:
        ValueError: if the protein is invalid, if *token_to_id* is empty, or
            if a residue has none of its codons in *token_to_id*.
    """
    protein = normalize_protein_sequence(protein_sequence)
    if not token_to_id:
        raise ValueError("token_to_id must not be empty")
    vocab_size = max(token_to_id.values()) + 1

    # Collect (row, column) pairs on the host and scatter once. The per-element
    # writes plus a .any() check per residue each cost a separate op (and a host
    # sync) on an accelerator device.
    rows: list[int] = []
    cols: list[int] = []
    for index, aa in enumerate(protein):
        found = [
            token_id
            for token_id in (token_to_id.get(codon) for codon in synonymous_codons_for_aa(aa))
            if token_id is not None
        ]
        if not found:
            raise ValueError(f"No tokenizer codons available for amino acid: {aa}")
        rows.extend([index] * len(found))
        cols.extend(found)

    mask = torch.zeros((len(protein), vocab_size), dtype=torch.bool, device=device)
    row_index = torch.tensor(rows, dtype=torch.long, device=mask.device)
    col_index = torch.tensor(cols, dtype=torch.long, device=mask.device)
    mask[row_index, col_index] = True
    return mask
|
|
61
|
+
|