spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/deep_spacr.py CHANGED
@@ -6,7 +6,6 @@ from torch.autograd import grad
6
6
  from torch.optim.lr_scheduler import StepLR
7
7
  import torch.nn.functional as F
8
8
  from IPython.display import display, clear_output
9
-
10
9
  import matplotlib.pyplot as plt
11
10
  from PIL import Image
12
11
 
@@ -200,11 +199,20 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
200
199
  from .io import _save_settings, _copy_missclassified
201
200
  from .utils import pick_best_model
202
201
  from .core import generate_loaders
203
-
202
+ from .settings import set_default_train_test_model
203
+
204
+ torch.cuda.empty_cache()
205
+ torch.cuda.memory.empty_cache()
206
+ gc.collect()
207
+
208
+ settings = set_default_train_test_model(settings)
209
+ channels_str = ''.join(settings['channels'])
210
+ dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
211
+ os.makedirs(dst, exist_ok=True)
204
212
  settings['src'] = src
213
+ settings['dst'] = dst
205
214
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
206
- settings_csv = os.path.join(src,'settings','train_test_model_settings.csv')
207
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
215
+ settings_csv = os.path.join(dst,'train_test_model_settings.csv')
208
216
  settings_df.to_csv(settings_csv, index=False)
209
217
 
210
218
  if custom_model:
@@ -212,15 +220,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
212
220
 
213
221
  if settings['train']:
214
222
  _save_settings(settings, src)
215
- torch.cuda.empty_cache()
216
- torch.cuda.memory.empty_cache()
217
- gc.collect()
218
- dst = os.path.join(src,'model')
219
- os.makedirs(dst, exist_ok=True)
220
- settings['src'] = src
221
- settings['dst'] = dst
223
+
222
224
  if settings['train']:
223
- train, val, plate_names = generate_loaders(src,
225
+ train, val, plate_names, train_fig = generate_loaders(src,
224
226
  train_mode=settings['train_mode'],
225
227
  mode='train',
226
228
  image_size=settings['image_size'],
@@ -231,11 +233,42 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
231
233
  pin_memory=settings['pin_memory'],
232
234
  normalize=settings['normalize'],
233
235
  channels=settings['channels'],
236
+ augment=settings['augment'],
234
237
  verbose=settings['verbose'])
235
-
236
-
238
+
239
+ train_batch_1_figure = os.path.join(dst, 'batch_1.pdf')
240
+ train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
241
+
242
+ if settings['train']:
243
+ model = train_model(dst = settings['dst'],
244
+ model_type=settings['model_type'],
245
+ train_loaders = train,
246
+ train_loader_names = plate_names,
247
+ train_mode = settings['train_mode'],
248
+ epochs = settings['epochs'],
249
+ learning_rate = settings['learning_rate'],
250
+ init_weights = settings['init_weights'],
251
+ weight_decay = settings['weight_decay'],
252
+ amsgrad = settings['amsgrad'],
253
+ optimizer_type = settings['optimizer_type'],
254
+ use_checkpoint = settings['use_checkpoint'],
255
+ dropout_rate = settings['dropout_rate'],
256
+ num_workers = settings['num_workers'],
257
+ val_loaders = val,
258
+ test_loaders = None,
259
+ intermedeate_save = settings['intermedeate_save'],
260
+ schedule = settings['schedule'],
261
+ loss_type=settings['loss_type'],
262
+ gradient_accumulation=settings['gradient_accumulation'],
263
+ gradient_accumulation_steps=settings['gradient_accumulation_steps'],
264
+ channels=settings['channels'])
265
+
266
+ torch.cuda.empty_cache()
267
+ torch.cuda.memory.empty_cache()
268
+ gc.collect()
269
+
237
270
  if settings['test']:
238
- test, _, plate_names_test = generate_loaders(src,
271
+ test, _, plate_names_test, train_fig = generate_loaders(src,
239
272
  train_mode=settings['train_mode'],
240
273
  mode='test',
241
274
  image_size=settings['image_size'],
@@ -246,6 +279,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
246
279
  pin_memory=settings['pin_memory'],
247
280
  normalize=settings['normalize'],
248
281
  channels=settings['channels'],
282
+ augment=False,
249
283
  verbose=settings['verbose'])
250
284
  if model == None:
251
285
  model_path = pick_best_model(src+'/model')
@@ -258,10 +292,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
258
292
  print(type(model))
259
293
  print(model)
260
294
 
261
- model_fldr = os.path.join(src,'model')
295
+ model_fldr = dst
262
296
  time_now = datetime.date.today().strftime('%y%m%d')
263
- result_loc = f'{model_fldr}/{model_type}_time_{time_now}_result.csv'
264
- acc_loc = f'{model_fldr}/{model_type}_time_{time_now}_acc.csv'
297
+ result_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_result.csv"
298
+ acc_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_acc.csv"
265
299
  print(f'Results wil be saved in: {result_loc}')
266
300
 
267
301
  result, accuracy = test_model_performance(loaders=test,
@@ -274,37 +308,12 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
274
308
  result.to_csv(result_loc, index=True, header=True, mode='w')
275
309
  accuracy.to_csv(acc_loc, index=True, header=True, mode='w')
276
310
  _copy_missclassified(accuracy)
277
- else:
278
- test = None
279
-
280
- if settings['train']:
281
- train_model(dst = settings['dst'],
282
- model_type=settings['model_type'],
283
- train_loaders = train,
284
- train_loader_names = plate_names,
285
- train_mode = settings['train_mode'],
286
- epochs = settings['epochs'],
287
- learning_rate = settings['learning_rate'],
288
- init_weights = settings['init_weights'],
289
- weight_decay = settings['weight_decay'],
290
- amsgrad = settings['amsgrad'],
291
- optimizer_type = settings['optimizer_type'],
292
- use_checkpoint = settings['use_checkpoint'],
293
- dropout_rate = settings['dropout_rate'],
294
- num_workers = settings['num_workers'],
295
- val_loaders = val,
296
- test_loaders = test,
297
- intermedeate_save = settings['intermedeate_save'],
298
- schedule = settings['schedule'],
299
- loss_type=settings['loss_type'],
300
- gradient_accumulation=settings['gradient_accumulation'],
301
- gradient_accumulation_steps=settings['gradient_accumulation_steps'])
302
311
 
303
312
  torch.cuda.empty_cache()
304
313
  torch.cuda.memory.empty_cache()
305
314
  gc.collect()
306
315
 
307
- def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4):
316
+ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, num_workers=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
308
317
  """
309
318
  Trains a model using the specified parameters.
310
319
 
@@ -349,7 +358,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
349
358
  kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
350
359
 
351
360
  for idx, (images, labels, filenames) in enumerate(train_loaders):
352
- batch, channels, height, width = images.shape
361
+ batch, chans, height, width = images.shape
353
362
  break
354
363
 
355
364
  model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
@@ -432,10 +441,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
432
441
  if schedule == 'step_lr':
433
442
  scheduler.step()
434
443
 
435
- _save_progress(dst, results_df, train_metrics_df)
444
+ _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
436
445
  clear_output(wait=True)
437
446
  display(results_df)
438
- _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
447
+ _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
439
448
 
440
449
  if train_mode == 'irm':
441
450
  dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -505,10 +514,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
505
514
 
506
515
  clear_output(wait=True)
507
516
  display(results_df)
508
- _save_progress(dst, results_df, train_metrics_df)
517
+ _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
509
518
  _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
510
519
  print(f'Saved model: {dst}')
511
- return
520
+ return model
512
521
 
513
522
  def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
514
523
 
@@ -693,4 +702,82 @@ def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_si
693
702
  if save_integrated_grads:
694
703
  os.makedirs(save_dir, exist_ok=True)
695
704
  integrated_grads_image = Image.fromarray((integrated_grads * 255).astype(np.uint8))
696
- integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
705
+ integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
706
+
707
+ class SmoothGrad:
708
+ def __init__(self, model, n_samples=50, stdev_spread=0.15):
709
+ self.model = model
710
+ self.n_samples = n_samples
711
+ self.stdev_spread = stdev_spread
712
+
713
+ def compute_smooth_grad(self, input_tensor, target_class):
714
+ self.model.eval()
715
+ stdev = self.stdev_spread * (input_tensor.max() - input_tensor.min())
716
+ total_gradients = torch.zeros_like(input_tensor)
717
+
718
+ for i in range(self.n_samples):
719
+ noise = torch.normal(mean=0, std=stdev, size=input_tensor.shape).to(input_tensor.device)
720
+ noisy_input = input_tensor + noise
721
+ noisy_input.requires_grad_()
722
+ output = self.model(noisy_input)
723
+ self.model.zero_grad()
724
+ output[0, target_class].backward()
725
+ total_gradients += noisy_input.grad
726
+
727
+ avg_gradients = total_gradients / self.n_samples
728
+ return avg_gradients.abs()
729
+
730
+ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'):
731
+
732
+ from .utils import preprocess_image
733
+
734
+ use_cuda = torch.cuda.is_available()
735
+ device = torch.device("cuda" if use_cuda else "cpu")
736
+
737
+ model = torch.load(model_path)
738
+ model.to(device)
739
+ smooth_grad = SmoothGrad(model)
740
+
741
+ if save_smooth_grad and not os.path.exists(save_dir):
742
+ os.makedirs(save_dir)
743
+
744
+ images = []
745
+ filenames = []
746
+ for file in os.listdir(src):
747
+ if not file.endswith('.png'):
748
+ continue
749
+ image_path = os.path.join(src, file)
750
+ image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels)
751
+ images.append(image)
752
+ filenames.append(file)
753
+
754
+ input_tensor = input_tensor.to(device)
755
+ smooth_grad_map = smooth_grad.compute_smooth_grad(input_tensor, target_label_idx)
756
+ smooth_grad_map = np.mean(smooth_grad_map.cpu().data.numpy(), axis=1).squeeze()
757
+
758
+ fig, ax = plt.subplots(1, 3, figsize=(20, 5))
759
+ ax[0].imshow(image)
760
+ ax[0].axis('off')
761
+ ax[0].set_title("Original Image")
762
+ ax[1].imshow(smooth_grad_map, cmap='hot')
763
+ ax[1].axis('off')
764
+ ax[1].set_title("SmoothGrad")
765
+ overlay = np.array(image)
766
+ overlay = overlay / overlay.max()
767
+ smooth_grad_map_rgb = np.stack([smooth_grad_map] * 3, axis=-1) # Convert smooth grad map to RGB
768
+ overlay = (overlay * 0.5 + smooth_grad_map_rgb * 0.5).clip(0, 1)
769
+ ax[2].imshow(overlay)
770
+ ax[2].axis('off')
771
+ ax[2].set_title("Overlay")
772
+ plt.show()
773
+
774
+ if save_smooth_grad:
775
+ os.makedirs(save_dir, exist_ok=True)
776
+ smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
777
+ smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
778
+
779
+ # Usage
780
+ #src = '/path/to/images'
781
+ #model_path = '/path/to/model.pth'
782
+ #target_label_idx = 0 # Change this to the target class index
783
+ #visualize_smooth_grad(src, model_path, target_label_idx)
spacr/graph_learning.py CHANGED
@@ -9,6 +9,7 @@ from PIL import Image
9
9
  import dgl.nn.pytorch as dglnn
10
10
  from sklearn.datasets import make_classification
11
11
  from .utils import SelectChannels
12
+ from IPython.display import display
12
13
 
13
14
  # approach outline
14
15
  #
@@ -241,6 +242,31 @@ def analyze_associations(probabilities, sequencing_data):
241
242
  sequencing_data['positive_prob'] = probabilities
242
243
  return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)
243
244
 
245
+ def process_sequencing_df(seq):
246
+
247
+ if isinstance(seq, pd.DataFrame):
248
+ sequencing_df = seq
249
+ elif isinstance(seq, str):
250
+ sequencing_df = pd.read_csv(seq)
251
+
252
+ # Check if 'plate_row' column exists and split into 'plate' and 'row'
253
+ if 'plate_row' in sequencing_df.columns:
254
+ sequencing_df[['plate', 'row']] = sequencing_df['plate_row'].str.split('_', expand=True)
255
+
256
+ # Check if 'plate', 'row' and 'col' or 'plate', 'row' and 'column' exist
257
+ if {'plate', 'row', 'col'}.issubset(sequencing_df.columns) or {'plate', 'row', 'column'}.issubset(sequencing_df.columns):
258
+ if 'col' in sequencing_df.columns:
259
+ sequencing_df['prc'] = sequencing_df[['plate', 'row', 'col']].agg('_'.join, axis=1)
260
+ elif 'column' in sequencing_df.columns:
261
+ sequencing_df['prc'] = sequencing_df[['plate', 'row', 'column']].agg('_'.join, axis=1)
262
+
263
+ # Check if 'count', 'total_reads', 'read_fraction', 'grna' exist and create new dataframe
264
+ if {'count', 'total_reads', 'read_fraction', 'grna'}.issubset(sequencing_df.columns):
265
+ new_df = sequencing_df[['grna', 'prc', 'count', 'total_reads', 'read_fraction']]
266
+ return new_df
267
+
268
+ return sequencing_df
269
+
244
270
  def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
245
271
  if test_mode:
246
272
  # Load MNIST data
@@ -260,7 +286,6 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
260
286
 
261
287
  # Normalize synthetic sequencing data
262
288
  sequencing_data = normalize_sequencing_data(sequencing_data)
263
-
264
289
  else:
265
290
  from .io import _read_and_join_tables
266
291
  from .utils import get_db_paths, get_sequencing_paths, correct_paths
@@ -274,18 +299,13 @@ def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classe
274
299
  sequencing_data = pd.DataFrame()
275
300
  for seq in seq_paths:
276
301
  sequencing_df = pd.read_csv(seq)
302
+ sequencing_df = process_sequencing_df(sequencing_df)
277
303
  sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)
278
304
 
279
- all_df = pd.DataFrame()
280
- for db_path in db_paths:
281
- df = _read_and_join_tables(db_path, table_names=['png_list'])
282
- all_df = pd.concat([all_df, df], axis=0)
283
-
284
- tables = ['png_list']
285
305
  all_df = pd.DataFrame()
286
306
  image_paths = []
287
307
  for i, db_path in enumerate(db_paths):
288
- df = _read_and_join_tables(db_path, table_names=tables)
308
+ df = _read_and_join_tables(db_path, table_names=['png_list'])
289
309
  df, image_paths_tmp = correct_paths(df, src[i])
290
310
  all_df = pd.concat([all_df, df], axis=0)
291
311
  image_paths.extend(image_paths_tmp)
spacr/gui.py CHANGED
@@ -59,10 +59,10 @@ class MainApp(tk.Tk):
59
59
 
60
60
  # Load the logo image
61
61
  if not self.load_logo(logo_frame):
62
- tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Arial', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
62
+ tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
63
63
 
64
64
  # Add SpaCr text below the logo with padding for sharper text
65
- tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Arial', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
65
+ tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24, tkFont.NORMAL)).pack(padx=10, pady=10)
66
66
 
67
67
  # Create a frame for the buttons and descriptions
68
68
  buttons_frame = tk.Frame(self.content_frame, bg="black")
@@ -72,10 +72,10 @@ class MainApp(tk.Tk):
72
72
  app_func, app_desc = app_data
73
73
 
74
74
  # Create custom button with text
75
- button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name))
75
+ button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name), font=('Helvetica', 12))
76
76
  button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
77
77
 
78
- description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Arial', 10, tkFont.NORMAL))
78
+ description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 10, tkFont.NORMAL))
79
79
  description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
80
80
 
81
81
  # Ensure buttons have a fixed width
@@ -131,7 +131,7 @@ class MainApp(tk.Tk):
131
131
 
132
132
  app_frame = tk.Frame(self.content_frame, bg="black")
133
133
  app_frame.pack(fill=tk.BOTH, expand=True)
134
- selected_app_func(app_frame, self.winfo_width(), self.winfo_height())
134
+ selected_app_func(app_frame)#, self.winfo_width(), self.winfo_height())
135
135
 
136
136
  def clear_frame(self, frame):
137
137
  for widget in frame.winfo_children():
spacr/gui_2.py CHANGED
@@ -1,52 +1,93 @@
1
- import customtkinter as ctk
1
+ import tkinter as tk
2
+ from tkinter import ttk
3
+ from tkinter import font as tkFont
2
4
  from PIL import Image, ImageTk
3
5
  import os
4
6
  import requests
5
7
 
6
- class MainApp(ctk.CTk):
8
+ # Import your GUI apps
9
+ from .gui_mask_app import initiate_mask_root
10
+ from .gui_measure_app import initiate_measure_root
11
+ from .annotate_app import initiate_annotation_app_root
12
+ from .mask_app import initiate_mask_app_root
13
+ from .gui_classify_app import initiate_classify_root
14
+
15
+ from .gui_utils import CustomButton, style_text_boxes
16
+
17
+ class MainApp(tk.Tk):
7
18
  def __init__(self):
8
19
  super().__init__()
9
20
  self.title("SpaCr GUI Collection")
10
- self.geometry("1200x800")
11
- ctk.set_appearance_mode("dark") # Modes: "System" (standard), "Dark", "Light"
12
- ctk.set_default_color_theme("dark-blue") # Themes: "blue" (standard), "green", "dark-blue")
13
-
14
- # Set scaling factor for high DPI displays; use a floating-point value.
15
- self.tk.call('tk', 'scaling', 1.5)
16
-
21
+ self.geometry("1100x1500")
22
+ self.configure(bg="black")
23
+ #self.attributes('-fullscreen', True)
24
+
25
+ style = ttk.Style()
26
+ style_text_boxes(style)
27
+
28
+ self.gui_apps = {
29
+ "Mask": (initiate_mask_root, "Generate cellpose masks for cells, nuclei and pathogen images."),
30
+ "Measure": (initiate_measure_root, "Measure single object intensity and morphological feature. Crop and save single object image"),
31
+ "Annotate": (initiate_annotation_app_root, "Annotation single object images on a grid. Annotations are saved to database."),
32
+ "Make Masks": (initiate_mask_app_root, "Adjust pre-existing Cellpose models to your specific dataset for improved performance"),
33
+ "Classify": (initiate_classify_root, "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images.")
34
+ }
35
+
36
+ self.selected_app = tk.StringVar()
17
37
  self.create_widgets()
18
38
 
19
39
  def create_widgets(self):
20
- self.content_frame = ctk.CTkFrame(self)
21
- self.content_frame.pack(fill="both", expand=True, padx=20, pady=20)
40
+ # Create the menu bar
41
+ #create_menu_bar(self)
42
+ # Create a canvas to hold the selected app and other elements
43
+ self.canvas = tk.Canvas(self, bg="black", highlightthickness=0, width=4000, height=4000)
44
+ self.canvas.grid(row=0, column=0, sticky="nsew")
45
+ self.grid_rowconfigure(0, weight=1)
46
+ self.grid_columnconfigure(0, weight=1)
47
+ # Create a frame inside the canvas to hold the main content
48
+ self.content_frame = tk.Frame(self.canvas, bg="black")
49
+ self.content_frame.pack(fill=tk.BOTH, expand=True)
50
+ # Create startup screen with buttons for each GUI app
51
+ self.create_startup_screen()
22
52
 
23
- logo_frame = ctk.CTkFrame(self.content_frame)
53
+ def create_startup_screen(self):
54
+ self.clear_frame(self.content_frame)
55
+
56
+ # Create a frame for the logo and description
57
+ logo_frame = tk.Frame(self.content_frame, bg="black")
24
58
  logo_frame.pack(pady=20, expand=True)
25
59
 
60
+ # Load the logo image
26
61
  if not self.load_logo(logo_frame):
27
- ctk.CTkLabel(logo_frame, text="Logo not found", text_color="white", font=('Helvetica', 24)).pack(padx=10, pady=10)
28
-
29
- ctk.CTkLabel(logo_frame, text="SpaCr", text_color="#00BFFF", font=('Helvetica', 36, "bold")).pack(padx=10, pady=10)
30
-
31
- button = ctk.CTkButton(
32
- self.content_frame,
33
- text="Mask",
34
- command=self.load_mask_app,
35
- width=250,
36
- height=60,
37
- corner_radius=20,
38
- fg_color="#1E90FF",
39
- hover_color="#4682B4",
40
- text_color="white",
41
- font=("Helvetica", 18, "bold")
42
- )
43
- button.pack(pady=20)
62
+ tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Helvetica', 24)).pack(padx=10, pady=10)
63
+
64
+ # Add SpaCr text below the logo with padding for sharper text
65
+ tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Helvetica', 24)).pack(padx=10, pady=10)
66
+
67
+ # Create a frame for the buttons and descriptions
68
+ buttons_frame = tk.Frame(self.content_frame, bg="black")
69
+ buttons_frame.pack(pady=10, expand=True, padx=10)
70
+
71
+ for i, (app_name, app_data) in enumerate(self.gui_apps.items()):
72
+ app_func, app_desc = app_data
73
+
74
+ # Create custom button with text
75
+ button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name, app_func), font=('Helvetica', 12))
76
+ button.grid(row=i, column=0, pady=10, padx=10, sticky="w")
77
+
78
+ description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Helvetica', 12))
79
+ description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")
80
+
81
+ # Ensure buttons have a fixed width
82
+ buttons_frame.grid_columnconfigure(0, minsize=150)
83
+ # Ensure descriptions expand as needed
84
+ buttons_frame.grid_columnconfigure(1, weight=1)
44
85
 
45
86
  def load_logo(self, frame):
46
87
  def download_image(url, save_path):
47
88
  try:
48
89
  response = requests.get(url, stream=True)
49
- response.raise_for_status()
90
+ response.raise_for_status() # Raise an HTTPError for bad responses
50
91
  with open(save_path, 'wb') as f:
51
92
  for chunk in response.iter_content(chunk_size=8192):
52
93
  f.write(chunk)
@@ -57,34 +98,63 @@ class MainApp(ctk.CTk):
57
98
 
58
99
  try:
59
100
  img_path = os.path.join(os.path.dirname(__file__), 'logo_spacr.png')
101
+ print(f"Trying to load logo from {img_path}")
60
102
  logo_image = Image.open(img_path)
61
103
  except (FileNotFoundError, Image.UnidentifiedImageError):
104
+ print(f"File {img_path} not found or is not a valid image. Attempting to download from GitHub.")
62
105
  if download_image('https://raw.githubusercontent.com/EinarOlafsson/spacr/main/spacr/logo_spacr.png', img_path):
63
106
  try:
107
+ print(f"Downloaded file size: {os.path.getsize(img_path)} bytes")
64
108
  logo_image = Image.open(img_path)
65
109
  except Image.UnidentifiedImageError as e:
110
+ print(f"Downloaded file is not a valid image: {e}")
66
111
  return False
67
112
  else:
68
113
  return False
69
114
  except Exception as e:
115
+ print(f"An error occurred while loading the logo: {e}")
70
116
  return False
71
-
72
117
  try:
73
- logo_image = logo_image.resize((200, 200), Image.Resampling.LANCZOS)
118
+ logo_image = logo_image.resize((800, 800), Image.Resampling.LANCZOS)
74
119
  logo_photo = ImageTk.PhotoImage(logo_image)
75
- logo_label = ctk.CTkLabel(frame, image=logo_photo)
120
+ logo_label = tk.Label(frame, image=logo_photo, bg="black")
76
121
  logo_label.image = logo_photo # Keep a reference to avoid garbage collection
77
122
  logo_label.pack()
78
123
  return True
79
124
  except Exception as e:
125
+ print(f"An error occurred while processing the logo image: {e}")
80
126
  return False
81
127
 
82
- def load_mask_app(self):
83
- print("Mask app loaded.") # Placeholder for mask app loading function
128
+ def load_app_v1(self, app_name):
129
+ selected_app_func, _ = self.gui_apps[app_name]
130
+ self.clear_frame(self.content_frame)
131
+
132
+ app_frame = tk.Frame(self.content_frame, bg="black")
133
+ app_frame.pack(fill=tk.BOTH, expand=True)
134
+ selected_app_func(app_frame)
135
+
136
+ def load_app(root, app_name, app_func):
137
+ if hasattr(root, 'current_app_id'):
138
+ root.after_cancel(root.current_app_id)
139
+ root.current_app_id = None
140
+
141
+ # Clear the current content frame
142
+ for widget in root.content_frame.winfo_children():
143
+ widget.destroy()
144
+
145
+ # Initialize the selected app
146
+ app_frame = tk.Frame(root.content_frame, bg="black")
147
+ app_frame.pack(fill=tk.BOTH, expand=True)
148
+ app_func(app_frame)
149
+
150
+ def clear_frame(self, frame):
151
+ for widget in frame.winfo_children():
152
+ widget.destroy()
153
+
84
154
 
85
155
  def gui_app():
86
156
  app = MainApp()
87
157
  app.mainloop()
88
158
 
89
159
  if __name__ == "__main__":
90
- gui_app()
160
+ gui_app()
spacr/gui_classify_app.py CHANGED
@@ -69,13 +69,13 @@ def import_settings(scrollable_frame):
69
69
  vars_dict = generate_fields(new_settings, scrollable_frame)
70
70
 
71
71
  #@log_function_call
72
- def initiate_classify_root(parent_frame):#, width, height):
72
+ def initiate_classify_root(parent_frame):
73
73
  global vars_dict, q, canvas, fig_queue, canvas_widget, thread_control
74
74
 
75
75
  style = ttk.Style(parent_frame)
76
76
  set_dark_style(style)
77
77
  style_text_boxes(style)
78
- set_default_font(parent_frame, font_name="Arial", size=8)
78
+ set_default_font(parent_frame, font_name="Helvetica", size=8)
79
79
 
80
80
  parent_frame.configure(bg='#333333')
81
81
  parent_frame.grid_rowconfigure(0, weight=1)
@@ -179,7 +179,7 @@ def gui_classify():
179
179
  root = tk.Tk()
180
180
  root.geometry("1000x800")
181
181
  root.title("SpaCer: generate masks")
182
- initiate_classify_root(root, 1000, 800)
182
+ initiate_classify_root(root),
183
183
  create_menu_bar(root)
184
184
  root.mainloop()
185
185