spacr 0.3.81__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/__init__.py CHANGED
@@ -67,8 +67,4 @@ logging.basicConfig(filename='spacr.log', level=logging.INFO,
67
67
 
68
68
  from .utils import download_models
69
69
 
70
- # Check if models already exist
71
- #models_dir = os.path.join(os.path.dirname(__file__), 'resources', 'models', 'cp')
72
- #if not os.path.exists(models_dir) or not os.listdir(models_dir):
73
- # print("Models not found, downloading...")
74
70
  download_models()
spacr/core.py CHANGED
@@ -7,15 +7,20 @@ from IPython.display import display
7
7
  import warnings
8
8
  warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
9
9
 
10
- def preprocess_generate_masks(src, settings={}):
10
+ def preprocess_generate_masks(settings):
11
11
 
12
12
  from .io import preprocess_img_data, _load_and_concatenate_arrays
13
13
  from .plot import plot_image_mask_overlay, plot_arrays
14
- from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
14
+ from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings, delete_intermedeate_files
15
15
  from .settings import set_default_settings_preprocess_generate_masks
16
-
17
- if not isinstance(settings['src'], (str, list)):
18
- ValueError(f'src must be a string or a list of strings')
16
+
17
+
18
+ if 'src' in settings:
19
+ if not isinstance(settings['src'], (str, list)):
20
+ ValueError(f'src must be a string or a list of strings')
21
+ return
22
+ else:
23
+ ValueError(f'src is a required parameter')
19
24
  return
20
25
 
21
26
  if isinstance(settings['src'], str):
@@ -27,9 +32,8 @@ def preprocess_generate_masks(src, settings={}):
27
32
  print(f'Processing folder: {source_folder}')
28
33
  settings['src'] = source_folder
29
34
  src = source_folder
30
- settings = set_default_settings_preprocess_generate_masks(src, settings)
31
-
32
- save_settings(settings, name='gen_mask')
35
+ settings = set_default_settings_preprocess_generate_masks(settings)
36
+ save_settings(settings, name='gen_mask_settings')
33
37
 
34
38
  if not settings['pathogen_channel'] is None:
35
39
  custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
@@ -158,6 +162,10 @@ def preprocess_generate_masks(src, settings={}):
158
162
 
159
163
  torch.cuda.empty_cache()
160
164
  gc.collect()
165
+
166
+ if settings['delete_intermediate']:
167
+ delete_intermedeate_files(settings)
168
+
161
169
  print("Successfully completed run")
162
170
  return
163
171
 
@@ -172,8 +180,10 @@ def generate_cellpose_masks(src, settings, object_type):
172
180
  gc.collect()
173
181
  if not torch.cuda.is_available():
174
182
  print(f'Torch CUDA is not available, using CPU')
175
-
176
- settings = set_default_settings_preprocess_generate_masks(src, settings)
183
+
184
+ settings['src'] = src
185
+
186
+ settings = set_default_settings_preprocess_generate_masks(settings)
177
187
 
178
188
  if settings['verbose']:
179
189
  settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
@@ -192,11 +202,13 @@ def generate_cellpose_masks(src, settings, object_type):
192
202
  timelapse_objects = settings['timelapse_objects']
193
203
 
194
204
  batch_size = settings['batch_size']
205
+
195
206
  cellprob_threshold = settings[f'{object_type}_CP_prob']
196
207
 
197
208
  flow_threshold = settings[f'{object_type}_FT']
198
209
 
199
210
  object_settings = _get_object_settings(object_type, settings)
211
+
200
212
  model_name = object_settings['model_name']
201
213
 
202
214
  cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
@@ -473,7 +485,7 @@ def generate_image_umap(settings={}):
473
485
  df, image_paths_tmp = correct_paths(df, settings['src'][i])
474
486
  all_df = pd.concat([all_df, df], axis=0)
475
487
  #image_paths.extend(image_paths_tmp)
476
-
488
+
477
489
  all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
478
490
 
479
491
  if settings['exclude_conditions']:
@@ -493,7 +505,7 @@ def generate_image_umap(settings={}):
493
505
 
494
506
  # Extract and reset the index for the column to compare
495
507
  col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
496
-
508
+ print(col_to_compare)
497
509
  #if settings['only_top_features']:
498
510
  # column_list = None
499
511
 
@@ -781,8 +793,10 @@ def generate_mediar_masks(src, settings, object_type):
781
793
  if not torch.cuda.is_available():
782
794
  print(f'Torch CUDA is not available, using CPU')
783
795
 
796
+ settings['src'] = src
797
+
784
798
  # Preprocess settings
785
- settings = set_default_settings_preprocess_generate_masks(src, settings)
799
+ settings = set_default_settings_preprocess_generate_masks(settings)
786
800
 
787
801
  if settings['verbose']:
788
802
  settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
spacr/deep_spacr.py CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
12
12
  from sklearn.metrics import auc, precision_recall_curve
13
13
  from IPython.display import display
14
14
  from multiprocessing import cpu_count
15
+ import torch.optim as optim
15
16
 
16
17
  from torchvision import transforms
17
18
  from torch.utils.data import DataLoader
@@ -76,10 +77,11 @@ def apply_model_to_tar(settings={}):
76
77
  from .io import TarImageDataset
77
78
  from .utils import process_vision_results, print_progress
78
79
 
79
- if os.path.exists(settings['dataset']):
80
- tar_path = settings['dataset']
81
- else:
82
- tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
80
+ #if os.path.exists(settings['dataset']):
81
+ # tar_path = settings['dataset']
82
+ #else:
83
+ # tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
84
+ tar_path = settings['tar_path']
83
85
  model_path = settings['model_path']
84
86
 
85
87
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -103,7 +105,8 @@ def apply_model_to_tar(settings={}):
103
105
  data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
104
106
 
105
107
  model_name = os.path.splitext(os.path.basename(model_path))[0]
106
- dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
108
+ #dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
109
+ dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
107
110
  date_name = datetime.date.today().strftime('%y%m%d')
108
111
  dst = os.path.dirname(tar_path)
109
112
  result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
@@ -934,3 +937,373 @@ def deep_spacr(settings={}):
934
937
 
935
938
  if os.path.exists(settings['model_path']):
936
939
  apply_model_to_tar(settings)
940
+
941
def model_knowledge_transfer(
    teacher_paths,
    student_save_path,
    data_loader,  # A DataLoader with (images, labels)
    device='cpu',
    student_model_name='maxvit_t',
    pretrained=True,
    dropout_rate=None,
    use_checkpoint=False,
    alpha=0.5,
    temperature=2.0,
    lr=1e-4,
    epochs=10
):
    """
    Performs multi-teacher knowledge distillation on a new labeled dataset,
    producing a single student TorchModel that combines (distills) the
    teachers' knowledge plus the labeled data.

    Per batch the student minimises

        alpha * CE(student_logits, labels)
        + (1 - alpha) * T**2 * KL(p_teachers || softmax(student_logits / T))

    where ``p_teachers`` is the mean of every teacher's
    ``softmax(teacher_logits / T)`` and ``T`` is ``temperature``.

    Usage:
        student = model_knowledge_transfer(
            teacher_paths=[
                'teacherA.pth',
                'teacherB.pth',
                ...
            ],
            student_save_path='distilled_student.pth',
            data_loader=my_data_loader,
            device='cuda',
            student_model_name='maxvit_t',
            alpha=0.5,
            temperature=2.0,
            lr=1e-4,
            epochs=10
        )

    Then load it via:
        fused_student = torch.load('distilled_student_KD.pth')
        # fused_student is a TorchModel instance, ready for inference.

    Args:
        teacher_paths (list[str]): Non-empty list of paths to teacher models
            (a pickled TorchModel, or a dict with the state dict under 'model'
            plus optional 'model_name', 'pretrained', 'dropout_rate',
            'use_checkpoint'). They must produce outputs of the same dimension.
        student_save_path (str): Destination path for the final student. A
            trailing '.pth' is stripped and '_KD.pth' appended, so
            'student.pth' is written as 'student_KD.pth'.
        data_loader (iterable): Yields (images, labels) batches from the new
            dataset, e.g. a torch DataLoader. It is iterated once per epoch.
        device (str): 'cpu' or 'cuda'.
        student_model_name (str): Architecture name for the student TorchModel
            (also the fallback architecture for dict teacher checkpoints).
        pretrained (bool): If the student should be initialized as pretrained.
        dropout_rate (float): Passed through to the TorchModel constructor.
        use_checkpoint (bool): Passed through to the TorchModel constructor.
        alpha (float): Weight of the real-label CE loss, in [0, 1]; the
            distillation loss gets weight ``1 - alpha``.
        temperature (float): Distillation temperature (> 0, typically > 1).
        lr (float): Adam learning rate for the student.
        epochs (int): Number of training epochs.

    Returns:
        TorchModel: The final, trained student model.

    Raises:
        ValueError: If ``teacher_paths`` is empty, ``alpha`` is outside
            [0, 1], ``temperature`` is not positive, a teacher checkpoint is
            neither a TorchModel nor a dict, or ``data_loader`` yields no
            batches.
    """
    import torch.nn.functional as F

    # Validate cheap arguments before any model is loaded or built.
    if not teacher_paths:
        raise ValueError("teacher_paths must contain at least one teacher model path.")
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}.")
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}.")

    from spacr.utils import TorchModel  # Adapt if needed

    # Adjust filename to reflect knowledge-distillation
    if student_save_path.endswith('.pth'):
        base = os.path.splitext(student_save_path)[0]
    else:
        base = student_save_path
    student_save_path = base + '_KD.pth'

    # -- 1. Load teacher models --
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    teachers = []
    print("Loading teacher models:")
    for path in teacher_paths:
        print(f"  Loading teacher: {path}")
        ckpt = torch.load(path, map_location=device)
        if isinstance(ckpt, TorchModel):
            teacher = ckpt.to(device)
        elif isinstance(ckpt, dict):
            # Rebuild the architecture from stored metadata (falling back to the
            # student's settings) and load the weights stored under 'model'.
            teacher = TorchModel(
                model_name=ckpt.get('model_name', student_model_name),
                pretrained=ckpt.get('pretrained', pretrained),
                dropout_rate=ckpt.get('dropout_rate', dropout_rate),
                use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
            ).to(device)
            teacher.load_state_dict(ckpt['model'])
        else:
            raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")

        teacher.eval()  # For consistent batchnorm, dropout
        teachers.append(teacher)

    # -- 2. Initialize the student TorchModel --
    student_model = TorchModel(
        model_name=student_model_name,
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        use_checkpoint=use_checkpoint
    ).to(device)

    # -- 3. Optimizer --
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    # NOTE(review): softmax over dim=1 and F.cross_entropy assume the models
    # output multi-class logits of shape (B, num_classes) and that `labels`
    # are integer class indices -- confirm against TorchModel's output head.
    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        num_batches = 0

        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass student
            logits_s = student_model(images)

            # Teacher ensemble: mean of temperature-scaled probabilities.
            with torch.no_grad():
                teacher_probs = [F.softmax(tm(images) / temperature, dim=1) for tm in teachers]
                teacher_probs_ensemble = torch.stack(teacher_probs).mean(dim=0)

            # kl_div expects log-probabilities as input and probabilities as target.
            student_log_probs = F.log_softmax(logits_s / temperature, dim=1)
            # The T**2 factor keeps the soft-target gradient scale comparable to CE.
            loss_distill = F.kl_div(
                student_log_probs,
                teacher_probs_ensemble,
                reduction='batchmean'
            ) * (temperature ** 2)

            # Real label loss on the unscaled logits
            loss_ce = F.cross_entropy(logits_s, labels)

            loss = alpha * loss_ce + (1 - alpha) * loss_distill

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: works for loaders without __len__ and avoids
        # a ZeroDivisionError on an empty loader.
        if num_batches == 0:
            raise ValueError("data_loader yielded no batches; nothing to train on.")
        avg_loss = running_loss / num_batches
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

    # -- 4. Save final student as a TorchModel --
    torch.save(student_model, student_save_path)
    print(f"Knowledge-distilled student saved to: {student_save_path}")

    return student_model
