factorforge-cds 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. factorforge/__init__.py +19 -0
  2. factorforge/__main__.py +8 -0
  3. factorforge/cli/__init__.py +5 -0
  4. factorforge/cli/legacy_cli.py +157 -0
  5. factorforge/cli/main.py +305 -0
  6. factorforge/core/interfaces/__init__.py +7 -0
  7. factorforge/core/interfaces/exporter.py +13 -0
  8. factorforge/core/interfaces/optimizer.py +85 -0
  9. factorforge/core/interfaces/validator.py +9 -0
  10. factorforge/database.py +150 -0
  11. factorforge/engines/__init__.py +60 -0
  12. factorforge/engines/ml/__init__.py +0 -0
  13. factorforge/engines/ml/plant_optimizer.py +325 -0
  14. factorforge/engines/registry.py +141 -0
  15. factorforge/engines/v1_archived/__init__.py +15 -0
  16. factorforge/engines/v2/__init__.py +13 -0
  17. factorforge/engines/v2/codon_table_builder.py +107 -0
  18. factorforge/engines/v2/construct_builder.py +403 -0
  19. factorforge/engines/v2/exporter.py +455 -0
  20. factorforge/engines/v2/optimizer.py +190 -0
  21. factorforge/engines/v2/pipeline.py +275 -0
  22. factorforge/engines/v2/rules/__init__.py +3 -0
  23. factorforge/engines/v2/rules/domesticator.py +403 -0
  24. factorforge/engines/v2/rules/reverse_translator.py +765 -0
  25. factorforge/engines/v2/rules/rule_engine.py +867 -0
  26. factorforge/engines/v2/scoring.py +232 -0
  27. factorforge/engines/v2/utils.py +231 -0
  28. factorforge/engines/v2/validator.py +383 -0
  29. factorforge/engines/v3/__init__.py +12 -0
  30. factorforge/engines/v3/explain.py +119 -0
  31. factorforge/engines/v3/inference/__init__.py +6 -0
  32. factorforge/engines/v3/inference/constrained_decoder.py +80 -0
  33. factorforge/engines/v3/inference/v2_adapter.py +72 -0
  34. factorforge/engines/v3/metrics.py +145 -0
  35. factorforge/engines/v3/modeling_bart_decoder.py +127 -0
  36. factorforge/engines/v3/pipeline.py +192 -0
  37. factorforge/engines/v3/synonym_mask.py +61 -0
  38. factorforge/engines/v3/tokenizer.py +192 -0
  39. factorforge/ml/__init__.py +33 -0
  40. factorforge/ml/feasibility.py +199 -0
  41. factorforge/ml/metrics.py +295 -0
  42. factorforge/utils/__init__.py +31 -0
  43. factorforge/utils/construct_id.py +8 -0
  44. factorforge/utils/exceptions.py +32 -0
  45. factorforge/utils/sequence_validator.py +189 -0
  46. factorforge/utils/validation.py +104 -0
  47. factorforge_cds-3.0.0.dist-info/METADATA +475 -0
  48. factorforge_cds-3.0.0.dist-info/RECORD +52 -0
  49. factorforge_cds-3.0.0.dist-info/WHEEL +5 -0
  50. factorforge_cds-3.0.0.dist-info/entry_points.txt +2 -0
  51. factorforge_cds-3.0.0.dist-info/licenses/LICENSE +201 -0
  52. factorforge_cds-3.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,383 @@
