spacr 0.0.35__py3-none-any.whl → 0.0.61__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- spacr/__init__.py +2 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +514 -2
- spacr/annotate_app.py +113 -117
- spacr/core.py +864 -728
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +2 -16
- spacr/graph_learning.py +297 -253
- spacr/gui.py +9 -8
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +7 -8
- spacr/gui_mask_app.py +13 -13
- spacr/gui_measure_app.py +8 -10
- spacr/gui_utils.py +134 -35
- spacr/io.py +311 -467
- spacr/mask_app.py +110 -6
- spacr/measure.py +19 -5
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +70 -2
- spacr/plot.py +23 -6
- spacr/sequencing.py +1130 -0
- spacr/sim.py +0 -42
- spacr/timelapse.py +0 -1
- spacr/train.py +172 -13
- spacr/umap.py +0 -689
- spacr/utils.py +1322 -75
- {spacr-0.0.35.dist-info → spacr-0.0.61.dist-info}/METADATA +14 -29
- spacr-0.0.61.dist-info/RECORD +39 -0
- {spacr-0.0.35.dist-info → spacr-0.0.61.dist-info}/entry_points.txt +1 -0
- spacr-0.0.35.dist-info/RECORD +0 -35
- {spacr-0.0.35.dist-info → spacr-0.0.61.dist-info}/LICENSE +0 -0
- {spacr-0.0.35.dist-info → spacr-0.0.61.dist-info}/WHEEL +0 -0
- {spacr-0.0.35.dist-info → spacr-0.0.61.dist-info}/top_level.txt +0 -0
spacr/foldseek.py
CHANGED
@@ -1,26 +1,12 @@
-import os, shutil, subprocess, tarfile,
-import pandas as pd
-from scipy.stats import fisher_exact
-from statsmodels.stats.multitest import multipletests
-from concurrent.futures import ProcessPoolExecutor, as_completed
-import seaborn as sns
-import matplotlib.pyplot as plt
+import os, shutil, subprocess, tarfile, requests
 import numpy as np
-
-import requests, time, random
-from concurrent.futures import ProcessPoolExecutor, as_completed
-
 import pandas as pd
 from scipy.stats import fisher_exact
 from statsmodels.stats.multitest import multipletests
 from concurrent.futures import ProcessPoolExecutor, as_completed
-import pandas as pd
-from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
-
 import seaborn as sns
 import matplotlib.pyplot as plt
-import
-from matplotlib.ticker import FixedLocator
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
 
 def run_command(command):
     print(f"Executing: {command}")
spacr/graph_learning.py
CHANGED
@@ -1,276 +1,320 @@
 import os
-
-import torch
-import torch.nn.functional as F
-from collections import defaultdict
-from torch.utils.data import Dataset, DataLoader
+os.environ['DGLBACKEND'] = 'pytorch'
+import torch, dgl
 import pandas as pd
-import
-import
[... old lines 10-131: removed content not rendered by the diff viewer ...]
-        self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
-        self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
-        self.dropout = nn.Dropout(dropout_rate)
-
-        # Attention layer to let each cell attend to all genes
-        self.attention = Attention(hidden_dim, attn_dim)
-
-        # This layer is used to transform the combined features after attention
-        self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
-
-        # Output layer for predicting cell scores, ensuring it matches the number of cells
-        self.cell_output = nn.Linear(hidden_dim, output_dim)
+import torch.nn as nn
+from torchvision import datasets, transforms
+from sklearn.preprocessing import StandardScaler
+from PIL import Image
+import dgl.nn.pytorch as dglnn
+from sklearn.datasets import make_classification
+from .utils import SelectChannels
+
+# approach outline
+#
+# 1. Data Preparation:
+# Test Mode: Load MNIST data and generate synthetic gRNA data.
+# Real Data: Load image paths and sequencing data as fractions.
+#
+# 2. Graph Construction:
+# Each well is represented as a graph.
+# Each graph has cell nodes (with image features) and gRNA nodes (with gRNA fraction features).
+# Each cell node is connected to each gRNA node within the same well.
+#
+# 3. Model Training:
+# Use an encoder-decoder architecture with the Graph Transformer model.
+# The encoder processes the cell and gRNA nodes.
+# The decoder outputs the phenotype score for each cell node.
+# The model is trained on all wells (including positive and negative controls).
+# The model learns to score the gRNA in column 1 (negative control) as 0 and the gRNA in column 2 (positive control) as 1 based on the cell features.
+#
+# 4. Model Application:
+# Apply the trained model to all wells to get classification probabilities.
+#
+# 5. Evaluation:
+# Evaluate the model's performance using the control wells.
+#
+# 6. Association Analysis:
+# Analyze the association between gRNAs and the classification scores.
+#
+# The model learns the associations between cell features and phenotype scores based on the controls and then generalizes this learning to the screening wells.
+
+# Load MNIST data for testing
+def load_mnist_data():
+    transform = transforms.Compose([
+        transforms.Resize((28, 28)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
+    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
+    return mnist_train, mnist_test
+
+# Generate synthetic gRNA data
+def generate_synthetic_grna_data(n_samples, n_features):
+    X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=5, n_redundant=0, n_classes=2, random_state=42)
+    synthetic_data = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
+    synthetic_data['label'] = y
+    return synthetic_data
+
+# Preprocess image
+def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
+
+    if normalize:
+        preprocess = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        preprocess = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels)])
+
+    image = Image.open(image_path).convert('RGB')
+    return preprocess(image)
+
+def extract_metadata_from_path(path):
+    """
+    Extract metadata from the image path.
+    The path format is expected to be plate_well_field_objectnumber.png
+
+    Parameters:
+    path (str): The path to the image file.
+
+    Returns:
+    dict: A dictionary with the extracted metadata.
+    """
+    filename = os.path.basename(path)
+    name, ext = os.path.splitext(filename)
+
+    # Ensure the file has the correct extension
+    if ext.lower() != '.png':
+        raise ValueError("Expected a .png file")
+
+    # Split the name by underscores
+    parts = name.split('_')
+    if len(parts) != 4:
+        raise ValueError("Expected filename format: plate_well_field_objectnumber.png")
+
+    plate, well, field, object_number = parts
+
+    return {'plate': plate, 'well': well,'field': field, 'object_number': object_number}
+
+# Load images
+def load_images(image_paths, image_size=224, channels=[1,2,3], normalize=True):
+    images = []
+    metadata_list = []
+    for path in image_paths:
+        image = preprocess_image(path, image_size, channels, normalize)
+        images.append(image)
+        metadata = extract_metadata_from_path(path) # Extract metadata from image path or database
+        metadata_list.append(metadata)
+    return torch.stack(images), metadata_list
+
+# Normalize sequencing data
+def normalize_sequencing_data(sequencing_data):
+    scaler = StandardScaler()
+    sequencing_data.iloc[:, 2:] = scaler.fit_transform(sequencing_data.iloc[:, 2:])
+    return sequencing_data
+
+# Construct graph for each well
+def construct_well_graph(images, image_metadata, grna_data):
+    cell_nodes = len(images)
+    grna_nodes = grna_data.shape[0]
+
+    graph = dgl.DGLGraph()
+    graph.add_nodes(cell_nodes + grna_nodes)
 
-
-
-        transformed_gene_features = F.relu(self.gene_transform(gene_features))
-        transformed_cell_features = F.relu(self.cell_transform(cell_features))
+    cell_features = torch.stack(images)
+    grna_features = torch.tensor(grna_data).float()
 
-
-
+    features = torch.cat([cell_features, grna_features], dim=0)
+    graph.ndata['features'] = features
 
-
-
-
-
-
+    for i in range(cell_nodes):
+        for j in range(cell_nodes, cell_nodes + grna_nodes):
+            graph.add_edge(i, j)
+            graph.add_edge(j, i)
+
+    return graph
 
-
+def create_graphs_for_wells(images, metadata_list, sequencing_data):
+    graphs = []
+    labels = []
 
-
-
+    for well in sequencing_data['well'].unique():
+        well_images = [img for img, meta in zip(images, metadata_list) if meta['well'] == well]
+        well_metadata = [meta for meta in metadata_list if meta['well'] == well]
+        well_grna_data = sequencing_data[sequencing_data['well'] == well].iloc[:, 2:].values
 
-
-
+        graph = construct_well_graph(well_images, well_metadata, well_grna_data)
+        graphs.append(graph)
 
-
-
+        if well_metadata[0]['column'] == 1: # Negative control
+            labels.append(0)
+        elif well_metadata[0]['column'] == 2: # Positive control
+            labels.append(1)
+        else:
+            labels.append(-1) # Screen wells, will be used for evaluation
+
+    return graphs, labels
+
+# Define Encoder-Decoder Transformer Model
+class Encoder(nn.Module):
+    def __init__(self, in_feats, hidden_feats):
+        super(Encoder, self).__init__()
+        self.conv1 = dglnn.GraphConv(in_feats, hidden_feats)
+        self.conv2 = dglnn.GraphConv(hidden_feats, hidden_feats)
+
+    def forward(self, g, features):
+        x = self.conv1(g, features)
+        x = torch.relu(x)
+        x = self.conv2(g, x)
+        x = torch.relu(x)
+        return x
+
+class Decoder(nn.Module):
+    def __init__(self, hidden_feats, out_feats):
+        super(Decoder, self).__init__()
+        self.linear = nn.Linear(hidden_feats, out_feats)
+
+    def forward(self, x):
+        return self.linear(x)
 
-
-
-
-
-
+class GraphTransformer(nn.Module):
+    def __init__(self, in_feats, hidden_feats, out_feats):
+        super(GraphTransformer, self).__init__()
+        self.encoder = Encoder(in_feats, hidden_feats)
+        self.decoder = Decoder(hidden_feats, out_feats)
 
-
-
-
+    def forward(self, g, features):
+        x = self.encoder(g, features)
+        with g.local_scope():
+            g.ndata['h'] = x
+            hg = dgl.mean_nodes(g, 'h')
+            return self.decoder(hg)
 
-
-
-        accumulate_grad_batches=1
-        threshold=acc_threshold
-
+def train(graphs, labels, model, loss_fn, optimizer, epochs=100):
     for epoch in range(epochs):
         model.train()
         total_loss = 0
-
-
-        optimizer.zero_grad()
-        batch_count = 0 # Initialize batch_count
+        correct = 0
+        total = 0
 
-        for graph in graphs:
-
-
-
-
-
-
-
-
+        for graph, label in zip(graphs, labels):
+            if label == -1:
+                continue # Skip screen wells for training
+
+            features = graph.ndata['features']
+            logits = model(graph, features)
+            loss = loss_fn(logits, torch.tensor([label]))
+
+            optimizer.zero_grad()
             loss.backward()
-
-
-
-
-
-
-
-            batch_count += 1 # Increment batch_count
-            if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
-                optimizer.step()
-                optimizer.zero_grad()
-
-            total_loss += loss.item() * accumulate_grad_batches
+            optimizer.step()
+
+            total_loss += loss.item()
+            _, predicted = torch.max(logits, 1)
+            correct += (predicted == label).sum().item()
+            total += 1
 
-        accuracy =
-
-        print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
-
-    # Save the training log and model as before
-    os.makedirs(save_fldr, exist_ok=True)
-    log_path = os.path.join(save_fldr, 'training_log.csv')
-    training_log_df = pd.DataFrame(training_log)
-    training_log_df.to_csv(log_path, index=False)
-    print(f"Training log saved to {log_path}")
-
-    model_path = os.path.join(save_fldr, 'model.pth')
-    torch.save(model.state_dict(), model_path)
-    print(f"Model saved to {model_path}")
+        accuracy = correct / total if total > 0 else 0
+        print(f'Epoch {epoch}, Loss: {total_loss / total:.4f}, Accuracy: {accuracy * 100:.2f}%')
 
-
-
-def annotate_cells_with_genes(graphs, model, gene_id_to_index):
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model.to(device)
+def apply_model(graphs, model):
     model.eval()
-
+    results = []
 
     with torch.no_grad():
         for graph in graphs:
[... old lines 242-276: removed content not rendered by the diff viewer ...]
+            features = graph.ndata['features']
+            logits = model(graph, features)
+            probabilities = torch.softmax(logits, dim=1)
+            results.append(probabilities[:, 1].item())
+
+    return results
+
+def analyze_associations(probabilities, sequencing_data):
+    # Analyze associations between gRNAs and classification scores
+    sequencing_data['positive_prob'] = probabilities
+    return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
+
+def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
+    if test_mode:
+        # Load MNIST data
+        mnist_train, mnist_test = load_mnist_data()
+
+        # Generate synthetic gRNA data
+        synthetic_grna_data = generate_synthetic_grna_data(len(mnist_train), 10) # 10 synthetic features
+        sequencing_data = synthetic_grna_data
+
+        # Load MNIST images and metadata
+        images = []
+        metadata_list = []
+        for idx, (img, label) in enumerate(mnist_train):
+            images.append(img)
+            metadata_list.append({'index': idx, 'plate': 'plate1', 'well': idx, 'column': label})
+        images = torch.stack(images)
+
+        # Normalize synthetic sequencing data
+        sequencing_data = normalize_sequencing_data(sequencing_data)
+
+    else:
+        from .io import _read_and_join_tables
+        from .utils import get_db_paths, get_sequencing_paths, correct_paths
+
+        db_paths = get_db_paths(src)
+        seq_paths = get_sequencing_paths(src)
+
+        if isinstance(src, str):
+            src = [src]
+
+        sequencing_data = pd.DataFrame()
+        for seq in seq_paths:
+            sequencing_df = pd.read_csv(seq)
+            sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
+
+        all_df = pd.DataFrame()
+        for db_path in db_paths:
+            df = _read_and_join_tables(db_path, table_names=['png_list'])
+            all_df = pd.concat([all_df, df], axis=0)
+
+        tables = ['png_list']
+        all_df = pd.DataFrame()
+        image_paths = []
+        for i, db_path in enumerate(db_paths):
+            df = _read_and_join_tables(db_path, table_names=tables)
+            df, image_paths_tmp = correct_paths(df, src[i])
+            all_df = pd.concat([all_df, df], axis=0)
+            image_paths.extend(image_paths_tmp)
+
+        if row_limit is not None:
+            all_df = all_df.sample(n=row_limit, random_state=42)
+
+        images, metadata_list = load_images(image_paths, image_size, channels, normalize)
+        sequencing_data = normalize_sequencing_data(sequencing_data)
+
+    # Step 1: Create graphs for each well
+    graphs, labels = create_graphs_for_wells(images, metadata_list, sequencing_data)
+
+    # Step 2: Train Graph Transformer Model
+    in_feats = graphs[0].ndata['features'].shape[1]
+    model = GraphTransformer(in_feats, hidden_feats, n_classes)
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+    # Train the model
+    train(graphs, labels, model, loss_fn, optimizer, epochs)
+
+    # Step 3: Apply the model to all wells (including screen wells)
+    screen_graphs = [graph for graph, label in zip(graphs, labels) if label == -1]
+    probabilities = apply_model(screen_graphs, model)
+
+    # Step 4: Analyze associations between gRNAs and classification scores
+    associations = analyze_associations(probabilities, sequencing_data)
+    print("Top associated gRNAs with positive control phenotype:")
+    print(associations.head())
+
+    return model, associations
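Reviewer's note: the core pattern added in graph_learning.py is one DGL graph per well, with cell nodes fully connected to gRNA nodes, a GraphConv encoder, and a mean-pooled linear decoder producing one logit row per well graph. The sketch below is a minimal, self-contained rehearsal of that pattern, not spacr's API: toy_well_graph and TinyWellModel are hypothetical names, the feature sizes are arbitrary, and it uses the current dgl.graph constructor where the diff still builds graphs with the deprecated dgl.DGLGraph() plus add_nodes/add_edge.

    import torch
    import torch.nn as nn
    import dgl
    import dgl.nn.pytorch as dglnn

    def toy_well_graph(n_cells=4, n_grnas=2, feat_dim=8):
        # One well as a bipartite graph: every cell node is connected to every
        # gRNA node in both directions, mirroring construct_well_graph above.
        src, dst = [], []
        for i in range(n_cells):
            for j in range(n_cells, n_cells + n_grnas):
                src += [i, j]
                dst += [j, i]
        g = dgl.graph((torch.tensor(src), torch.tensor(dst)),
                      num_nodes=n_cells + n_grnas)
        # Cell and gRNA features share one 'features' matrix, as in the diff
        # (random here; spacr uses image features and gRNA fractions).
        g.ndata['features'] = torch.randn(n_cells + n_grnas, feat_dim)
        return g

    class TinyWellModel(nn.Module):
        # Same encoder/decoder shape as the diff's GraphTransformer, reduced
        # to one GraphConv layer for brevity.
        def __init__(self, in_feats, hidden_feats, out_feats):
            super().__init__()
            self.conv = dglnn.GraphConv(in_feats, hidden_feats)
            self.head = nn.Linear(hidden_feats, out_feats)

        def forward(self, g, x):
            h = torch.relu(self.conv(g, x))
            with g.local_scope():
                g.ndata['h'] = h
                # Mean-pool node embeddings to one vector per graph, then
                # decode to per-well class logits.
                return self.head(dgl.mean_nodes(g, 'h'))

    g = toy_well_graph()
    model = TinyWellModel(8, 16, 2)
    logits = model(g, g.ndata['features'])              # shape (1, 2): one logit pair per well
    loss = nn.CrossEntropyLoss()(logits, torch.tensor([1]))  # label 1 = positive-control well

Because the decoder runs on the pooled graph embedding, each well contributes a single training example; this matches the diff's train loop, which feeds one graph and one control label (0 or 1) per step and skips screen wells labeled -1.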