1103
+
1104
def model_fusion(model_paths,
                 save_path,
                 device='cpu',
                 model_name='maxvit_t',
                 pretrained=True,
                 dropout_rate=None,
                 use_checkpoint=False,
                 aggregator='mean'):
    """
    Fuses an arbitrary number of TorchModel instances by combining their weights
    (using mean, geomean, median, sum, max, or min) and saves the entire fused
    model object.

    You can later load the fused model with:
        model = torch.load('fused_model_mean.pth')

    which returns a ready-to-use TorchModel instance.

    Aggregation details:
      - Only floating-point state_dict entries are aggregated. Non-floating
        entries (integer counters such as BatchNorm's num_batches_tracked) are
        bookkeeping rather than learned weights and are taken from the first
        model unchanged.
      - Aggregation is done in at least float32 precision and the result is
        cast back to each entry's original dtype.
      - 'geomean' is only defined for strictly positive values and raises
        instead of silently writing NaN weights.
      - 'median' uses torch.median, which returns the lower of the two middle
        values when the number of models is even.

    Parameters:
        model_paths (list of str): Non-empty list of model checkpoints to fuse.
            Each checkpoint can be:
                - A dict with keys ['model', 'model_name', ...]
                - A TorchModel instance
        save_path (str): Destination path for the fused model. A trailing
            '.pth' is stripped and f'_{aggregator}.pth' appended.
        device (str): 'cpu' or 'cuda' for loading weights and final model device.
        model_name (str): Default model name (used if not in checkpoint).
        pretrained (bool): Default if not in checkpoint.
        dropout_rate (float): Default if not in checkpoint.
        use_checkpoint (bool): Default if not in checkpoint.
        aggregator (str): How to combine weights across models:
            'mean', 'geomean', 'median', 'sum', 'max', or 'min'.

    Returns:
        fused_model (TorchModel): The final fused TorchModel instance
            with combined weights.

    Raises:
        ValueError: If ``aggregator`` is invalid, ``model_paths`` is empty, a
            checkpoint has an unsupported format, the models' state_dict keys
            differ, or 'geomean' meets a non-positive value.
    """
    # Validate arguments before importing/loading anything.
    valid_aggregators = {'mean', 'geomean', 'median', 'sum', 'max', 'min'}
    if aggregator not in valid_aggregators:
        raise ValueError(f"Invalid aggregator '{aggregator}'. "
                         f"Must be one of {valid_aggregators}.")
    if not model_paths:
        raise ValueError("model_paths must contain at least one model checkpoint path.")

    from spacr.utils import TorchModel

    if save_path.endswith('.pth'):
        save_path_part1 = os.path.splitext(save_path)[0]
    else:
        save_path_part1 = save_path
    save_path = save_path_part1 + f'_{aggregator}.pth'

    # --- 1. Load the first checkpoint to figure out architecture & hyperparams ---
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    print(f"Loading the first model from: {model_paths[0]} to derive architecture")
    first_ckpt = torch.load(model_paths[0], map_location=device)

    if isinstance(first_ckpt, dict):
        # A dict with the state dict under 'model' plus optional metadata;
        # stored metadata overrides the function defaults.
        model_name = first_ckpt.get('model_name', model_name)
        pretrained = first_ckpt.get('pretrained', pretrained)
        dropout_rate = first_ckpt.get('dropout_rate', dropout_rate)
        use_checkpoint = first_ckpt.get('use_checkpoint', use_checkpoint)

        fused_model = TorchModel(
            model_name=model_name,
            pretrained=pretrained,
            dropout_rate=dropout_rate,
            use_checkpoint=use_checkpoint
        ).to(device)

        state_dicts = [first_ckpt['model']]
    elif isinstance(first_ckpt, TorchModel):
        # The checkpoint is directly a TorchModel instance; reuse it as the container.
        fused_model = first_ckpt.to(device)
        state_dicts = [fused_model.state_dict()]
    else:
        raise ValueError("Unsupported checkpoint format. Must be a dict or a TorchModel instance.")

    # --- 2. Load the rest of the checkpoints ---
    for path in model_paths[1:]:
        print(f"Loading model from: {path}")
        ckpt = torch.load(path, map_location=device)
        if isinstance(ckpt, dict):
            state_dicts.append(ckpt['model'])
        elif isinstance(ckpt, TorchModel):
            state_dicts.append(ckpt.state_dict())
        else:
            raise ValueError(f"Unsupported checkpoint format in {path} (must be dict or TorchModel).")

    # --- 3. Verify all state dicts have the same keys ---
    fused_sd = fused_model.state_dict()
    for sd in state_dicts:
        if fused_sd.keys() != sd.keys():
            raise ValueError("All models must have identical architecture/state_dict keys.")

    # --- 4. Define aggregator logic ---
    def combine_tensors(tensor_list, mode, key):
        """Combine the per-model versions of one state_dict entry with ``mode``."""
        reference = tensor_list[0]
        if not torch.is_floating_point(reference):
            # Integer buffers are counters, not weights: keep the first model's value.
            return reference.clone()

        # Stack along a new leading dim => shape (num_models, *tensor.shape).
        # promote_types keeps float64 as float64 and lifts half/bfloat16 to float32.
        work_dtype = torch.promote_types(reference.dtype, torch.float32)
        stacked = torch.stack(tensor_list, dim=0).to(work_dtype)

        if mode == 'mean':
            combined = stacked.mean(dim=0)
        elif mode == 'geomean':
            # exp(mean(log(x))) is undefined for x <= 0 and would yield NaN/-inf weights.
            if (stacked <= 0).any():
                raise ValueError(f"Aggregator 'geomean' requires strictly positive values, "
                                 f"but state_dict entry '{key}' contains values <= 0.")
            combined = torch.exp(torch.log(stacked).mean(dim=0))
        elif mode == 'median':
            combined = stacked.median(dim=0).values
        elif mode == 'sum':
            combined = stacked.sum(dim=0)
        elif mode == 'max':
            combined = stacked.max(dim=0).values
        elif mode == 'min':
            combined = stacked.min(dim=0).values
        else:
            raise ValueError(f"Unsupported aggregator: {mode}")

        return combined.to(reference.dtype)

    # --- 5. Combine the weights ---
    for key in list(fused_sd.keys()):
        all_tensors = [sd[key] for sd in state_dicts]
        fused_sd[key] = combine_tensors(all_tensors, aggregator, key)

    fused_model.load_state_dict(fused_sd)

    # --- 6. Save the entire TorchModel object ---
    torch.save(fused_model, save_path)
    print(f"Fused model (aggregator='{aggregator}') saved as a full TorchModel to: {save_path}")

    return fused_model