1
+ """
2
+ Input Validator for FactorForge v2
3
+ Input validation and preprocessing module (P0-1)
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import re
9
+ from enum import Enum
10
+ from typing import Any
11
+
12
+
13
class SequenceType(Enum):
    """Kind of input sequence detected by ``InputValidator``."""

    PROTEIN = "protein"
    DNA = "dna"
    FASTA = "fasta"
    UNKNOWN = "unknown"


class ValidationLevel(Enum):
    """Overall severity of a validation result (the worst issue found)."""

    VALID = "valid"
    WARNING = "warning"
    ERROR = "error"


class InputValidator:
    """
    Input sequence validation and preprocessing.

    Features:
    - Auto-detect AA/DNA/FASTA
    - Report errors and warnings, each with a code, message and suggestion
    - Handle frame errors, stop codons, non-standard AAs

    ``warnings`` and ``errors`` are instance attributes that ``validate`` resets
    and then fills, so a single instance should not be used for concurrent
    validations.
    """

    # Standard amino acids (20 + STOP)
    STANDARD_AA = set("ACDEFGHIKLMNPQRSTVWY*")

    # Ambiguous / non-standard amino acids, with a human-readable description
    AMBIGUOUS_AA = {
        "B": "Asx (Asn or Asp)",
        "Z": "Glx (Gln or Glu)",
        "X": "Xaa (Unknown)",
        "J": "Xle (Leu or Ile)",
        "U": "Sec (Selenocysteine)",
        "O": "Pyl (Pyrrolysine)",
    }

    # DNA bases
    DNA_BASES = set("ATGC")
    AMBIGUOUS_DNA = set("NRYSWKMBDHV")  # IUPAC ambiguity codes

    # Every IUPAC DNA letter is also a (standard or ambiguous) amino-acid
    # letter, so alphabet membership alone classifies peptides such as
    # "MKHDAW" as DNA. A sequence is only treated as DNA when at least this
    # fraction of its characters are definite bases (A/C/G/T) or N.
    DNA_CORE_BASES = set("ACGTN")
    MIN_DNA_CORE_FRACTION = 0.5

    def __init__(self) -> None:
        """Initialize empty warning/error collections."""
        self.warnings: list[dict[str, Any]] = []
        self.errors: list[dict[str, Any]] = []

    def detect_sequence_type(self, sequence: str) -> SequenceType:
        """
        Auto-detect sequence type.

        Args:
            sequence: Input sequence (may include whitespace)

        Returns:
            ``FASTA`` if the stripped input starts with '>'; ``DNA`` if every
            character is an IUPAC nucleotide code and at least
            ``MIN_DNA_CORE_FRACTION`` of them are A/C/G/T/N; ``PROTEIN`` if
            every character is a standard/ambiguous amino-acid letter or '*';
            otherwise ``UNKNOWN`` (also for empty/whitespace-only input).

        Raises:
            None.

        Examples:
            >>> validator = InputValidator()
            >>> validator.detect_sequence_type("ATGC").value
            'dna'
            >>> validator.detect_sequence_type("MKHDAW").value
            'protein'
        """
        # Remove whitespace and newlines
        clean_seq = re.sub(r"\s+", "", sequence).upper()

        if not clean_seq:
            return SequenceType.UNKNOWN

        # Check FASTA format (starts with '>')
        if sequence.strip().startswith(">"):
            return SequenceType.FASTA

        unique_chars = set(clean_seq)

        # DNA: IUPAC alphabet only AND dominated by definite bases (see
        # MIN_DNA_CORE_FRACTION); otherwise fall through to the protein check.
        if unique_chars <= (self.DNA_BASES | self.AMBIGUOUS_DNA):
            core_count = sum(1 for ch in clean_seq if ch in self.DNA_CORE_BASES)
            if core_count / len(clean_seq) >= self.MIN_DNA_CORE_FRACTION:
                return SequenceType.DNA

        # Protein: amino acid characters
        if unique_chars <= (self.STANDARD_AA | set(self.AMBIGUOUS_AA)):
            return SequenceType.PROTEIN

        return SequenceType.UNKNOWN

    def validate(self, sequence: str, auto_fix: bool = False) -> dict[str, Any]:
        """
        Validate an input sequence.

        Args:
            sequence: Input sequence (raw DNA, protein, or FASTA text).
            auto_fix: If True, apply automatic fixes. Currently the only fix
                is trimming a DNA sequence to a whole number of codons.

        Returns:
            {
                "type": "protein|dna|fasta|unknown",
                "valid": True/False,
                "level": "valid|warning|error",
                "warnings": [...],
                "errors": [...],
                "metadata": {...},
                "processed_sequence": "..."  # Preprocessed sequence
            }
            FASTA results additionally contain "fasta_header".

        Raises:
            None.

        Examples:
            >>> validator = InputValidator()
            >>> result = validator.validate("ATG GCC TAA")
            >>> result["type"]
            'dna'
        """
        self.warnings = []
        self.errors = []

        seq_type = self.detect_sequence_type(sequence)

        if seq_type == SequenceType.FASTA:
            return self._validate_fasta(sequence, auto_fix)
        if seq_type == SequenceType.DNA:
            return self._validate_dna(sequence, auto_fix)
        if seq_type == SequenceType.PROTEIN:
            return self._validate_protein(sequence, auto_fix)

        self.errors.append(
            {
                "code": "UNKNOWN_TYPE",
                "message": "Unable to detect sequence type.",
                "suggestion": "Enter a DNA (ATGC) or Protein (20 AA) sequence.",
            }
        )
        return self._build_result(seq_type, sequence)

    def _validate_fasta(self, sequence: str, auto_fix: bool) -> dict[str, Any]:
        """Validate FASTA input by validating the first record's sequence.

        Only the first record is validated. Concatenating several records
        would produce a meaningless sequence for frame and stop checks.
        Additional records trigger a MULTIPLE_FASTA_RECORDS warning.
        """
        lines = [line.strip() for line in sequence.strip().splitlines()]

        if not lines or not lines[0].startswith(">"):
            self.errors.append(
                {"code": "INVALID_FASTA", "message": "Missing FASTA header (must start with '>')."}
            )
            return self._build_result(SequenceType.FASTA, sequence)

        header = lines[0][1:].strip()

        # Collect sequence lines of the first record only; count later headers.
        seq_lines: list[str] = []
        extra_records = 0
        for line in lines[1:]:
            if line.startswith(">"):
                extra_records += 1
            elif line and extra_records == 0:
                seq_lines.append(line)

        # Re-enter validate() so the record body gets DNA/protein detection.
        seq_result = self.validate("".join(seq_lines), auto_fix)

        seq_result["fasta_header"] = header
        seq_result["type"] = SequenceType.FASTA.value

        if extra_records:
            seq_result["warnings"].append(
                {
                    "code": "MULTIPLE_FASTA_RECORDS",
                    "message": (
                        f"Input contains {extra_records + 1} FASTA records; "
                        "only the first record was validated."
                    ),
                    "suggestion": "Submit one record at a time.",
                }
            )
            # The level was computed before this warning was added.
            if seq_result["level"] == ValidationLevel.VALID.value:
                seq_result["level"] = ValidationLevel.WARNING.value

        return seq_result

    def _validate_dna(self, sequence: str, auto_fix: bool) -> dict[str, Any]:
        """Validate a DNA sequence: alphabet, frame, stop codons and GC content."""
        clean_seq = re.sub(r"\s+", "", sequence).upper()
        observed = set(clean_seq)

        # 1. Invalid bases
        invalid_chars = observed - (self.DNA_BASES | self.AMBIGUOUS_DNA)
        if invalid_chars:
            self.errors.append(
                {
                    "code": "INVALID_DNA_CHARS",
                    "message": f"Invalid DNA bases: {', '.join(sorted(invalid_chars))}",
                    "suggestion": "Use only ATGC or IUPAC codes.",
                }
            )

        # 2. Ambiguous bases
        ambiguous_chars = observed & self.AMBIGUOUS_DNA
        if ambiguous_chars:
            self.warnings.append(
                {
                    "code": "AMBIGUOUS_DNA",
                    "message": f"Ambiguous bases found: {', '.join(sorted(ambiguous_chars))}",
                    "suggestion": "Consider replacing with exact bases.",
                }
            )

        # 3. Frame check (multiple of 3)
        remainder = len(clean_seq) % 3
        if remainder:
            self.warnings.append(
                {
                    "code": "FRAME_ERROR",
                    "message": f"Sequence length is not a multiple of 3 ({len(clean_seq)} bp).",
                    "suggestion": f"Remove the last {remainder} bp or add {3 - remainder} bp.",
                    # Slice bounds that restore the frame: seq[:trim_end]
                    # drops the 3' overhang, seq[trim_start:] the 5' overhang.
                    "auto_fix_option": {
                        "trim_end": len(clean_seq) - remainder,
                        "trim_start": remainder,
                    },
                }
            )

            if auto_fix:
                clean_seq = clean_seq[: len(clean_seq) - remainder]
                self.warnings[-1]["auto_fixed"] = True

        # 4. Stop codons (1-indexed codon positions, complete codons only)
        stop_codons = {"TAA", "TAG", "TGA"}
        stop_positions = [
            i // 3 + 1
            for i in range(0, len(clean_seq) - 2, 3)
            if clean_seq[i : i + 3] in stop_codons
        ]

        # Any stop that is not the final codon is internal.
        codon_count = len(clean_seq) // 3
        if any(pos != codon_count for pos in stop_positions):
            self.warnings.append(
                {
                    "code": "INTERNAL_STOP",
                    "message": f"Internal stop codon found: positions {stop_positions}",
                    "suggestion": "Verify whether this is intended.",
                }
            )

        # 5. GC content (skipped when nothing is left, e.g. after auto-trim)
        gc_count = clean_seq.count("G") + clean_seq.count("C")
        gc_content = (gc_count / len(clean_seq) * 100) if clean_seq else 0

        if clean_seq and (gc_content < 30 or gc_content > 70):
            self.warnings.append(
                {
                    "code": "EXTREME_GC",
                    "message": f"GC content is extreme: {gc_content:.1f}%",
                    "suggestion": "Recommended range is 30-70%.",
                }
            )

        metadata = {
            "length": len(clean_seq),
            "gc_content": round(gc_content, 2),
            "has_stop": len(stop_positions) > 0,
            "stop_positions": stop_positions,
            "ambiguous_bases": sorted(ambiguous_chars),
        }

        return self._build_result(SequenceType.DNA, clean_seq, metadata)

    def _validate_protein(self, sequence: str, auto_fix: bool) -> dict[str, Any]:
        """Validate a protein sequence: alphabet, ambiguous residues, stops.

        ``auto_fix`` is accepted for interface symmetry but currently unused.
        """
        clean_seq = re.sub(r"\s+", "", sequence).upper()

        # 1. Invalid amino acids
        invalid_chars = set(clean_seq) - (self.STANDARD_AA | set(self.AMBIGUOUS_AA))
        if invalid_chars:
            self.errors.append(
                {
                    "code": "INVALID_AA_CHARS",
                    "message": f"Invalid amino acids: {', '.join(sorted(invalid_chars))}",
                    "suggestion": "Use only standard 20 AAs or * (STOP).",
                }
            )

        # 2. Ambiguous / non-standard amino acids
        ambiguous_found = {
            aa: description
            for aa, description in self.AMBIGUOUS_AA.items()
            if aa in clean_seq
        }

        if ambiguous_found:
            self.warnings.append(
                {
                    "code": "AMBIGUOUS_AA",
                    "message": f"Ambiguous amino acids found: {ambiguous_found}",
                    "suggestion": "Consider replacing with specific amino acids.",
                }
            )

        # 3. Internal STOP check (1-indexed residue positions)
        stop_positions = [i + 1 for i, aa in enumerate(clean_seq) if aa == "*"]

        if any(pos != len(clean_seq) for pos in stop_positions):
            self.warnings.append(
                {
                    "code": "INTERNAL_STOP",
                    "message": f"Internal stop codon found: positions {stop_positions}",
                    "suggestion": "Verify whether this is intended.",
                }
            )

        metadata = {
            "length": len(clean_seq),
            "has_stop": len(stop_positions) > 0,
            "stop_positions": stop_positions,
            "ambiguous_aa": ambiguous_found,
        }

        return self._build_result(SequenceType.PROTEIN, clean_seq, metadata)

    def _build_result(
        self,
        seq_type: SequenceType,
        processed_seq: str,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Build the result dict from the collected warnings/errors.

        Any error makes the result invalid (level "error"). Otherwise any
        warning gives level "warning", and no issues gives level "valid".
        """
        if self.errors:
            level = ValidationLevel.ERROR
            valid = False
        elif self.warnings:
            level = ValidationLevel.WARNING
            valid = True
        else:
            level = ValidationLevel.VALID
            valid = True

        return {
            "type": seq_type.value,
            "valid": valid,
            "level": level.value,
            "warnings": self.warnings,
            "errors": self.errors,
            "metadata": metadata or {},
            "processed_sequence": processed_seq,
        }
