spacr 0.0.36__py3-none-any.whl → 0.0.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/foldseek.py CHANGED
@@ -1,26 +1,12 @@
1
- import os, shutil, subprocess, tarfile, glob, requests, time, random
2
- import pandas as pd
3
- from scipy.stats import fisher_exact
4
- from statsmodels.stats.multitest import multipletests
5
- from concurrent.futures import ProcessPoolExecutor, as_completed
6
- import seaborn as sns
7
- import matplotlib.pyplot as plt
1
+ import os, shutil, subprocess, tarfile, requests
8
2
  import numpy as np
9
-
10
- import requests, time, random
11
- from concurrent.futures import ProcessPoolExecutor, as_completed
12
-
13
3
  import pandas as pd
14
4
  from scipy.stats import fisher_exact
15
5
  from statsmodels.stats.multitest import multipletests
16
6
  from concurrent.futures import ProcessPoolExecutor, as_completed
17
- import pandas as pd
18
- from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
19
-
20
7
  import seaborn as sns
21
8
  import matplotlib.pyplot as plt
22
- import numpy as np
23
- from matplotlib.ticker import FixedLocator
9
+ from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
24
10
 
25
11
  def run_command(command):
26
12
  print(f"Executing: {command}")
spacr/graph_learning.py CHANGED
@@ -1,276 +1,320 @@
1
1
  import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from collections import defaultdict
6
- from torch.utils.data import Dataset, DataLoader
2
+ os.environ['DGLBACKEND'] = 'pytorch'
3
+ import torch, dgl
7
4
  import pandas as pd
