genie-dca 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genie_dca-2.0.0/Genie/__init__.py +26 -0
- genie_dca-2.0.0/Genie/core/__init__.py +7 -0
- genie_dca-2.0.0/Genie/core/evolution.py +222 -0
- genie_dca-2.0.0/Genie/main.py +711 -0
- genie_dca-2.0.0/Genie/utils/__init__.py +28 -0
- genie_dca-2.0.0/Genie/utils/codon_utils.py +603 -0
- genie_dca-2.0.0/Genie/utils/parser.py +121 -0
- genie_dca-2.0.0/Genie/utils/pca_utils.py +250 -0
- genie_dca-2.0.0/Genie/utils/stats.py +125 -0
- genie_dca-2.0.0/Genie_aa/__init__.py +18 -0
- genie_dca-2.0.0/Genie_aa/main.py +490 -0
- genie_dca-2.0.0/Genie_aa/sampling/__init__.py +11 -0
- genie_dca-2.0.0/Genie_aa/sampling/sampling.py +195 -0
- genie_dca-2.0.0/Genie_aa/utils/__init__.py +7 -0
- genie_dca-2.0.0/Genie_aa/utils/parser.py +133 -0
- genie_dca-2.0.0/Genie_aa/utils/stats.py +126 -0
- genie_dca-2.0.0/LICENSE +21 -0
- genie_dca-2.0.0/MANIFEST.in +9 -0
- genie_dca-2.0.0/PKG-INFO +370 -0
- genie_dca-2.0.0/README.md +331 -0
- genie_dca-2.0.0/genie_dca.egg-info/PKG-INFO +370 -0
- genie_dca-2.0.0/genie_dca.egg-info/SOURCES.txt +30 -0
- genie_dca-2.0.0/genie_dca.egg-info/dependency_links.txt +1 -0
- genie_dca-2.0.0/genie_dca.egg-info/entry_points.txt +5 -0
- genie_dca-2.0.0/genie_dca.egg-info/requires.txt +16 -0
- genie_dca-2.0.0/genie_dca.egg-info/top_level.txt +3 -0
- genie_dca-2.0.0/pyproject.toml +57 -0
- genie_dca-2.0.0/scripts/__init__.py +15 -0
- genie_dca-2.0.0/scripts/reconstruct_at_timesteps.py +265 -0
- genie_dca-2.0.0/scripts/reconstruct_chains.py +261 -0
- genie_dca-2.0.0/setup.cfg +4 -0
- genie_dca-2.0.0/setup.py +4 -0
|
"""Genie package: top-level public API re-exports.

Exposes the CLI entry point, the core evolution kernel, the codon utility
builders/argument parser, and the reconstruction helpers from the sibling
top-level ``scripts`` package.
"""
# Genie package
from .main import main
from .core import evolve_sequences
from .utils import (
    build_codon_neighbors,
    build_codon_to_index_map,
    build_amino_to_codons_map,
    parse_arguments
)

# Import reconstruction functions from scripts
# NOTE(review): this is an absolute import of the top-level `scripts` package
# shipped alongside `Genie` (see top_level.txt); the generic name risks
# colliding with an unrelated `scripts` module on sys.path — confirm intended.
from scripts import reconstruct_at_timesteps, reconstruct_chains_from_log