353
+
354
+
355
# --- Usage example ---
if __name__ == "__main__":
    import json

    demo_validator = InputValidator()

    # Multi-line FASTA record used by the fourth demo case.
    gfp_fasta = "\n".join(
        [
            ">GFP_test",
            "ATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGG",
            "GTGGTGCCCATCCTGGTCGAGCTGGACGGCGAC",
            "TAA",
        ]
    )

    # (banner, input sequence, auto_fix) for each demo run.
    demo_cases = [
        ("=== Test 1: Valid DNA ===", "ATG GCC AAA TAA", False),
        ("\n=== Test 2: DNA with frame error ===", "ATGGCCAA", True),
        ("\n=== Test 3: Protein with ambiguous AA ===", "MAKXLF*", False),
        ("\n=== Test 4: FASTA format ===", gfp_fasta, False),
    ]

    for banner, sample, fix in demo_cases:
        print(banner)
        outcome = demo_validator.validate(sample, auto_fix=fix)
        print(json.dumps(outcome, indent=2, ensure_ascii=False))
@@ -0,0 +1,12 @@
1
+ """
2
+ FactorForge v3 - BART decoder scaffolding.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
__version__ = "3.0.0"

from .pipeline import V3Optimizer, V3Pipeline
from .tokenizer import AATokenizer, CodonTokenizer

# Names re-exported as the public API of the v3 subpackage.
__all__ = [
    "AATokenizer",
    "CodonTokenizer",
    "V3Optimizer",
    "V3Pipeline",
]
@@ -0,0 +1,119 @@
1
+ """FDA-style explainability report helpers for FactorForge v3."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ import re
8
+ from dataclasses import dataclass
9
+ from typing import Any
10
+
11
# Enzyme name -> recognition motif (top strand) screened by the report's
# restriction-site constraint check.
DEFAULT_RESTRICTION_SITES: dict[str, str] = dict(
    EcoRI="GAATTC",
    BamHI="GGATCC",
    BsaI="GGTCTC",
    BsmBI="CGTCTC",
    NotI="GCGGCCGC",
)
18
+
19
+
20
@dataclass(frozen=True)
class ExplainabilityInputs:
    """Immutable bundle of everything ``build_fda_report`` needs for one run.

    ``frozen`` only prevents reassigning fields. The dict-valued fields
    (``metrics``, ``config``, ``post_guard``) remain mutable objects.
    """

    # Protein the DNA was generated for (the report records only its length).
    aa_sequence: str
    # Generated coding DNA; scanned by the report's constraint checks.
    dna_sequence: str
    # Run metrics; the report reads the "cai" and "gc_content" entries.
    metrics: dict[str, float]
    # Model identifier; the report also stores a SHA-256 of this string.
    model_id: str
    tokenizer_hash: str
    seed: int
    # Generation settings, copied into the report's "determinism" section.
    config: dict[str, Any]
    # Optional post-generation guard results; reported as {} when None.
    post_guard: dict[str, Any] | None = None