1237
+
1238
def annotate_filter_vision(settings):
    """
    Annotate result CSV files with conditions and optionally filter them.

    For every CSV in ``settings['src']`` (a path or a list of paths):
      1. read it and, if it lacks a ``column_name`` column, copy it from ``column``;
      2. annotate it via the project helper ``annotate_conditions`` using the
         ``cells``, ``cell_loc``, ``pathogens``, ``pathogen_loc``,
         ``treatments`` and ``treatment_loc`` settings;
      3. if ``settings['filter_column']`` is not None and is a column, keep only
         rows whose value is above ``upper_threshold`` or below
         ``lower_threshold``;
      4. write the result to ``<src without extension>_annotated_filtered.csv``;
      5. if ``settings['remove_train']`` is truthy, drop rows whose image file
         name already exists as a PNG in ``.../datasets/training/train/nc`` or
         ``.../train/pc`` and rewrite that CSV.

    Note: a string ``settings['src']`` is replaced in place by a one-element list.
    """
    from .utils import annotate_conditions

    def filter_csv_by_png(csv_file):
        """
        Filters a DataFrame by removing rows that match PNG filenames in the
        training folders.

        The training folders are located relative to the ``datasets`` directory
        that ``csv_file`` lives under. If ``csv_file`` is not under a
        ``datasets`` directory, the CSV is returned unfiltered.

        Parameters:
            csv_file (str): Path to the CSV file.

        Returns:
            pd.DataFrame: Filtered DataFrame.
        """
        df = pd.read_csv(csv_file)

        datasets_marker = os.sep + "datasets" + os.sep
        if datasets_marker not in csv_file:
            # Previously this crashed with an unpacking error from str.split.
            print(f"No 'datasets' folder found in {csv_file}; training images were not removed.")
            return df

        before_datasets = csv_file.split(datasets_marker, 1)[0]
        train_fldr = os.path.join(before_datasets, 'datasets', 'training', 'train')

        # Collect PNG filenames from train/nc and train/pc
        png_files = set()
        for folder in (os.path.join(train_fldr, 'nc'), os.path.join(train_fldr, 'pc')):
            if os.path.exists(folder):
                png_files.update(name for name in os.listdir(folder) if name.endswith(".png"))

        # png_files holds bare file names; compare on the basename of 'path' so
        # rows storing full paths match too (bare names are unaffected).
        basenames = df['path'].astype(str).map(os.path.basename)
        return df[~basenames.isin(png_files)]

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    for src in settings['src']:
        output_csv = os.path.splitext(src)[0] + '_annotated_filtered.csv'
        print(output_csv)

        df = pd.read_csv(src)

        if 'column_name' not in df.columns:
            df['column_name'] = df['column']

        df = annotate_conditions(df,
                                 cells=settings['cells'],
                                 cell_loc=settings['cell_loc'],
                                 pathogens=settings['pathogens'],
                                 pathogen_loc=settings['pathogen_loc'],
                                 treatments=settings['treatments'],
                                 treatment_loc=settings['treatment_loc'])

        filter_column = settings['filter_column']
        if filter_column is not None and filter_column in df.columns:
            # Keep the extremes: rows above the upper or below the lower threshold.
            filtered_df = df[(df[filter_column] > settings['upper_threshold']) | (df[filter_column] < settings['lower_threshold'])]
            print(f'Filtered DataFrame with {len(df)} rows to {len(filtered_df)} rows.')
        else:
            if filter_column is not None:
                print(f"{settings['filter_column']} not in DataFrame columns.")
            filtered_df = df

        filtered_df.to_csv(output_csv, index=False)

        if settings['remove_train']:
            df = filter_csv_by_png(output_csv)
            df.to_csv(output_csv, index=False)
