spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,72 @@
1
+ import csv
2
+ import os
3
+ import requests
4
+
5
def _download_alphafold_file(session, url, path, timeout):
    """Stream ``url`` into ``path``; return True on HTTP 200, False otherwise.

    The body is written to ``path + '.part'`` and renamed only when complete,
    so an interrupted transfer never leaves a truncated file under the final name.
    """
    with session.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            return False
        tmp_path = path + '.part'
        with open(tmp_path, 'wb') as handle:
            for chunk in response.iter_content(chunk_size=1 << 16):
                handle.write(chunk)
        os.replace(tmp_path, path)
    print(f"Downloaded: {os.path.basename(path)}")
    return True


def download_alphafold_structures(tsv_location, dst, version="4", timeout=60):
    """Download AlphaFold PDB, mmCIF and PAE files for the UniProt entries in a TSV.

    The TSV must be tab-separated with an ``Entry`` column. For each entry the
    files ``AF-<entry>-F1-model_v<version>.pdb``, ``AF-<entry>-F1-model_v<version>.cif``
    and ``AF-<entry>-F1-predicted_aligned_error_v<version>.json`` are fetched from
    ``https://alphafold.ebi.ac.uk/files/`` into ``dst/pdb``, ``dst/cif`` and
    ``dst/pae``. A file is only written when the server answers HTTP 200.
    Entries for which any file could not be fetched are listed in
    ``dst/failed_downloads.csv``.

    Args:
        tsv_location: Path to the TSV file listing UniProt entries.
        dst: Destination folder; it and its sub-folders are created if missing.
        version: AlphaFold DB model version used in the file names.
        timeout: Seconds to wait on the server per request.

    Returns:
        None. Files are written to disk and progress is printed.
    """
    dst_pdb = os.path.join(dst, 'pdb')
    dst_cif = os.path.join(dst, 'cif')
    dst_pae = os.path.join(dst, 'pae')
    for folder in (dst, dst_pdb, dst_cif, dst_pae):
        os.makedirs(folder, exist_ok=True)

    failed_downloads = []  # entries with at least one file missing

    with requests.Session() as session, open(tsv_location, 'r', newline='') as tsv_file:
        reader = csv.DictReader(tsv_file, delimiter='\t')
        for row in reader:
            entry = row['Entry']
            # The PDB model comes first: entries unknown to AlphaFold DB then
            # cost a single request because all() stops at the first failure.
            targets = (
                (f"AF-{entry}-F1-model_v{version}.pdb", dst_pdb),
                (f"AF-{entry}-F1-model_v{version}.cif", dst_cif),
                (f"AF-{entry}-F1-predicted_aligned_error_v{version}.json", dst_pae),
            )
            try:
                complete = all(
                    _download_alphafold_file(
                        session,
                        f"https://alphafold.ebi.ac.uk/files/{name}",
                        os.path.join(folder, name),
                        timeout,
                    )
                    for name, folder in targets
                )
            except (requests.RequestException, OSError) as e:
                # Best effort: a network/disk error on one entry must not stop the batch.
                print(f"Error downloading structure for {entry}: {e}")
                failed_downloads.append(entry)
                continue

            if not complete:
                failed_downloads.append(entry)
                print(f"Failed to download structure for: {entry}")

    if failed_downloads:
        failed_path = os.path.join(dst, 'failed_downloads.csv')
        with open(failed_path, 'w', newline='') as failed_file:
            writer = csv.writer(failed_file)
            writer.writerow(['Entry'])
            writer.writerows([entry] for entry in failed_downloads)
        print(f"Failed download entries saved to: {failed_path}")
68
+
69
if __name__ == "__main__":
    # Example usage. It is guarded so that importing this module does not start
    # downloading into a hard-coded, user-specific folder.
    tsv_location = '/home/carruthers/Downloads/GT1_proteome/GT1_proteins_uniprot.tsv'  # Replace with the path to your TSV file containing a list of UniProt entries
    dst_folder = '/home/carruthers/Downloads/GT1_proteome'  # Replace with your destination folder
    download_alphafold_structures(tsv_location, dst_folder)
