factorforge-cds 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. factorforge/__init__.py +19 -0
  2. factorforge/__main__.py +8 -0
  3. factorforge/cli/__init__.py +5 -0
  4. factorforge/cli/legacy_cli.py +157 -0
  5. factorforge/cli/main.py +305 -0
  6. factorforge/core/interfaces/__init__.py +7 -0
  7. factorforge/core/interfaces/exporter.py +13 -0
  8. factorforge/core/interfaces/optimizer.py +85 -0
  9. factorforge/core/interfaces/validator.py +9 -0
  10. factorforge/database.py +150 -0
  11. factorforge/engines/__init__.py +60 -0
  12. factorforge/engines/ml/__init__.py +0 -0
  13. factorforge/engines/ml/plant_optimizer.py +325 -0
  14. factorforge/engines/registry.py +141 -0
  15. factorforge/engines/v1_archived/__init__.py +15 -0
  16. factorforge/engines/v2/__init__.py +13 -0
  17. factorforge/engines/v2/codon_table_builder.py +107 -0
  18. factorforge/engines/v2/construct_builder.py +403 -0
  19. factorforge/engines/v2/exporter.py +455 -0
  20. factorforge/engines/v2/optimizer.py +190 -0
  21. factorforge/engines/v2/pipeline.py +275 -0
  22. factorforge/engines/v2/rules/__init__.py +3 -0
  23. factorforge/engines/v2/rules/domesticator.py +403 -0
  24. factorforge/engines/v2/rules/reverse_translator.py +765 -0
  25. factorforge/engines/v2/rules/rule_engine.py +867 -0
  26. factorforge/engines/v2/scoring.py +232 -0
  27. factorforge/engines/v2/utils.py +231 -0
  28. factorforge/engines/v2/validator.py +383 -0
  29. factorforge/engines/v3/__init__.py +12 -0
  30. factorforge/engines/v3/explain.py +119 -0
  31. factorforge/engines/v3/inference/__init__.py +6 -0
  32. factorforge/engines/v3/inference/constrained_decoder.py +80 -0
  33. factorforge/engines/v3/inference/v2_adapter.py +72 -0
  34. factorforge/engines/v3/metrics.py +145 -0
  35. factorforge/engines/v3/modeling_bart_decoder.py +127 -0
  36. factorforge/engines/v3/pipeline.py +192 -0
  37. factorforge/engines/v3/synonym_mask.py +61 -0
  38. factorforge/engines/v3/tokenizer.py +192 -0
  39. factorforge/ml/__init__.py +33 -0
  40. factorforge/ml/feasibility.py +199 -0
  41. factorforge/ml/metrics.py +295 -0
  42. factorforge/utils/__init__.py +31 -0
  43. factorforge/utils/construct_id.py +8 -0
  44. factorforge/utils/exceptions.py +32 -0
  45. factorforge/utils/sequence_validator.py +189 -0
  46. factorforge/utils/validation.py +104 -0
  47. factorforge_cds-3.0.0.dist-info/METADATA +475 -0
  48. factorforge_cds-3.0.0.dist-info/RECORD +52 -0
  49. factorforge_cds-3.0.0.dist-info/WHEEL +5 -0
  50. factorforge_cds-3.0.0.dist-info/entry_points.txt +2 -0
  51. factorforge_cds-3.0.0.dist-info/licenses/LICENSE +201 -0
  52. factorforge_cds-3.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,145 @@