30
+
31
+
32
def build_fda_report(inputs: ExplainabilityInputs) -> dict[str, Any]:
    """Assemble a JSON-serialisable explainability report for one run.

    The report contains:
    - the model id and a SHA-256 of that id string (not of the model weights);
    - the tokenizer hash;
    - input and output lengths;
    - CAI and GC metrics taken from ``inputs.metrics``, defaulting to 0.0;
    - constraint scans of the DNA;
    - post-guard results, or {} when absent;
    - the seed and config.
    """
    dna = inputs.dna_sequence
    metric_values = inputs.metrics

    model_section = {
        "id": inputs.model_id,
        "hash": _hash_string(inputs.model_id),
    }
    outputs_section = {
        "dna_length": len(dna),
        "cai": float(metric_values.get("cai", 0.0)),
        "gc_percent": float(metric_values.get("gc_content", 0.0)),
    }

    report: dict[str, Any] = {
        "model": model_section,
        "tokenizer": {"hash": inputs.tokenizer_hash},
        "inputs": {"aa_length": len(inputs.aa_sequence)},
        "outputs": outputs_section,
        "constraint_checks": _constraint_checks(dna),
        "post_guard": inputs.post_guard or {},
        "determinism": {"seed": inputs.seed, "config": inputs.config},
    }
    return report