spacr/graph_learning.py CHANGED
@@ -1,82 +1,276 @@
1
+ import os
1
2
  import torch
2
- from torch_geometric.data import Data
3
- from torch_geometric.nn import GCNConv, global_mean_pool
3
+ import torch.nn as nn
4
4
  import torch.nn.functional as F
5
- from torch.nn import Linear
5
+ from collections import defaultdict
6
+ from torch.utils.data import Dataset, DataLoader
6
7
  import pandas as pd
7
- from sklearn.preprocessing import LabelEncoder
8
+ import numpy as np
9
+ import torch.optim as optim
8
10
 
9
- #1. Cell Nodes: Represent individual cells. Each cell node could have attributes like cell area, nuclear area, and a CNN-based phenotype score. These nodes could be visualized as circles labeled with "C".
10
- #2. Well Nodes: Represent wells in the 384-well plates. Wells are intermediary nodes that link cells to genes based on the experimental setup. Each well could contain multiple cells and be associated with certain gene knockouts. These nodes might not have direct attributes in the schematic but serve to connect cell nodes to gene nodes. These can be visualized as squares labeled with "W".
11
- #3. Gene Nodes: Represent genes that have been knocked out. Gene nodes are connected to well nodes, indicating which genes are knocked out in each well. Attributes might include the fraction of sequencing reads for that gene, indicating its relative abundance or importance in the well. These nodes can be visualized as diamonds labeled with "G".
12
-
13
- # Define a simple GNN model
14
- class GNN(torch.nn.Module):
15
- def __init__(self):
16
- super(GNN, self).__init__()
17
- self.conv1 = GCNConv(1, 16) # Assume node features are 1-dimensional for simplicity
18
- self.conv2 = GCNConv(16, 32)
19
- self.out = Linear(32, 1) # Predicting a single score for each cell/well
20
-
21
- def forward(self, data):
22
- x, edge_index = data.x, data.edge_index
11
def generate_graphs(sequencing, scores, cell_min, gene_min_read):
    """Build one dense cell-gene graph per well from sequencing and cell-score CSVs.

    Args:
        sequencing: CSV path with columns ``prc`` (well id), ``grna`` (gene id)
            and ``count`` (read count).
        scores: CSV path with columns ``prcfo`` (cell id), ``prc`` (well id)
            and ``pred`` (cell score).
        cell_min: Minimum number of cells a well needs in order to become a graph.
        gene_min_read: Gene rows with ``count < gene_min_read`` are dropped first.

    Returns:
        ``(graphs, gene_id_to_index)``.

        ``gene_id_to_index`` maps every gene id that survives the read filter
        to a row index, in first-appearance order.

        Each graph is a dict with the following keys:

        * ``gene_features``: a ``(num_genes, 1)`` float tensor. Each entry is the
          gene's share of the well's filtered reads, or 0 for genes not in the well.
        * ``cell_features``: a ``(num_cells, 1)`` float tensor of cell scores.
        * ``adjacency_matrix``: a square ``(num_genes + num_cells)`` float tensor.
          Genes occupy the first ``num_genes`` rows/columns and cells the rest.
          Each cell row has a 1 in the column of every gene present in the well.
        * ``num_cells`` / ``num_genes``: ints. ``num_genes`` is the global gene
          count, so every graph has the same gene block.

        Wells without cells, or with fewer than ``cell_min`` cells, are skipped.
        Graphs follow the first-appearance order of wells in the sequencing data.
    """
    # Load and preprocess sequencing (gene) data.
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
    gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data.
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Global gene id -> row index mapping shared by all graphs.
    unique_genes = gene_df['gene_id'].unique()
    gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
    num_genes = len(gene_id_to_index)

    # Group the cell table once instead of re-scanning it for every well.
    cells_by_well = {well_id: group for well_id, group in cell_df.groupby('well_id', sort=False)}

    graphs = []
    for well_id, well_genes in gene_df.groupby('well_id', sort=False):
        well_cells = cells_by_well.get(well_id)

        # Skip wells with no cells or with fewer cells than the threshold.
        if well_cells is None or well_cells.empty or len(well_cells) < cell_min:
            continue

        gene_indices = [gene_id_to_index[gene_id] for gene_id in well_genes['gene_id']]

        # Gene features: read fraction for genes in this well, 0 elsewhere.
        # Plain assignment keeps "last row wins" if a gene is listed twice.
        gene_features = torch.zeros((num_genes, 1), dtype=torch.float)
        for gene_index, fraction in zip(gene_indices, well_genes['well_read_fraction']):
            gene_features[gene_index, 0] = float(fraction)

        cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
        num_cells = cell_features.size(0)
        num_nodes = num_genes + num_cells

        # NOTE(review): the matrix is dense and sized by the *global* gene count,
        # so memory grows as (num_genes + num_cells)^2 per well.
        adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
        adj[num_genes:, gene_indices] = 1  # connect every cell to every gene in the well

        graphs.append({
            'adjacency_matrix': adj,
            'gene_features': gene_features,
            'cell_features': cell_features,
            'num_cells': num_cells,
            'num_genes': num_genes,
        })

    print(f'Generated dataset with {len(graphs)} graphs')
    return graphs, gene_id_to_index