spacr/gui_core.py CHANGED
@@ -324,40 +324,6 @@ def show_next_figure():
324
324
  index_control.set(figure_index)
325
325
  index_control.set_to(len(figures) - 1)
326
326
  display_figure(fig)
327
-
328
- def process_fig_queue_v1():
329
- global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
330
- from .gui_elements import standardize_figure
331
-
332
- #print("process_fig_queue called", flush=True)
333
- try:
334
- while not fig_queue.empty():
335
- fig = fig_queue.get_nowait()
336
- if fig is None:
337
- print("Warning: Retrieved a None figure from fig_queue.", flush=True)
338
- continue
339
-
340
- # Standardize the figure appearance before adding it
341
- fig = standardize_figure(fig)
342
- figures.append(fig)
343
-
344
- # Update slider maximum
345
- index_control.set_to(len(figures) - 1)
346
-
347
- # If no figure has been displayed yet
348
- if figure_index == -1:
349
- figure_index = 0
350
- display_figure(figures[figure_index])
351
- index_control.set(figure_index)
352
-
353
- except Exception as e:
354
- print("Exception in process_fig_queue:", e, flush=True)
355
- traceback.print_exc()
356
-
357
- finally:
358
- # Schedule process_fig_queue() to run again
359
- after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
360
- parent_frame.after_tasks.append(after_id)
361
327
 