8
- import numpy as np
9
- import torch.optim as optim
10
-
11
- def generate_graphs(sequencing, scores, cell_min, gene_min_read):
12
- # Load and preprocess sequencing (gene) data
13
- gene_df = pd.read_csv(sequencing)
14
- gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
15
- # Filter out genes with read counts less than gene_min_read
16
- gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
17
- total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
18
- gene_df = gene_df.merge(total_reads_per_well, on='well_id')
19
- gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']
20
-
21
- # Load and preprocess cell score data
22
- cell_df = pd.read_csv(scores)
23
- cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})
24
-
25
- # Create a global mapping of gene IDs to indices
26
- unique_genes = gene_df['gene_id'].unique()
27
- gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
28
-
29
- graphs = []
30
- for well_id in pd.unique(gene_df['well_id']):
31
- well_genes = gene_df[gene_df['well_id'] == well_id]
32
- well_cells = cell_df[cell_df['well_id'] == well_id]
33
-
34
- # Skip wells with no cells or genes or with fewer cells than threshold
35
- if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
36
- continue
37
-
38
- # Initialize gene features tensor with zeros for all unique genes
39
- gene_features = torch.zeros((len(gene_id_to_index), 1), dtype=torch.float)
40
-
41
- # Update gene features tensor with well_read_fraction for genes present in this well
42
- for _, row in well_genes.iterrows():
43
- gene_index = gene_id_to_index[row['gene_id']]
44
- gene_features[gene_index] = torch.tensor([[row['well_read_fraction']]])
45
-
46
- # Prepare cell features (scores)
47
- cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
48
-
49
- num_genes = len(gene_id_to_index)
50
- num_cells = cell_features.size(0)
51
- num_nodes = num_genes + num_cells
52
-
53
- # Create adjacency matrix connecting each cell to all genes in the well
54
- adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
55
- for _, row in well_genes.iterrows():
56
- gene_index = gene_id_to_index[row['gene_id']]
57
- adj[num_genes:, gene_index] = 1
58
-
59
- graph = {
60
- 'adjacency_matrix': adj,
61
- 'gene_features': gene_features,
62
- 'cell_features': cell_features,
63
- 'num_cells': num_cells,
64
- 'num_genes': num_genes
65
- }
66
- graphs.append(graph)
67
-
68
- print(f'Generated dataset with {len(graphs)} graphs')
69
- return graphs, gene_id_to_index
70
-
71
- def print_graphs_info(graphs, gene_id_to_index):
72
- # Invert the gene_id_to_index mapping for easy lookup
73
- index_to_gene_id = {v: k for k, v in gene_id_to_index.items()}
74
-
75
- for i, graph in enumerate(graphs, start=1):
76
- print(f"Graph {i}:")
77
- num_genes = graph['num_genes']
78
- num_cells = graph['num_cells']
79
- gene_features = graph['gene_features']
80
- cell_features = graph['cell_features']
81
-
82
- print(f" Number of Genes: {num_genes}")
83
- print(f" Number of Cells: {num_cells}")
84
-
85
- # Identify genes present in the graph based on non-zero feature values
86
- present_genes = [index_to_gene_id[idx] for idx, feature in enumerate(gene_features) if feature.item() > 0]
87
- print(" Genes present in this Graph:", present_genes)
88
-
89
- # Display gene features for genes present in the graph
90
- print(" Gene Features:")
91
- for gene_id in present_genes:
92
- idx = gene_id_to_index[gene_id]
93
- print(f" {gene_id}: {gene_features[idx].item()}")
94
-
95
- # Display a sample of cell features, for brevity
96
- print(" Cell Features (sample):")
97
- for idx, feature in enumerate(cell_features[:min(5, len(cell_features))]):
98
- print(f" Cell {idx+1}: {feature.item()}")
99
-
100
- print("-" * 40)
101
-
102
- class Attention(nn.Module):
103
- def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
104
- super(Attention, self).__init__()
105
- self.query = nn.Linear(feature_dim, attn_dim)
106
- self.key = nn.Linear(feature_dim, attn_dim)
107
- self.value = nn.Linear(feature_dim, feature_dim)
108
- self.scale = 1.0 / (attn_dim ** 0.5)
109
- self.dropout = nn.Dropout(dropout_rate)
110
-
111
- def forward(self, gene_features, cell_features):
112
- # Queries come from the cell features
113
- q = self.query(cell_features)
114
- # Keys and values come from the gene features
115
- k = self.key(gene_features)
116
- v = self.value(gene_features)
117
-
118
- # Compute attention weights
119
- attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
120
- attn_weights = F.softmax(attn_weights, dim=-1)
121
- # Apply dropout to attention weights
122
- attn_weights = self.dropout(attn_weights)
123
-
124
- # Apply attention weights to the values
125
- attn_output = torch.matmul(attn_weights, v)
126
-
127
- return attn_output, attn_weights
128
-
129
- class GraphTransformer(nn.Module):
130
- def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
131
- super(GraphTransformer, self).__init__()
132
- self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
133
- self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
134
- self.dropout = nn.Dropout(dropout_rate)
135
-
136
- # Attention layer to let each cell attend to all genes
137
- self.attention = Attention(hidden_dim, attn_dim)
138
-
139
- # This layer is used to transform the combined features after attention
140
- self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
141
-
142
- # Output layer for predicting cell scores, ensuring it matches the number of cells
143
- self.cell_output = nn.Linear(hidden_dim, output_dim)
5
+ import torch.nn as nn
6
+ from torchvision import datasets, transforms
7
+ from sklearn.preprocessing import StandardScaler
8
+ from PIL import Image
9
+ import dgl.nn.pytorch as dglnn
10
+ from sklearn.datasets import make_classification
11
+ from .utils import SelectChannels
12
+
13
+ # approach outline
14
+ #
15
+ # 1. Data Preparation:
16
+ # Test Mode: Load MNIST data and generate synthetic gRNA data.
17
+ # Real Data: Load image paths and sequencing data as fractions.
18
+ #
19
+ # 2. Graph Construction:
20
+ # Each well is represented as a graph.
21
+ # Each graph has cell nodes (with image features) and gRNA nodes (with gRNA fraction features).
22
+ # Each cell node is connected to each gRNA node within the same well.
23
+ #
24
+ # 3. Model Training:
25
+ # Use an encoder-decoder architecture with the Graph Transformer model.
26
+ # The encoder processes the cell and gRNA nodes.
27
+ # The decoder outputs the phenotype score for each cell node.
28
+ # The model is trained on all wells (including positive and negative controls).
29
+ # The model learns to score the gRNA in column 1 (negative control) as 0 and the gRNA in column 2 (positive control) as 1 based on the cell features.
30
+ #
31
+ # 4. Model Application:
32
+ # Apply the trained model to all wells to get classification probabilities.
33
+ #
34
+ # 5. Evaluation:
35
+ # Evaluate the model's performance using the control wells.
36
+ #
37
+ # 6. Association Analysis:
38
+ # Analyze the association between gRNAs and the classification scores.
39
+ #
40
+ # The model learns the associations between cell features and phenotype scores based on the controls and then generalizes this learning to the screening wells.
41
+
42
+ # Load MNIST data for testing
43
+ def load_mnist_data():
44
+ transform = transforms.Compose([
45
+ transforms.Resize((28, 28)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize((0.1307,), (0.3081,))
48
+ ])
49
+ mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
50
+ mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
51
+ return mnist_train, mnist_test
52
+
53
+ # Generate synthetic gRNA data
54
+ def generate_synthetic_grna_data(n_samples, n_features):
55
+ X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=5, n_redundant=0, n_classes=2, random_state=42)
56
+ synthetic_data = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
57
+ synthetic_data['label'] = y
58
+ return synthetic_data
59
+
60
+ # Preprocess image
61
+ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
62
+
63
+ if normalize:
64
+ preprocess = transforms.Compose([
65
+ transforms.ToTensor(),
66
+ transforms.CenterCrop(size=(image_size, image_size)),
67
+ SelectChannels(channels),
68
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
69
+ else:
70
+ preprocess = transforms.Compose([
71
+ transforms.ToTensor(),
72
+ transforms.CenterCrop(size=(image_size, image_size)),
73
+ SelectChannels(channels)])
74
+
75
+ image = Image.open(image_path).convert('RGB')
76
+ return preprocess(image)
77
+
78
+ def extract_metadata_from_path(path):
79
+ """
80
+ Extract metadata from the image path.
81
+ The path format is expected to be plate_well_field_objectnumber.png
82
+
83
+ Parameters:
84
+ path (str): The path to the image file.
85
+
86
+ Returns:
87
+ dict: A dictionary with the extracted metadata.
88
+ """
89
+ filename = os.path.basename(path)
90
+ name, ext = os.path.splitext(filename)
91
+
92
+ # Ensure the file has the correct extension
93
+ if ext.lower() != '.png':
94
+ raise ValueError("Expected a .png file")
95
+
96
+ # Split the name by underscores
97
+ parts = name.split('_')
98
+ if len(parts) != 4:
99
+ raise ValueError("Expected filename format: plate_well_field_objectnumber.png")
100
+
101
+ plate, well, field, object_number = parts
102
+
103
+ return {'plate': plate, 'well': well,'field': field, 'object_number': object_number}
104
+
105
+ # Load images
106
+ def load_images(image_paths, image_size=224, channels=[1,2,3], normalize=True):
107
+ images = []
108
+ metadata_list = []
109
+ for path in image_paths:
110
+ image = preprocess_image(path, image_size, channels, normalize)
111
+ images.append(image)
112
+ metadata = extract_metadata_from_path(path) # Extract metadata from image path or database
113
+ metadata_list.append(metadata)
114
+ return torch.stack(images), metadata_list
115
+
116
+ # Normalize sequencing data
117
+ def normalize_sequencing_data(sequencing_data):
118
+ scaler = StandardScaler()
119
+ sequencing_data.iloc[:, 2:] = scaler.fit_transform(sequencing_data.iloc[:, 2:])
120
+ return sequencing_data
121
+
122
+ # Construct graph for each well
123
+ def construct_well_graph(images, image_metadata, grna_data):
124
+ cell_nodes = len(images)
125
+ grna_nodes = grna_data.shape[0]
126
+
127
+ graph = dgl.DGLGraph()
128
+ graph.add_nodes(cell_nodes + grna_nodes)
144
129
 
145
- def forward(self, adjacency_matrix, gene_features, cell_features):
146
- # Apply initial transformation to gene and cell features
147
- transformed_gene_features = F.relu(self.gene_transform(gene_features))
148
- transformed_cell_features = F.relu(self.cell_transform(cell_features))
130
+ cell_features = torch.stack(images)
131
+ grna_features = torch.tensor(grna_data).float()
149
132
 
150
- # Incorporate attention mechanism
151
- attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)
133
+ features = torch.cat([cell_features, grna_features], dim=0)
134
+ graph.ndata['features'] = features
152
135
 
153
- # Combine the transformed cell features with the attention output features
154
- combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
155
-
156
- # Apply dropout here as well
157
- combined_cell_features = self.dropout(combined_cell_features)
136
+ for i in range(cell_nodes):
137
+ for j in range(cell_nodes, cell_nodes + grna_nodes):
138
+ graph.add_edge(i, j)
139
+ graph.add_edge(j, i)
140
+
141
+ return graph
158
142
 
159
- combined_cell_features = F.relu(self.combine_transform(combined_cell_features))
143
+ def create_graphs_for_wells(images, metadata_list, sequencing_data):
144
+ graphs = []
145
+ labels = []
160
146
 
161
- # Combine gene and cell features for message passing
162
- combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
147
+ for well in sequencing_data['well'].unique():
148
+ well_images = [img for img, meta in zip(images, metadata_list) if meta['well'] == well]
149
+ well_metadata = [meta for meta in metadata_list if meta['well'] == well]
150
+ well_grna_data = sequencing_data[sequencing_data['well'] == well].iloc[:, 2:].values
163
151
 
164
- # Apply message passing via adjacency matrix multiplication
165
- message_passed_features = torch.matmul(adjacency_matrix, combined_features)
152
+ graph = construct_well_graph(well_images, well_metadata, well_grna_data)
153
+ graphs.append(graph)
166
154
 
167
- # Predict cell scores from the post-message passed cell features
168
- cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])
155
+ if well_metadata[0]['column'] == 1: # Negative control
156
+ labels.append(0)
157
+ elif well_metadata[0]['column'] == 2: # Positive control
158
+ labels.append(1)
159
+ else:
160
+ labels.append(-1) # Screen wells, will be used for evaluation
161
+
162
+ return graphs, labels
163
+
164
+ # Define Encoder-Decoder Transformer Model
165
+ class Encoder(nn.Module):
166
+ def __init__(self, in_feats, hidden_feats):
167
+ super(Encoder, self).__init__()
168
+ self.conv1 = dglnn.GraphConv(in_feats, hidden_feats)
169
+ self.conv2 = dglnn.GraphConv(hidden_feats, hidden_feats)
170
+
171
+ def forward(self, g, features):
172
+ x = self.conv1(g, features)
173
+ x = torch.relu(x)
174
+ x = self.conv2(g, x)
175
+ x = torch.relu(x)
176
+ return x
177
+
178
+ class Decoder(nn.Module):
179
+ def __init__(self, hidden_feats, out_feats):
180
+ super(Decoder, self).__init__()
181
+ self.linear = nn.Linear(hidden_feats, out_feats)
182
+
183
+ def forward(self, x):
184
+ return self.linear(x)
169
185
 