70
+
71
def print_graphs_info(graphs, gene_id_to_index):
    """Print a readable summary of the graphs returned by ``generate_graphs``.

    For every graph (numbered from 1) this prints:

    * the gene and cell counts;
    * the genes whose feature value is positive, together with that value;
    * at most the first five cell scores.
    """
    gene_by_position = {position: gene_id for gene_id, position in gene_id_to_index.items()}

    for graph_number, graph in enumerate(graphs, start=1):
        print(f"Graph {graph_number}:")
        gene_count, cell_count = graph['num_genes'], graph['num_cells']
        genes, cells = graph['gene_features'], graph['cell_features']

        print(f" Number of Genes: {gene_count}")
        print(f" Number of Cells: {cell_count}")

        # A gene is treated as present when its feature value is strictly positive.
        present_genes = []
        for position, value in enumerate(genes):
            if value.item() > 0:
                present_genes.append(gene_by_position[position])
        print(" Genes present in this Graph:", present_genes)

        print(" Gene Features:")
        for gene_id in present_genes:
            print(f" {gene_id}: {genes[gene_id_to_index[gene_id]].item()}")

        # Slicing already clamps to the available cells.
        print(" Cell Features (sample):")
        for cell_number, value in enumerate(cells[:5], start=1):
            print(f" Cell {cell_number}: {value.item()}")

        print("-" * 40)
