spacr 0.0.17__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/graph_learning.py CHANGED
@@ -1,95 +1,276 @@
- import pandas as pd
+ import os
  import torch
- from torch_geometric.data import Data, Dataset, DataLoader
- from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from collections import defaultdict
+ from torch.utils.data import Dataset, DataLoader
+ import pandas as pd
  import numpy as np
- from sklearn.preprocessing import LabelEncoder
+ import torch.optim as optim
+
+ def generate_graphs(sequencing, scores, cell_min, gene_min_read):
+     # Load and preprocess sequencing (gene) data
+     gene_df = pd.read_csv(sequencing)
+     gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
+     # Filter out genes with read counts less than gene_min_read
+     gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
+     total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
+     gene_df = gene_df.merge(total_reads_per_well, on='well_id')
+     gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']
+
+     # Load and preprocess cell score data
+     cell_df = pd.read_csv(scores)
+     cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})
+
+     # Create a global mapping of gene IDs to indices
+     unique_genes = gene_df['gene_id'].unique()
+     gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
+
+     graphs = []
+     for well_id in pd.unique(gene_df['well_id']):
+         well_genes = gene_df[gene_df['well_id'] == well_id]
+         well_cells = cell_df[cell_df['well_id'] == well_id]
+
+         # Skip wells with no cells or genes or with fewer cells than threshold
+         if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
+             continue
+
+         # Initialize gene features tensor with zeros for all unique genes
+         gene_features = torch.zeros((len(gene_id_to_index), 1), dtype=torch.float)
+
+         # Update gene features tensor with well_read_fraction for genes present in this well
+         for _, row in well_genes.iterrows():
+             gene_index = gene_id_to_index[row['gene_id']]
+             gene_features[gene_index] = torch.tensor([[row['well_read_fraction']]])
+
+         # Prepare cell features (scores)
+         cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
+
+         num_genes = len(gene_id_to_index)
+         num_cells = cell_features.size(0)
+         num_nodes = num_genes + num_cells
+
+         # Create adjacency matrix connecting each cell to all genes in the well
+         adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
+         for _, row in well_genes.iterrows():
+             gene_index = gene_id_to_index[row['gene_id']]
+             adj[num_genes:, gene_index] = 1
+
+         graph = {
+             'adjacency_matrix': adj,
+             'gene_features': gene_features,
+             'cell_features': cell_features,
+             'num_cells': num_cells,
+             'num_genes': num_genes
+         }
+         graphs.append(graph)
+
+     print(f'Generated dataset with {len(graphs)} graphs')
+     return graphs, gene_id_to_index
 
- class BasicGNN(torch.nn.Module):
-     def __init__(self, num_node_features, num_classes):
-         super(BasicGNN, self).__init__()
-         self.conv1 = GCNConv(num_node_features, 16)
-         self.conv2 = GCNConv(16, num_classes)
+ def print_graphs_info(graphs, gene_id_to_index):
+     # Invert the gene_id_to_index mapping for easy lookup
+     index_to_gene_id = {v: k for k, v in gene_id_to_index.items()}
 
-     def forward(self, data):
-         x, edge_index = data.x, data.edge_index
+     for i, graph in enumerate(graphs, start=1):
+         print(f"Graph {i}:")
+         num_genes = graph['num_genes']
+         num_cells = graph['num_cells']
+         gene_features = graph['gene_features']
+         cell_features = graph['cell_features']
+
+         print(f" Number of Genes: {num_genes}")
+         print(f" Number of Cells: {num_cells}")
+
+         # Identify genes present in the graph based on non-zero feature values
+         present_genes = [index_to_gene_id[idx] for idx, feature in enumerate(gene_features) if feature.item() > 0]
+         print(" Genes present in this Graph:", present_genes)
+
+         # Display gene features for genes present in the graph
+         print(" Gene Features:")
+         for gene_id in present_genes:
+             idx = gene_id_to_index[gene_id]
+             print(f" {gene_id}: {gene_features[idx].item()}")
+
+         # Display a sample of cell features, for brevity
+         print(" Cell Features (sample):")
+         for idx, feature in enumerate(cell_features[:min(5, len(cell_features))]):
+             print(f" Cell {idx+1}: {feature.item()}")
+
+         print("-" * 40)
+
+ class Attention(nn.Module):
+     def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
+         super(Attention, self).__init__()
+         self.query = nn.Linear(feature_dim, attn_dim)
+         self.key = nn.Linear(feature_dim, attn_dim)
+         self.value = nn.Linear(feature_dim, feature_dim)
+         self.scale = 1.0 / (attn_dim ** 0.5)
+         self.dropout = nn.Dropout(dropout_rate)
+
+     def forward(self, gene_features, cell_features):
+         # Queries come from the cell features
+         q = self.query(cell_features)
+         # Keys and values come from the gene features
+         k = self.key(gene_features)
+         v = self.value(gene_features)
 
-         x = self.conv1(x, edge_index)
-         x = torch.relu(x)
-         x = self.conv2(x, edge_index)
+         # Compute attention weights
+         attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+         attn_weights = F.softmax(attn_weights, dim=-1)
+         # Apply dropout to attention weights
+         attn_weights = self.dropout(attn_weights)
+
+         # Apply attention weights to the values
+         attn_output = torch.matmul(attn_weights, v)
 
-         return x
-
- def create_graph(cells_df, wells_df, well_id='prc', infer_id='gene'):
-     # Nodes: Combine cell and well nodes
-     num_cells = cells_df.shape[0]
-     num_wells = len(wells_df[well_id].unique())
-     num_genes = len(wells_df[infer_id].unique())
-
-     # Edges: You need to define edges based on your cells' and wells' relationships
-     edge_index = [...] # Fill with (source, target) index pairs
-
-     # Node features: Depending on your data, this might include measurements for cells, and gene fractions for wells
-     x = [...] # Feature matrix
-
-     # Labels: If you're predicting something specific, like gene knockouts
-     y = [...] # Target labels for nodes
-
-     data = Data(x=torch.tensor(x, dtype=torch.float), edge_index=torch.tensor(edge_index, dtype=torch.long), y=torch.tensor(y))
-
-     return data
+         return attn_output, attn_weights
 
- def create_graph(cells_df, wells_df, well_id='prc', infer_id='gene'):
-     # Assume cells_df and wells_df are preprocessed to include 'well_id' and 'gene_id' as encoded fields
-
-     # Node feature creation (this is highly data-dependent; consider cells_df features like cell area, intensity, etc.)
-     cell_features = [...] # Extract cell features into a matrix
-     well_features = [...] # Optional: Aggregate or represent well features
-     gene_features = [...] # Optional: Represent gene features
-
-     x = np.concatenate([cell_features, well_features, gene_features], axis=0)
-
-     # Edge index construction
-     edge_index = [...] # You'll need to construct this based on your data relationships
-
-     # Labels (assuming you have a column 'label' in cells_df for cell-level labels)
-     y = cells_df['label'].values
+ class GraphTransformer(nn.Module):
+     def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
+         super(GraphTransformer, self).__init__()
+         self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
+         self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
+         self.dropout = nn.Dropout(dropout_rate)
+
+         # Attention layer to let each cell attend to all genes
+         self.attention = Attention(hidden_dim, attn_dim)
+
+         # This layer is used to transform the combined features after attention
+         self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
+
+         # Output layer for predicting cell scores, ensuring it matches the number of cells
+         self.cell_output = nn.Linear(hidden_dim, output_dim)
+
+     def forward(self, adjacency_matrix, gene_features, cell_features):
+         # Apply initial transformation to gene and cell features
+         transformed_gene_features = F.relu(self.gene_transform(gene_features))
+         transformed_cell_features = F.relu(self.cell_transform(cell_features))
+
+         # Incorporate attention mechanism
+         attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)
+
+         # Combine the transformed cell features with the attention output features
+         combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
+
+         # Apply dropout here as well
+         combined_cell_features = self.dropout(combined_cell_features)
+
+         combined_cell_features = F.relu(self.combine_transform(combined_cell_features))
+
+         # Combine gene and cell features for message passing
+         combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
+
+         # Apply message passing via adjacency matrix multiplication
+         message_passed_features = torch.matmul(adjacency_matrix, combined_features)
+
+         # Predict cell scores from the post-message passed cell features
+         cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])
+
+         return cell_scores, attn_weights
 
-     # If y needs to include well and gene nodes, you'll have to expand it appropriately, possibly with dummy labels
+ def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)
+
+     criterion = nn.MSELoss()
+     #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+     training_log = []
 
-     data = Data(x=torch.tensor(x, dtype=torch.float),
-                 edge_index=torch.tensor(edge_index, dtype=torch.long),
-                 y=torch.tensor(y, dtype=torch.long)) # Adjust dtype as needed
+     accumulate_grad_batches=1
+     threshold=acc_threshold
 
-     return data
+     for epoch in range(epochs):
+         model.train()
+         total_loss = 0
+         total_correct = 0
+         total_samples = 0
+         optimizer.zero_grad()
+         batch_count = 0 # Initialize batch_count
+
+         for graph in graphs:
+             adjacency_matrix = graph['adjacency_matrix'].to(device)
+             gene_features = graph['gene_features'].to(device)
+             cell_features = graph['cell_features'].to(device)
+             num_cells = graph['num_cells']
+             predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
+             predictions = predictions.squeeze()
+             true_scores = cell_features[:num_cells, 0]
+             loss = criterion(predictions, true_scores) / accumulate_grad_batches
+             loss.backward()
 
+             # Calculate "accuracy"
+             with torch.no_grad():
+                 correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
+                 total_correct += correct_predictions
+                 total_samples += num_cells
 
- def train_gnn(cell_data_loc, well_data_loc, well_id='prc', infer_id='gene', lr=0.01):
+             batch_count += 1 # Increment batch_count
+             if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
+                 optimizer.step()
+                 optimizer.zero_grad()
 
-     # Example loading step
-     cells_df = pd.read_csv(cell_data_loc)
-     wells_df = pd.read_csv(well_data_loc)
+             total_loss += loss.item() * accumulate_grad_batches
+
+         accuracy = total_correct / total_samples
+         training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
+         print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
+
+     # Save the training log and model as before
+     os.makedirs(save_fldr, exist_ok=True)
+     log_path = os.path.join(save_fldr, 'training_log.csv')
+     training_log_df = pd.DataFrame(training_log)
+     training_log_df.to_csv(log_path, index=False)
+     print(f"Training log saved to {log_path}")
+
+     model_path = os.path.join(save_fldr, 'model.pth')
+     torch.save(model.state_dict(), model_path)
+     print(f"Model saved to {model_path}")
 
-     well_encoder = LabelEncoder()
-     cells_df['well_id'] = well_encoder.fit_transform(cells_df[well_id])
+     return model
+
+ def annotate_cells_with_genes(graphs, model, gene_id_to_index):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     model.to(device)
+     model.eval()
+     annotated_data = []
 
-     gene_encoder = LabelEncoder()
-     wells_df['gene_id'] = gene_encoder.fit_transform(wells_df[infer_id])
+     with torch.no_grad():
+         for graph in graphs:
+             adjacency_matrix = graph['adjacency_matrix'].to(device)
+             gene_features = graph['gene_features'].to(device)
+             cell_features = graph['cell_features'].to(device)
 
-     graph_data = create_graph(cells_df, wells_df)
+             predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
+             predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
+             attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())
 
+             # This approach assumes all genes in gene_id_to_index are used in the model.
+             # Create a list of gene IDs present in this specific graph.
+             present_gene_ids = [key for key, value in gene_id_to_index.items() if value < gene_features.size(0)]
 
-     # Example instantiation and use
-     model = BasicGNN(num_node_features=..., num_classes=...)
+             for cell_idx in range(cell_features.size(0)):
+                 true_score = cell_features[cell_idx, 0].item()
+                 predicted_score = predictions[cell_idx]
+
+                 # Find the index of the most probable gene.
+                 most_probable_gene_idx = attn_weights[cell_idx].argmax()
 
-     # Assuming binary classification for simplicity
-     criterion = torch.nn.BCEWithLogitsLoss()
-     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+                 if len(present_gene_ids) > most_probable_gene_idx: # Ensure index is within the range
+                     most_probable_gene_id = present_gene_ids[most_probable_gene_idx]
+                     most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx] if attn_weights.ndim > 1 else attn_weights[most_probable_gene_idx]
 
-     for epoch in range(200):
-         optimizer.zero_grad()
-         out = model(graph_data)
-         loss = criterion(out, graph_data.y)
-         loss.backward()
-         optimizer.step()
-         print(f'Epoch {epoch}, Loss: {loss.item()}')
+                     annotated_data.append({
+                         "Cell ID": cell_idx,
+                         "Most Probable Gene": most_probable_gene_id,
+                         "Cell Score": true_score,
+                         "Predicted Cell Score": predicted_score,
+                         "Probability Score for Highest Gene": most_probable_gene_score
+                     })
+                 else:
+                     # Handle the case where the index is out of bounds - this should not happen but is here for robustness
+                     print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
+
+     return pd.DataFrame(annotated_data)
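The new functions above form a pipeline: generate_graphs() builds one graph per well from a gRNA count table and a per-cell score table, train_graph_transformer() fits the attention model, and annotate_cells_with_genes() maps each cell to its most-attended gene. The following is a minimal usage sketch, not code from the package; the CSV paths and output folder are hypothetical placeholders, while the keyword names and the expected input columns ('prc', 'grna', 'count', 'prcfo', 'pred') come from the code above.

    from spacr.graph_learning import (generate_graphs, print_graphs_info,
                                      train_graph_transformer, annotate_cells_with_genes)

    # Hypothetical inputs: a gRNA sequencing count CSV and a per-cell score CSV.
    graphs, gene_id_to_index = generate_graphs(sequencing='grna_counts.csv',
                                               scores='cell_scores.csv',
                                               cell_min=10,        # skip wells with fewer than 10 scored cells
                                               gene_min_read=2)    # drop gRNAs with fewer than 2 reads
    print_graphs_info(graphs, gene_id_to_index)                    # optional sanity check of a few graphs

    model = train_graph_transformer(graphs, lr=0.01, epochs=100, save_fldr='results')
    annotations = annotate_cells_with_genes(graphs, model, gene_id_to_index)
    annotations.to_csv('results/cell_gene_annotations.csv', index=False)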
@@ -0,0 +1,84 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
+
+ # Let's assume that the feature embedding part and the dataset loading part
+ # has already been taken care of, and your data is already in the format
+ # suitable for PyTorch (i.e., Tensors).
+
+ class FeatureEmbedder(nn.Module):
+     def __init__(self, vocab_sizes, embedding_size):
+         super(FeatureEmbedder, self).__init__()
+         self.embeddings = nn.ModuleDict({
+             key: nn.Embedding(num_embeddings=vocab_size+1,
+                               embedding_dim=embedding_size,
+                               padding_idx=vocab_size)
+             for key, vocab_size in vocab_sizes.items()
+         })
+         # Adding the 'visit' embedding
+         self.embeddings['visit'] = nn.Parameter(torch.zeros(1, embedding_size))
+
+     def forward(self, feature_map, max_num_codes):
+         # Implementation will depend on how you want to handle sparse data
+         # This is just a placeholder
+         embeddings = {}
+         masks = {}
+         for key, tensor in feature_map.items():
+             embeddings[key] = self.embeddings[key](tensor.long())
+             mask = torch.ones_like(tensor, dtype=torch.float32)
+             masks[key] = mask.unsqueeze(-1)
+
+         # Batch size hardcoded for simplicity in example
+         batch_size = 1 # Replace with actual batch size
+         embeddings['visit'] = self.embeddings['visit'].expand(batch_size, -1, -1)
+         masks['visit'] = torch.ones(batch_size, 1)
+
+         return embeddings, masks
+
+ class GraphConvolutionalTransformer(nn.Module):
+     def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
+         super(GraphConvolutionalTransformer, self).__init__()
+         # Transformer Blocks
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(
+                 d_model=embedding_size,
+                 nhead=num_attention_heads,
+                 batch_first=True)
+             for _ in range(kwargs.get('num_transformer_stack', 3))
+         ])
+         # Output Layer for Classification
+         self.output_layer = nn.Linear(embedding_size, 1)
+
+     def feedforward(self, features, mask=None, training=None):
+         # Implement feedforward logic (placeholder)
+         pass
+
+     def forward(self, embeddings, masks, mask=None, training=False):
+         features = embeddings
+         attentions = [] # Storing attentions if needed
+
+         # Pass through each Transformer block
+         for layer in self.layers:
+             features = layer(features) # Apply transformer encoding here
+
+         if mask is not None:
+             features = features * mask
+
+         logits = self.output_layer(features[:, 0, :]) # Using the 'visit' embedding for classification
+         return logits, attentions
+
+ # Usage Example
+ vocab_sizes = {'dx_ints':3249, 'proc_ints':2210}
+ embedding_size = 128
+ gct_params = {
+     'embedding_size': embedding_size,
+     'num_transformer_stack': 3,
+     'num_attention_heads': 1
+ }
+ feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
+ gct_model = GraphConvolutionalTransformer(**gct_params)
+
+ # Assume `feature_map` is a dictionary of tensors, and `max_num_codes` is provided
+ embeddings, masks = feature_embedder(feature_map, max_num_codes)
+ logits, attentions = gct_model(embeddings, masks)
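The usage example at the end of this new block references feature_map and max_num_codes without defining them, and GraphConvolutionalTransformer.forward() as written expects a single tensor rather than the dict FeatureEmbedder returns, so the block reads as scaffolding. Below is a sketch of what a compatible feature_map could look like, assuming each feature is a batch of padded integer code indices; the shapes and values are illustrative, not taken from the package.

    import torch

    batch_size, max_num_codes = 1, 10
    # Hypothetical padded code tensors; FeatureEmbedder reserves index vocab_size as padding.
    feature_map = {
        'dx_ints': torch.randint(0, 3249, (batch_size, max_num_codes)),
        'proc_ints': torch.randint(0, 2210, (batch_size, max_num_codes)),
    }

    embeddings, masks = feature_embedder(feature_map, max_num_codes)
    # embeddings['dx_ints'] has shape (batch_size, max_num_codes, embedding_size);
    # the per-feature embeddings would still need to be concatenated into one
    # (batch, sequence, embedding_size) tensor before calling gct_model.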
spacr/gui_classify_app.py CHANGED
@@ -17,7 +17,7 @@ except AttributeError:
 
  from .logger import log_function_call
  from .gui_utils import ScrollableFrame, StdoutRedirector, create_dark_mode, set_dark_style, set_default_font, generate_fields, process_stdout_stderr, safe_literal_eval, clear_canvas, main_thread_update_function
- from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper
+ from .gui_utils import classify_variables, check_classify_gui_settings, train_test_model_wrapper, read_settings_from_csv, update_settings_from_csv
 
  thread_control = {"run_thread": None, "stop_requested": False}
 
@@ -39,7 +39,6 @@ def run_classify_gui(q, fig_queue, stop_requested):
      process_stdout_stderr(q)
      try:
          settings = check_classify_gui_settings(vars_dict)
-         #settings = add_mask_gui_defaults(settings)
          for key in settings:
              value = settings[key]
              print(key, value, type(value))
@@ -65,25 +64,10 @@ def import_settings(scrollable_frame):
      global vars_dict
 
      csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
-
-     if not csv_file_path:
-         return
-
-     imported_variables = {}
-
-     with open(csv_file_path, newline='') as csvfile:
-         reader = csv.DictReader(csvfile)
-         for row in reader:
-             key = row['Key']
-             value = row['Value']
-             # Evaluate the value safely using safe_literal_eval
-             imported_variables[key] = safe_literal_eval(value)
-
-     # Track changed variables and apply the imported ones, printing changes as we go
-     for key, var in vars_dict.items():
-         if key in imported_variables and var.get() != imported_variables[key]:
-             print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
-             var.set(imported_variables[key])
+     csv_settings = read_settings_from_csv(csv_file_path)
+     variables = classify_variables()
+     new_settings = update_settings_from_csv(variables, csv_settings)
+     vars_dict = generate_fields(new_settings, scrollable_frame)
 
  @log_function_call
  def initiate_classify_root(width, height):
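read_settings_from_csv() and update_settings_from_csv() live in spacr.gui_utils and their implementations are not part of this diff. Judging from how they are called here, and from the 'Key'/'Value' CSV format the removed inline parser handled, they centralize roughly the behaviour sketched below; the (type, options, default) layout assumed for the variables mapping is an illustration, not taken from the package.

    import csv

    def read_settings_from_csv(csv_file_path):
        # Parse a two-column settings CSV into {key: parsed_value}.
        settings = {}
        with open(csv_file_path, newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                settings[row['Key']] = safe_literal_eval(row['Value'])
        return settings

    def update_settings_from_csv(variables, csv_settings):
        # Replace the default value of every variable that appears in the imported CSV.
        updated = dict(variables)
        for key, value in csv_settings.items():
            if key in updated:
                var_type, options, _default = updated[key]  # assumed (type, options, default) tuples
                updated[key] = (var_type, options, value)
        return updated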
@@ -171,6 +155,7 @@ def initiate_classify_root(width, height):
      canvas_widget = canvas.get_tk_widget()
      horizontal_container.add(canvas_widget, stretch="always")
      canvas.draw()
+     canvas.figure = figure
 
      # Console output setup below the settings
      console_output = scrolledtext.ScrolledText(vertical_container, height=10)
spacr/gui_mask_app.py CHANGED
@@ -15,10 +15,9 @@ try:
  except AttributeError:
      pass
 
-
  from .logger import log_function_call
  from .gui_utils import ScrollableFrame, StdoutRedirector, safe_literal_eval, clear_canvas, main_thread_update_function, create_dark_mode, set_dark_style, set_default_font, generate_fields, process_stdout_stderr
- from .gui_utils import mask_variables, check_mask_gui_settings, add_mask_gui_defaults, preprocess_generate_masks_wrapper
+ from .gui_utils import mask_variables, check_mask_gui_settings, preprocess_generate_masks_wrapper, read_settings_from_csv, update_settings_from_csv #, add_mask_gui_defaults
 
  thread_control = {"run_thread": None, "stop_requested": False}
 
@@ -40,7 +39,7 @@ def run_mask_gui(q, fig_queue, stop_requested):
      process_stdout_stderr(q)
      try:
          settings = check_mask_gui_settings(vars_dict)
-         settings = add_mask_gui_defaults(settings)
+         #settings = add_mask_gui_defaults(settings)
          #for key in settings:
          # value = settings[key]
          # print(key, value, type(value))
@@ -66,25 +65,11 @@ def import_settings(scrollable_frame):
      global vars_dict
 
      csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
-
-     if not csv_file_path:
-         return
-
-     imported_variables = {}
-
-     with open(csv_file_path, newline='') as csvfile:
-         reader = csv.DictReader(csvfile)
-         for row in reader:
-             key = row['Key']
-             value = row['Value']
-             # Evaluate the value safely using safe_literal_eval
-             imported_variables[key] = safe_literal_eval(value)
-
-     # Track changed variables and apply the imported ones, printing changes as we go
-     for key, var in vars_dict.items():
-         if key in imported_variables and var.get() != imported_variables[key]:
-             print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
-             var.set(imported_variables[key])
+     csv_settings = read_settings_from_csv(csv_file_path)
+     variables = mask_variables()
+     #variables = add_mask_gui_defaults(variables)
+     new_settings = update_settings_from_csv(variables, csv_settings)
+     vars_dict = generate_fields(new_settings, scrollable_frame)
 
  @log_function_call
  def initiate_mask_root(width, height):
@@ -185,10 +170,10 @@ def initiate_mask_root(width, height):
 
      # This is your GUI setup where you create the Run button
      run_button = ttk.Button(scrollable_frame.scrollable_frame, text="Run",command=lambda: start_process(q, fig_queue))
-     run_button.grid(row=40, column=0, pady=10)
+     run_button.grid(row=45, column=0, pady=10)
 
      abort_button = ttk.Button(scrollable_frame.scrollable_frame, text="Abort", command=initiate_abort)
-     abort_button.grid(row=40, column=1, pady=10)
+     abort_button.grid(row=45, column=1, pady=10)
 
      progress_label = ttk.Label(scrollable_frame.scrollable_frame, text="Processing: 0%", background="#333333", foreground="white")
      progress_label.grid(row=41, column=0, columnspan=2, sticky="ew", pady=(5, 0))
@@ -210,24 +195,5 @@ def gui_mask():
      root, vars_dict = initiate_mask_root(1000, 1500)
      root.mainloop()
 
- #def gui_mask():
- # from .cli import get_arg_parser
- # from .version import version_str
- #
- # args = get_arg_parser().parse_args()
- #
- # if args.version:
- # print(version_str)
- # return
- #
- # if args.headless:
- # settings = {}
- # spacr.core.preprocess_generate_masks(settings['src'], settings=settings, advanced_settings={})
- # return
- #
- # global vars_dict, root
- # root, vars_dict = initiate_mask_root(1000, 1500)
- # root.mainloop()
-
  if __name__ == "__main__":
      gui_mask()
spacr/gui_measure_app.py CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
  matplotlib.use('Agg') # Use the non-GUI Agg backend
  from multiprocessing import Process, Queue, Value
  from ttkthemes import ThemedTk
- from tkinter import filedialog
+ from tkinter import filedialog, StringVar, BooleanVar, IntVar, DoubleVar, Tk
 
  try:
      ctypes.windll.shcore.SetProcessDpiAwareness(True)
@@ -16,7 +16,7 @@ except AttributeError:
 
  from .logger import log_function_call
  from .gui_utils import ScrollableFrame, StdoutRedirector, process_stdout_stderr, set_dark_style, set_default_font, generate_fields, create_dark_mode, main_thread_update_function
- from .gui_utils import measure_variables, measure_crop_wrapper, clear_canvas, safe_literal_eval, check_measure_gui_settings, add_measure_gui_defaults
+ from .gui_utils import measure_variables, measure_crop_wrapper, clear_canvas, safe_literal_eval, check_measure_gui_settings, read_settings_from_csv, update_settings_from_csv
 
  thread_control = {"run_thread": None, "stop_requested": False}
 
@@ -25,8 +25,8 @@ def run_measure_gui(q, fig_queue, stop_requested):
      global vars_dict
      process_stdout_stderr(q)
      try:
+         print('hello')
          settings = check_measure_gui_settings(vars_dict)
-         settings = add_measure_gui_defaults(settings)
          #for key in settings:
          # value = settings[key]
          # print(key, value, type(value))
@@ -60,29 +60,15 @@ def initiate_abort():
          thread_control["run_thread"].terminate()
          thread_control["run_thread"] = None
 
+ @log_function_call
  def import_settings(scrollable_frame):
-     global vars_dict, original_variables_structure
+     global vars_dict
 
      csv_file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
-
-     if not csv_file_path:
-         return
-
-     imported_variables = {}
-
-     with open(csv_file_path, newline='') as csvfile:
-         reader = csv.DictReader(csvfile)
-         for row in reader:
-             key = row['Key']
-             value = row['Value']
-             # Evaluate the value safely using safe_literal_eval
-             imported_variables[key] = safe_literal_eval(value)
-
-     # Track changed variables and apply the imported ones, printing changes as we go
-     for key, var in vars_dict.items():
-         if key in imported_variables and var.get() != imported_variables[key]:
-             print(f"Updating '{key}' from '{var.get()}' to '{imported_variables[key]}'")
-             var.set(imported_variables[key])
+     csv_settings = read_settings_from_csv(csv_file_path)
+     variables = measure_variables()
+     new_settings = update_settings_from_csv(variables, csv_settings)
+     vars_dict = generate_fields(new_settings, scrollable_frame)
 
  @log_function_call
  def initiate_measure_root(width, height):
@@ -201,7 +187,7 @@ def initiate_measure_root(width, height):
      _process_fig_queue()
      create_dark_mode(root, style, console_output)
 
-     root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
+     #root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
 
      return root, vars_dict