spacr 0.3.80__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +0 -4
- spacr/core.py +27 -13
- spacr/deep_spacr.py +378 -5
- spacr/gui_core.py +82 -20
- spacr/gui_elements.py +192 -3
- spacr/gui_utils.py +1 -1
- spacr/io.py +5 -176
- spacr/measure.py +10 -6
- spacr/ml.py +369 -46
- spacr/plot.py +201 -90
- spacr/settings.py +80 -21
- spacr/submodules.py +282 -1
- spacr/toxo.py +98 -75
- spacr/utils.py +144 -49
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/METADATA +2 -1
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/RECORD +20 -20
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/LICENSE +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/WHEEL +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.80.dist-info → spacr-0.4.0.dist-info}/top_level.txt +0 -0
spacr/__init__.py
CHANGED
@@ -67,8 +67,4 @@ logging.basicConfig(filename='spacr.log', level=logging.INFO,
 
 from .utils import download_models
 
-# Check if models already exist
-#models_dir = os.path.join(os.path.dirname(__file__), 'resources', 'models', 'cp')
-#if not os.path.exists(models_dir) or not os.listdir(models_dir):
-#    print("Models not found, downloading...")
 download_models()
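Net effect of this change: download_models() now runs unconditionally whenever the package is imported; the commented-out check for already-downloaded models is removed entirely. A minimal illustration of the new import-time behavior (assuming download_models() itself skips files that are already present, which this diff does not show):

    import spacr  # importing the package now always calls download_models()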
spacr/core.py
CHANGED
@@ -7,15 +7,20 @@ from IPython.display import display
 import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
-def preprocess_generate_masks(src, settings={}):
+def preprocess_generate_masks(settings):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_image_mask_overlay, plot_arrays
-    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings, delete_intermedeate_files
     from .settings import set_default_settings_preprocess_generate_masks
-
-
-
+
+
+    if 'src' in settings:
+        if not isinstance(settings['src'], (str, list)):
+            ValueError(f'src must be a string or a list of strings')
+            return
+    else:
+        ValueError(f'src is a required parameter')
         return
 
     if isinstance(settings['src'], str):
@@ -27,9 +32,8 @@ def preprocess_generate_masks(src, settings={}):
         print(f'Processing folder: {source_folder}')
         settings['src'] = source_folder
         src = source_folder
-        settings = set_default_settings_preprocess_generate_masks(
-
-        save_settings(settings, name='gen_mask')
+        settings = set_default_settings_preprocess_generate_masks(settings)
+        save_settings(settings, name='gen_mask_settings')
 
     if not settings['pathogen_channel'] is None:
         custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
@@ -158,6 +162,10 @@ def preprocess_generate_masks(src, settings={}):
 
     torch.cuda.empty_cache()
     gc.collect()
+
+    if settings['delete_intermediate']:
+        delete_intermedeate_files(settings)
+
     print("Successfully completed run")
     return
 
@@ -172,8 +180,10 @@ def generate_cellpose_masks(src, settings, object_type):
     gc.collect()
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
-
-    settings =
+
+    settings['src'] = src
+
+    settings = set_default_settings_preprocess_generate_masks(settings)
 
     if settings['verbose']:
         settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
@@ -192,11 +202,13 @@ def generate_cellpose_masks(src, settings, object_type):
     timelapse_objects = settings['timelapse_objects']
 
     batch_size = settings['batch_size']
+
     cellprob_threshold = settings[f'{object_type}_CP_prob']
 
     flow_threshold = settings[f'{object_type}_FT']
 
     object_settings = _get_object_settings(object_type, settings)
+
     model_name = object_settings['model_name']
 
     cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
@@ -473,7 +485,7 @@ def generate_image_umap(settings={}):
             df, image_paths_tmp = correct_paths(df, settings['src'][i])
             all_df = pd.concat([all_df, df], axis=0)
             #image_paths.extend(image_paths_tmp)
-
+
     all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
 
     if settings['exclude_conditions']:
@@ -493,7 +505,7 @@ def generate_image_umap(settings={}):
 
     # Extract and reset the index for the column to compare
     col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
-
+    print(col_to_compare)
     #if settings['only_top_features']:
     #    column_list = None
 
@@ -781,8 +793,10 @@ def generate_mediar_masks(src, settings, object_type):
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
 
+    settings['src'] = src
+
     # Preprocess settings
-    settings = set_default_settings_preprocess_generate_masks(
+    settings = set_default_settings_preprocess_generate_masks(settings)
 
     if settings['verbose']:
         settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
spacr/deep_spacr.py
CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
 from sklearn.metrics import auc, precision_recall_curve
 from IPython.display import display
 from multiprocessing import cpu_count
+import torch.optim as optim
 
 from torchvision import transforms
 from torch.utils.data import DataLoader
@@ -76,10 +77,11 @@ def apply_model_to_tar(settings={}):
     from .io import TarImageDataset
     from .utils import process_vision_results, print_progress
 
-    if os.path.exists(settings['dataset']):
-        tar_path = settings['dataset']
-    else:
-        tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+    #if os.path.exists(settings['dataset']):
+    #    tar_path = settings['dataset']
+    #else:
+    #    tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+    tar_path = settings['tar_path']
     model_path = settings['model_path']
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -103,7 +105,8 @@ def apply_model_to_tar(settings={}):
     data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
 
     model_name = os.path.splitext(os.path.basename(model_path))[0]
-    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    #dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
     date_name = datetime.date.today().strftime('%y%m%d')
     dst = os.path.dirname(tar_path)
     result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
@@ -934,3 +937,373 @@ def deep_spacr(settings={}):
 
     if os.path.exists(settings['model_path']):
         apply_model_to_tar(settings)
+
+def model_knowledge_transfer(
+    teacher_paths,
+    student_save_path,
+    data_loader,  # A DataLoader with (images, labels)
+    device='cpu',
+    student_model_name='maxvit_t',
+    pretrained=True,
+    dropout_rate=None,
+    use_checkpoint=False,
+    alpha=0.5,
+    temperature=2.0,
+    lr=1e-4,
+    epochs=10
+):
+    """
+    Performs multi-teacher knowledge distillation on a new labeled dataset,
+    producing a single student TorchModel that combines (distills) the
+    teachers' knowledge plus the labeled data.
+
+    Usage:
+      student = model_knowledge_transfer(
+          teacher_paths=[
+              'teacherA.pth',
+              'teacherB.pth',
+              ...
+          ],
+          student_save_path='distilled_student.pth',
+          data_loader=my_data_loader,
+          device='cuda',
+          student_model_name='maxvit_t',
+          alpha=0.5,
+          temperature=2.0,
+          lr=1e-4,
+          epochs=10
+      )
+
+    Then load it via:
+      fused_student = torch.load('distilled_student.pth')
+      # fused_student is a TorchModel instance, ready for inference.
+
+    Args:
+        teacher_paths (list[str]): List of paths to teacher models (TorchModel
+            or dict with 'model' in it). They must have the same architecture
+            or at least produce the same dimension of output.
+        student_save_path (str): Destination path to save the final student
+            TorchModel.
+        data_loader (DataLoader): Yields (images, labels) from the new dataset.
+        device (str): 'cpu' or 'cuda'.
+        student_model_name (str): Architecture name for the student TorchModel.
+        pretrained (bool): If the student should be initialized as pretrained.
+        dropout_rate (float): If needed by your TorchModel init.
+        use_checkpoint (bool): If needed by your TorchModel init.
+        alpha (float): Weight balancing real-label CE vs. distillation loss
+            (0..1).
+        temperature (float): Distillation temperature (>1 typically).
+        lr (float): Learning rate for the student.
+        epochs (int): Number of training epochs.
+
+    Returns:
+        TorchModel: The final, trained student model.
+    """
+    from spacr.utils import TorchModel  # Adapt if needed
+
+    # Adjust filename to reflect knowledge-distillation if desired
+    if student_save_path.endswith('.pth'):
+        base, ext = os.path.splitext(student_save_path)
+    else:
+        base = student_save_path
+    student_save_path = base + '_KD.pth'
+
+    # -- 1. Load teacher models --
+    teachers = []
+    print("Loading teacher models:")
+    for path in teacher_paths:
+        print(f"  Loading teacher: {path}")
+        ckpt = torch.load(path, map_location=device)
+        if isinstance(ckpt, TorchModel):
+            teacher = ckpt.to(device)
+        elif isinstance(ckpt, dict):
+            # If it's a dict with 'model' inside
+            # We might need to check if it has 'model_name', etc.
+            # But let's keep it simple: same architecture as the student
+            teacher = TorchModel(
+                model_name=ckpt.get('model_name', student_model_name),
+                pretrained=ckpt.get('pretrained', pretrained),
+                dropout_rate=ckpt.get('dropout_rate', dropout_rate),
+                use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
+            ).to(device)
+            teacher.load_state_dict(ckpt['model'])
+        else:
+            raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")
+
+        teacher.eval()  # For consistent batchnorm, dropout
+        teachers.append(teacher)
+
+    # -- 2. Initialize the student TorchModel --
+    student_model = TorchModel(
+        model_name=student_model_name,
+        pretrained=pretrained,
+        dropout_rate=dropout_rate,
+        use_checkpoint=use_checkpoint
+    ).to(device)
+
+    # You could load a partial checkpoint into the student here if desired.
+
+    # -- 3. Optimizer --
+    optimizer = optim.Adam(student_model.parameters(), lr=lr)
+
+    # Distillation training loop
+    for epoch in range(epochs):
+        student_model.train()
+        running_loss = 0.0
+
+        for images, labels in data_loader:
+            images, labels = images.to(device), labels.to(device)
+
+            # Forward pass student
+            logits_s = student_model(images)  # shape: (B, num_classes)
+            logits_s_temp = logits_s / temperature  # scale by T
+
+            # Distillation from teachers
+            with torch.no_grad():
+                # We'll average teacher probabilities
+                teacher_probs_list = []
+                for tm in teachers:
+                    logits_t = tm(images) / temperature
+                    # convert to probabilities
+                    teacher_probs_list.append(F.softmax(logits_t, dim=1))
+                # average them
+                teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0)
+
+            # Student probabilities (log-softmax)
+            student_log_probs = F.log_softmax(logits_s_temp, dim=1)
+
+            # Distillation loss => KLDiv
+            loss_distill = F.kl_div(
+                student_log_probs,
+                teacher_probs_ensemble,
+                reduction='batchmean'
+            ) * (temperature ** 2)
+
+            # Real label loss => cross-entropy
+            # We can compute this on the raw logits or scaled. Typically raw logits is standard:
+            loss_ce = F.cross_entropy(logits_s, labels)
+
+            # Weighted sum
+            loss = alpha * loss_ce + (1 - alpha) * loss_distill
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+
+        avg_loss = running_loss / len(data_loader)
+        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")
+
+    # -- 4. Save final student as a TorchModel --
+    torch.save(student_model, student_save_path)
+    print(f"Knowledge-distilled student saved to: {student_save_path}")
+
+    return student_model
+
+def model_fusion(model_paths,
+                 save_path,
+                 device='cpu',
+                 model_name='maxvit_t',
+                 pretrained=True,
+                 dropout_rate=None,
+                 use_checkpoint=False,
+                 aggregator='mean'):
+    """
+    Fuses an arbitrary number of TorchModel instances by combining their weights
+    (using mean, geomean, median, sum, max, or min) and saves the entire fused
+    model object.
+
+    You can later load the fused model with:
+        model = torch.load('fused_model.pth')
+
+    which returns a ready-to-use TorchModel instance.
+
+    Parameters:
+        model_paths (list of str): Paths to the model checkpoints to fuse.
+            Each checkpoint can be:
+              - A dict with keys ['model', 'model_name', ...]
+              - A TorchModel instance
+        save_path (str): Destination path to save the fused model.
+        device (str): 'cpu' or 'cuda' for loading weights and final model device.
+        model_name (str): Default model name (used if not in checkpoint).
+        pretrained (bool): Default if not in checkpoint.
+        dropout_rate (float): Default if not in checkpoint.
+        use_checkpoint (bool): Default if not in checkpoint.
+        aggregator (str): How to combine weights across models:
+            'mean', 'geomean', 'median', 'sum', 'max', or 'min'.
+
+    Returns:
+        fused_model (TorchModel): The final fused TorchModel instance
+        with combined weights.
+    """
+    from spacr.utils import TorchModel
+
+    if save_path.endswith('.pth'):
+        save_path_part1, ext = os.path.splitext(save_path)
+    else:
+        save_path_part1 = save_path
+
+    save_path = save_path_part1 + f'_{aggregator}.pth'
+
+    valid_aggregators = {'mean', 'geomean', 'median', 'sum', 'max', 'min'}
+    if aggregator not in valid_aggregators:
+        raise ValueError(f"Invalid aggregator '{aggregator}'. "
+                         f"Must be one of {valid_aggregators}.")
+
+    # --- 1. Load the first checkpoint to figure out architecture & hyperparams ---
+    print(f"Loading the first model from: {model_paths[0]} to derive architecture")
+    first_ckpt = torch.load(model_paths[0], map_location=device)
+
+    if isinstance(first_ckpt, dict):
+        # It's a dict with state_dict + possibly metadata
+        # Use any stored metadata if present
+        model_name = first_ckpt.get('model_name', model_name)
+        pretrained = first_ckpt.get('pretrained', pretrained)
+        dropout_rate = first_ckpt.get('dropout_rate', dropout_rate)
+        use_checkpoint = first_ckpt.get('use_checkpoint', use_checkpoint)
+
+        # Initialize the fused model
+        fused_model = TorchModel(
+            model_name=model_name,
+            pretrained=pretrained,
+            dropout_rate=dropout_rate,
+            use_checkpoint=use_checkpoint
+        ).to(device)
+
+        # We'll collect state dicts in a list
+        state_dicts = [first_ckpt['model']]  # the actual weights
+    elif isinstance(first_ckpt, TorchModel):
+        # The checkpoint is directly a TorchModel instance
+        fused_model = first_ckpt.to(device)
+        state_dicts = [fused_model.state_dict()]
+    else:
+        raise ValueError("Unsupported checkpoint format. Must be a dict or a TorchModel instance.")
+
+    # --- 2. Load the rest of the checkpoints ---
+    for path in model_paths[1:]:
+        print(f"Loading model from: {path}")
+        ckpt = torch.load(path, map_location=device)
+        if isinstance(ckpt, dict):
+            state_dicts.append(ckpt['model'])  # Just the state dict portion
+        elif isinstance(ckpt, TorchModel):
+            state_dicts.append(ckpt.state_dict())
+        else:
+            raise ValueError(f"Unsupported checkpoint format in {path} (must be dict or TorchModel).")
+
+    # --- 3. Verify all state dicts have the same keys ---
+    fused_sd = fused_model.state_dict()
+    for sd in state_dicts:
+        if fused_sd.keys() != sd.keys():
+            raise ValueError("All models must have identical architecture/state_dict keys.")
+
+    # --- 4. Define aggregator logic ---
+    def combine_tensors(tensor_list, mode='mean'):
+        """Given a list of Tensors, combine them using the chosen aggregator."""
+        # stack along new dimension => shape (num_models, *tensor.shape)
+        stacked = torch.stack(tensor_list, dim=0).float()
+
+        if mode == 'mean':
+            return stacked.mean(dim=0)
+        elif mode == 'geomean':
+            # geometric mean = exp(mean(log(tensor)))
+            # caution: requires all > 0
+            return torch.exp(torch.log(stacked).mean(dim=0))
+        elif mode == 'median':
+            return stacked.median(dim=0).values
+        elif mode == 'sum':
+            return stacked.sum(dim=0)
+        elif mode == 'max':
+            return stacked.max(dim=0).values
+        elif mode == 'min':
+            return stacked.min(dim=0).values
+        else:
+            raise ValueError(f"Unsupported aggregator: {mode}")
+
+    # --- 5. Combine the weights ---
+    for key in fused_sd.keys():
+        # gather all versions of this tensor
+        all_tensors = [sd[key] for sd in state_dicts]
+        fused_sd[key] = combine_tensors(all_tensors, mode=aggregator)
+
+    # Load combined weights into the fused model
+    fused_model.load_state_dict(fused_sd)
+
+    # --- 6. Save the entire TorchModel object ---
+    torch.save(fused_model, save_path)
+    print(f"Fused model (aggregator='{aggregator}') saved as a full TorchModel to: {save_path}")
+
+    return fused_model
+
+def annotate_filter_vision(settings):
+
+    from .utils import annotate_conditions
+
+    def filter_csv_by_png(csv_file):
+        """
+        Filters a DataFrame by removing rows that match PNG filenames in a folder.
+
+        Parameters:
+            csv_file (str): Path to the CSV file.
+
+        Returns:
+            pd.DataFrame: Filtered DataFrame.
+        """
+        # Split the path to identify the datasets folder and build the training folder path
+        before_datasets, after_datasets = csv_file.split(os.sep + "datasets" + os.sep, 1)
+        train_fldr = os.path.join(before_datasets, 'datasets', 'training', 'train')
+
+        # Paths for train/nc and train/pc
+        nc_folder = os.path.join(train_fldr, 'nc')
+        pc_folder = os.path.join(train_fldr, 'pc')
+
+        # Load the CSV file into a DataFrame
+        df = pd.read_csv(csv_file)
+
+        # Collect PNG filenames from train/nc and train/pc
+        png_files = set()
+        for folder in [nc_folder, pc_folder]:
+            if os.path.exists(folder):  # Ensure the folder exists
+                png_files.update({file for file in os.listdir(folder) if file.endswith(".png")})
+
+        # Filter the DataFrame by excluding rows where filenames match PNG files
+        filtered_df = df[~df['path'].isin(png_files)]
+
+        return filtered_df
+
+    if isinstance(settings['src'], str):
+        settings['src'] = [settings['src']]
+
+    for src in settings['src']:
+        ann_src, ext = os.path.splitext(src)
+        output_csv = ann_src+'_annotated_filtered.csv'
+        print(output_csv)
+
+        df = pd.read_csv(src)
+
+        if 'column_name' not in df.columns:
+            df['column_name'] = df['column']
+
+        df = annotate_conditions(df,
+                                 cells=settings['cells'],
+                                 cell_loc=settings['cell_loc'],
+                                 pathogens=settings['pathogens'],
+                                 pathogen_loc=settings['pathogen_loc'],
+                                 treatments=settings['treatments'],
+                                 treatment_loc=settings['treatment_loc'])
+
+        if not settings['filter_column'] is None:
+            if settings['filter_column'] in df.columns:
+                filtered_df = df[(df[settings['filter_column']] > settings['upper_threshold']) | (df[settings['filter_column']] < settings['lower_threshold'])]
+                print(f'Filtered DataFrame with {len(df)} rows to {len(filtered_df)} rows.')
+            else:
+                print(f"{settings['filter_column']} not in DataFrame columns.")
+                filtered_df = df
+        else:
+            filtered_df = df
+
+        filtered_df.to_csv(output_csv, index=False)
+
+        if settings['remove_train']:
+            df = filter_csv_by_png(output_csv)
+            df.to_csv(output_csv, index=False)