101
+
102
class Attention(nn.Module):
    """Scaled dot-product cross-attention from cells (queries) to genes (keys/values).

    Args:
        feature_dim: Last dimension of both inputs, and of the attended output.
        attn_dim: Width of the query/key projections used for the scores.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
        super().__init__()
        # Submodules are registered in the order query, key, value, dropout.
        # This keeps seeded initialisation and state_dict keys unchanged.
        self.query = nn.Linear(feature_dim, attn_dim)
        self.key = nn.Linear(feature_dim, attn_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.scale = 1.0 / (attn_dim ** 0.5)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, gene_features, cell_features):
        """Return ``(attended, weights)``.

        ``weights`` holds, for every cell (row) and gene (column), the
        softmax-normalised projected-query / projected-key score, with
        dropout applied in training mode. ``attended`` is ``weights``
        multiplied by the value projection of ``gene_features``.
        """
        scores = self.query(cell_features) @ self.key(gene_features).transpose(-2, -1)
        weights = self.dropout(torch.softmax(scores * self.scale, dim=-1))
        attended = weights @ self.value(gene_features)
        return attended, weights
33
128
 
34
- def construct_graph(cell_data_loc, well_data_loc, well_id='prc', infer_id='gene', features=[]):
35
-
36
- # Example loading step
37
- cells_df = pd.read_csv(cell_data_loc)
38
- wells_df = pd.read_csv(well_data_loc)
39
-
40
- # Encode categorical data
41
- well_encoder = LabelEncoder()
42
- gene_encoder = LabelEncoder()
43
-
44
- cells_df['well_id'] = well_encoder.fit_transform(cells_df[well_id])
45
- wells_df['gene_id'] = gene_encoder.fit_transform(wells_df[infer_id])
46
-
47
- # Assume cell features are in columns ['feature1', 'feature2', ...]
48
- cell_features = torch.tensor(cells_df[[features]].values, dtype=torch.float)
49
-
50
- # Creating nodes for cells and assigning phenotype scores as labels
51
- y = torch.tensor(cells_df['phenotype_score'].values, dtype=torch.float).unsqueeze(1)
52
-
53
- # Constructing edges (this is simplified; you should define edges based on your data structure)
54
- edge_index = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.long).t().contiguous()
55
-
56
- graphdata = Data(x=cell_features, edge_index=torch.tensor(edge_index, dtype=torch.long), y=y)
57
- return graphdata, len(features)
129
class GraphTransformer(nn.Module):
    """Predict one value per cell from a well graph built by ``generate_graphs``.

    Forward pipeline:

    1. Gene and cell features are projected to ``hidden_dim`` with ReLU.
    2. Every cell attends over all gene nodes.
    3. The attended vector is concatenated with the cell's own projection and
       mixed back to ``hidden_dim``.
    4. One round of message passing (``adjacency_matrix @ node_features``) is
       applied, with genes stacked before cells.
    5. The last ``num_cells`` rows are mapped to ``output_dim``.

    NOTE(review): with the adjacency built by ``generate_graphs``, a cell row is
    non-zero only in gene columns. After message passing each cell's output is
    therefore the sum of the projected features of the genes in its well. The
    cell/attention branch does not reach ``cell_scores``, so all cells of a well
    get the same prediction. Confirm this is intended.

    Args:
        gene_feature_size: Input width of gene node features.
        cell_feature_size: Input width of cell node features.
        hidden_dim: Width of all hidden projections.
        output_dim: Width of the per-cell output.
        attn_dim: Query/key width of the cell->gene attention.
        dropout_rate: Dropout used after feature combination and on the
            attention weights.
    """

    def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
        super(GraphTransformer, self).__init__()
        self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
        self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)

        # Attention layer letting each cell attend to all genes. It honours the
        # same dropout rate as the rest of the model; previously it was fixed
        # at the Attention default.
        self.attention = Attention(hidden_dim, attn_dim, dropout_rate=dropout_rate)

        # Mixes [cell projection, attended gene features] back to hidden_dim.
        self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)

        # Per-cell output head.
        self.cell_output = nn.Linear(hidden_dim, output_dim)

    def forward(self, adjacency_matrix, gene_features, cell_features):
        """Return ``(cell_scores, attn_weights)``.

        Args:
            adjacency_matrix: Square matrix over ``num_genes + num_cells`` nodes,
                with genes first.
            gene_features: ``(num_genes, gene_feature_size)``.
            cell_features: ``(num_cells, cell_feature_size)``.

        Returns:
            ``cell_scores`` of shape ``(num_cells, output_dim)``, and the
            cell->gene attention weights from ``Attention``.
        """
        transformed_gene_features = F.relu(self.gene_transform(gene_features))
        transformed_cell_features = F.relu(self.cell_transform(cell_features))

        attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)

        # Concatenate each cell's own projection with what it attended to.
        combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
        combined_cell_features = self.dropout(combined_cell_features)
        combined_cell_features = F.relu(self.combine_transform(combined_cell_features))

        # Node order must match the adjacency matrix: genes first, then cells.
        combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
        message_passed_features = torch.matmul(adjacency_matrix, combined_features)

        # The cell nodes are the last cell_features.size(0) rows.
        cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])

        return cell_scores, attn_weights
61
171
 
172
def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
    """Train a ``GraphTransformer`` on per-well graphs and save the log and weights.

    Args:
        graphs: Graph dicts as produced by ``generate_graphs``.
        lr: AdamW learning rate.
        dropout_rate: Dropout rate passed to the model.
        weight_decay: AdamW weight decay.
        epochs: Number of passes over ``graphs``.
        save_fldr: Folder receiving ``training_log.csv`` and ``model.pth``.
            It is created if needed; ``''`` means the current directory.
        acc_threshold: A prediction counts as "correct" when
            ``|pred - true| <= acc_threshold * |true|``.

    Returns:
        The trained model, left on the device it was trained on.

    Raises:
        ValueError: If ``graphs`` is empty.
    """
    if not graphs:
        raise ValueError("graphs must contain at least one graph")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    training_log = []

    accumulate_grad_batches = 1
    threshold = acc_threshold

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0
        optimizer.zero_grad()
        batch_count = 0

        for graph in graphs:
            adjacency_matrix = graph['adjacency_matrix'].to(device)
            gene_features = graph['gene_features'].to(device)
            cell_features = graph['cell_features'].to(device)
            num_cells = graph['num_cells']
            predictions, _ = model(adjacency_matrix, gene_features, cell_features)
            # Squeeze only the output dimension, so a one-cell well keeps shape (1,).
            predictions = predictions.squeeze(-1)
            # NOTE(review): the target is the same cell score that is fed in as
            # the cell feature; confirm this is the intended objective.
            true_scores = cell_features[:num_cells, 0]
            loss = criterion(predictions, true_scores) / accumulate_grad_batches
            loss.backward()

            with torch.no_grad():
                # Relative-error "accuracy" written without a division. Zero or
                # negative targets therefore cannot produce inf/nan or spurious hits.
                within_tolerance = torch.abs(predictions - true_scores) <= threshold * torch.abs(true_scores)
                total_correct += within_tolerance.sum().item()
                total_samples += num_cells

            batch_count += 1
            if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * accumulate_grad_batches

        average_loss = total_loss / len(graphs)
        accuracy = total_correct / total_samples if total_samples else 0.0
        training_log.append({"Epoch": epoch+1, "Average Loss": average_loss, "Accuracy": accuracy})
        print(f"Epoch {epoch+1}, Loss: {average_loss}, Accuracy: {accuracy}", end="\r", flush=True)
    print()  # finish the carriage-returned progress line

    # os.makedirs('') raises, so only create the folder when one was given.
    if save_fldr:
        os.makedirs(save_fldr, exist_ok=True)
    log_path = os.path.join(save_fldr, 'training_log.csv')
    training_log_df = pd.DataFrame(training_log)
    training_log_df.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")

    model_path = os.path.join(save_fldr, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    return model
233
+
234
def annotate_cells_with_genes(graphs, model, gene_id_to_index):
    """Run ``model`` on every graph and report the most-attended gene per cell.

    For each cell, the gene with the highest attention weight is chosen among
    the genes connected to that cell in ``adjacency_matrix``, i.e. genes present
    in its well. If a cell has no connected gene, all genes are considered.

    Args:
        graphs: Graph dicts with ``adjacency_matrix``, ``gene_features`` and
            ``cell_features`` (see ``generate_graphs``).
        model: Module called as ``model(adjacency_matrix, gene_features,
            cell_features)`` and returning ``(predictions, attn_weights)``, with
            one prediction row and one attention row per cell.
        gene_id_to_index: Gene id -> gene row index mapping.

    Returns:
        DataFrame with one row per cell and the columns ``Cell ID``,
        ``Most Probable Gene``, ``Cell Score``, ``Predicted Cell Score`` and
        ``Probability Score for Highest Gene``. ``Cell ID`` is the cell's
        position within its graph.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    # Explicit inverse mapping; do not rely on dict insertion order.
    index_to_gene_id = {index: gene_id for gene_id, index in gene_id_to_index.items()}
    annotated_data = []

    with torch.no_grad():
        for graph in graphs:
            adjacency_matrix = graph['adjacency_matrix'].to(device)
            gene_features = graph['gene_features'].to(device)
            cell_features = graph['cell_features'].to(device)
            num_genes = gene_features.size(0)
            num_cells = cell_features.size(0)

            predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
            # reshape (not squeeze) keeps the cell axis intact when a well has
            # a single cell and/or a single gene.
            predictions = predictions.reshape(num_cells, -1).cpu().numpy()
            attn_weights = attn_weights.reshape(num_cells, num_genes).cpu().numpy()
            # Cell rows of the adjacency: True where the cell is linked to a gene.
            connected = (adjacency_matrix[num_genes:, :num_genes] > 0).cpu().numpy()

            for cell_idx in range(num_cells):
                true_score = cell_features[cell_idx, 0].item()
                predicted_score = predictions[cell_idx, 0] if predictions.shape[1] == 1 else predictions[cell_idx]
                cell_attention = attn_weights[cell_idx]

                candidates = np.flatnonzero(connected[cell_idx]) if cell_idx < connected.shape[0] else np.array([], dtype=int)
                if candidates.size == 0:
                    candidates = np.arange(num_genes)
                most_probable_gene_idx = int(candidates[np.argmax(cell_attention[candidates])])

                most_probable_gene_id = index_to_gene_id.get(most_probable_gene_idx)
                if most_probable_gene_id is None:
                    # Should not happen: the model attended over a gene row that
                    # the mapping does not know.
                    print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
                    continue

                annotated_data.append({
                    "Cell ID": cell_idx,
                    "Most Probable Gene": most_probable_gene_id,
                    "Cell Score": true_score,
                    "Predicted Cell Score": predicted_score,
                    "Probability Score for Highest Gene": cell_attention[most_probable_gene_idx]
                })

    return pd.DataFrame(annotated_data)
