genie-dca 2.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. genie_dca-2.0.0/Genie/__init__.py +26 -0
  2. genie_dca-2.0.0/Genie/core/__init__.py +7 -0
  3. genie_dca-2.0.0/Genie/core/evolution.py +222 -0
  4. genie_dca-2.0.0/Genie/main.py +711 -0
  5. genie_dca-2.0.0/Genie/utils/__init__.py +28 -0
  6. genie_dca-2.0.0/Genie/utils/codon_utils.py +603 -0
  7. genie_dca-2.0.0/Genie/utils/parser.py +121 -0
  8. genie_dca-2.0.0/Genie/utils/pca_utils.py +250 -0
  9. genie_dca-2.0.0/Genie/utils/stats.py +125 -0
  10. genie_dca-2.0.0/Genie_aa/__init__.py +18 -0
  11. genie_dca-2.0.0/Genie_aa/main.py +490 -0
  12. genie_dca-2.0.0/Genie_aa/sampling/__init__.py +11 -0
  13. genie_dca-2.0.0/Genie_aa/sampling/sampling.py +195 -0
  14. genie_dca-2.0.0/Genie_aa/utils/__init__.py +7 -0
  15. genie_dca-2.0.0/Genie_aa/utils/parser.py +133 -0
  16. genie_dca-2.0.0/Genie_aa/utils/stats.py +126 -0
  17. genie_dca-2.0.0/LICENSE +21 -0
  18. genie_dca-2.0.0/MANIFEST.in +9 -0
  19. genie_dca-2.0.0/PKG-INFO +370 -0
  20. genie_dca-2.0.0/README.md +331 -0
  21. genie_dca-2.0.0/genie_dca.egg-info/PKG-INFO +370 -0
  22. genie_dca-2.0.0/genie_dca.egg-info/SOURCES.txt +30 -0
  23. genie_dca-2.0.0/genie_dca.egg-info/dependency_links.txt +1 -0
  24. genie_dca-2.0.0/genie_dca.egg-info/entry_points.txt +5 -0
  25. genie_dca-2.0.0/genie_dca.egg-info/requires.txt +16 -0
  26. genie_dca-2.0.0/genie_dca.egg-info/top_level.txt +3 -0
  27. genie_dca-2.0.0/pyproject.toml +57 -0
  28. genie_dca-2.0.0/scripts/__init__.py +15 -0
  29. genie_dca-2.0.0/scripts/reconstruct_at_timesteps.py +265 -0
  30. genie_dca-2.0.0/scripts/reconstruct_chains.py +261 -0
  31. genie_dca-2.0.0/setup.cfg +4 -0
  32. genie_dca-2.0.0/setup.py +4 -0