362
328
  def process_fig_queue():
363
329
  global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
@@ -544,7 +510,7 @@ def import_settings(settings_type='mask'):
544
510
  #vars_dict = hide_all_settings(vars_dict, categories=None)
545
511
  csv_settings = read_settings_from_csv(csv_file_path)
546
512
  if settings_type == 'mask':
547
- settings = set_default_settings_preprocess_generate_masks(src='path', settings={})
513
+ settings = set_default_settings_preprocess_generate_masks(settings={})
548
514
  elif settings_type == 'measure':
549
515
  settings = get_measure_crop_settings(settings={})
550
516
  elif settings_type == 'classify':
@@ -596,7 +562,7 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
596
562
  settings_frame.grid_columnconfigure(0, weight=1)
597
563
 
598
564
  if settings_type == 'mask':
599
- settings = set_default_settings_preprocess_generate_masks(src='path', settings={})
565
+ settings = set_default_settings_preprocess_generate_masks(settings={})
600
566
  elif settings_type == 'measure':
601
567
  settings = get_measure_crop_settings(settings={})
602
568
  elif settings_type == 'classify':
@@ -912,7 +878,7 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
912
878
  q.put(f"Error: {e}")
913
879
  return
914
880
 
915
- if thread_control.get("run_thread") is not None:
881
+ if isinstance(thread_control, dict) and thread_control.get("run_thread") is not None:
916
882
  initiate_abort()