@@ -1,82 +1,84 @@
1
1
  import torch
2
- from torch_geometric.data import Data
3
- from torch_geometric.nn import GCNConv, global_mean_pool
2
+ import torch.nn as nn
4
3
  import torch.nn.functional as F
5
- from torch.nn import Linear
6
- import pandas as pd
7
- from sklearn.preprocessing import LabelEncoder
4
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
8
5
 
9
- #1. Cell Nodes: Represent individual cells. Each cell node could have attributes like cell area, nuclear area, and a CNN-based phenotype score. These nodes could be visualized as circles labeled with "C".
10
- #2. Well Nodes: Represent wells in the 384-well plates. Wells are intermediary nodes that link cells to genes based on the experimental setup. Each well could contain multiple cells and be associated with certain gene knockouts. These nodes might not have direct attributes in the schematic but serve to connect cell nodes to gene nodes. These can be visualized as squares labeled with "W".
11
- #3. Gene Nodes: Represent genes that have been knocked out. Gene nodes are connected to well nodes, indicating which genes are knocked out in each well. Attributes might include the fraction of sequencing reads for that gene, indicating its relative abundance or importance in the well. These nodes can be visualized as diamonds labeled with "G".
12
-
13
- # Define a simple GNN model
14
- class GNN(torch.nn.Module):
15
- def __init__(self):
16
- super(GNN, self).__init__()
17
- self.conv1 = GCNConv(1, 16) # Assume node features are 1-dimensional for simplicity
18
- self.conv2 = GCNConv(16, 32)
19
- self.out = Linear(32, 1) # Predicting a single score for each cell/well
6
+ # Let's assume that the feature-embedding and dataset-loading steps
7
+ # have already been taken care of, and that your data is already in a format
8
+ # suitable for PyTorch (i.e., tensors).
20
9
 