@@ -0,0 +1,26 @@
1
+ # Genie package
2
+ from .main import main
3
+ from .core import evolve_sequences
4
+ from .utils import (
5
+ build_codon_neighbors,
6
+ build_codon_to_index_map,
7
+ build_amino_to_codons_map,
8
+ parse_arguments
9
+ )
10
+
11
+ # Import reconstruction functions from scripts
12
+ from scripts import reconstruct_at_timesteps, reconstruct_chains_from_log
13
+
14
+ __version__ = "2.0.0"
15
+ __all__ = [
16
+ "main",
17
+ "evolve_sequences",
18
+ "build_codon_neighbors",
19
+ "build_codon_to_index_map",
20
+ "build_amino_to_codons_map",
21
+ "parse_arguments",
22
+ "reconstruct_at_timesteps",
23
+ "reconstruct_chains_from_log"
24
+ ]
25
+
26
+
@@ -0,0 +1,7 @@
1
+ # Core evolution algorithms
2
+ from .evolution import evolve_sequences
3
+
4
+ __all__ = [
5
+ 'evolve_sequences'
6
+ ]
7
+
@@ -0,0 +1,222 @@
1
+ """
2
+ Evolution module for genetic sequence evolution using DCA models.
3
+ """
4
+ import torch
5
+ import time
6
+ from typing import Dict, List, Tuple, Optional
7
+
8
+
9
+ def evolve_sequences(
10
+ chains: torch.Tensor,
11
+ dna_chains: torch.Tensor,
12
+ params: Dict[str, torch.Tensor],
13
+ codon_neighbor_tensor: torch.Tensor,
14
+ codon_neighbor_codon_tensor: torch.Tensor,
15
+ mutation_lookup: torch.Tensor,
16
+ num_options: torch.Tensor,
17
+ codon_usage: torch.Tensor,
18
+ p: float = 0.5,
19
+ p_values: Optional[torch.Tensor] = None,
20
+ device: Optional[torch.device] = None,
21
+ dtype: torch.dtype = torch.float32,
22
+ beta: float = 1.0
23
+ ) -> Tuple[torch.Tensor, torch.Tensor, float, float]:
24
+ """
25
+ Evolve sequences using unified GPU kernel (no split/merge overhead).
26
+
27
+ MASSIVELY OPTIMIZED: All chains processed in parallel with mask-based logic.
28
+ Eliminates split/merge overhead and enables full GPU parallelization.
29
+
30
+ Args:
31
+ chains: One-hot encoded amino acid sequences (n_chains, seq_length, q)
32
+ dna_chains: DNA sequences as codon indices (n_chains, seq_length)
33
+ params: DCA model parameters with bias and coupling_matrix
34
+ codon_neighbor_tensor: Pre-computed neighbor accessibility (num_codons, 3, q)
35
+ codon_neighbor_codon_tensor: Pre-computed codon neighbor accessibility (num_codons, 3, num_codons)
36
+ mutation_lookup: Pre-computed codon mutations (num_codons, 3, q, max_neighbors)
37
+ num_options: Count of valid options (num_codons, 3, q)
38
+ codon_usage: Tensor (num_codons,) with codon usage frequencies
39
+ p: Float probability threshold for Metropolis vs Gibbs selection
40
+ p_values: Pre-generated random values (n_chains,) for Metropolis/Gibbs split (optional)
41
+ device: Torch device (CPU/GPU)
42
+ dtype: Torch data type
43
+ beta: Inverse temperature for Gibbs sampling
44
+
45
+ Returns:
46
+ Tuple of (evolved amino acid chains, evolved DNA chains, metro_time, gibbs_time)
47
+ """
48
+ N, L, q = chains.shape
49
+
50
+ # Generate or use pre-generated random numbers for Metropolis/Gibbs split
51
+ if p_values is None:
52
+ random_values = torch.rand(N, device=device)
53
+ else:
54
+ random_values = p_values
55
+
56
+ # Create boolean mask for Gibbs (True) vs Metropolis (False)
57
+ use_gibbs = random_values > p # Shape: (N,)
58
+
59
+ # ========== UNIFIED KERNEL - ALL CHAINS IN PARALLEL ==========
60
+
61
+ # Step 1: Randomly select one position per chain
62
+ selected_sites = torch.randint(0, L, (N,), device=device) # (N,) e.g. [3, 0, 7, ..., L-1]
63
+ batch_arange = torch.arange(N, device=device) # (N,) e.g. [0, 1, 2, ..., N-1]
64
+
65
+ # Step 2: Extract biases and couplings for selected sites
66
+ biases = params["bias"][selected_sites] # (N, q)
67
+ couplings_batch = params["coupling_matrix"][selected_sites] # (N, q, L, q)
68
+
69
+ # Step 3: Compute coupling term using optimized bmm (faster than einsum)
70
+ # Reshape: (N, q, L, q) @ (N, L, q) -> (N, q, q) -> sum over last dim -> (N, q)
71
+ # More efficient: (N, q, L*q) @ (N, L*q, 1) -> (N, q, 1) -> squeeze
72
+ chains_flat = chains.reshape(N, L * q, 1) # (N, L*q, 1)
73
+ couplings_flat = couplings_batch.reshape(N, q, L * q) # (N, q, L*q)
74
+ coupling_term = torch.bmm(couplings_flat, chains_flat).squeeze(-1) # (N, q)
75
+ local_field = biases + coupling_term # (N, q)
76
+
77
+ # Step 4: Get current state
78
+ current_codon_indices = dna_chains[batch_arange, selected_sites] # (N,)
79
+ current_aa_onehot = chains[batch_arange, selected_sites] # (N, q)
80
+
81
+ # ========== EXTRACT PARAMS ONCE (avoid repeated dict lookups) ==========
82
+ gap_idx = params["gap_codon_idx"]
83
+ non_gap_codon_tensor = params["non_gap_codon_tensor"]
84
+ stop_codon_mask = params["stop_codon_mask"]
85
+ codon_to_aa_onehot = params["codon_to_aa_onehot"]
86
+ log_codon_usage = params["log_codon_usage"]
87
+ codon_to_aa_idx = params["codon_to_aa_idx"]
88
+
89
+ # ========== METROPOLIS LOGIC (gap insertion/deletion) ==========
90
+
91
+ is_gap = (current_codon_indices == gap_idx) # (N,)
92
+
93
+ # METROPOLIS PROPOSAL RULES (symmetric):
94
+ # - From gap (index 0): propose any of the 64 non-gap codons uniformly (p=1/64 each)
95
+ # - From non-gap: propose gap (p=1/64) OR stay (p=63/64)
96
+ # - Stop codons (indices 62, 63, 64) are proposed but REJECTED in acceptance
97
+
98
+ # Number of non-gap codons (should be 64: 61 coding + 3 stop)
99
+ num_non_gap = non_gap_codon_tensor.shape[0] # Should be 64
100
+
101
+ # For gap positions: propose random codon from ALL 64 non-gap codons (including stops)
102
+ random_non_gap_idx = torch.randint(0, num_non_gap, (N,), device=device) # (N,) between 0 and 63
103
+ random_non_gap_codon = non_gap_codon_tensor[random_non_gap_idx] # (N,) codon indices between 0 and 63
104
+
105
+ # For non-gap positions: 1/64 chance gap, 63/64 chance stay
106
+ rand_vals = torch.rand(N, device=device, dtype=dtype)
107
+ propose_gap_from_nongap = rand_vals < (1.0 / 64.0)
108
+
109
+ # Combine proposals efficiently (no clone, direct where chain)
110
+ # Priority: is_gap → random, propose_gap → gap, else → current
111
+ # Create gap_proposal dynamically (torch.full is extremely fast on GPU)
112
+ gap_proposal = torch.full((N,), gap_idx, dtype=torch.long, device=device)
113
+ metro_proposed_codon = torch.where(
114
+ is_gap,
115
+ random_non_gap_codon,
116
+ torch.where(propose_gap_from_nongap, gap_proposal, current_codon_indices)
117
+ )
118
+
119
+ # Reject stop codons BEFORE computing energy (use pre-computed mask - faster than isin)
120
+ is_stop_codon = stop_codon_mask[metro_proposed_codon] # Boolean indexing (molto più veloce)
121
+
122
+ # If stop codon proposed, replace with current (will be rejected anyway)
123
+ metro_proposed_codon = torch.where(is_stop_codon, current_codon_indices, metro_proposed_codon)
124
+
125
+ # Convert to amino acids using pre-computed one-hot lookup (no one_hot() call)
126
+ metro_proposed_aa_onehot = codon_to_aa_onehot[metro_proposed_codon] # (N, q) direct lookup
127
+
128
+ # Metropolis acceptance with codon usage bias
129
+ # Acceptance = (codon_usage[new] / codon_usage[old]) * exp(-beta * delta_E)
130
+ # - Symmetric proposal (1/64 all directions)
131
+ # - Codon usage bias included in acceptance ratio
132
+ # - Stop codons have usage=0.0, so they are automatically rejected
133
+
134
+ delta_E = torch.sum((current_aa_onehot - metro_proposed_aa_onehot) * local_field, dim=-1)
135
+
136
+ # Get codon usage for current and proposed codons
137
+ current_codon_usage = codon_usage[current_codon_indices] # (N,)
138
+ proposed_codon_usage = codon_usage[metro_proposed_codon] # (N,)
139
+
140
+ # Codon usage ratio (stop codons already replaced with current, so no extra check needed)
141
+ codon_usage_ratio = proposed_codon_usage / (current_codon_usage + 1e-10)
142
+
143
+ # Metropolis acceptance: min(1, (usage_new/usage_old) * exp(-beta * delta_E))
144
+ # Clamp fused with computation (no intermediate tensor)
145
+ metro_acceptance_prob = torch.clamp(codon_usage_ratio * torch.exp(-beta * delta_E), 0.0, 1.0)
146
+ metro_accept = torch.rand(N, device=device, dtype=dtype) < metro_acceptance_prob
147
+
148
+ # ========== GIBBS LOGIC (codon-aware sampling) ==========
149
+
150
+ # Select random nucleotide position (0, 1, or 2)
151
+ nucleotide_positions = torch.randint(0, 3, (N,), device=device)
152
+
153
+ # Get valid codon mask (N, num_codons) - single gather operation
154
+ num_codons = codon_neighbor_codon_tensor.shape[0] # e.g. 65
155
+ valid_codon_mask = codon_neighbor_codon_tensor[current_codon_indices, nucleotide_positions] # (N, num_codons)
156
+
157
+ # Build codon->AA mapping with mask (avoid clone by using where)
158
+ codon_aa_indices = torch.where(
159
+ valid_codon_mask,
160
+ codon_to_aa_idx.unsqueeze(0), # Lazy broadcast, no clone
161
+ -1
162
+ ) # (N, num_codons)
163
+
164
+ # Compute logits efficiently with fused operations
165
+ # Direct indexing for valid codons (faster than gather + where)
166
+ aa_indices_safe = torch.where(valid_codon_mask, codon_aa_indices, 0)
167
+ codon_logits = (beta * local_field).gather(1, aa_indices_safe) + log_codon_usage # Fused add
168
+
169
+ # Set invalid codons to -inf (single where, no intermediate tensor)
170
+ codon_logits = torch.where(valid_codon_mask, codon_logits, torch.tensor(float('-inf'), dtype=dtype, device=device))
171
+
172
+ # Gumbel-Max sampling (faster than multinomial)
173
+ gumbel_noise = -torch.log(-torch.log(torch.rand(N, num_codons, device=device, dtype=dtype) + 1e-10) + 1e-10)
174
+ gibbs_proposed_codon = (codon_logits + gumbel_noise).argmax(dim=-1)
175
+
176
+ # Convert to amino acids using pre-computed one-hot lookup (no one_hot() call)
177
+ gibbs_proposed_aa_onehot = codon_to_aa_onehot[gibbs_proposed_codon] # (N, q) direct lookup
178
+
179
+ # ========== COMBINE RESULTS WITH MASKS ==========
180
+
181
+ # Combine acceptance: Gibbs always accepts, Metropolis uses metro_accept
182
+ accept_mutation = use_gibbs | metro_accept # Single boolean mask (N,)
183
+
184
+ # Cache unsqueeze operations for reuse (avoid recomputing)
185
+ use_gibbs_2d = use_gibbs.unsqueeze(-1) # (N, 1)
186
+ accept_mutation_2d = accept_mutation.unsqueeze(-1) # (N, 1)
187
+
188
+ # Select proposed codon based on method (no intermediate tensor)
189
+ proposed_codon = torch.where(use_gibbs, gibbs_proposed_codon, metro_proposed_codon)
190
+
191
+ # Single where for codon (no nested where)
192
+ final_codon = torch.where(accept_mutation, proposed_codon, current_codon_indices)
193
+
194
+ # Single where for AA with lazy broadcasting (no expand)
195
+ proposed_aa_onehot = torch.where(
196
+ use_gibbs_2d, # Reuse cached unsqueeze
197
+ gibbs_proposed_aa_onehot,
198
+ metro_proposed_aa_onehot
199
+ )
200
+
201
+ final_aa_onehot = torch.where(
202
+ accept_mutation_2d, # Reuse cached unsqueeze
203
+ proposed_aa_onehot,
204
+ current_aa_onehot
205
+ )
206
+
207
+ # Update chains in-place
208
+ chains[batch_arange, selected_sites] = final_aa_onehot
209
+ dna_chains[batch_arange, selected_sites] = final_codon
210
+
211
+
212
+ # also output a tensor containing where you have selcted sites and the new aa
213
+ selected_sites_tensor = selected_sites.unsqueeze(-1) # (N, 1)
214
+ # not in one hot
215
+ new_aa_tensor = final_aa_onehot.argmax(dim=-1) # (N,)
216
+ # put the tensor together in shape (N, 2)
217
+ mutation_info_tensor = torch.cat((selected_sites_tensor, new_aa_tensor.unsqueeze(-1)), dim=-1) # (N, 2)
218
+ # You can return these tensors if needed for analysis
219
+
220
+ # Return timing (0 since unified kernel)
221
+ return chains, dna_chains, mutation_info_tensor
222
+