1
+ """Metrics for FactorForge v3 (CAI, perplexity, GC%)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import math
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING
10
+
11
+ if TYPE_CHECKING: # pragma: no cover - for type checkers only
12
+ import torch
13
+
14
+
15
def _default_codon_table_path() -> Path:
    """Return the default location of the N. benthamiana codon-usage JSON.

    The path is ``Path(__file__).resolve().parents[4] / "data" /
    "nbenthamiana_codons.json"``. For this module that is the parent of the
    directory containing the ``factorforge`` package. That would be the
    repository root in a ``<repo>/src/factorforge`` checkout (assumption —
    confirm the intended layout).

    NOTE(review): the wheel's file list contains no ``data/`` directory, so in
    an installed package this path presumably does not exist. Callers should
    pass an explicit path in that case.
    """
    module_file = Path(__file__).resolve()
    project_root = module_file.parents[4]
    return project_root.joinpath("data", "nbenthamiana_codons.json")
17
+
18
+
19
@dataclass(frozen=True)
class CodonUsageTable:
    """Codon-usage lookup tables as produced by ``load_codon_usage_table``.

    The dataclass is frozen, so attributes cannot be reassigned. The dicts
    themselves are still mutable objects.
    """

    # Codon string -> amino-acid label read from the JSON "aa" field.
    codon_to_aa: dict[str, str]
    # Codon -> frequency divided by the highest frequency among its synonyms
    # (stop codons labelled "*" are not included by the loader).
    codon_weights: dict[str, float]
    # Amino acid -> its most frequent codon (stop codons excluded by the loader).
    best_codon_for_aa: dict[str, str]
    # Optional provenance string taken from the JSON "source" field.
    source: str | None = None
25
+
26
+
27
def load_codon_usage_table(path: Path | str | None = None) -> CodonUsageTable:
    """Load a codon-usage JSON file into a :class:`CodonUsageTable`.

    The file must be a JSON object with a ``"codons"`` object that maps each
    codon to ``{"aa": <label>, "frequency": <number>}``. An optional string
    ``"source"`` is kept as provenance.

    Codon keys are normalised to upper-case DNA: whitespace is stripped and
    ``U`` becomes ``T``. This matches the normalisation :func:`compute_cai`
    applies to its input. Amino-acid labels are stripped and upper-cased.

    An entry is skipped when:
      * ``aa`` is not a string, or
      * ``frequency`` is not a finite, non-negative int/float.
    JSON booleans are rejected even though ``bool`` is an ``int`` subclass.

    Args:
        path: JSON file to read. ``None`` (or an empty string) selects
            :func:`_default_codon_table_path`. ``str`` paths are accepted as
            well as :class:`pathlib.Path`.

    Returns:
        A table holding codon labels, per-codon weights and the most frequent
        codon per amino acid.

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is not valid JSON, is not an object, lacks a
            ``"codons"`` object, or yields no usable codon entries.
    """
    table_path = Path(path) if path else _default_codon_table_path()
    raw = json.loads(table_path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("Codon table JSON must be an object")

    codons = raw.get("codons")
    if not isinstance(codons, dict):
        raise ValueError("Codon table JSON missing 'codons'")

    codon_to_aa: dict[str, str] = {}
    codon_freq: dict[str, float] = {}
    for codon, entry in codons.items():
        if not isinstance(codon, str) or not isinstance(entry, dict):
            continue
        aa = entry.get("aa")
        freq = entry.get("frequency")
        # bool is an int subclass; reject it so a JSON `true` is not read as 1.0.
        if not isinstance(aa, str) or isinstance(freq, bool) or not isinstance(freq, (int, float)):
            continue
        freq = float(freq)
        # NaN/inf or negative frequencies would produce meaningless weights.
        if not math.isfinite(freq) or freq < 0:
            continue
        # Normalise like compute_cai does, so lowercase/RNA keys still match lookups.
        key = codon.strip().upper().replace("U", "T")
        codon_to_aa[key] = aa.strip().upper()
        codon_freq[key] = freq

    if not codon_to_aa:
        raise ValueError("Codon table JSON contains no valid codon entries")

    source = raw.get("source")
    return CodonUsageTable(
        codon_to_aa=codon_to_aa,
        codon_weights=_build_codon_weights(codon_to_aa, codon_freq),
        best_codon_for_aa=_best_codon_map(codon_to_aa, codon_freq),
        source=source if isinstance(source, str) else None,
    )
59
+
60
+
61
def compute_cai(dna_sequence: str, table: CodonUsageTable) -> float:
    """Return the Codon Adaptation Index of ``dna_sequence`` under ``table``.

    The sequence is upper-cased and ``U`` is mapped to ``T``. It is read in
    frame from position 0, and a trailing partial codon is ignored.

    The CAI is the geometric mean of ``table.codon_weights`` over the sense
    codons. Stop codons (``table.codon_to_aa[codon] == "*"``) are skipped. The
    loader never assigns them a weight, so counting them would force every CDS
    that ends in a stop codon to score 0.0.

    Returns:
        The CAI in ``(0, 1]``, or ``0.0`` in any of these cases:
          * there are no sense codons;
          * a sense codon is missing from the table (e.g. contains ``N``);
          * a sense codon has a non-positive weight (its log is undefined).
    """
    seq = dna_sequence.upper().replace("U", "T")
    usable_length = len(seq) - len(seq) % 3

    log_sum = 0.0
    counted = 0
    for start in range(0, usable_length, 3):
        codon = seq[start : start + 3]
        if table.codon_to_aa.get(codon) == "*":
            continue
        weight = table.codon_weights.get(codon)
        if weight is None or weight <= 0:
            return 0.0
        log_sum += math.log(weight)
        counted += 1

    if counted == 0:
        return 0.0
    return math.exp(log_sum / counted)
77
+
78
+
79
def compute_gc(dna_sequence: str) -> float:
    """Return the G+C content of ``dna_sequence`` as a percentage (0-100).

    Matching is case-insensitive. The denominator is the length of the whole
    upper-cased string, so any non-G/C character (A, T, N, whitespace, ...)
    lowers the result. An empty string yields ``0.0``.
    """
    upper = dna_sequence.upper()
    if not upper:
        return 0.0
    gc_count = sum(map(upper.count, ("G", "C")))
    return gc_count / len(upper) * 100.0
85
+
86
+
87
def compute_perplexity(
    logits: "torch.Tensor",
    labels: "torch.Tensor",
    ignore_index: int = -100,
) -> float:
    """Return ``exp(mean token cross-entropy)`` of ``logits`` against ``labels``.

    The last dimension of ``logits`` is treated as the vocabulary axis.
    ``logits`` is flattened to ``[-1, vocab]`` and ``labels`` to ``[-1]``.
    Positions whose label equals ``ignore_index`` are excluded from the mean.

    ``reshape`` is used instead of ``view``. ``view`` raises on non-contiguous
    inputs such as the shifted slices ``logits[:, :-1]`` / ``labels[:, 1:]``
    used for next-token loss.

    NOTE: if every label equals ``ignore_index``, the mean reduction averages
    over zero tokens and PyTorch yields NaN, so the result is NaN.

    Raises:
        ImportError: if torch is not installed.
    """
    torch, functional = _require_torch()
    vocab_size = int(logits.shape[-1])
    loss = functional.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=ignore_index,
    )
    return float(torch.exp(loss).item())
100
+
101
+
102
def _build_codon_weights(
    codon_to_aa: dict[str, str],
    codon_freq: dict[str, float],
) -> dict[str, float]:
    """Map each non-stop codon to its relative adaptiveness.

    The weight is the codon's frequency divided by the highest frequency among
    codons of the same amino acid. Missing frequencies count as 0.0. If an
    amino acid's peak frequency is not positive, all of its codons get 0.0.
    Codons labelled ``"*"`` are omitted. Output order follows ``codon_to_aa``.
    """
    # First pass: the peak frequency per amino acid.
    peak_by_aa: dict[str, float] = {}
    for codon, aa in codon_to_aa.items():
        if aa != "*":
            freq = codon_freq.get(codon, 0.0)
            peak_by_aa[aa] = max(peak_by_aa.get(aa, freq), freq)

    # Second pass: normalise each codon by its amino acid's peak.
    return {
        codon: (codon_freq.get(codon, 0.0) / peak_by_aa[aa] if peak_by_aa[aa] > 0 else 0.0)
        for codon, aa in codon_to_aa.items()
        if aa != "*"
    }
120
+
121
+
122
def _best_codon_map(
    codon_to_aa: dict[str, str],
    codon_freq: dict[str, float],
) -> dict[str, str]:
    """Return the most frequent codon for every non-stop amino acid.

    Missing frequencies count as 0.0. On ties, the codon that appears first in
    ``codon_to_aa`` wins. Codons labelled ``"*"`` are ignored.
    """
    synonyms: dict[str, list[str]] = {}
    for codon, aa in codon_to_aa.items():
        if aa == "*":
            continue
        synonyms.setdefault(aa, []).append(codon)

    # max() keeps the first maximal element, preserving first-wins tie-breaking.
    return {
        aa: max(codons, key=lambda c: codon_freq.get(c, 0.0))
        for aa, codons in synonyms.items()
    }
135
+
136
+
137
def _require_torch():
    """Import torch lazily and return ``(torch, torch.nn.functional)``.

    Keeps torch an optional dependency: it is only imported when a torch-based
    metric is actually computed.

    Raises:
        ImportError: with installation guidance if torch is unavailable.
    """
    try:
        import torch
        import torch.nn.functional as functional
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            "ML dependencies not installed. Install with: pip install -e \".[ml]\""
        ) from exc
    return torch, functional
@@ -0,0 +1,127 @@
1
+ """BART decoder skeleton for FactorForge v3."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ try:
8
+ import torch
9
+ from torch import nn
10
+ from transformers import BartConfig
11
+ from transformers.models.bart.modeling_bart import BartDecoder
12
+ except ImportError as exc: # pragma: no cover - exercised in ML installs
13
+ raise ImportError(
14
+ "ML dependencies not installed. Install with: pip install -e \".[ml]\""
15
+ ) from exc
16
+
17
+ if TYPE_CHECKING:
18
+ from torch import Tensor
19
+
20
+
21
class BartDecoderSkeleton(nn.Module):
    """Lightweight BART decoder that consumes unified encoder embeddings.

    How it works:
      * Encoder outputs of width ``encoder_dim`` are linearly projected to
        ``d_model``.
      * They are fed as cross-attention states to a Hugging Face ``BartDecoder``.
      * An untied linear head maps the decoder states to vocabulary logits.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        encoder_dim: int = 320,  # ESM2 t6_8M per-token embedding dim
        decoder_layers: int = 4,
        decoder_attention_heads: int = 4,
        ffn_dim: int = 1024,
        max_position_embeddings: int = 512,
        dropout: float = 0.1,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
    ) -> None:
        super().__init__()
        # Only the decoder stack is instantiated below. The encoder_* fields just
        # give BartConfig a complete encoder-decoder configuration.
        self.config = BartConfig(
            vocab_size=vocab_size,
            d_model=d_model,
            decoder_layers=decoder_layers,
            decoder_attention_heads=decoder_attention_heads,
            decoder_ffn_dim=ffn_dim,
            max_position_embeddings=max_position_embeddings,
            encoder_layers=1,
            encoder_attention_heads=decoder_attention_heads,
            dropout=dropout,
            activation_function="gelu",
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=bos_token_id,
            is_encoder_decoder=True,
        )

        # Maps external encoder embeddings (last dim = encoder_dim) into d_model.
        self.encoder_projection = nn.Linear(encoder_dim, d_model)
        self.decoder = BartDecoder(self.config)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # A buffer, not a Parameter: zero unless overwritten (e.g. by load_state_dict).
        self.register_buffer("final_logits_bias", torch.zeros((1, vocab_size)))

    def forward(
        self,
        encoder_hidden_states: "Tensor",
        decoder_input_ids: "Tensor",
        attention_mask: "Tensor | None" = None,
        encoder_attention_mask: "Tensor | None" = None,
    ) -> "Tensor":
        """Teacher-forcing forward pass returning logits [B, T, vocab_size].

        Args:
            encoder_hidden_states: encoder embeddings whose last dimension is
                ``encoder_dim``. Assumed to be [B, S, encoder_dim] — TODO confirm.
            decoder_input_ids: decoder token ids, assumed [B, T].
            attention_mask: optional decoder-side mask passed to ``BartDecoder``.
            encoder_attention_mask: optional mask over encoder positions.
        """
        encoder_states = self.encoder_projection(encoder_hidden_states)
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=False,
            return_dict=True,
        )
        logits = self.lm_head(outputs.last_hidden_state) + self.final_logits_bias.to(
            outputs.last_hidden_state.device
        )
        return logits

    @torch.no_grad()
    def generate(
        self,
        encoder_hidden_states: "Tensor",
        max_new_tokens: int,
        bos_token_id: int,
        eos_token_id: int,
        pad_token_id: int | None = None,
    ) -> "Tensor":
        """Greedy decoding (no beam search, no KV cache).

        Each step re-runs :meth:`forward` over the whole prefix. Runs without
        autograd. Dropout is still active if the module is in training mode, so
        call ``.eval()`` first for deterministic output.

        Rows are tracked individually:
          * Once a row emits ``eos_token_id``, every later position in that row
            is filled with ``pad_token_id``, or with ``eos_token_id`` when no
            pad id is given.
          * Decoding stops as soon as every row has finished, or after
            ``max_new_tokens`` steps.
          * The leading BOS position is never treated as an end-of-sequence
            marker.

        Returns:
            LongTensor [batch, 1 + steps] whose first column is ``bos_token_id``.
        """
        batch_size = encoder_hidden_states.shape[0]
        device = encoder_hidden_states.device

        decoder_input_ids = torch.full(
            (batch_size, 1),
            bos_token_id,
            dtype=torch.long,
            device=device,
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        filler = eos_token_id if pad_token_id is None else pad_token_id

        for _ in range(max_new_tokens):
            logits = self.forward(
                encoder_hidden_states=encoder_hidden_states,
                decoder_input_ids=decoder_input_ids,
                attention_mask=None,
                encoder_attention_mask=None,
            )
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            # Rows that already ended keep emitting the filler token.
            next_token = next_token.masked_fill(finished, filler)
            decoder_input_ids = torch.cat(
                [decoder_input_ids, next_token.unsqueeze(1)],
                dim=1,
            )
            finished |= next_token == eos_token_id
            # Stop once every row has produced EOS, not only when all do so on the same step.
            if finished.all():
                break

        return decoder_input_ids
@@ -0,0 +1,192 @@
1
+ """FactorForge v3 pipeline scaffolding."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from factorforge.core.interfaces import OptimizationResult, OptimizerEngine
9
+ from factorforge.engines.registry import EngineRegistry
10
+ from factorforge.engines.v2.rules.domesticator import Domesticator
11
+ from factorforge.engines.v2.rules.rule_engine import RuleEngine
12
+
13
+ from .explain import ExplainabilityInputs, build_fda_report
14
+ from .metrics import CodonUsageTable, compute_cai, compute_gc, load_codon_usage_table
15
+ from .tokenizer import AA_TOKENS, AATokenizer, CodonTokenizer
16
+
17
+ if TYPE_CHECKING:
18
+ from .modeling_bart_decoder import BartDecoderSkeleton
19
+
20
+
21
# Residue alphabet accepted by _normalize_aa: every token in the tokenizer's AA_TOKENS.
VALID_AA = {*AA_TOKENS}
22
+
23
+
24
@dataclass
class V3Result:
    """Outcome of a v3 pipeline run."""

    # Final DNA sequence.
    sequence: str
    # Numeric metrics, e.g. {"cai": ..., "gc_content": ...}.
    metrics: dict[str, float]
    # Free-form run details; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
    # Optional explainability report.
    report: dict[str, Any] | None = None