21
- def forward(self, data):
22
- x, edge_index = data.x, data.edge_index
10
class FeatureEmbedder(nn.Module):
    """Embed categorical code features and provide a learnable 'visit' token.

    Args:
        vocab_sizes: Mapping from feature name to vocabulary size. Each feature
            gets an ``nn.Embedding`` with one extra row. Index ``vocab_size`` is
            its padding index: it embeds to zeros and is masked out.
        embedding_size: Embedding width shared by all features.
    """

    def __init__(self, vocab_sizes, embedding_size):
        super(FeatureEmbedder, self).__init__()
        self.embeddings = nn.ModuleDict({
            key: nn.Embedding(num_embeddings=vocab_size + 1,
                              embedding_dim=embedding_size,
                              padding_idx=vocab_size)
            for key, vocab_size in vocab_sizes.items()
        })
        # nn.ModuleDict only accepts Modules (storing a Parameter in it raises
        # TypeError), so the 'visit' token is a separate Parameter.
        self.visit_embedding = nn.Parameter(torch.zeros(1, embedding_size))

    def forward(self, feature_map, max_num_codes=None):
        """Embed every tensor in ``feature_map`` and append the 'visit' token.

        Args:
            feature_map: Mapping from feature name to an integer code tensor.
                It is presumably ``(batch, num_codes)``; TODO confirm with callers.
            max_num_codes: Accepted for API compatibility; currently unused.

        Returns:
            ``(embeddings, masks)`` dicts keyed by feature name plus ``'visit'``.
            ``masks[key]`` is ``1.0`` for real codes and ``0.0`` for padding
            positions, with a trailing singleton dimension. ``masks['visit']``
            has shape ``(batch_size, 1)``.
        """
        embeddings = {}
        masks = {}
        for key, tensor in feature_map.items():
            codes = tensor.long()
            embedding = self.embeddings[key]
            embeddings[key] = embedding(codes)
            # Mask out the padding index declared on the embedding.
            masks[key] = (codes != embedding.padding_idx).to(torch.float32).unsqueeze(-1)

        # Batch size comes from the first feature tensor (leading dim), with a
        # fallback to 1 when there is none or it is 1-D.
        batch_size = 1
        for tensor in feature_map.values():
            if tensor.dim() >= 2:
                batch_size = tensor.size(0)
            break

        embeddings['visit'] = self.visit_embedding.expand(batch_size, -1, -1)
        masks['visit'] = torch.ones(batch_size, 1, device=self.visit_embedding.device)

        return embeddings, masks
