spacr 0.0.17__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +2 -0
- spacr/alpha.py +18 -0
- spacr/cli.py +1 -200
- spacr/core.py +267 -56
- spacr/graph_learning.py +256 -75
- spacr/graph_learning_lap.py +84 -0
- spacr/gui_classify_app.py +6 -21
- spacr/gui_mask_app.py +9 -43
- spacr/gui_measure_app.py +10 -24
- spacr/gui_sim_app.py +0 -213
- spacr/gui_utils.py +84 -66
- spacr/io.py +258 -110
- spacr/measure.py +11 -17
- spacr/old_code.py +187 -1
- spacr/plot.py +92 -87
- spacr/timelapse.py +213 -52
- spacr/utils.py +219 -118
- {spacr-0.0.17.dist-info → spacr-0.0.20.dist-info}/METADATA +28 -26
- spacr-0.0.20.dist-info/RECORD +31 -0
- {spacr-0.0.17.dist-info → spacr-0.0.20.dist-info}/WHEEL +1 -1
- spacr/gui_temp.py +0 -212
- spacr/test_annotate_app.py +0 -58
- spacr/test_plot.py +0 -43
- spacr/test_train.py +0 -39
- spacr/test_utils.py +0 -33
- spacr-0.0.17.dist-info/RECORD +0 -34
- {spacr-0.0.17.dist-info → spacr-0.0.20.dist-info}/LICENSE +0 -0
- {spacr-0.0.17.dist-info → spacr-0.0.20.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.17.dist-info → spacr-0.0.20.dist-info}/top_level.txt +0 -0
spacr/graph_learning.py
CHANGED
@@ -1,95 +1,276 @@
|
|
1
|
-
import
|
1
|
+
import os
|
2
2
|
import torch
|
3
|
-
|
4
|
-
|
3
|
+
import torch.nn as nn
|
4
|
+
import torch.nn.functional as F
|
5
|
+
from collections import defaultdict
|
6
|
+
from torch.utils.data import Dataset, DataLoader
|
7
|
+
import pandas as pd
|
5
8
|
import numpy as np
|
6
|
-
|
9
|
+
import torch.optim as optim
|
10
|
+
|
11
|
+
def generate_graphs(sequencing, scores, cell_min, gene_min_read):
|
12
|
+
# Load and preprocess sequencing (gene) data
|
13
|
+
gene_df = pd.read_csv(sequencing)
|
14
|
+
gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
|
15
|
+
# Filter out genes with read counts less than gene_min_read
|
16
|
+
gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
|
17
|
+
total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
|
18
|
+
gene_df = gene_df.merge(total_reads_per_well, on='well_id')
|
19
|
+
gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']
|
20
|
+
|
21
|
+
# Load and preprocess cell score data
|
22
|
+
cell_df = pd.read_csv(scores)
|
23
|
+
cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})
|
24
|
+
|
25
|
+
# Create a global mapping of gene IDs to indices
|
26
|
+
unique_genes = gene_df['gene_id'].unique()
|
27
|
+
gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
|
28
|
+
|
29
|
+
graphs = []
|
30
|
+
for well_id in pd.unique(gene_df['well_id']):
|
31
|
+
well_genes = gene_df[gene_df['well_id'] == well_id]
|
32
|
+
well_cells = cell_df[cell_df['well_id'] == well_id]
|
33
|
+
|
34
|
+
# Skip wells with no cells or genes or with fewer cells than threshold
|
35
|
+
if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
|
36
|
+
continue
|
37
|
+
|
38
|
+
# Initialize gene features tensor with zeros for all unique genes
|
39
|
+
gene_features = torch.zeros((len(gene_id_to_index), 1), dtype=torch.float)
|
40
|
+
|
41
|
+
# Update gene features tensor with well_read_fraction for genes present in this well
|
42
|
+
for _, row in well_genes.iterrows():
|
43
|
+
gene_index = gene_id_to_index[row['gene_id']]
|
44
|
+
gene_features[gene_index] = torch.tensor([[row['well_read_fraction']]])
|
45
|
+
|
46
|
+
# Prepare cell features (scores)
|
47
|
+
cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
|
48
|
+
|
49
|
+
num_genes = len(gene_id_to_index)
|
50
|
+
num_cells = cell_features.size(0)
|
51
|
+
num_nodes = num_genes + num_cells
|
52
|
+
|
53
|
+
# Create adjacency matrix connecting each cell to all genes in the well
|
54
|
+
adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
|
55
|
+
for _, row in well_genes.iterrows():
|
56
|
+
gene_index = gene_id_to_index[row['gene_id']]
|
57
|
+
adj[num_genes:, gene_index] = 1
|
58
|
+
|
59
|
+
graph = {
|
60
|
+
'adjacency_matrix': adj,
|
61
|
+
'gene_features': gene_features,
|
62
|
+
'cell_features': cell_features,
|
63
|
+
'num_cells': num_cells,
|
64
|
+
'num_genes': num_genes
|
65
|
+
}
|
66
|
+
graphs.append(graph)
|
67
|
+
|
68
|
+
print(f'Generated dataset with {len(graphs)} graphs')
|
69
|
+
return graphs, gene_id_to_index
|
7
70
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
self.conv1 = GCNConv(num_node_features, 16)
|
12
|
-
self.conv2 = GCNConv(16, num_classes)
|
71
|
+
def print_graphs_info(graphs, gene_id_to_index):
|
72
|
+
# Invert the gene_id_to_index mapping for easy lookup
|
73
|
+
index_to_gene_id = {v: k for k, v in gene_id_to_index.items()}
|
13
74
|
|
14
|
-
|
15
|
-
|
75
|
+
for i, graph in enumerate(graphs, start=1):
|
76
|
+
print(f"Graph {i}:")
|
77
|
+
num_genes = graph['num_genes']
|
78
|
+
num_cells = graph['num_cells']
|
79
|
+
gene_features = graph['gene_features']
|
80
|
+
cell_features = graph['cell_features']
|
81
|
+
|
82
|
+
print(f" Number of Genes: {num_genes}")
|
83
|
+
print(f" Number of Cells: {num_cells}")
|
84
|
+
|
85
|
+
# Identify genes present in the graph based on non-zero feature values
|
86
|
+
present_genes = [index_to_gene_id[idx] for idx, feature in enumerate(gene_features) if feature.item() > 0]
|
87
|
+
print(" Genes present in this Graph:", present_genes)
|
88
|
+
|
89
|
+
# Display gene features for genes present in the graph
|
90
|
+
print(" Gene Features:")
|
91
|
+
for gene_id in present_genes:
|
92
|
+
idx = gene_id_to_index[gene_id]
|
93
|
+
print(f" {gene_id}: {gene_features[idx].item()}")
|
94
|
+
|
95
|
+
# Display a sample of cell features, for brevity
|
96
|
+
print(" Cell Features (sample):")
|
97
|
+
for idx, feature in enumerate(cell_features[:min(5, len(cell_features))]):
|
98
|
+
print(f" Cell {idx+1}: {feature.item()}")
|
99
|
+
|
100
|
+
print("-" * 40)
|
101
|
+
|
102
|
+
class Attention(nn.Module):
|
103
|
+
def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
|
104
|
+
super(Attention, self).__init__()
|
105
|
+
self.query = nn.Linear(feature_dim, attn_dim)
|
106
|
+
self.key = nn.Linear(feature_dim, attn_dim)
|
107
|
+
self.value = nn.Linear(feature_dim, feature_dim)
|
108
|
+
self.scale = 1.0 / (attn_dim ** 0.5)
|
109
|
+
self.dropout = nn.Dropout(dropout_rate)
|
110
|
+
|
111
|
+
def forward(self, gene_features, cell_features):
|
112
|
+
# Queries come from the cell features
|
113
|
+
q = self.query(cell_features)
|
114
|
+
# Keys and values come from the gene features
|
115
|
+
k = self.key(gene_features)
|
116
|
+
v = self.value(gene_features)
|
16
117
|
|
17
|
-
|
18
|
-
|
19
|
-
|
118
|
+
# Compute attention weights
|
119
|
+
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
120
|
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
121
|
+
# Apply dropout to attention weights
|
122
|
+
attn_weights = self.dropout(attn_weights)
|
123
|
+
|
124
|
+
# Apply attention weights to the values
|
125
|
+
attn_output = torch.matmul(attn_weights, v)
|
20
126
|
|
21
|
-
return
|
22
|
-
|
23
|
-
def create_graph(cells_df, wells_df, well_id='prc', infer_id='gene'):
|
24
|
-
# Nodes: Combine cell and well nodes
|
25
|
-
num_cells = cells_df.shape[0]
|
26
|
-
num_wells = len(wells_df[well_id].unique())
|
27
|
-
num_genes = len(wells_df[infer_id].unique())
|
28
|
-
|
29
|
-
# Edges: You need to define edges based on your cells' and wells' relationships
|
30
|
-
edge_index = [...] # Fill with (source, target) index pairs
|
31
|
-
|
32
|
-
# Node features: Depending on your data, this might include measurements for cells, and gene fractions for wells
|
33
|
-
x = [...] # Feature matrix
|
34
|
-
|
35
|
-
# Labels: If you're predicting something specific, like gene knockouts
|
36
|
-
y = [...] # Target labels for nodes
|
37
|
-
|
38
|
-
data = Data(x=torch.tensor(x, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long), y=torch.tensor(y))
|
39
|
-
|
40
|
-
return data
|
127
|
+
return attn_output, attn_weights
|
41
128
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
129
|
+
class GraphTransformer(nn.Module):
|
130
|
+
def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
|
131
|
+
super(GraphTransformer, self).__init__()
|
132
|
+
self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
|
133
|
+
self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
|
134
|
+
self.dropout = nn.Dropout(dropout_rate)
|
135
|
+
|
136
|
+
# Attention layer to let each cell attend to all genes
|
137
|
+
self.attention = Attention(hidden_dim, attn_dim)
|
138
|
+
|
139
|
+
# This layer is used to transform the combined features after attention
|
140
|
+
self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
|
141
|
+
|
142
|
+
# Output layer for predicting cell scores, ensuring it matches the number of cells
|
143
|
+
self.cell_output = nn.Linear(hidden_dim, output_dim)
|
144
|
+
|
145
|
+
def forward(self, adjacency_matrix, gene_features, cell_features):
|
146
|
+
# Apply initial transformation to gene and cell features
|
147
|
+
transformed_gene_features = F.relu(self.gene_transform(gene_features))
|
148
|
+
transformed_cell_features = F.relu(self.cell_transform(cell_features))
|
149
|
+
|
150
|
+
# Incorporate attention mechanism
|
151
|
+
attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)
|
152
|
+
|
153
|
+
# Combine the transformed cell features with the attention output features
|
154
|
+
combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
|
155
|
+
|
156
|
+
# Apply dropout here as well
|
157
|
+
combined_cell_features = self.dropout(combined_cell_features)
|
158
|
+
|
159
|
+
combined_cell_features = F.relu(self.combine_transform(combined_cell_features))
|
160
|
+
|
161
|
+
# Combine gene and cell features for message passing
|
162
|
+
combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
|
163
|
+
|
164
|
+
# Apply message passing via adjacency matrix multiplication
|
165
|
+
message_passed_features = torch.matmul(adjacency_matrix, combined_features)
|
166
|
+
|
167
|
+
# Predict cell scores from the post-message passed cell features
|
168
|
+
cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])
|
169
|
+
|
170
|
+
return cell_scores, attn_weights
|
57
171
|
|
58
|
-
|
172
|
+
def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
|
173
|
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
174
|
+
model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)
|
175
|
+
|
176
|
+
criterion = nn.MSELoss()
|
177
|
+
#optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
178
|
+
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
179
|
+
|
180
|
+
training_log = []
|
59
181
|
|
60
|
-
|
61
|
-
|
62
|
-
y=torch.tensor(y, dtype=torch.long)) # Adjust dtype as needed
|
182
|
+
accumulate_grad_batches=1
|
183
|
+
threshold=acc_threshold
|
63
184
|
|
64
|
-
|
185
|
+
for epoch in range(epochs):
|
186
|
+
model.train()
|
187
|
+
total_loss = 0
|
188
|
+
total_correct = 0
|
189
|
+
total_samples = 0
|
190
|
+
optimizer.zero_grad()
|
191
|
+
batch_count = 0 # Initialize batch_count
|
192
|
+
|
193
|
+
for graph in graphs:
|
194
|
+
adjacency_matrix = graph['adjacency_matrix'].to(device)
|
195
|
+
gene_features = graph['gene_features'].to(device)
|
196
|
+
cell_features = graph['cell_features'].to(device)
|
197
|
+
num_cells = graph['num_cells']
|
198
|
+
predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
|
199
|
+
predictions = predictions.squeeze()
|
200
|
+
true_scores = cell_features[:num_cells, 0]
|
201
|
+
loss = criterion(predictions, true_scores) / accumulate_grad_batches
|
202
|
+
loss.backward()
|
65
203
|
|
204
|
+
# Calculate "accuracy"
|
205
|
+
with torch.no_grad():
|
206
|
+
correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
|
207
|
+
total_correct += correct_predictions
|
208
|
+
total_samples += num_cells
|
66
209
|
|
67
|
-
|
210
|
+
batch_count += 1 # Increment batch_count
|
211
|
+
if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
|
212
|
+
optimizer.step()
|
213
|
+
optimizer.zero_grad()
|
68
214
|
|
69
|
-
|
70
|
-
|
71
|
-
|
215
|
+
total_loss += loss.item() * accumulate_grad_batches
|
216
|
+
|
217
|
+
accuracy = total_correct / total_samples
|
218
|
+
training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
|
219
|
+
print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
|
220
|
+
|
221
|
+
# Save the training log and model as before
|
222
|
+
os.makedirs(save_fldr, exist_ok=True)
|
223
|
+
log_path = os.path.join(save_fldr, 'training_log.csv')
|
224
|
+
training_log_df = pd.DataFrame(training_log)
|
225
|
+
training_log_df.to_csv(log_path, index=False)
|
226
|
+
print(f"Training log saved to {log_path}")
|
227
|
+
|
228
|
+
model_path = os.path.join(save_fldr, 'model.pth')
|
229
|
+
torch.save(model.state_dict(), model_path)
|
230
|
+
print(f"Model saved to {model_path}")
|
72
231
|
|
73
|
-
|
74
|
-
|
232
|
+
return model
|
233
|
+
|
234
|
+
def annotate_cells_with_genes(graphs, model, gene_id_to_index):
|
235
|
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
236
|
+
model.to(device)
|
237
|
+
model.eval()
|
238
|
+
annotated_data = []
|
75
239
|
|
76
|
-
|
77
|
-
|
240
|
+
with torch.no_grad():
|
241
|
+
for graph in graphs:
|
242
|
+
adjacency_matrix = graph['adjacency_matrix'].to(device)
|
243
|
+
gene_features = graph['gene_features'].to(device)
|
244
|
+
cell_features = graph['cell_features'].to(device)
|
78
245
|
|
79
|
-
|
246
|
+
predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
|
247
|
+
predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
|
248
|
+
attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())
|
80
249
|
|
250
|
+
# This approach assumes all genes in gene_id_to_index are used in the model.
|
251
|
+
# Create a list of gene IDs present in this specific graph.
|
252
|
+
present_gene_ids = [key for key, value in gene_id_to_index.items() if value < gene_features.size(0)]
|
81
253
|
|
82
|
-
|
83
|
-
|
254
|
+
for cell_idx in range(cell_features.size(0)):
|
255
|
+
true_score = cell_features[cell_idx, 0].item()
|
256
|
+
predicted_score = predictions[cell_idx]
|
257
|
+
|
258
|
+
# Find the index of the most probable gene.
|
259
|
+
most_probable_gene_idx = attn_weights[cell_idx].argmax()
|
84
260
|
|
85
|
-
|
86
|
-
|
87
|
-
|
261
|
+
if len(present_gene_ids) > most_probable_gene_idx: # Ensure index is within the range
|
262
|
+
most_probable_gene_id = present_gene_ids[most_probable_gene_idx]
|
263
|
+
most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx] if attn_weights.ndim > 1 else attn_weights[most_probable_gene_idx]
|
88
264
|
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
265
|
+
annotated_data.append({
|
266
|
+
"Cell ID": cell_idx,
|
267
|
+
"Most Probable Gene": most_probable_gene_id,
|
268
|
+
"Cell Score": true_score,
|
269
|
+
"Predicted Cell Score": predicted_score,
|
270
|
+
"Probability Score for Highest Gene": most_probable_gene_score
|
271
|
+
})
|
272
|
+
else:
|
273
|
+
# Handle the case where the index is out of bounds - this should not happen but is here for robustness
|
274
|
+
print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
|
275
|
+
|
276
|
+
return pd.DataFrame(annotated_data)
|
@@ -0,0 +1,84 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
5
|
+
|
6
|
+
# Let's assume that the feature embedding part and the dataset loading part
|
7
|
+
# has already been taken care of, and your data is already in the format
|
8
|
+
# suitable for PyTorch (i.e., Tensors).
|
9
|
+
|
10
|
+
class FeatureEmbedder(nn.Module):
|
11
|
+
def __init__(self, vocab_sizes, embedding_size):
|
12
|
+
super(FeatureEmbedder, self).__init__()
|
13
|
+
self.embeddings = nn.ModuleDict({
|
14
|
+
key: nn.Embedding(num_embeddings=vocab_size+1,
|
15
|
+
embedding_dim=embedding_size,
|
16
|
+
padding_idx=vocab_size)
|
17
|
+
for key, vocab_size in vocab_sizes.items()
|
18
|
+
})
|
19
|
+
# Adding the 'visit' embedding
|
20
|
+
self.embeddings['visit'] = nn.Parameter(torch.zeros(1, embedding_size))
|
21
|
+
|
22
|
+
def forward(self, feature_map, max_num_codes):
|
23
|
+
# Implementation will depend on how you want to handle sparse data
|
24
|
+
# This is just a placeholder
|
25
|
+
embeddings = {}
|
26
|
+
masks = {}
|
27
|
+
for key, tensor in feature_map.items():
|
28
|
+
embeddings[key] = self.embeddings[key](tensor.long())
|
29
|
+
mask = torch.ones_like(tensor, dtype=torch.float32)
|
30
|
+
masks[key] = mask.unsqueeze(-1)
|
31
|
+
|
32
|
+
# Batch size hardcoded for simplicity in example
|
33
|
+
batch_size = 1 # Replace with actual batch size
|
34
|
+
embeddings['visit'] = self.embeddings['visit'].expand(batch_size, -1, -1)
|
35
|
+
masks['visit'] = torch.ones(batch_size, 1)
|
36
|
+
|
37
|
+
return embeddings, masks
|
38
|
+
|
39
|
+
class GraphConvolutionalTransformer(nn.Module):
|
40
|
+
def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
|
41
|
+
super(GraphConvolutionalTransformer, self).__init__()
|
42
|
+
# Transformer Blocks
|
43
|
+
self.layers = nn.ModuleList([
|
44
|
+
nn.TransformerEncoderLayer(
|
45
|
+
d_model=embedding_size,
|
46
|
+
nhead=num_attention_heads,
|
47
|
+
batch_first=True)
|
48
|
+
for _ in range(kwargs.get('num_transformer_stack', 3))
|
49
|
+
])
|
50
|
+
# Output Layer for Classification
|
51
|
+
self.output_layer = nn.Linear(embedding_size, 1)
|
52
|
+
|
53
|
+
def feedforward(self, features, mask=None, training=None):
|
54
|
+
# Implement feedforward logic (placeholder)
|
55
|
+
pass
|
56
|
+
|
57
|
+
def forward(self, embeddings, masks, mask=None, training=False):
|
58
|
+
features = embeddings
|
59
|
+
attentions = [] # Storing attentions if needed
|
60
|
+
|
61
|
+
# Pass through each Transformer block
|
62
|
+
for layer in self.layers:
|
63
|
+
features = layer(features) # Apply transformer encoding here
|
64
|
+
|
65
|
+
if mask is not None:
|
66
|
+
features = features * mask
|
67
|
+
|
68
|
+
logits = self.output_layer(features[:, 0, :]) # Using the 'visit' embedding for classification
|
69
|
+
return logits, attentions
|
70
|
+
|
71
|
+
# Usage Example
|
72
|
+
vocab_sizes = {'dx_ints':3249, 'proc_ints':2210}
|
73
|
+
embedding_size = 128
|
74
|
+
gct_params = {
|
75
|
+
'embedding_size': embedding_size,
|
76
|
+
'num_transformer_stack': 3,
|
77
|
+
'num_attention_heads': 1
|
78
|
+
}
|
79
|
+
feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
|
80
|
+
gct_model = GraphConvolutionalTransformer(**gct_params)
|
81
|
+
|
82
|
+
# Assume `feature_map` is a dictionary of tensors, and `max_num_codes` is provided
|
83
|
+
embeddings, masks = feature_embedder(feature_map, max_num_codes)
|
84
|
+
logits, attentions = gct_model(embeddings, masks)
|
spacr/gui_classify_app.py
CHANGED
@@ -17,7 +17,7 @@ except AttributeError:
|
|
17
17
|
|
18
18
|
from .logger import log_function_call
|
19
19
|
from .gui_utils import ScrollableFrame, StdoutRedirector, create_dark_mode, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, safe_literal_eval, clear_canvas, main_thread_update_function
|
20
|
-
from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper
|
20
|
+
from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv
|
21
21
|
|
22
22
|
thread_control = {"run_thread": None, "stop_requested": False}
|
23
23
|
|
@@ -39,7 +39,6 @@ def run_classify_gui(q, fig_queue, stop_requested):
|
|
39
39
|
process_stdout_stderr(q)
|
40
40
|
try:
|
41
41
|
settings = check_classify_gui_settings(vars_dict)
|
42
|
-
#settings = add_mask_gui_defaults(settings)
|
43
42
|
for key in settings:
|
44
43
|
value = settings[key]
|
45
44
|
print(key, value, type(value))
|
@@ -65,25 +64,10 @@ def import_settings(scrollable_frame):
|
|
65
64
|
global vars_dict
|
66
65
|
|
67
66
|
csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
imported_variables = {}
|
73
|
-
|
74
|
-
with open(csv_file_path, newline='') as csvfile:
|
75
|
-
reader = csv.DictReader(csvfile)
|
76
|
-
for row in reader:
|
77
|
-
key = row['Key']
|
78
|
-
value = row['Value']
|
79
|
-
# Evaluate the value safely using safe_literal_eval
|
80
|
-
imported_variables[key] = safe_literal_eval(value)
|
81
|
-
|
82
|
-
# Track changed variables and apply the imported ones, printing changes as we go
|
83
|
-
for key, var in vars_dict.items():
|
84
|
-
if key in imported_variables and var.get() != imported_variables[key]:
|
85
|
-
print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
|
86
|
-
var.set(imported_variables[key])
|
67
|
+
csv_settings = read_settings_from_csv(csv_file_path)
|
68
|
+
variables = classify_variables()
|
69
|
+
new_settings = update_settings_from_csv(variables, csv_settings)
|
70
|
+
vars_dict = generate_fields(new_settings, scrollable_frame)
|
87
71
|
|
88
72
|
@log_function_call
|
89
73
|
def initiate_classify_root(width, height):
|
@@ -171,6 +155,7 @@ def initiate_classify_root(width, height):
|
|
171
155
|
canvas_widget = canvas.get_tk_widget()
|
172
156
|
horizontal_container.add(canvas_widget, stretch="always")
|
173
157
|
canvas.draw()
|
158
|
+
canvas.figure = figure
|
174
159
|
|
175
160
|
# Console output setup below the settings
|
176
161
|
console_output = scrolledtext.ScrolledText(vertical_container, height=10)
|
spacr/gui_mask_app.py
CHANGED
@@ -15,10 +15,9 @@ try:
|
|
15
15
|
except AttributeError:
|
16
16
|
pass
|
17
17
|
|
18
|
-
|
19
18
|
from .logger import log_function_call
|
20
19
|
from .gui_utils import ScrollableFrame, StdoutRedirector, safe_literal_eval, clear_canvas, main_thread_update_function, create_dark_mode, set_dark_style, set_default_font, generate_fields, process_stdout_stderr
|
21
|
-
from .gui_utils import mask_variables, check_mask_gui_settings,
|
20
|
+
from .gui_utils import mask_variables, check_mask_gui_settings, preprocess_generate_masks_wrapper, read_settings_from_csv, update_settings_from_csv #, add_mask_gui_defaults
|
22
21
|
|
23
22
|
thread_control = {"run_thread": None, "stop_requested": False}
|
24
23
|
|
@@ -40,7 +39,7 @@ def run_mask_gui(q, fig_queue, stop_requested):
|
|
40
39
|
process_stdout_stderr(q)
|
41
40
|
try:
|
42
41
|
settings = check_mask_gui_settings(vars_dict)
|
43
|
-
settings = add_mask_gui_defaults(settings)
|
42
|
+
#settings = add_mask_gui_defaults(settings)
|
44
43
|
#for key in settings:
|
45
44
|
# value = settings[key]
|
46
45
|
# print(key, value, type(value))
|
@@ -66,25 +65,11 @@ def import_settings(scrollable_frame):
|
|
66
65
|
global vars_dict
|
67
66
|
|
68
67
|
csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
with open(csv_file_path, newline='') as csvfile:
|
76
|
-
reader = csv.DictReader(csvfile)
|
77
|
-
for row in reader:
|
78
|
-
key = row['Key']
|
79
|
-
value = row['Value']
|
80
|
-
# Evaluate the value safely using safe_literal_eval
|
81
|
-
imported_variables[key] = safe_literal_eval(value)
|
82
|
-
|
83
|
-
# Track changed variables and apply the imported ones, printing changes as we go
|
84
|
-
for key, var in vars_dict.items():
|
85
|
-
if key in imported_variables and var.get() != imported_variables[key]:
|
86
|
-
print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
|
87
|
-
var.set(imported_variables[key])
|
68
|
+
csv_settings = read_settings_from_csv(csv_file_path)
|
69
|
+
variables = mask_variables()
|
70
|
+
#variables = add_mask_gui_defaults(variables)
|
71
|
+
new_settings = update_settings_from_csv(variables, csv_settings)
|
72
|
+
vars_dict = generate_fields(new_settings, scrollable_frame)
|
88
73
|
|
89
74
|
@log_function_call
|
90
75
|
def initiate_mask_root(width, height):
|
@@ -185,10 +170,10 @@ def initiate_mask_root(width, height):
|
|
185
170
|
|
186
171
|
# This is your GUI setup where you create the Run button
|
187
172
|
run_button = ttk.Button(scrollable_frame.scrollable_frame, text="Run",command=lambda: start_process(q, fig_queue))
|
188
|
-
run_button.grid(row=
|
173
|
+
run_button.grid(row=45, column=0, pady=10)
|
189
174
|
|
190
175
|
abort_button = ttk.Button(scrollable_frame.scrollable_frame, text="Abort", command=initiate_abort)
|
191
|
-
abort_button.grid(row=
|
176
|
+
abort_button.grid(row=45, column=1, pady=10)
|
192
177
|
|
193
178
|
progress_label = ttk.Label(scrollable_frame.scrollable_frame, text="Processing: 0%", background="#333333", foreground="white")
|
194
179
|
progress_label.grid(row=41, column=0, columnspan=2, sticky="ew", pady=(5, 0))
|
@@ -210,24 +195,5 @@ def gui_mask():
|
|
210
195
|
root, vars_dict = initiate_mask_root(1000, 1500)
|
211
196
|
root.mainloop()
|
212
197
|
|
213
|
-
#def gui_mask():
|
214
|
-
# from .cli import get_arg_parser
|
215
|
-
# from .version import version_str
|
216
|
-
#
|
217
|
-
# args = get_arg_parser().parse_args()
|
218
|
-
#
|
219
|
-
# if args.version:
|
220
|
-
# print(version_str)
|
221
|
-
# return
|
222
|
-
#
|
223
|
-
# if args.headless:
|
224
|
-
# settings = {}
|
225
|
-
# spacr.core.preprocess_generate_masks(settings['src'], settings=settings, advanced_settings={})
|
226
|
-
# return
|
227
|
-
#
|
228
|
-
# global vars_dict, root
|
229
|
-
# root, vars_dict = initiate_mask_root(1000, 1500)
|
230
|
-
# root.mainloop()
|
231
|
-
|
232
198
|
if __name__ == "__main__":
|
233
199
|
gui_mask()
|
spacr/gui_measure_app.py
CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
7
7
|
matplotlib.use('Agg') # Use the non-GUI Agg backend
|
8
8
|
from multiprocessing import Process, Queue, Value
|
9
9
|
from ttkthemes import ThemedTk
|
10
|
-
from tkinter import filedialog
|
10
|
+
from tkinter import filedialog, StringVar, BooleanVar, IntVar, DoubleVar, Tk
|
11
11
|
|
12
12
|
try:
|
13
13
|
ctypes.windll.shcore.SetProcessDpiAwareness(True)
|
@@ -16,7 +16,7 @@ except AttributeError:
|
|
16
16
|
|
17
17
|
from .logger import log_function_call
|
18
18
|
from .gui_utils import ScrollableFrame, StdoutRedirector, process_stdout_stderr, set_dark_style, set_default_font, generate_fields, create_dark_mode, main_thread_update_function
|
19
|
-
from .gui_utils import measure_variables, measure_crop_wrapper, clear_canvas, safe_literal_eval, check_measure_gui_settings,
|
19
|
+
from .gui_utils import measure_variables, measure_crop_wrapper, clear_canvas, safe_literal_eval, check_measure_gui_settings, read_settings_from_csv, update_settings_from_csv
|
20
20
|
|
21
21
|
thread_control = {"run_thread": None, "stop_requested": False}
|
22
22
|
|
@@ -25,8 +25,8 @@ def run_measure_gui(q, fig_queue, stop_requested):
|
|
25
25
|
global vars_dict
|
26
26
|
process_stdout_stderr(q)
|
27
27
|
try:
|
28
|
+
print('hello')
|
28
29
|
settings = check_measure_gui_settings(vars_dict)
|
29
|
-
settings = add_measure_gui_defaults(settings)
|
30
30
|
#for key in settings:
|
31
31
|
# value = settings[key]
|
32
32
|
# print(key, value, type(value))
|
@@ -60,29 +60,15 @@ def initiate_abort():
|
|
60
60
|
thread_control["run_thread"].terminate()
|
61
61
|
thread_control["run_thread"] = None
|
62
62
|
|
63
|
+
@log_function_call
|
63
64
|
def import_settings(scrollable_frame):
|
64
|
-
global vars_dict
|
65
|
+
global vars_dict
|
65
66
|
|
66
67
|
csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
imported_variables = {}
|
72
|
-
|
73
|
-
with open(csv_file_path, newline='') as csvfile:
|
74
|
-
reader = csv.DictReader(csvfile)
|
75
|
-
for row in reader:
|
76
|
-
key = row['Key']
|
77
|
-
value = row['Value']
|
78
|
-
# Evaluate the value safely using safe_literal_eval
|
79
|
-
imported_variables[key] = safe_literal_eval(value)
|
80
|
-
|
81
|
-
# Track changed variables and apply the imported ones, printing changes as we go
|
82
|
-
for key, var in vars_dict.items():
|
83
|
-
if key in imported_variables and var.get() != imported_variables[key]:
|
84
|
-
print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
|
85
|
-
var.set(imported_variables[key])
|
68
|
+
csv_settings = read_settings_from_csv(csv_file_path)
|
69
|
+
variables = measure_variables()
|
70
|
+
new_settings = update_settings_from_csv(variables, csv_settings)
|
71
|
+
vars_dict = generate_fields(new_settings, scrollable_frame)
|
86
72
|
|
87
73
|
@log_function_call
|
88
74
|
def initiate_measure_root(width, height):
|
@@ -201,7 +187,7 @@ def initiate_measure_root(width, height):
|
|
201
187
|
_process_fig_queue()
|
202
188
|
create_dark_mode(root, style, console_output)
|
203
189
|
|
204
|
-
root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
|
190
|
+
#root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
|
205
191
|
|
206
192
|
return root, vars_dict
|
207
193
|
|