30
+
31
+
32
+ class V3Pipeline:  # Protein -> DNA via decoder (or v2 fallback), optional post-guard, metrics and report.
33
+ """Minimal v3 pipeline scaffold (decoder + post-guard hook)."""
34
+
35
+ def __init__(  # Falsy/None arguments are replaced with the defaults below (via `or`).
36
+ self,
37
+ aa_tokenizer: AATokenizer | None = None,
38
+ codon_tokenizer: CodonTokenizer | None = None,
39
+ codon_usage: CodonUsageTable | None = None,
40
+ decoder: "BartDecoderSkeleton | None" = None,  # required only when run() receives encoder_embeddings
41
+ model_id: str = "v3-bart-decoder-skeleton",  # recorded in metadata and the report
42
+ ) -> None:
43
+ self.aa_tokenizer = aa_tokenizer or AATokenizer.default()  # NOTE(review): not used by the methods of this class
44
+ self.codon_tokenizer = codon_tokenizer or CodonTokenizer.default()
45
+ self.codon_usage = codon_usage or load_codon_usage_table()  # reads the default codon-table JSON when not supplied
46
+ self.decoder = decoder
47
+ self.model_id = model_id
48
+
49
+ def run(  # Normalize protein -> DNA -> optional post-guard -> CAI/GC metrics -> report.
50
+ self,
51
+ sequence: str,  # protein sequence; _normalize_aa raises ValueError if empty/invalid
52
+ encoder_embeddings: Any | None = None,  # given: constrained decoder path; None: v2 fallback
53
+ max_new_tokens: int = 24,  # NOTE(review): unused in this method
54
+ apply_post_guard: bool = True,
55
+ seed: int = 0,  # only recorded in the report; no RNG is seeded here
56
+ config: dict[str, Any] | None = None,  # recorded in the report ({} when None)
57
+ ) -> V3Result:
58
+ aa_sequence = _normalize_aa(sequence)
59
+
60
+ if encoder_embeddings is not None:
61
+ if self.decoder is None:
62
+ raise ValueError("Decoder not configured. Provide a BartDecoderSkeleton instance.")
63
+ from factorforge.engines.v3.inference.constrained_decoder import (  # lazy import: only needed on the model path
64
+ constrained_greedy_decode,
65
+ validate_candidate_or_fallback,
66
+ )
67
+
68
+ token_ids = constrained_greedy_decode(
69
+ self.decoder,
70
+ encoder_embeddings,
71
+ aa_sequence,
72
+ self.codon_tokenizer,
73
+ )
74
+ dna_sequence = self.codon_tokenizer.decode(token_ids[0].tolist())  # only the first row of the result is decoded
75
+ fallback = validate_candidate_or_fallback(aa_sequence, dna_sequence)
76
+ dna_sequence = fallback["dna_sequence"]  # NOTE(review): whether a fallback replaced the candidate is not recorded
77
+ else:
78
+ dna_sequence = self._fallback_gc_target(aa_sequence)  # no embeddings: v2 optimizer with the "gc_target" profile
79
+
80
+ post_guard: dict[str, Any] | None = None
81
+ if apply_post_guard:
82
+ dna_sequence, post_guard = self._apply_post_guard(dna_sequence)
83
+
84
+ metrics = {  # computed on the sequence that is returned
85
+ "cai": compute_cai(dna_sequence, self.codon_usage),
86
+ "gc_content": compute_gc(dna_sequence),
87
+ }
88
+
89
+ report_inputs = ExplainabilityInputs(
90
+ aa_sequence=aa_sequence,
91
+ dna_sequence=dna_sequence,
92
+ metrics=metrics,
93
+ model_id=self.model_id,
94
+ tokenizer_hash=self.codon_tokenizer.mapping_hash(),
95
+ seed=seed,
96
+ config=config or {},
97
+ post_guard=post_guard,
98
+ )
99
+ report = build_fda_report(report_inputs)
100
+
101
+ metadata = {
102
+ "model_id": self.model_id,
103
+ "post_guard": post_guard or {},
104
+ }
105
+
106
+ return V3Result(sequence=dna_sequence, metrics=metrics, metadata=metadata, report=report)
107
+
108
+ def _reverse_translate_best(self, aa_sequence: str) -> str:  # most frequent codon per residue; ValueError if unmapped (not called in this chunk)
109
+ codons: list[str] = []
110
+ for aa in aa_sequence:
111
+ codon = self.codon_usage.best_codon_for_aa.get(aa)
112
+ if codon is None:
113
+ raise ValueError(f"No codon mapping for amino acid: {aa}")
114
+ codons.append(codon)
115
+ return "".join(codons)
116
+
117
+ def _fallback_gc_target(self, aa_sequence: str) -> str:  # delegates to the v2 adapter and returns its "dna_sequence"
118
+ from factorforge.engines.v3.inference.v2_adapter import optimize_with_v2
119
+
120
+ return optimize_with_v2(aa_sequence, options={"profile": "gc_target"})["dna_sequence"]
121
+
122
+ def _apply_post_guard(self, dna_sequence: str) -> tuple[str, dict[str, Any]]:  # returns (guarded sequence, details dict)
123
+ rule_engine = RuleEngine()  # a fresh RuleEngine/Domesticator is built on every call
124
+ domesticator = Domesticator()
125
+
126
+ # Step 1: PolyA iterative fix (same approach as the v2 pipeline)
127
+ polya_fix: dict[str, Any] | None = None
128
+ current_seq = dna_sequence
129
+ has_polya = any(p in dna_sequence for p in rule_engine.POLYA_PATTERNS)
130
+ if has_polya:
131
+ polya_result = rule_engine.fix_polya_iterative(current_seq)
132
+ if polya_result["success"]:
133
+ current_seq = polya_result["modified_seq"]
134
+ polya_fix = polya_result  # NOTE(review): nesting not recoverable from this rendering — confirm whether recorded only on success
135
+
136
+ # Step 2: Full scan on fixed sequence
137
+ scan_results = rule_engine.scan_all(current_seq)  # NOTE(review): scans the pre-domestication sequence, not the returned one
138
+
139
+ # Step 3: Domestication
140
+ domestication = domesticator.domesticate(current_seq, standard="golden_gate")
141
+ domesticated = domestication.get("domesticated_seq", current_seq)  # falls back to the unchanged input if the key is absent
142
+
143
+ post_guard = {
144
+ "scan_results": scan_results,
145
+ "domestication": domestication,
146
+ "polya_fix": polya_fix,
147
+ "edited": domesticated != dna_sequence,  # compared with the input, so poly-A fixes also count as edits
148
+ }
149
+ return domesticated, post_guard
150
+
151
+
152
class V3Optimizer(OptimizerEngine):
    """Expose :class:`V3Pipeline` through the ``OptimizerEngine`` interface."""

    name = "v3 BART Decoder"
    version = "3.0.0"

    def __init__(self, pipeline: V3Pipeline | None = None) -> None:
        # A falsy/None pipeline is replaced by a default V3Pipeline().
        self.pipeline = pipeline if pipeline else V3Pipeline()

    def optimize(
        self,
        sequence: str,
        profile: str | None = None,
        **kwargs: Any,
    ) -> OptimizationResult:
        """Run the pipeline on ``sequence`` and wrap the outcome.

        ``profile`` is accepted for interface compatibility but not used here.
        Extra keyword arguments are forwarded to ``V3Pipeline.run``. The
        pipeline's report is not carried into the result.
        """
        outcome = self.pipeline.run(sequence, **kwargs)
        return OptimizationResult(
            sequence=outcome.sequence,
            metrics=outcome.metrics,
            metadata=outcome.metadata,
        )

    def validate(self, sequence: str) -> bool:
        """Return True if ``sequence`` passes ``_normalize_aa``, else False."""
        try:
            _normalize_aa(sequence)
        except ValueError:
            return False
        else:
            return True

    def get_supported_profiles(self) -> list[str]:
        """Return the profile names advertised by this engine."""
        supported = ["default"]
        return supported