170
- return cell_scores, attn_weights
171
-
172
- def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
173
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
174
- model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)
186
+ class GraphTransformer(nn.Module):
187
+ def __init__(self, in_feats, hidden_feats, out_feats):
188
+ super(GraphTransformer, self).__init__()
189
+ self.encoder = Encoder(in_feats, hidden_feats)
190
+ self.decoder = Decoder(hidden_feats, out_feats)
175
191
 
176
- criterion = nn.MSELoss()
177
- #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
178
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
192
+ def forward(self, g, features):
193
+ x = self.encoder(g, features)
194
+ with g.local_scope():
195
+ g.ndata['h'] = x
196
+ hg = dgl.mean_nodes(g, 'h')
197
+ return self.decoder(hg)
179
198
 
180
- training_log = []
181
-
182
- accumulate_grad_batches=1
183
- threshold=acc_threshold
184
-
199
+ def train(graphs, labels, model, loss_fn, optimizer, epochs=100):
185
200
  for epoch in range(epochs):
186
201
  model.train()
187
202
  total_loss = 0
188
- total_correct = 0
189
- total_samples = 0
190
- optimizer.zero_grad()
191
- batch_count = 0 # Initialize batch_count
203
+ correct = 0
204
+ total = 0
192
205
 
193
- for graph in graphs:
194
- adjacency_matrix = graph['adjacency_matrix'].to(device)
195
- gene_features = graph['gene_features'].to(device)
196
- cell_features = graph['cell_features'].to(device)
197
- num_cells = graph['num_cells']
198
- predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
199
- predictions = predictions.squeeze()
200
- true_scores = cell_features[:num_cells, 0]
201
- loss = criterion(predictions, true_scores) / accumulate_grad_batches
206
+ for graph, label in zip(graphs, labels):
207
+ if label == -1:
208
+ continue # Skip screen wells for training
209
+
210
+ features = graph.ndata['features']
211
+ logits = model(graph, features)
212
+ loss = loss_fn(logits, torch.tensor([label]))
213
+
214
+ optimizer.zero_grad()
202
215
  loss.backward()