59
38
 
60
- def train_gnn(cell_data_loc, well_data_loc, well_id='prc', infer_id='gene', lr=0.01, epochs=100):
61
-
62
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63
- data, nr_of_features = construct_graph(cell_data_loc, well_data_loc).to(device)
64
-
65
- model = GNN(feature_size=nr_of_features).to(device)
39
class GraphConvolutionalTransformer(nn.Module):
    """Stack of Transformer encoder layers with a linear head on the first token.

    Args:
        embedding_size: Model width (``d_model``) of every encoder layer.
        num_attention_heads: Heads per layer; must divide ``embedding_size``.
        **kwargs: ``num_transformer_stack`` (default 3) sets the number of layers.
    """

    def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
        super(GraphConvolutionalTransformer, self).__init__()
        # Transformer blocks (batch-first: inputs are (batch, seq, embedding_size)).
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_attention_heads,
                batch_first=True)
            for _ in range(kwargs.get('num_transformer_stack', 3))
        ])
        # Output layer producing a single logit.
        self.output_layer = nn.Linear(embedding_size, 1)

    def feedforward(self, features, mask=None, training=None):
        """Placeholder kept for API compatibility; not implemented (returns None)."""
        pass

    @staticmethod
    def _stack_embeddings(embeddings, masks):
        """Concatenate a dict of ``(batch, n_k, E)`` embeddings along the sequence axis.

        ``'visit'`` (if present) goes first so that position 0 is the visit token.
        Returns ``(features, key_padding_mask)``, where the mask is boolean
        ``(batch, seq)`` and True marks positions to ignore. A key missing from
        ``masks`` is treated as fully unpadded.
        """
        keys = (['visit'] if 'visit' in embeddings else []) + [key for key in embeddings if key != 'visit']
        features = torch.cat([embeddings[key] for key in keys], dim=1)

        if masks is None:
            masks = {}
        padding = []
        for key in keys:
            embedding = embeddings[key]
            batch, length = embedding.shape[0], embedding.shape[1]
            key_mask = masks.get(key)
            if key_mask is None:
                keep = torch.ones(batch, length, dtype=torch.bool, device=embedding.device)
            else:
                keep = key_mask.reshape(batch, length).to(embedding.device) != 0
            padding.append(~keep)
        return features, torch.cat(padding, dim=1)

    def forward(self, embeddings, masks, mask=None, training=False):
        """Encode the inputs and return ``(logits, attentions)``.

        Args:
            embeddings: Either a ``(batch, seq, embedding_size)`` tensor, or the
                embeddings dict returned by ``FeatureEmbedder``. A dict is
                concatenated along the sequence axis with ``'visit'`` first.
            masks: For dict input, the masks dict from ``FeatureEmbedder``.
                Zero entries become key-padding positions. It is ignored for
                tensor input.
            mask: Optional tensor multiplied into the encoded features before
                the output head.
            training: Unused; train/eval mode follows ``self.training``.

        Returns:
            ``logits`` of shape ``(batch, 1)``, computed from sequence position 0,
            and ``attentions``, currently always an empty list.
        """
        key_padding_mask = None
        if isinstance(embeddings, dict):
            features, key_padding_mask = self._stack_embeddings(embeddings, masks)
        else:
            features = embeddings
        attentions = []  # attention maps are not collected yet

        for layer in self.layers:
            features = layer(features, src_key_padding_mask=key_padding_mask)

        if mask is not None:
            features = features * mask

        logits = self.output_layer(features[:, 0, :])  # position 0 = 'visit' token for dict input
        return logits, attentions