183
+
184
+
185
def _normalize_aa(sequence: str) -> str:
    """Normalise a protein sequence and validate it against ``VALID_AA``.

    All whitespace is removed and letters are upper-cased. That includes
    spaces, tabs and ``\\r``/``\\n`` line breaks from FASTA or Windows-style
    text. Previously only spaces and ``\\n`` were removed, so CRLF input failed
    with an invisible ``\\r`` "invalid amino acid".

    Raises:
        ValueError: if nothing remains after normalisation, or if any residue
            is not in ``VALID_AA``. The offending residues are listed, sorted.
    """
    seq = "".join(sequence.split()).upper()
    if not seq:
        raise ValueError("Protein sequence is empty.")
    invalid = {aa for aa in seq if aa not in VALID_AA}
    if invalid:
        raise ValueError(f"Invalid amino acids found: {''.join(sorted(invalid))}")
    return seq
@@ -0,0 +1,61 @@
1
+ """Synonymous codon masks for v3 training and constrained decoding."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Mapping
6
+
7
+ import torch
8
+
9
+ from factorforge.ml.metrics import STANDARD_GENETIC_CODE
10
+
11
+
12
def _group_codons_by_aa(genetic_code: Mapping[str, str]) -> dict[str, tuple[str, ...]]:
    """Invert a codon -> amino-acid table into amino acid -> codons.

    Keys appear in the order each amino acid is first seen. Each tuple lists
    its codons in table order.
    """
    grouped: dict[str, list[str]] = {}
    for codon, aa in genetic_code.items():
        grouped.setdefault(aa, []).append(codon)
    return {aa: tuple(codons) for aa, codons in grouped.items()}


# Amino acid -> synonymous codons from the standard genetic code. Presumably
# includes the stop symbol "*", which the helpers below explicitly exclude.
AA_TO_CODONS: dict[str, tuple[str, ...]] = _group_codons_by_aa(STANDARD_GENETIC_CODE)
19
+
20
+
21
def normalize_protein_sequence(protein_sequence: str) -> str:
    """Return an upper-case, whitespace-free protein ready for codon masking.

    One trailing ``*`` (stop) is dropped.

    Raises:
        ValueError: if the result is empty, or contains residues that have no
            entry in ``AA_TO_CODONS``. An internal ``*`` counts as such a
            residue.
    """
    cleaned = "".join(protein_sequence.upper().split())
    cleaned = cleaned[:-1] if cleaned.endswith("*") else cleaned
    if not cleaned:
        raise ValueError("protein_sequence must not be empty")
    unsupported = sorted({res for res in cleaned if res == "*" or res not in AA_TO_CODONS})
    if unsupported:
        raise ValueError(f"No synonymous codons for amino acid(s): {''.join(unsupported)}")
    return cleaned
32
+
33
+
34
def synonymous_codons_for_aa(amino_acid: str) -> tuple[str, ...]:
    """Return the codons that encode one amino acid (case-insensitive).

    Raises:
        ValueError: for the stop symbol ``*`` or any residue without codons
            in ``AA_TO_CODONS``.
    """
    key = amino_acid.upper()
    codons = () if key == "*" else AA_TO_CODONS.get(key, tuple())
    if codons:
        return codons
    raise ValueError(f"No synonymous codons for amino acid: {amino_acid}")
41
+
42
+
43
def build_synonym_token_mask(
    protein_sequence: str,
    token_to_id: Mapping[str, int],
    *,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Build a boolean mask [protein_length, vocab_size] for synonymous codons only.

    Row ``i`` is ``True`` exactly at the token ids of the codons returned by
    :func:`synonymous_codons_for_aa` for residue ``i``. Residues come from the
    protein after :func:`normalize_protein_sequence`. ``vocab_size`` is
    ``max(token_to_id.values()) + 1``.

    Token ids are collected on the host, with a cache per amino acid. The mask
    is then filled by one indexed assignment instead of many per-element tensor
    writes and a per-row ``.any()`` check. This avoids per-row device
    synchronisation when ``device`` is a GPU.

    Raises:
        ValueError: if the protein is empty or invalid, if ``token_to_id`` is
            empty, or if some residue has none of its codons in ``token_to_id``.
    """
    protein = normalize_protein_sequence(protein_sequence)
    if not token_to_id:
        raise ValueError("token_to_id must not be empty")
    vocab_size = max(token_to_id.values()) + 1

    ids_by_aa: dict[str, list[int]] = {}
    row_index: list[int] = []
    col_index: list[int] = []
    for position, aa in enumerate(protein):
        ids = ids_by_aa.get(aa)
        if ids is None:
            ids = []
            for codon in synonymous_codons_for_aa(aa):
                token_id = token_to_id.get(codon)
                if token_id is not None:
                    ids.append(token_id)
            ids_by_aa[aa] = ids
        if not ids:
            raise ValueError(f"No tokenizer codons available for amino acid: {aa}")
        row_index.extend([position] * len(ids))
        col_index.extend(ids)

    mask = torch.zeros((len(protein), vocab_size), dtype=torch.bool, device=device)
    mask[
        torch.tensor(row_index, dtype=torch.long, device=mask.device),
        torch.tensor(col_index, dtype=torch.long, device=mask.device),
    ] = True
    return mask
61
+