203
-
204
- # Calculate "accuracy"
205
- with torch.no_grad():
206
- correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
207
- total_correct += correct_predictions
208
- total_samples += num_cells
209
-
210
- batch_count += 1 # Increment batch_count
211
- if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
212
- optimizer.step()
213
- optimizer.zero_grad()
214
-
215
- total_loss += loss.item() * accumulate_grad_batches
216
+ optimizer.step()
217
+
218
+ total_loss += loss.item()
219
+ _, predicted = torch.max(logits, 1)
220
+ correct += (predicted == label).sum().item()
221
+ total += 1
216
222
 
217
- accuracy = total_correct / total_samples
218
- training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
219
- print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
220
-
221
- # Save the training log and model as before
222
- os.makedirs(save_fldr, exist_ok=True)
223
- log_path = os.path.join(save_fldr, 'training_log.csv')
224
- training_log_df = pd.DataFrame(training_log)
225
- training_log_df.to_csv(log_path, index=False)
226
- print(f"Training log saved to {log_path}")
227
-
228
- model_path = os.path.join(save_fldr, 'model.pth')
229
- torch.save(model.state_dict(), model_path)
230
- print(f"Model saved to {model_path}")
223
+ accuracy = correct / total if total > 0 else 0
224
+ print(f'Epoch {epoch}, Loss: {total_loss / total:.4f}, Accuracy: {accuracy * 100:.2f}%')
231
225
 