70
+
71
if __name__ == "__main__":
    # Usage example. It runs only when this module is executed directly, so
    # importing the module has no side effects.
    vocab_sizes = {'dx_ints':3249, 'proc_ints':2210}
    embedding_size = 128
    gct_params = {
        'embedding_size': embedding_size,
        'num_transformer_stack': 3,
        'num_attention_heads': 1
    }
    feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
    gct_model = GraphConvolutionalTransformer(**gct_params)

    # Dummy input: `batch_size` samples with `max_num_codes` code slots per
    # feature. Values range over 0..vocab_size inclusive; vocab_size is the
    # padding index of each embedding.
    batch_size = 2
    max_num_codes = 50
    feature_map = {
        key: torch.randint(0, vocab_size + 1, (batch_size, max_num_codes))
        for key, vocab_size in vocab_sizes.items()
    }

    embeddings, masks = feature_embedder(feature_map, max_num_codes)
    logits, attentions = gct_model(embeddings, masks)
    print(logits.shape)
spacr/gui_classify_app.py CHANGED
@@ -17,7 +17,7 @@ except AttributeError:
17
17
 
18
18
  from .logger import log_function_call
19
19
  from .gui_utils import ScrollableFrame, StdoutRedirector, create_dark_mode, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, safe_literal_eval, clear_canvas, main_thread_update_function
20
- from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper
20
+ from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv
21
21
 
22
22
  thread_control = {"run_thread": None, "stop_requested": False}
23
23
 
@@ -39,7 +39,6 @@ def run_classify_gui(q, fig_queue, stop_requested):
39
39
  process_stdout_stderr(q)
40
40
  try:
41
41
  settings = check_classify_gui_settings(vars_dict)
42
- #settings = add_mask_gui_defaults(settings)
43
42
  for key in settings:
44
43
  value = settings[key]
45
44
  print(key, value, type(value))
@@ -65,25 +64,10 @@ def import_settings(scrollable_frame):
65
64
  global vars_dict
66
65
 
67
66
  csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
68
-
69
- if not csv_file_path:
70
- return
71
-
72
- imported_variables = {}
73
-
74
- with open(csv_file_path, newline='') as csvfile:
75
- reader = csv.DictReader(csvfile)
76
- for row in reader:
77
- key = row['Key']
78
- value = row['Value']
79
- # Evaluate the value safely using safe_literal_eval
80
- imported_variables[key] = safe_literal_eval(value)
81
-
82
- # Track changed variables and apply the imported ones, printing changes as we go
83
- for key, var in vars_dict.items():
84
- if key in imported_variables and var.get() != imported_variables[key]:
85
- print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
86
- var.set(imported_variables[key])
67
+ csv_settings = read_settings_from_csv(csv_file_path)
68
+ variables = classify_variables()
69
+ new_settings = update_settings_from_csv(variables, csv_settings)
70
+ vars_dict = generate_fields(new_settings, scrollable_frame)
87
71
 
88
72
  @log_function_call
89
73
  def initiate_classify_root(width, height):