59
+
60
+
61
def write_fda_report(path: str, report: dict[str, Any]) -> None:
    """Write *report* to *path* as indented, key-sorted UTF-8 JSON.

    The report is serialised before the file is opened. If serialisation
    fails (e.g. a non-JSON-serialisable value raises ``TypeError``), an
    existing file at *path* is left untouched rather than truncated or
    partially written.

    Raises:
        TypeError / ValueError: If ``report`` cannot be encoded as JSON.
        OSError: If *path* cannot be opened for writing.
    """
    payload = json.dumps(report, indent=2, sort_keys=True)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(payload)
64
+
65
+
66
def _constraint_checks(sequence: str) -> dict[str, Any]:
    """Run every sequence-level scan used by the report and key results by check name."""
    checks: dict[str, Any] = {}
    checks["poly_a_runs"] = _scan_poly_a(sequence)
    checks["restriction_sites"] = _scan_restriction_sites(
        sequence, DEFAULT_RESTRICTION_SITES
    )
    checks["splice_like_motifs"] = _scan_splice_like(sequence)
    checks["homopolymers"] = _scan_homopolymers(sequence)
    checks["repeats"] = _scan_repeats(sequence)
    return checks
74
+
75
+
76
+ def _scan_poly_a(sequence: str, min_len: int = 6) -> list[dict[str, int]]:
77
+ matches = re.finditer(rf"A{{{min_len},}}", sequence)
78
+ return [{"start": m.start(), "end": m.end()} for m in matches]
79
+
80
+
81
+ def _scan_restriction_sites(
82
+ sequence: str, sites: dict[str, str]
83
+ ) -> list[dict[str, Any]]:
84
+ hits: list[dict[str, Any]] = []
85
+ for enzyme, motif in sites.items():
86
+ positions = [m.start() for m in re.finditer(motif, sequence)]
87
+ if positions:
88
+ hits.append({"enzyme": enzyme, "motif": motif, "positions": positions})
89
+ return hits
90
+
91
+
92
+ def _scan_splice_like(sequence: str) -> list[dict[str, int]]:
93
+ matches = re.finditer(r"GT[ACGT]{2,20}AG", sequence)
94
+ return [{"start": m.start(), "end": m.end()} for m in matches]
95
+
96
+
97
+ def _scan_homopolymers(sequence: str, min_len: int = 6) -> list[dict[str, Any]]:
98
+ runs: list[dict[str, Any]] = []
99
+ for base in "ACGT":
100
+ for match in re.finditer(rf"{base}{{{min_len},}}", sequence):
101
+ runs.append({"base": base, "start": match.start(), "end": match.end()})
102
+ return runs
103
+
104
+
105
+ def _scan_repeats(sequence: str, kmer: int = 6) -> list[dict[str, Any]]:
106
+ seen: dict[str, list[int]] = {}
107
+ for idx in range(0, max(0, len(sequence) - kmer + 1)):
108
+ token = sequence[idx : idx + kmer]
109
+ seen.setdefault(token, []).append(idx)
110
+
111
+ repeats: list[dict[str, Any]] = []
112
+ for token, positions in seen.items():
113
+ if len(positions) > 1:
114
+ repeats.append({"kmer": token, "positions": positions})
115
+ return repeats
116
+
117
+
118
+ def _hash_string(value: str) -> str:
119
+ return hashlib.sha256(value.encode("utf-8")).hexdigest()
@@ -0,0 +1,6 @@
1
+ """Inference adapters for FactorForge v3-alpha."""
2
+
3
from .v2_adapter import optimize_with_v2