232
- return model
233
-
234
- def annotate_cells_with_genes(graphs, model, gene_id_to_index):
235
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
236
- model.to(device)
226
+ def apply_model(graphs, model):
237
227
  model.eval()
238
- annotated_data = []
228
+ results = []
239
229
 
240
230
  with torch.no_grad():
241
231
  for graph in graphs:
242
- adjacency_matrix = graph['adjacency_matrix'].to(device)
243
- gene_features = graph['gene_features'].to(device)
244
- cell_features = graph['cell_features'].to(device)
245
-
246
- predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
247
- predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
248
- attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())
249
-
250
- # This approach assumes all genes in gene_id_to_index are used in the model.
251
- # Create a list of gene IDs present in this specific graph.
252
- present_gene_ids = [key for key, value in gene_id_to_index.items() if value < gene_features.size(0)]
253
-
254
- for cell_idx in range(cell_features.size(0)):
255
- true_score = cell_features[cell_idx, 0].item()
256
- predicted_score = predictions[cell_idx]
257
-
258
- # Find the index of the most probable gene.
259
- most_probable_gene_idx = attn_weights[cell_idx].argmax()
260
-
261
- if len(present_gene_ids) > most_probable_gene_idx: # Ensure index is within the range
262
- most_probable_gene_id = present_gene_ids[most_probable_gene_idx]
263
- most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx] if attn_weights.ndim > 1 else attn_weights[most_probable_gene_idx]
264
-
265
- annotated_data.append({
266
- "Cell ID": cell_idx,
267
- "Most Probable Gene": most_probable_gene_id,
268
- "Cell Score": true_score,
269
- "Predicted Cell Score": predicted_score,
270
- "Probability Score for Highest Gene": most_probable_gene_score
271
- })
272
- else:
273
- # Handle the case where the index is out of bounds - this should not happen but is here for robustness
274
- print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
275
-
276
- return pd.DataFrame(annotated_data)
232
+ features = graph.ndata['features']
233
+ logits = model(graph, features)
234
+ probabilities = torch.softmax(logits, dim=1)
235
+ results.append(probabilities[:, 1].item())
236
+
237
+ return results
238
+
239
+ def analyze_associations(probabilities, sequencing_data):
240
+ # Analyze associations between gRNAs and classification scores
241
+ sequencing_data['positive_prob'] = probabilities
242
+ return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
243
+
244
+ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
245
+ if test_mode:
246
+ # Load MNIST data
247
+ mnist_train, mnist_test = load_mnist_data()
248
+
249
+ # Generate synthetic gRNA data
250
+ synthetic_grna_data = generate_synthetic_grna_data(len(mnist_train), 10) # 10 synthetic features
251
+ sequencing_data = synthetic_grna_data
252
+
253
+ # Load MNIST images and metadata
254
+ images = []
255
+ metadata_list = []
256
+ for idx, (img, label) in enumerate(mnist_train):
257
+ images.append(img)
258
+ metadata_list.append({'index': idx, 'plate': 'plate1', 'well': idx, 'column': label})
259
+ images = torch.stack(images)
260
+
261
+ # Normalize synthetic sequencing data
262
+ sequencing_data = normalize_sequencing_data(sequencing_data)
263
+
264
+ else:
265
+ from .io import _read_and_join_tables
266
+ from .utils import get_db_paths, get_sequencing_paths, correct_paths
267
+
268
+ db_paths = get_db_paths(src)
269
+ seq_paths = get_sequencing_paths(src)
270
+
271
+ if isinstance(src, str):
272
+ src = [src]
273
+
274
+ sequencing_data = pd.DataFrame()
275
+ for seq in seq_paths:
276
+ sequencing_df = pd.read_csv(seq)
277
+ sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
278
+
279
+ all_df = pd.DataFrame()
280
+ for db_path in db_paths:
281
+ df = _read_and_join_tables(db_path, table_names=['png_list'])
282
+ all_df = pd.concat([all_df, df], axis=0)
283
+
284
+ tables = ['png_list']
285
+ all_df = pd.DataFrame()
286
+ image_paths = []
287
+ for i, db_path in enumerate(db_paths):
288
+ df = _read_and_join_tables(db_path, table_names=tables)
289
+ df, image_paths_tmp = correct_paths(df, src[i])
290
+ all_df = pd.concat([all_df, df], axis=0)
291
+ image_paths.extend(image_paths_tmp)
292
+
293
+ if row_limit is not None:
294
+ all_df = all_df.sample(n=row_limit, random_state=42)
295
+
296
+ images, metadata_list = load_images(image_paths, image_size, channels, normalize)
297
+ sequencing_data = normalize_sequencing_data(sequencing_data)
298
+
299
+ # Step 1: Create graphs for each well
300
+ graphs, labels = create_graphs_for_wells(images, metadata_list, sequencing_data)
301
+
302
+ # Step 2: Train Graph Transformer Model
303
+ in_feats = graphs[0].ndata['features'].shape[1]
304
+ model = GraphTransformer(in_feats, hidden_feats, n_classes)
305
+ loss_fn = nn.CrossEntropyLoss()
306
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
307
+
308
+ # Train the model
309
+ train(graphs, labels, model, loss_fn, optimizer, epochs)
310
+
311
+ # Step 3: Apply the model to all wells (including screen wells)
312
+ screen_graphs = [graph for graph, label in zip(graphs, labels) if label == -1]
313
+ probabilities = apply_model(screen_graphs, model)
314
+
315
+ # Step 4: Analyze associations between gRNAs and classification scores
316
+ associations = analyze_associations(probabilities, sequencing_data)
317
+ print("Top associated gRNAs with positive control phenotype:")
318
+ print(associations.head())
319
+
320
+ return model, associations