917
883
 
918
884
  stop_requested = Value('i', 0)
@@ -1018,6 +984,66 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget):
1018
984
  print(f"Error updating GUI canvas: {e}")
1019
985
  finally:
1020
986
  root.after(uppdate_frequency, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
987
+
988
def cleanup_previous_instance():
    """
    Cleans up resources from the previous application instance.

    In order, this:
      1. cancels pending Tk ``after`` callbacks recorded in ``after_tasks`` on
         the previous frame's toplevel window and on the frame itself (done
         first, while ``parent_frame`` can still reach its window);
      2. destroys every child widget of ``parent_frame`` and resets it to None;
      3. drains ``q`` and ``fig_queue`` without blocking and resets them to None;
      4. sets ``thread_control['stop'] = True`` when ``thread_control`` is a dict;
      5. resets ``usage_bars``, ``figures`` and ``figure_index``;
      6. clears/destroys ``canvas`` and resets it to None.

    Each step is skipped when its global is None, so the function can be
    called even when no previous instance exists.
    """
    global parent_frame, usage_bars, figures, figure_index, thread_control, canvas, q, fig_queue
    import queue

    if parent_frame is not None:
        # 1. Cancel all pending `after` tasks. This must run before parent_frame
        #    is reset; previously it was checked afterwards and never executed.
        try:
            parent_window = parent_frame.winfo_toplevel()
        except Exception as e:
            print(f"Error locating parent window: {e}")
            parent_window = None

        if parent_window is not None:
            window_tasks = getattr(parent_window, 'after_tasks', None)
            frame_tasks = getattr(parent_frame, 'after_tasks', None)
            task_lists = [window_tasks]
            if frame_tasks is not window_tasks:
                task_lists.append(frame_tasks)
            for tasks in task_lists:
                if not tasks:
                    continue
                for after_id in tasks:
                    try:
                        parent_window.after_cancel(after_id)
                    except Exception as e:
                        print(f"Error cancelling after task {after_id}: {e}")
                tasks.clear()

        # 2. Destroy all widgets in the parent frame
        for widget in parent_frame.winfo_children():
            try:
                widget.destroy()
            except Exception as e:
                print(f"Error destroying widget: {e}")
        parent_frame.update_idletasks()
        parent_frame = None

    # 3. Clear global queues. get_nowait() never blocks, unlike
    #    `while not q.empty(): q.get()`, which can hang if the queue is
    #    emptied by someone else between the two calls.
    def _drain(queue_obj):
        while True:
            try:
                queue_obj.get_nowait()
            except queue.Empty:
                break

    if q is not None:
        _drain(q)
        q = None

    if fig_queue is not None:
        _drain(fig_queue)
        fig_queue = None

    # 4. Signal the previous worker to stop. thread_control is not guaranteed
    #    to be a dict, so only index it when it is one.
    if isinstance(thread_control, dict):
        thread_control['stop'] = True

    # 5. Reset usage bars, figures, and indices
    usage_bars = []
    figures = deque()
    figure_index = -1

    # 6. Clear canvas or other visualizations
    if canvas is not None:
        try:
            if hasattr(canvas, 'figure'):  # Check if it's a FigureCanvasTkAgg
                canvas.figure.clear()  # Clear the Matplotlib figure
                canvas.get_tk_widget().destroy()  # Destroy the Tkinter widget
            else:
                # Assume it's a standard Tkinter Canvas
                canvas.delete("all")
        except Exception as e:
            print(f"Error clearing canvas: {e}")
        canvas = None

    print("Previous instance cleaned up successfully.")
1021
1047
 
1022
1048
  def initiate_root(parent, settings_type='mask'):
1023
1049
  """
@@ -1033,7 +1059,11 @@ def initiate_root(parent, settings_type='mask'):
1033
1059
 
1034
1060
  global q, fig_queue, thread_control, parent_frame, scrollable_frame, button_frame, vars_dict, canvas, canvas_widget, button_scrollable_frame, progress_bar, uppdate_frequency, figures, figure_index, index_control, usage_bars
1035
1061
 
1036
- from .gui_utils import setup_frame, get_screen_dimensions
1062
+ # Clean up any previous instance
1063
+ cleanup_previous_instance()
1064
+
1065
+ from .gui_utils import setup_frame
1066
+ from .gui_elements import create_menu_bar
1037
1067
  from .settings import descriptions
1038
1068
  #from .openai import Chatbot
1039
1069
 
@@ -1096,6 +1126,7 @@ def initiate_root(parent, settings_type='mask'):
1096
1126
 
1097
1127
  process_console_queue()
1098
1128
  process_fig_queue()
1129
+ create_menu_bar(parent)
1099
1130
  after_id = parent_window.after(uppdate_frequency, lambda: main_thread_update_function(parent_window, q, fig_queue, canvas_widget))
1100
1131
  parent_window.after_tasks.append(after_id)
1101
1132