# Public API of the inference subpackage.
__all__ = [
    "optimize_with_v2",
]
6
+
@@ -0,0 +1,80 @@
1
+ """Constrained v3 decoding and fallback helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from factorforge.engines.v3.inference.v2_adapter import optimize_with_v2
10
+ from factorforge.engines.v3.synonym_mask import build_synonym_token_mask, normalize_protein_sequence
11
+ from factorforge.engines.v3.tokenizer import CodonTokenizer
12
+ from factorforge.utils.validation import validate_candidate_sequence
13
+
14
+
15
def constrained_greedy_decode(
    decoder: Any,
    encoder_hidden_states: torch.Tensor,
    protein_sequence: str,
    codon_tokenizer: CodonTokenizer,
) -> torch.Tensor:
    """Greedy decode one codon per amino acid under a synonym mask.

    At each step the decoder is re-run on the full prefix. Tokens not allowed
    by the synonym mask for that residue are suppressed, and the highest-scoring
    allowed token is appended. Decoding runs under ``torch.no_grad()`` because
    greedy argmax selection is not differentiable.

    Args:
        decoder: Callable accepting ``encoder_hidden_states`` and
            ``decoder_input_ids`` keyword arguments and returning 3-D logits
            whose last position (``[:, -1, :]``) scores the next token.
        encoder_hidden_states: Encoder output; only batch size 1 is supported.
        protein_sequence: Protein to decode; normalised with
            ``normalize_protein_sequence`` before building the mask.
        codon_tokenizer: Provides ``token_to_id``, ``bos_token_id`` and
            ``eos_token_id``.

    Returns:
        Long tensor of shape ``(1, len(protein) + 2)``: BOS, one codon token per
        residue, then EOS.

    Raises:
        ValueError: If ``encoder_hidden_states`` has a batch size other than 1.
    """
    # Validate cheaply before doing any work.
    batch_size = int(encoder_hidden_states.shape[0])
    if batch_size != 1:
        raise ValueError("constrained_greedy_decode currently supports batch_size=1")

    device = encoder_hidden_states.device
    protein = normalize_protein_sequence(protein_sequence)
    # Boolean mask of allowed codon tokens; row `position` is used for residue
    # `position`. NOTE(review): assumed shape (len(protein), vocab) -- confirm
    # against build_synonym_token_mask. A row with no allowed token makes argmax
    # fall back to an arbitrary (masked) index.
    mask = build_synonym_token_mask(
        protein,
        codon_tokenizer.token_to_id,
        device=device,
    )

    decoder_input_ids = torch.full(
        (1, 1),
        codon_tokenizer.bos_token_id,
        dtype=torch.long,
        device=device,
    )

    with torch.no_grad():
        for position in range(len(protein)):
            logits = decoder(
                encoder_hidden_states=encoder_hidden_states,
                decoder_input_ids=decoder_input_ids,
            )
            step_logits = logits[:, -1, :]
            # Use the dtype's most negative finite value: a literal -1e9 does not
            # fit in float16 and makes masked_fill raise an overflow error.
            fill_value = torch.finfo(step_logits.dtype).min
            masked_logits = step_logits.masked_fill(~mask[position].unsqueeze(0), fill_value)
            next_token = torch.argmax(masked_logits, dim=-1)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token.unsqueeze(1)], dim=1)

    eos = torch.tensor(
        [[codon_tokenizer.eos_token_id]],
        dtype=torch.long,
        device=device,
    )
    return torch.cat([decoder_input_ids, eos], dim=1)
55
+
56
+
57
def validate_candidate_or_fallback(
    protein_sequence: str,
    dna_sequence: str,
    fallback_options: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Return a valid v3 candidate or deterministic v2 fallback metadata.

    The protein is normalised and the DNA checked with
    ``validate_candidate_sequence``. If the check passes, the v3 candidate is
    returned with ``fallback_used=False``. Otherwise ``optimize_with_v2``
    (called with ``fallback_options``) produces the result. That result is
    tagged ``fallback_used=True``, and the validator errors are recorded
    under ``metadata["fallback_reason"]``.
    """
    protein = normalize_protein_sequence(protein_sequence)
    report = validate_candidate_sequence(protein, dna_sequence)

    if not report["passed"]:
        result = optimize_with_v2(protein, options=fallback_options)
        result["fallback_used"] = True
        metadata = result.setdefault("metadata", {})
        metadata["fallback_reason"] = report["errors"]
        return result

    return {
        "engine": "v3",
        "protein_sequence": protein,
        "dna_sequence": dna_sequence,
        "validator": report,
        "fallback_used": False,
        "warnings": list(report["warnings"]),
        "errors": [],
    }