__version__ = "2.0.0"
# Explicit public API surface for `from Genie import *`.
__all__ = [
    "main",
    "evolve_sequences",
    "build_codon_neighbors",
    "build_codon_to_index_map",
    "build_amino_to_codons_map",
    "parse_arguments",
    "reconstruct_at_timesteps",
    "reconstruct_chains_from_log"
]
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Evolution module for genetic sequence evolution using DCA models.
|
|
3
|
+
"""
|
|
4
|
+
import torch
|
|
5
|
+
import time
|
|
6
|
+
from typing import Dict, List, Tuple, Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def evolve_sequences(
    chains: torch.Tensor,
    dna_chains: torch.Tensor,
    params: Dict[str, torch.Tensor],
    codon_neighbor_tensor: torch.Tensor,
    codon_neighbor_codon_tensor: torch.Tensor,
    mutation_lookup: torch.Tensor,
    num_options: torch.Tensor,
    codon_usage: torch.Tensor,
    p: float = 0.5,
    p_values: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    beta: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Perform one evolution step on all chains in parallel (unified kernel).

    One site per chain is selected at random; each chain is then updated with
    either a Metropolis move (gap insertion/deletion, codon-usage-biased
    acceptance) or a codon-aware Gibbs move, chosen per chain by comparing a
    random draw against ``p``. All chains are processed with mask-based logic,
    avoiding split/merge overhead.

    Args:
        chains: One-hot encoded amino acid sequences (n_chains, seq_length, q).
            Modified in place at the selected sites.
        dna_chains: DNA sequences as codon indices (n_chains, seq_length).
            Modified in place at the selected sites.
        params: DCA model parameters. Must contain "bias" (L, q),
            "coupling_matrix" (L, q, L, q), "gap_codon_idx" (assumed to be a
            Python int — used in torch.full), "non_gap_codon_tensor",
            "stop_codon_mask" (bool, num_codons), "codon_to_aa_onehot"
            (num_codons, q), "log_codon_usage" (num_codons,) and
            "codon_to_aa_idx" (num_codons,).
        codon_neighbor_tensor: Pre-computed neighbor accessibility
            (num_codons, 3, q). Currently unused by this kernel; kept for
            interface compatibility.
        codon_neighbor_codon_tensor: Pre-computed codon neighbor accessibility
            (num_codons, 3, num_codons), boolean.
        mutation_lookup: Pre-computed codon mutations
            (num_codons, 3, q, max_neighbors). Currently unused; kept for
            interface compatibility.
        num_options: Count of valid options (num_codons, 3, q). Currently
            unused; kept for interface compatibility.
        codon_usage: Tensor (num_codons,) with codon usage frequencies
            (stop codons are expected to have usage 0.0).
        p: Probability threshold for Metropolis vs Gibbs selection.
        p_values: Pre-generated random values (n_chains,) for the
            Metropolis/Gibbs split (optional; drawn here when None).
        device: Torch device (CPU/GPU).
        dtype: Torch data type for random draws and logits.
        beta: Inverse temperature for both acceptance and Gibbs logits.

    Returns:
        Tuple of (evolved amino acid chains, evolved DNA chains,
        mutation_info_tensor). The first two are the SAME tensors passed in
        (updated in place). mutation_info_tensor has shape (n_chains, 2) with
        columns (selected site index, new amino-acid index).
    """
    # FIX(review): the previous annotation/docstring advertised a 4-tuple with
    # (metro_time, gibbs_time) floats; the function actually returns three
    # tensors. Annotation and docs now match the real return value.
    N, L, q = chains.shape

    # Generate or use pre-generated random numbers for Metropolis/Gibbs split
    if p_values is None:
        random_values = torch.rand(N, device=device)
    else:
        random_values = p_values

    # Boolean mask: True -> Gibbs move, False -> Metropolis move
    use_gibbs = random_values > p  # (N,)

    # ========== UNIFIED KERNEL - ALL CHAINS IN PARALLEL ==========

    # Step 1: Randomly select one position per chain
    selected_sites = torch.randint(0, L, (N,), device=device)  # (N,)
    batch_arange = torch.arange(N, device=device)  # (N,) row indices

    # Step 2: Extract biases and couplings for the selected sites
    biases = params["bias"][selected_sites]  # (N, q)
    couplings_batch = params["coupling_matrix"][selected_sites]  # (N, q, L, q)

    # Step 3: Coupling term via bmm (faster than einsum):
    # (N, q, L*q) @ (N, L*q, 1) -> (N, q, 1) -> squeeze
    chains_flat = chains.reshape(N, L * q, 1)  # (N, L*q, 1)
    couplings_flat = couplings_batch.reshape(N, q, L * q)  # (N, q, L*q)
    coupling_term = torch.bmm(couplings_flat, chains_flat).squeeze(-1)  # (N, q)
    local_field = biases + coupling_term  # (N, q)

    # Step 4: Current state at the selected site
    current_codon_indices = dna_chains[batch_arange, selected_sites]  # (N,)
    current_aa_onehot = chains[batch_arange, selected_sites]  # (N, q)

    # ========== EXTRACT PARAMS ONCE (avoid repeated dict lookups) ==========
    gap_idx = params["gap_codon_idx"]
    non_gap_codon_tensor = params["non_gap_codon_tensor"]
    stop_codon_mask = params["stop_codon_mask"]
    codon_to_aa_onehot = params["codon_to_aa_onehot"]
    log_codon_usage = params["log_codon_usage"]
    codon_to_aa_idx = params["codon_to_aa_idx"]

    # ========== METROPOLIS LOGIC (gap insertion/deletion) ==========

    is_gap = (current_codon_indices == gap_idx)  # (N,)

    # METROPOLIS PROPOSAL RULES (symmetric):
    # - From gap: propose any of the 64 non-gap codons uniformly (p=1/64 each)
    # - From non-gap: propose gap (p=1/64) OR stay (p=63/64)
    # - Stop codons are proposed but rejected in the acceptance step

    # Number of non-gap codons (expected 64: 61 coding + 3 stop)
    num_non_gap = non_gap_codon_tensor.shape[0]

    # For gap positions: propose a random codon from ALL non-gap codons
    random_non_gap_idx = torch.randint(0, num_non_gap, (N,), device=device)
    random_non_gap_codon = non_gap_codon_tensor[random_non_gap_idx]  # (N,)

    # For non-gap positions: 1/64 chance gap, 63/64 chance stay
    rand_vals = torch.rand(N, device=device, dtype=dtype)
    propose_gap_from_nongap = rand_vals < (1.0 / 64.0)

    # Combine proposals (priority: is_gap -> random, propose_gap -> gap, else stay)
    gap_proposal = torch.full((N,), gap_idx, dtype=torch.long, device=device)
    metro_proposed_codon = torch.where(
        is_gap,
        random_non_gap_codon,
        torch.where(propose_gap_from_nongap, gap_proposal, current_codon_indices)
    )

    # Reject stop codons BEFORE computing energy (pre-computed mask is faster
    # than torch.isin here)
    is_stop_codon = stop_codon_mask[metro_proposed_codon]

    # If a stop codon was proposed, fall back to the current codon
    # (a stop would be rejected anyway since its usage is 0)
    metro_proposed_codon = torch.where(is_stop_codon, current_codon_indices, metro_proposed_codon)

    # Convert to amino acids via pre-computed one-hot lookup (no one_hot() call)
    metro_proposed_aa_onehot = codon_to_aa_onehot[metro_proposed_codon]  # (N, q)

    # Metropolis acceptance with codon usage bias:
    # acceptance = min(1, (usage[new] / usage[old]) * exp(-beta * delta_E))
    delta_E = torch.sum((current_aa_onehot - metro_proposed_aa_onehot) * local_field, dim=-1)

    current_codon_usage = codon_usage[current_codon_indices]  # (N,)
    proposed_codon_usage = codon_usage[metro_proposed_codon]  # (N,)

    # Epsilon guards against division by zero for zero-usage current codons
    codon_usage_ratio = proposed_codon_usage / (current_codon_usage + 1e-10)

    metro_acceptance_prob = torch.clamp(codon_usage_ratio * torch.exp(-beta * delta_E), 0.0, 1.0)
    metro_accept = torch.rand(N, device=device, dtype=dtype) < metro_acceptance_prob

    # ========== GIBBS LOGIC (codon-aware sampling) ==========

    # Select a random nucleotide position within the codon (0, 1, or 2)
    nucleotide_positions = torch.randint(0, 3, (N,), device=device)

    # Valid-codon mask for a single-nucleotide change: (N, num_codons)
    num_codons = codon_neighbor_codon_tensor.shape[0]
    valid_codon_mask = codon_neighbor_codon_tensor[current_codon_indices, nucleotide_positions]

    # Map each candidate codon to its amino-acid index; invalid slots get 0
    # as a safe gather index (they are masked to -inf below, so the value
    # gathered there never matters).
    # FIX(review): dropped a redundant intermediate that first filled invalid
    # slots with -1 and then immediately overwrote them with 0.
    aa_indices_safe = torch.where(valid_codon_mask, codon_to_aa_idx.unsqueeze(0), 0)
    codon_logits = (beta * local_field).gather(1, aa_indices_safe) + log_codon_usage

    # Invalid codons get -inf so they can never be sampled
    codon_logits = torch.where(valid_codon_mask, codon_logits, torch.tensor(float('-inf'), dtype=dtype, device=device))

    # Gumbel-Max sampling (faster than multinomial); epsilons avoid log(0)
    gumbel_noise = -torch.log(-torch.log(torch.rand(N, num_codons, device=device, dtype=dtype) + 1e-10) + 1e-10)
    gibbs_proposed_codon = (codon_logits + gumbel_noise).argmax(dim=-1)

    # Convert to amino acids via pre-computed one-hot lookup
    gibbs_proposed_aa_onehot = codon_to_aa_onehot[gibbs_proposed_codon]  # (N, q)

    # ========== COMBINE RESULTS WITH MASKS ==========

    # Gibbs always accepts; Metropolis uses metro_accept
    accept_mutation = use_gibbs | metro_accept  # (N,)

    # Cache unsqueeze results for reuse
    use_gibbs_2d = use_gibbs.unsqueeze(-1)  # (N, 1)
    accept_mutation_2d = accept_mutation.unsqueeze(-1)  # (N, 1)

    # Select proposed codon based on the per-chain method
    proposed_codon = torch.where(use_gibbs, gibbs_proposed_codon, metro_proposed_codon)

    final_codon = torch.where(accept_mutation, proposed_codon, current_codon_indices)

    proposed_aa_onehot = torch.where(
        use_gibbs_2d,
        gibbs_proposed_aa_onehot,
        metro_proposed_aa_onehot
    )

    final_aa_onehot = torch.where(
        accept_mutation_2d,
        proposed_aa_onehot,
        current_aa_onehot
    )

    # Update chains in place at the selected sites
    chains[batch_arange, selected_sites] = final_aa_onehot
    dna_chains[batch_arange, selected_sites] = final_codon

    # Report which site was touched and the resulting amino acid (not one-hot)
    selected_sites_tensor = selected_sites.unsqueeze(-1)  # (N, 1)
    new_aa_tensor = final_aa_onehot.argmax(dim=-1)  # (N,)
    mutation_info_tensor = torch.cat((selected_sites_tensor, new_aa_tensor.unsqueeze(-1)), dim=-1)  # (N, 2)

    return chains, dna_chains, mutation_info_tensor