spacr 0.3.81__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -4
- spacr/core.py +27 -13
- spacr/deep_spacr.py +378 -5
- spacr/gui_core.py +69 -38
- spacr/gui_elements.py +193 -3
- spacr/gui_utils.py +1 -1
- spacr/io.py +5 -176
- spacr/measure.py +10 -6
- spacr/ml.py +369 -46
- spacr/plot.py +201 -90
- spacr/settings.py +52 -16
- spacr/submodules.py +282 -1
- spacr/toxo.py +98 -75
- spacr/utils.py +128 -36
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/METADATA +2 -1
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/RECORD +20 -20
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/LICENSE +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/WHEEL +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.81.dist-info → spacr-0.4.0.dist-info}/top_level.txt +0 -0
spacr/__init__.py
CHANGED
@@ -67,8 +67,4 @@ logging.basicConfig(filename='spacr.log', level=logging.INFO,
|
|
67
67
|
|
68
68
|
from .utils import download_models
|
69
69
|
|
70
|
-
# Check if models already exist
|
71
|
-
#models_dir = os.path.join(os.path.dirname(__file__), 'resources', 'models', 'cp')
|
72
|
-
#if not os.path.exists(models_dir) or not os.listdir(models_dir):
|
73
|
-
# print("Models not found, downloading...")
|
74
70
|
download_models()
|
spacr/core.py
CHANGED
@@ -7,15 +7,20 @@ from IPython.display import display
|
|
7
7
|
import warnings
|
8
8
|
warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
|
9
9
|
|
10
|
-
def preprocess_generate_masks(
|
10
|
+
def preprocess_generate_masks(settings):
|
11
11
|
|
12
12
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
13
13
|
from .plot import plot_image_mask_overlay, plot_arrays
|
14
|
-
from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
|
14
|
+
from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings, delete_intermedeate_files
|
15
15
|
from .settings import set_default_settings_preprocess_generate_masks
|
16
|
-
|
17
|
-
|
18
|
-
|
16
|
+
|
17
|
+
|
18
|
+
if 'src' in settings:
|
19
|
+
if not isinstance(settings['src'], (str, list)):
|
20
|
+
ValueError(f'src must be a string or a list of strings')
|
21
|
+
return
|
22
|
+
else:
|
23
|
+
ValueError(f'src is a required parameter')
|
19
24
|
return
|
20
25
|
|
21
26
|
if isinstance(settings['src'], str):
|
@@ -27,9 +32,8 @@ def preprocess_generate_masks(src, settings={}):
|
|
27
32
|
print(f'Processing folder: {source_folder}')
|
28
33
|
settings['src'] = source_folder
|
29
34
|
src = source_folder
|
30
|
-
settings = set_default_settings_preprocess_generate_masks(
|
31
|
-
|
32
|
-
save_settings(settings, name='gen_mask')
|
35
|
+
settings = set_default_settings_preprocess_generate_masks(settings)
|
36
|
+
save_settings(settings, name='gen_mask_settings')
|
33
37
|
|
34
38
|
if not settings['pathogen_channel'] is None:
|
35
39
|
custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
|
@@ -158,6 +162,10 @@ def preprocess_generate_masks(src, settings={}):
|
|
158
162
|
|
159
163
|
torch.cuda.empty_cache()
|
160
164
|
gc.collect()
|
165
|
+
|
166
|
+
if settings['delete_intermediate']:
|
167
|
+
delete_intermedeate_files(settings)
|
168
|
+
|
161
169
|
print("Successfully completed run")
|
162
170
|
return
|
163
171
|
|
@@ -172,8 +180,10 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
172
180
|
gc.collect()
|
173
181
|
if not torch.cuda.is_available():
|
174
182
|
print(f'Torch CUDA is not available, using CPU')
|
175
|
-
|
176
|
-
settings =
|
183
|
+
|
184
|
+
settings['src'] = src
|
185
|
+
|
186
|
+
settings = set_default_settings_preprocess_generate_masks(settings)
|
177
187
|
|
178
188
|
if settings['verbose']:
|
179
189
|
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
@@ -192,11 +202,13 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
192
202
|
timelapse_objects = settings['timelapse_objects']
|
193
203
|
|
194
204
|
batch_size = settings['batch_size']
|
205
|
+
|
195
206
|
cellprob_threshold = settings[f'{object_type}_CP_prob']
|
196
207
|
|
197
208
|
flow_threshold = settings[f'{object_type}_FT']
|
198
209
|
|
199
210
|
object_settings = _get_object_settings(object_type, settings)
|
211
|
+
|
200
212
|
model_name = object_settings['model_name']
|
201
213
|
|
202
214
|
cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
|
@@ -473,7 +485,7 @@ def generate_image_umap(settings={}):
|
|
473
485
|
df, image_paths_tmp = correct_paths(df, settings['src'][i])
|
474
486
|
all_df = pd.concat([all_df, df], axis=0)
|
475
487
|
#image_paths.extend(image_paths_tmp)
|
476
|
-
|
488
|
+
|
477
489
|
all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
|
478
490
|
|
479
491
|
if settings['exclude_conditions']:
|
@@ -493,7 +505,7 @@ def generate_image_umap(settings={}):
|
|
493
505
|
|
494
506
|
# Extract and reset the index for the column to compare
|
495
507
|
col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
|
496
|
-
|
508
|
+
print(col_to_compare)
|
497
509
|
#if settings['only_top_features']:
|
498
510
|
# column_list = None
|
499
511
|
|
@@ -781,8 +793,10 @@ def generate_mediar_masks(src, settings, object_type):
|
|
781
793
|
if not torch.cuda.is_available():
|
782
794
|
print(f'Torch CUDA is not available, using CPU')
|
783
795
|
|
796
|
+
settings['src'] = src
|
797
|
+
|
784
798
|
# Preprocess settings
|
785
|
-
settings = set_default_settings_preprocess_generate_masks(
|
799
|
+
settings = set_default_settings_preprocess_generate_masks(settings)
|
786
800
|
|
787
801
|
if settings['verbose']:
|
788
802
|
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
spacr/deep_spacr.py
CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
|
|
12
12
|
from sklearn.metrics import auc, precision_recall_curve
|
13
13
|
from IPython.display import display
|
14
14
|
from multiprocessing import cpu_count
|
15
|
+
import torch.optim as optim
|
15
16
|
|
16
17
|
from torchvision import transforms
|
17
18
|
from torch.utils.data import DataLoader
|
@@ -76,10 +77,11 @@ def apply_model_to_tar(settings={}):
|
|
76
77
|
from .io import TarImageDataset
|
77
78
|
from .utils import process_vision_results, print_progress
|
78
79
|
|
79
|
-
if os.path.exists(settings['dataset']):
|
80
|
-
|
81
|
-
else:
|
82
|
-
|
80
|
+
#if os.path.exists(settings['dataset']):
|
81
|
+
# tar_path = settings['dataset']
|
82
|
+
#else:
|
83
|
+
# tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
|
84
|
+
tar_path = settings['tar_path']
|
83
85
|
model_path = settings['model_path']
|
84
86
|
|
85
87
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
@@ -103,7 +105,8 @@ def apply_model_to_tar(settings={}):
|
|
103
105
|
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
|
104
106
|
|
105
107
|
model_name = os.path.splitext(os.path.basename(model_path))[0]
|
106
|
-
dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
|
108
|
+
#dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
|
109
|
+
dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
|
107
110
|
date_name = datetime.date.today().strftime('%y%m%d')
|
108
111
|
dst = os.path.dirname(tar_path)
|
109
112
|
result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
|
@@ -934,3 +937,373 @@ def deep_spacr(settings={}):
|
|
934
937
|
|
935
938
|
if os.path.exists(settings['model_path']):
|
936
939
|
apply_model_to_tar(settings)
|
940
|
+
|
941
|
+
def model_knowledge_transfer(
    teacher_paths,
    student_save_path,
    data_loader,  # A DataLoader with (images, labels)
    device='cpu',
    student_model_name='maxvit_t',
    pretrained=True,
    dropout_rate=None,
    use_checkpoint=False,
    alpha=0.5,
    temperature=2.0,
    lr=1e-4,
    epochs=10
):
    """
    Performs multi-teacher knowledge distillation on a new labeled dataset,
    producing a single student TorchModel that combines (distills) the
    teachers' knowledge plus the labeled data.

    The student is optimised with

        loss = alpha * CE(student_logits, labels)
             + (1 - alpha) * T**2 * KL(teacher_ensemble || student_T)

    where ``teacher_ensemble`` is the mean of the teachers' softmax outputs at
    temperature ``T`` and ``student_T`` is the student's softmax at ``T``.

    Usage:
        student = model_knowledge_transfer(
            teacher_paths=[
                'teacherA.pth',
                'teacherB.pth',
                ...
            ],
            student_save_path='distilled_student.pth',
            data_loader=my_data_loader,
            device='cuda',
            student_model_name='maxvit_t',
            alpha=0.5,
            temperature=2.0,
            lr=1e-4,
            epochs=10
        )

    The student is written with a '_KD' suffix, so load it via:
        fused_student = torch.load('distilled_student_KD.pth')
        # fused_student is a TorchModel instance, ready for inference.

    Args:
        teacher_paths (list[str] | str): Paths to teacher models (TorchModel
            or dict with 'model' in it). They must have the same architecture
            or at least produce the same dimension of output. A single path
            string is treated as a one-element list.
        student_save_path (str): Requested destination path. A trailing
            '.pth' is stripped and '_KD.pth' is appended, so the model is
            saved to '<stem>_KD.pth'.
        data_loader (DataLoader): Yields (images, labels) from the new dataset.
        device (str): 'cpu' or 'cuda'.
        student_model_name (str): Architecture name for the student TorchModel.
        pretrained (bool): If the student should be initialized as pretrained.
        dropout_rate (float): Passed through to the TorchModel constructor.
        use_checkpoint (bool): Passed through to the TorchModel constructor.
        alpha (float): Weight of the real-label cross-entropy term; the
            distillation term gets ``1 - alpha`` (typically in 0..1).
        temperature (float): Distillation temperature (>1 typically).
        lr (float): Adam learning rate for the student.
        epochs (int): Number of training epochs.

    Returns:
        TorchModel: The final, trained student model.

    Raises:
        ValueError: If ``teacher_paths`` is empty, or a checkpoint is neither
            a TorchModel nor a dict.
    """
    # Local import so this function does not depend on a module-level alias.
    import torch.nn.functional as F
    from spacr.utils import TorchModel  # Adapt if needed

    if isinstance(teacher_paths, str):
        teacher_paths = [teacher_paths]
    if not teacher_paths:
        # Without this check the failure surfaces later as an opaque
        # torch.stack() error on an empty list.
        raise ValueError("teacher_paths must contain at least one teacher model path.")

    # Always tag the output file as knowledge-distilled: '<stem>_KD.pth'.
    if student_save_path.endswith('.pth'):
        base, _ = os.path.splitext(student_save_path)
    else:
        base = student_save_path
    student_save_path = base + '_KD.pth'

    # -- 1. Load teacher models --
    teachers = []
    print("Loading teacher models:")
    for path in teacher_paths:
        print(f"  Loading teacher: {path}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints. Newer PyTorch releases default to
        # weights_only=True, which may refuse pickled TorchModel instances —
        # confirm against the installed torch version.
        ckpt = torch.load(path, map_location=device)
        if isinstance(ckpt, TorchModel):
            teacher = ckpt.to(device)
        elif isinstance(ckpt, dict):
            # Dict checkpoint: rebuild the architecture from stored metadata,
            # falling back to the student's settings, then load the weights
            # stored under 'model'.
            teacher = TorchModel(
                model_name=ckpt.get('model_name', student_model_name),
                pretrained=ckpt.get('pretrained', pretrained),
                dropout_rate=ckpt.get('dropout_rate', dropout_rate),
                use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
            ).to(device)
            teacher.load_state_dict(ckpt['model'])
        else:
            raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")

        teacher.eval()  # Teachers are frozen: fixed batchnorm stats, no dropout.
        teachers.append(teacher)

    # -- 2. Initialize the student TorchModel --
    student_model = TorchModel(
        model_name=student_model_name,
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        use_checkpoint=use_checkpoint
    ).to(device)

    # -- 3. Optimizer --
    optimizer = optim.Adam(student_model.parameters(), lr=lr)

    # -- 4. Distillation training loop --
    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        n_batches = 0  # Counted explicitly: works for loaders without len() and avoids /0.

        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass student
            logits_s = student_model(images)
            logits_s_temp = logits_s / temperature

            # Teacher ensemble: mean of temperature-softened probabilities.
            with torch.no_grad():
                teacher_probs_list = [
                    F.softmax(tm(images) / temperature, dim=1) for tm in teachers
                ]
                teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0)

            # kl_div expects log-probabilities as input and probabilities as
            # target; the T**2 factor keeps gradient magnitudes comparable
            # across temperatures.
            student_log_probs = F.log_softmax(logits_s_temp, dim=1)
            loss_distill = F.kl_div(
                student_log_probs,
                teacher_probs_ensemble,
                reduction='batchmean'
            ) * (temperature ** 2)

            # Real-label loss on the raw (unscaled) logits.
            loss_ce = F.cross_entropy(logits_s, labels)

            loss = alpha * loss_ce + (1 - alpha) * loss_distill

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            print(f"Epoch [{epoch+1}/{epochs}] - Warning: data_loader yielded no batches.")
            continue
        avg_loss = running_loss / n_batches
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

    # -- 5. Save final student as a TorchModel --
    torch.save(student_model, student_save_path)
    print(f"Knowledge-distilled student saved to: {student_save_path}")

    return student_model
|
1103
|
+
|
1104
|
+
def model_fusion(model_paths,
                 save_path,
                 device='cpu',
                 model_name='maxvit_t',
                 pretrained=True,
                 dropout_rate=None,
                 use_checkpoint=False,
                 aggregator='mean'):
    """
    Fuses an arbitrary number of TorchModel instances by combining their weights
    element-wise (using mean, geomean, median, sum, max, or min) and saves the
    entire fused model object.

    The output file name is derived from ``save_path``: a trailing '.pth' is
    stripped and '_<aggregator>.pth' is appended. For example, with
    ``save_path='fused_model.pth'`` and ``aggregator='mean'``, you can later
    load the fused model with:
        model = torch.load('fused_model_mean.pth')

    which returns a ready-to-use TorchModel instance.

    Parameters:
        model_paths (list of str | str): Paths to the model checkpoints to fuse.
            A single path string is treated as a one-element list.
            Each checkpoint can be:
                - A dict with keys ['model', 'model_name', ...]
                - A TorchModel instance
        save_path (str): Base destination path for the fused model.
        device (str): 'cpu' or 'cuda' for loading weights and final model device.
        model_name (str): Default model name (used if not in checkpoint).
        pretrained (bool): Default if not in checkpoint.
        dropout_rate (float): Default if not in checkpoint.
        use_checkpoint (bool): Default if not in checkpoint.
        aggregator (str): How to combine weights across models:
            'mean', 'geomean', 'median', 'sum', 'max', or 'min'.
            'geomean' uses exp(mean(log(w))) and yields NaN/-inf for
            non-positive weights.

    Returns:
        fused_model (TorchModel): The final fused TorchModel instance
                                  with combined weights.

    Raises:
        ValueError: If ``model_paths`` is empty, the aggregator is unknown, a
            checkpoint has an unsupported format, or the state_dict keys differ.
    """
    from spacr.utils import TorchModel

    valid_aggregators = {'mean', 'geomean', 'median', 'sum', 'max', 'min'}
    if aggregator not in valid_aggregators:
        raise ValueError(f"Invalid aggregator '{aggregator}'. "
                         f"Must be one of {valid_aggregators}.")

    # A bare string would otherwise be indexed character by character.
    if isinstance(model_paths, str):
        model_paths = [model_paths]
    if not model_paths:
        raise ValueError("model_paths must contain at least one checkpoint path.")

    if save_path.endswith('.pth'):
        save_path_part1, _ = os.path.splitext(save_path)
    else:
        save_path_part1 = save_path
    save_path = save_path_part1 + f'_{aggregator}.pth'

    # --- 1. Load the first checkpoint to figure out architecture & hyperparams ---
    print(f"Loading the first model from: {model_paths[0]} to derive architecture")
    # NOTE(review): torch.load unpickles arbitrary objects; only fuse trusted checkpoints.
    first_ckpt = torch.load(model_paths[0], map_location=device)

    if isinstance(first_ckpt, dict):
        # Dict checkpoint: prefer stored metadata over the function defaults.
        model_name = first_ckpt.get('model_name', model_name)
        pretrained = first_ckpt.get('pretrained', pretrained)
        dropout_rate = first_ckpt.get('dropout_rate', dropout_rate)
        use_checkpoint = first_ckpt.get('use_checkpoint', use_checkpoint)

        fused_model = TorchModel(
            model_name=model_name,
            pretrained=pretrained,
            dropout_rate=dropout_rate,
            use_checkpoint=use_checkpoint
        ).to(device)

        state_dicts = [first_ckpt['model']]  # the actual weights
    elif isinstance(first_ckpt, TorchModel):
        # The checkpoint is directly a TorchModel instance; reuse it as the container.
        fused_model = first_ckpt.to(device)
        state_dicts = [fused_model.state_dict()]
    else:
        raise ValueError("Unsupported checkpoint format. Must be a dict or a TorchModel instance.")

    # --- 2. Load the rest of the checkpoints ---
    for path in model_paths[1:]:
        print(f"Loading model from: {path}")
        ckpt = torch.load(path, map_location=device)
        if isinstance(ckpt, dict):
            state_dicts.append(ckpt['model'])  # Just the state dict portion
        elif isinstance(ckpt, TorchModel):
            state_dicts.append(ckpt.state_dict())
        else:
            raise ValueError(f"Unsupported checkpoint format in {path} (must be dict or TorchModel).")

    # --- 3. Verify all state dicts have the same keys ---
    fused_sd = fused_model.state_dict()
    for sd in state_dicts:
        if fused_sd.keys() != sd.keys():
            raise ValueError("All models must have identical architecture/state_dict keys.")

    # --- 4. Aggregator logic ---
    reducers = {
        'mean': lambda s: s.mean(dim=0),
        # geometric mean = exp(mean(log(tensor))); requires all values > 0
        'geomean': lambda s: torch.exp(torch.log(s).mean(dim=0)),
        'median': lambda s: s.median(dim=0).values,
        'sum': lambda s: s.sum(dim=0),
        'max': lambda s: s.max(dim=0).values,
        'min': lambda s: s.min(dim=0).values,
    }
    reduce_fn = reducers[aggregator]

    def combine_tensors(tensor_list):
        """Stack the tensors on a new leading axis, reduce it, and cast back
        to the dtype of the first tensor (integer buffers stay integer)."""
        stacked = torch.stack(tensor_list, dim=0).float()
        return reduce_fn(stacked).to(tensor_list[0].dtype)

    # --- 5. Combine the weights ---
    for key in fused_sd.keys():
        fused_sd[key] = combine_tensors([sd[key] for sd in state_dicts])

    # Load combined weights into the fused model
    fused_model.load_state_dict(fused_sd)

    # --- 6. Save the entire TorchModel object ---
    torch.save(fused_model, save_path)
    print(f"Fused model (aggregator='{aggregator}') saved as a full TorchModel to: {save_path}")

    return fused_model
|
1237
|
+
|
1238
|
+
def annotate_filter_vision(settings):
    """
    Annotate one or more vision-result CSVs with experimental conditions and
    optionally filter them.

    For every CSV in ``settings['src']`` (a path or list of paths) this:
      1. reads the CSV, creating 'column_name' from 'column' if absent;
      2. calls ``annotate_conditions`` with the cells/pathogens/treatments
         settings;
      3. if ``settings['filter_column']`` is set and present, keeps only rows
         whose value is > ``upper_threshold`` or < ``lower_threshold``;
      4. writes '<src stem>_annotated_filtered.csv';
      5. if ``settings['remove_train']`` is truthy, drops rows whose image is
         already present under '<...>/datasets/training/train/{nc,pc}' and
         rewrites that CSV.

    ``settings['src']`` is not modified.
    """
    from .utils import annotate_conditions

    def filter_csv_by_png(csv_file):
        """
        Remove rows whose image filename is present in the training folders.

        The training folder is located as '<prefix>/datasets/training/train',
        where <prefix> is everything before the first 'datasets' path
        component of ``csv_file``. PNG filenames in its 'nc' and 'pc'
        subfolders are collected. Rows whose 'path' basename matches one of
        them are dropped.

        Parameters:
            csv_file (str): Path to the CSV file.

        Returns:
            pd.DataFrame: Filtered DataFrame. It is unfiltered if ``csv_file``
            has no 'datasets' path component.
        """
        df = pd.read_csv(csv_file)

        marker = os.sep + "datasets" + os.sep
        if marker not in csv_file:
            # Previously this crashed on tuple unpacking of str.split().
            print(f"No 'datasets' folder found in {csv_file}; skipping removal of training images.")
            return df

        before_datasets, _ = csv_file.split(marker, 1)
        train_fldr = os.path.join(before_datasets, 'datasets', 'training', 'train')

        png_files = set()
        for folder in (os.path.join(train_fldr, 'nc'), os.path.join(train_fldr, 'pc')):
            if os.path.exists(folder):
                png_files.update(f for f in os.listdir(folder) if f.endswith(".png"))

        # Match on basename so the filter works whether 'path' holds bare
        # filenames or full paths.
        # NOTE(review): assumes training PNGs keep their original filenames — confirm.
        filenames = df['path'].astype(str).map(os.path.basename)
        return df[~filenames.isin(png_files)]

    sources = [settings['src']] if isinstance(settings['src'], str) else settings['src']

    for src in sources:
        ann_src, _ = os.path.splitext(src)
        output_csv = ann_src + '_annotated_filtered.csv'
        print(output_csv)

        df = pd.read_csv(src)

        if 'column_name' not in df.columns:
            df['column_name'] = df['column']

        df = annotate_conditions(df,
                                 cells=settings['cells'],
                                 cell_loc=settings['cell_loc'],
                                 pathogens=settings['pathogens'],
                                 pathogen_loc=settings['pathogen_loc'],
                                 treatments=settings['treatments'],
                                 treatment_loc=settings['treatment_loc'])

        filter_column = settings['filter_column']
        if filter_column is None:
            filtered_df = df
        elif filter_column in df.columns:
            # Keep only confident calls: values outside [lower, upper].
            values = df[filter_column]
            filtered_df = df[(values > settings['upper_threshold']) | (values < settings['lower_threshold'])]
            print(f'Filtered DataFrame with {len(df)} rows to {len(filtered_df)} rows.')
        else:
            print(f"{filter_column} not in DataFrame columns.")
            filtered_df = df

        filtered_df.to_csv(output_csv, index=False)

        if settings['remove_train']:
            df = filter_csv_by_png(output_csv)
            df.to_csv(output_csv, index=False)
|
spacr/gui_core.py
CHANGED
@@ -324,40 +324,6 @@ def show_next_figure():
|
|
324
324
|
index_control.set(figure_index)
|
325
325
|
index_control.set_to(len(figures) - 1)
|
326
326
|
display_figure(fig)
|
327
|
-
|
328
|
-
def process_fig_queue_v1():
|
329
|
-
global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
|
330
|
-
from .gui_elements import standardize_figure
|
331
|
-
|
332
|
-
#print("process_fig_queue called", flush=True)
|
333
|
-
try:
|
334
|
-
while not fig_queue.empty():
|
335
|
-
fig = fig_queue.get_nowait()
|
336
|
-
if fig is None:
|
337
|
-
print("Warning: Retrieved a None figure from fig_queue.", flush=True)
|
338
|
-
continue
|
339
|
-
|
340
|
-
# Standardize the figure appearance before adding it
|
341
|
-
fig = standardize_figure(fig)
|
342
|
-
figures.append(fig)
|
343
|
-
|
344
|
-
# Update slider maximum
|
345
|
-
index_control.set_to(len(figures) - 1)
|
346
|
-
|
347
|
-
# If no figure has been displayed yet
|
348
|
-
if figure_index == -1:
|
349
|
-
figure_index = 0
|
350
|
-
display_figure(figures[figure_index])
|
351
|
-
index_control.set(figure_index)
|
352
|
-
|
353
|
-
except Exception as e:
|
354
|
-
print("Exception in process_fig_queue:", e, flush=True)
|
355
|
-
traceback.print_exc()
|
356
|
-
|
357
|
-
finally:
|
358
|
-
# Schedule process_fig_queue() to run again
|
359
|
-
after_id = canvas_widget.after(uppdate_frequency, process_fig_queue)
|
360
|
-
parent_frame.after_tasks.append(after_id)
|
361
327
|
|
362
328
|
def process_fig_queue():
|
363
329
|
global canvas, fig_queue, canvas_widget, parent_frame, uppdate_frequency, figures, figure_index, index_control
|
@@ -544,7 +510,7 @@ def import_settings(settings_type='mask'):
|
|
544
510
|
#vars_dict = hide_all_settings(vars_dict, categories=None)
|
545
511
|
csv_settings = read_settings_from_csv(csv_file_path)
|
546
512
|
if settings_type == 'mask':
|
547
|
-
settings = set_default_settings_preprocess_generate_masks(
|
513
|
+
settings = set_default_settings_preprocess_generate_masks(settings={})
|
548
514
|
elif settings_type == 'measure':
|
549
515
|
settings = get_measure_crop_settings(settings={})
|
550
516
|
elif settings_type == 'classify':
|
@@ -596,7 +562,7 @@ def setup_settings_panel(vertical_container, settings_type='mask'):
|
|
596
562
|
settings_frame.grid_columnconfigure(0, weight=1)
|
597
563
|
|
598
564
|
if settings_type == 'mask':
|
599
|
-
settings = set_default_settings_preprocess_generate_masks(
|
565
|
+
settings = set_default_settings_preprocess_generate_masks(settings={})
|
600
566
|
elif settings_type == 'measure':
|
601
567
|
settings = get_measure_crop_settings(settings={})
|
602
568
|
elif settings_type == 'classify':
|
@@ -912,7 +878,7 @@ def start_process(q=None, fig_queue=None, settings_type='mask'):
|
|
912
878
|
q.put(f"Error: {e}")
|
913
879
|
return
|
914
880
|
|
915
|
-
if thread_control.get("run_thread") is not None:
|
881
|
+
if isinstance(thread_control, dict) and thread_control.get("run_thread") is not None:
|
916
882
|
initiate_abort()
|
917
883
|
|
918
884
|
stop_requested = Value('i', 0)
|
@@ -1018,6 +984,66 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget):
|
|
1018
984
|
print(f"Error updating GUI canvas: {e}")
|
1019
985
|
finally:
|
1020
986
|
root.after(uppdate_frequency, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget))
|
987
|
+
|
988
|
+
def cleanup_previous_instance():
    """
    Cleans up resources from the previous application instance.

    Cancels scheduled Tk ``after`` callbacks, destroys the widgets inside
    ``parent_frame``, drains the global ``q`` and ``fig_queue`` queues, sets
    ``thread_control['stop']``, and resets the usage-bar, figure and canvas
    globals. This lets ``initiate_root`` start from a clean state.
    """
    global parent_frame, usage_bars, figures, figure_index, thread_control, canvas, q, fig_queue
    import queue

    # 1. Cancel all pending `after` tasks. This must happen before
    #    parent_frame is reset; otherwise the lookup below can never run.
    if parent_frame is not None:
        owners = [parent_frame]
        try:
            parent_window = parent_frame.winfo_toplevel()
            if parent_window is not parent_frame:
                owners.append(parent_window)
        except Exception as e:
            print(f"Error locating parent window: {e}")
        for owner in owners:
            pending = getattr(owner, 'after_tasks', None)
            if not pending:
                continue
            for after_id in pending:
                try:
                    owner.after_cancel(after_id)
                except Exception as e:
                    print(f"Error cancelling scheduled task {after_id}: {e}")
            owner.after_tasks = []

    # 2. Destroy all widgets in the parent frame
    if parent_frame is not None:
        for widget in parent_frame.winfo_children():
            try:
                widget.destroy()
            except Exception as e:
                print(f"Error destroying widget: {e}")
        parent_frame.update_idletasks()
        parent_frame = None

    # 3. Clear global queues. Use non-blocking gets: empty() on a
    #    multiprocessing queue is only approximate, and a blocking get()
    #    could hang forever.
    def _drain(qu):
        while True:
            try:
                qu.get_nowait()
            except (queue.Empty, OSError, ValueError):
                # Empty, or the queue was already closed.
                break

    if q is not None:
        _drain(q)
        q = None

    if fig_queue is not None:
        _drain(fig_queue)
        fig_queue = None

    # 4. Signal the worker to stop. thread_control is not guaranteed to
    #    support item assignment (start_process guards it with isinstance).
    if thread_control is not None:
        try:
            thread_control['stop'] = True
        except TypeError as e:
            print(f"Could not signal thread_control to stop: {e}")

    # 5. Reset usage bars, figures, and indices
    usage_bars = []
    figures = deque()
    figure_index = -1

    # 6. Clear canvas or other visualizations
    if canvas is not None:
        try:
            if hasattr(canvas, 'figure'):  # Matplotlib FigureCanvasTkAgg
                canvas.figure.clear()
                canvas.get_tk_widget().destroy()
            else:
                # Assume it's a standard Tkinter Canvas
                canvas.delete("all")
        except Exception as e:
            # The widget may already be gone after step 2.
            print(f"Error clearing canvas: {e}")
        canvas = None

    print("Previous instance cleaned up successfully.")
|
1021
1047
|
|
1022
1048
|
def initiate_root(parent, settings_type='mask'):
|
1023
1049
|
"""
|
@@ -1033,7 +1059,11 @@ def initiate_root(parent, settings_type='mask'):
|
|
1033
1059
|
|
1034
1060
|
global q, fig_queue, thread_control, parent_frame, scrollable_frame, button_frame, vars_dict, canvas, canvas_widget, button_scrollable_frame, progress_bar, uppdate_frequency, figures, figure_index, index_control, usage_bars
|
1035
1061
|
|
1036
|
-
|
1062
|
+
# Clean up any previous instance
|
1063
|
+
cleanup_previous_instance()
|
1064
|
+
|
1065
|
+
from .gui_utils import setup_frame
|
1066
|
+
from .gui_elements import create_menu_bar
|
1037
1067
|
from .settings import descriptions
|
1038
1068
|
#from .openai import Chatbot
|
1039
1069
|
|
@@ -1096,6 +1126,7 @@ def initiate_root(parent, settings_type='mask'):
|
|
1096
1126
|
|
1097
1127
|
process_console_queue()
|
1098
1128
|
process_fig_queue()
|
1129
|
+
create_menu_bar(parent)
|
1099
1130
|
after_id = parent_window.after(uppdate_frequency, lambda: main_thread_update_function(parent_window, q, fig_queue, canvas_widget))
|
1100
1131
|
parent_window.after_tasks.append(after_id)
|
1101
1132
|
|