80
+
@@ -0,0 +1,72 @@
1
+ """Formal v2 adapter for v3-alpha baseline, teacher, and fallback use."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from factorforge.engines.v2.rules.reverse_translator import OptimizationProfile, ReverseTranslator
8
+ from factorforge.engines.v2.rules.rule_engine import RuleEngine
9
+ from factorforge.engines.v2.scoring import calculate_composite_score
10
+ from factorforge.engines.v3.metrics import load_codon_usage_table
11
+ from factorforge.ml.metrics import calculate_cai, calculate_gc
12
+ from factorforge.utils.validation import validate_candidate_sequence
13
+
14
+
15
def optimize_with_v2(
    protein_sequence: str,
    options: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Optimize a protein sequence with v2 semantics and return v3-alpha metadata.

    Args:
        protein_sequence: Amino-acid sequence. Whitespace is removed, the
            sequence is upper-cased, and trailing '*' characters are stripped.
        options: Optional settings.
            - ``"profile"`` (default ``"high_cai"``, case-insensitive) selects
              an ``OptimizationProfile``.
            - ``"scan_mode"`` (default ``"fast"``) is passed to
              ``RuleEngine.scan_all``.

    Returns:
        Dict with keys:
        - ``engine`` ("v2");
        - ``protein_sequence`` and ``dna_sequence``;
        - ``metrics``: cai, gc, gc_content, score;
        - ``validator``;
        - ``warnings`` and ``errors``;
        - ``metadata``: profile, scan_mode, scan_results.

    Raises:
        ValueError: Unknown profile name, or a protein that is empty after
            cleaning.
    """
    opts = options or {}
    profile_name = str(opts.get("profile", "high_cai")).lower()
    scan_mode = str(opts.get("scan_mode", "fast"))

    try:
        profile = OptimizationProfile(profile_name)
    except ValueError as exc:
        supported = ", ".join(member.value for member in OptimizationProfile)
        raise ValueError(
            f"Unknown v2 profile: {profile_name}. Supported profiles: {supported}"
        ) from exc

    protein = "".join(protein_sequence.upper().split()).rstrip("*")
    if not protein:
        raise ValueError("protein_sequence must not be empty")

    # Single best candidate from the v2 reverse translator.
    best = ReverseTranslator().generate_candidates(protein, profile=profile, n=1)[0]
    dna_sequence = best["sequence"]

    usage_table = load_codon_usage_table()
    cai_value = calculate_cai(dna_sequence, usage_table.codon_weights)
    gc_value = calculate_gc(dna_sequence)
    metrics = {
        "cai": cai_value,
        "gc": gc_value,
        "gc_content": gc_value,
        # The composite score uses the translator's own CAI/GC figures.
        "score": calculate_composite_score(
            cai=best["cai"],
            gc=best["gc"],
            sequence=dna_sequence,
            profile=profile.value,
        ),
    }

    scan_results = RuleEngine().scan_all(dna_sequence, mode=scan_mode)
    validator = validate_candidate_sequence(protein, dna_sequence)

    warning_messages = list(validator["warnings"])
    error_messages = list(validator["errors"])
    violation_count = sum(len(found) for found in scan_results.values())
    if violation_count:
        warning_messages.append(f"v2 rule scan reported {violation_count} violation(s)")

    return {
        "engine": "v2",
        "protein_sequence": protein,
        "dna_sequence": dna_sequence,
        "metrics": metrics,
        "validator": validator,
        "warnings": warning_messages,
        "errors": error_messages,
        "metadata": {
            "profile": profile.value,
            "scan_mode": scan_mode,
            "scan_results": scan_results,
        },
    }
72
+