spacr 0.3.80__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in the public registry they were published to. It is provided for informational purposes only.
spacr/__init__.py CHANGED
@@ -67,8 +67,4 @@ logging.basicConfig(filename='spacr.log', level=logging.INFO,
 
 from .utils import download_models
 
-# Check if models already exist
-#models_dir = os.path.join(os.path.dirname(__file__), 'resources', 'models', 'cp')
-#if not os.path.exists(models_dir) or not os.listdir(models_dir):
-#    print("Models not found, downloading...")
 download_models()
spacr/core.py CHANGED
@@ -7,15 +7,20 @@ from IPython.display import display
 import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
-def preprocess_generate_masks(src, settings={}):
+def preprocess_generate_masks(settings):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_image_mask_overlay, plot_arrays
-    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings, delete_intermedeate_files
     from .settings import set_default_settings_preprocess_generate_masks
-
-    if not isinstance(settings['src'], (str, list)):
-        ValueError(f'src must be a string or a list of strings')
+
+
+    if 'src' in settings:
+        if not isinstance(settings['src'], (str, list)):
+            ValueError(f'src must be a string or a list of strings')
+            return
+    else:
+        ValueError(f'src is a required parameter')
         return
 
     if isinstance(settings['src'], str):
@@ -27,9 +32,8 @@ def preprocess_generate_masks(src, settings={}):
         print(f'Processing folder: {source_folder}')
         settings['src'] = source_folder
         src = source_folder
-        settings = set_default_settings_preprocess_generate_masks(src, settings)
-
-        save_settings(settings, name='gen_mask')
+        settings = set_default_settings_preprocess_generate_masks(settings)
+        save_settings(settings, name='gen_mask_settings')
 
         if not settings['pathogen_channel'] is None:
             custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
@@ -158,6 +162,10 @@ def preprocess_generate_masks(src, settings={}):
 
         torch.cuda.empty_cache()
         gc.collect()
+
+        if settings['delete_intermediate']:
+            delete_intermedeate_files(settings)
+
     print("Successfully completed run")
     return
 
@@ -172,8 +180,10 @@ def generate_cellpose_masks(src, settings, object_type):
     gc.collect()
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
-
-    settings = set_default_settings_preprocess_generate_masks(src, settings)
+
+    settings['src'] = src
+
+    settings = set_default_settings_preprocess_generate_masks(settings)
 
     if settings['verbose']:
         settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
@@ -192,11 +202,13 @@ def generate_cellpose_masks(src, settings, object_type):
     timelapse_objects = settings['timelapse_objects']
 
     batch_size = settings['batch_size']
+
     cellprob_threshold = settings[f'{object_type}_CP_prob']
 
     flow_threshold = settings[f'{object_type}_FT']
 
     object_settings = _get_object_settings(object_type, settings)
+
     model_name = object_settings['model_name']
 
     cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
@@ -473,7 +485,7 @@ def generate_image_umap(settings={}):
         df, image_paths_tmp = correct_paths(df, settings['src'][i])
         all_df = pd.concat([all_df, df], axis=0)
         #image_paths.extend(image_paths_tmp)
-
+
     all_df['cond'] = all_df['column_name'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
 
     if settings['exclude_conditions']:
@@ -493,7 +505,7 @@ def generate_image_umap(settings={}):
 
     # Extract and reset the index for the column to compare
     col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)
-
+    print(col_to_compare)
     #if settings['only_top_features']:
     #    column_list = None
 
@@ -781,8 +793,10 @@ def generate_mediar_masks(src, settings, object_type):
     if not torch.cuda.is_available():
         print(f'Torch CUDA is not available, using CPU')
 
+    settings['src'] = src
+
     # Preprocess settings
-    settings = set_default_settings_preprocess_generate_masks(src, settings)
+    settings = set_default_settings_preprocess_generate_masks(settings)
 
     if settings['verbose']:
         settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
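
Note on the core.py changes above: preprocess_generate_masks now takes a single settings dict ('src' becomes a required key rather than a positional argument), and a new 'delete_intermediate' flag triggers the (misspelled) delete_intermedeate_files helper after a run. The 'src' validation constructs ValueError without raising it, so a bad or missing 'src' returns silently. A minimal, hypothetical calling sketch using only keys visible in this diff (values are placeholders; set_default_settings_preprocess_generate_masks fills in the rest):

    from spacr.core import preprocess_generate_masks

    settings = {
        'src': '/path/to/plate_folder',  # required: str or list of str
        'cell_channel': 0,               # channel keys referenced in this diff
        'nucleus_channel': 1,
        'pathogen_channel': 2,
        'delete_intermediate': True,     # new in 0.4.0: clean up intermediate files
    }
    preprocess_generate_masks(settings)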
spacr/deep_spacr.py CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
 from sklearn.metrics import auc, precision_recall_curve
 from IPython.display import display
 from multiprocessing import cpu_count
+import torch.optim as optim
 
 from torchvision import transforms
 from torch.utils.data import DataLoader
@@ -76,10 +77,11 @@ def apply_model_to_tar(settings={}):
     from .io import TarImageDataset
     from .utils import process_vision_results, print_progress
 
-    if os.path.exists(settings['dataset']):
-        tar_path = settings['dataset']
-    else:
-        tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+    #if os.path.exists(settings['dataset']):
+    #    tar_path = settings['dataset']
+    #else:
+    #    tar_path = os.path.join(settings['src'], 'datasets', settings['dataset'])
+    tar_path = settings['tar_path']
     model_path = settings['model_path']
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -103,7 +105,8 @@ def apply_model_to_tar(settings={}):
     data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
 
     model_name = os.path.splitext(os.path.basename(model_path))[0]
-    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    #dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
     date_name = datetime.date.today().strftime('%y%m%d')
    dst = os.path.dirname(tar_path)
     result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
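
Note on the apply_model_to_tar changes above: the tar archive location is now read directly from settings['tar_path'] (the old settings['dataset']/settings['src'] resolution is commented out), so existing callers presumably need something like the following hypothetical snippet:

    settings['tar_path'] = '/path/to/dataset.tar'  # placeholder path; replaces the old 'dataset' key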
@@ -934,3 +937,373 @@ def deep_spacr(settings={}):
 
     if os.path.exists(settings['model_path']):
         apply_model_to_tar(settings)
+
+def model_knowledge_transfer(
+    teacher_paths,
+    student_save_path,
+    data_loader,  # A DataLoader with (images, labels)
+    device='cpu',
+    student_model_name='maxvit_t',
+    pretrained=True,
+    dropout_rate=None,
+    use_checkpoint=False,
+    alpha=0.5,
+    temperature=2.0,
+    lr=1e-4,
+    epochs=10
+):
+    """
+    Performs multi-teacher knowledge distillation on a new labeled dataset,
+    producing a single student TorchModel that combines (distills) the
+    teachers' knowledge plus the labeled data.
+
+    Usage:
+        student = model_knowledge_transfer(
+            teacher_paths=[
+                'teacherA.pth',
+                'teacherB.pth',
+                ...
+            ],
+            student_save_path='distilled_student.pth',
+            data_loader=my_data_loader,
+            device='cuda',
+            student_model_name='maxvit_t',
+            alpha=0.5,
+            temperature=2.0,
+            lr=1e-4,
+            epochs=10
+        )
+
+    Then load it via:
+        fused_student = torch.load('distilled_student.pth')
+        # fused_student is a TorchModel instance, ready for inference.
+
+    Args:
+        teacher_paths (list[str]): List of paths to teacher models (TorchModel
+            or dict with 'model' in it). They must have the same architecture
+            or at least produce the same dimension of output.
+        student_save_path (str): Destination path to save the final student
+            TorchModel.
+        data_loader (DataLoader): Yields (images, labels) from the new dataset.
+        device (str): 'cpu' or 'cuda'.
+        student_model_name (str): Architecture name for the student TorchModel.
+        pretrained (bool): If the student should be initialized as pretrained.
+        dropout_rate (float): If needed by your TorchModel init.
+        use_checkpoint (bool): If needed by your TorchModel init.
+        alpha (float): Weight balancing real-label CE vs. distillation loss
+            (0..1).
+        temperature (float): Distillation temperature (>1 typically).
+        lr (float): Learning rate for the student.
+        epochs (int): Number of training epochs.
+
+    Returns:
+        TorchModel: The final, trained student model.
+    """
+    from spacr.utils import TorchModel  # Adapt if needed
+
+    # Adjust filename to reflect knowledge-distillation if desired
+    if student_save_path.endswith('.pth'):
+        base, ext = os.path.splitext(student_save_path)
+    else:
+        base = student_save_path
+    student_save_path = base + '_KD.pth'
+
+    # -- 1. Load teacher models --
+    teachers = []
+    print("Loading teacher models:")
+    for path in teacher_paths:
+        print(f"  Loading teacher: {path}")
+        ckpt = torch.load(path, map_location=device)
+        if isinstance(ckpt, TorchModel):
+            teacher = ckpt.to(device)
+        elif isinstance(ckpt, dict):
+            # If it's a dict with 'model' inside
+            # We might need to check if it has 'model_name', etc.
+            # But let's keep it simple: same architecture as the student
+            teacher = TorchModel(
+                model_name=ckpt.get('model_name', student_model_name),
+                pretrained=ckpt.get('pretrained', pretrained),
+                dropout_rate=ckpt.get('dropout_rate', dropout_rate),
+                use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
+            ).to(device)
+            teacher.load_state_dict(ckpt['model'])
+        else:
+            raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")
+
+        teacher.eval()  # For consistent batchnorm, dropout
+        teachers.append(teacher)
+
+    # -- 2. Initialize the student TorchModel --
+    student_model = TorchModel(
+        model_name=student_model_name,
+        pretrained=pretrained,
+        dropout_rate=dropout_rate,
+        use_checkpoint=use_checkpoint
+    ).to(device)
+
+    # You could load a partial checkpoint into the student here if desired.
+
+    # -- 3. Optimizer --
+    optimizer = optim.Adam(student_model.parameters(), lr=lr)
+
+    # Distillation training loop
+    for epoch in range(epochs):
+        student_model.train()
+        running_loss = 0.0
+
+        for images, labels in data_loader:
+            images, labels = images.to(device), labels.to(device)
+
+            # Forward pass student
+            logits_s = student_model(images)  # shape: (B, num_classes)
+            logits_s_temp = logits_s / temperature  # scale by T
+
+            # Distillation from teachers
+            with torch.no_grad():
+                # We'll average teacher probabilities
+                teacher_probs_list = []
+                for tm in teachers:
+                    logits_t = tm(images) / temperature
+                    # convert to probabilities
+                    teacher_probs_list.append(F.softmax(logits_t, dim=1))
+                # average them
+                teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0)
+
+            # Student probabilities (log-softmax)
+            student_log_probs = F.log_softmax(logits_s_temp, dim=1)
+
+            # Distillation loss => KLDiv
+            loss_distill = F.kl_div(
+                student_log_probs,
+                teacher_probs_ensemble,
+                reduction='batchmean'
+            ) * (temperature ** 2)
+
+            # Real label loss => cross-entropy
+            # We can compute this on the raw logits or scaled. Typically raw logits is standard:
+            loss_ce = F.cross_entropy(logits_s, labels)
+
+            # Weighted sum
+            loss = alpha * loss_ce + (1 - alpha) * loss_distill
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            running_loss += loss.item()
+
+        avg_loss = running_loss / len(data_loader)
+        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")
+
+    # -- 4. Save final student as a TorchModel --
+    torch.save(student_model, student_save_path)
+    print(f"Knowledge-distilled student saved to: {student_save_path}")
+
+    return student_model
+
+def model_fusion(model_paths,
+                 save_path,
+                 device='cpu',
+                 model_name='maxvit_t',
+                 pretrained=True,
+                 dropout_rate=None,
+                 use_checkpoint=False,
+                 aggregator='mean'):
+    """
+    Fuses an arbitrary number of TorchModel instances by combining their weights
+    (using mean, geomean, median, sum, max, or min) and saves the entire fused
+    model object.
+
+    You can later load the fused model with:
+        model = torch.load('fused_model.pth')
+
+    which returns a ready-to-use TorchModel instance.
+
+    Parameters:
+        model_paths (list of str): Paths to the model checkpoints to fuse.
+            Each checkpoint can be:
+                - A dict with keys ['model', 'model_name', ...]
+                - A TorchModel instance
+        save_path (str): Destination path to save the fused model.
+        device (str): 'cpu' or 'cuda' for loading weights and final model device.
+        model_name (str): Default model name (used if not in checkpoint).
+        pretrained (bool): Default if not in checkpoint.
+        dropout_rate (float): Default if not in checkpoint.
+        use_checkpoint (bool): Default if not in checkpoint.
+        aggregator (str): How to combine weights across models:
+            'mean', 'geomean', 'median', 'sum', 'max', or 'min'.
+
+    Returns:
+        fused_model (TorchModel): The final fused TorchModel instance
+            with combined weights.
+    """
+    from spacr.utils import TorchModel
+
+    if save_path.endswith('.pth'):
+        save_path_part1, ext = os.path.splitext(save_path)
+    else:
+        save_path_part1 = save_path
+
+    save_path = save_path_part1 + f'_{aggregator}.pth'
+
+    valid_aggregators = {'mean', 'geomean', 'median', 'sum', 'max', 'min'}
+    if aggregator not in valid_aggregators:
+        raise ValueError(f"Invalid aggregator '{aggregator}'. "
+                         f"Must be one of {valid_aggregators}.")
+
+    # --- 1. Load the first checkpoint to figure out architecture & hyperparams ---
+    print(f"Loading the first model from: {model_paths[0]} to derive architecture")
+    first_ckpt = torch.load(model_paths[0], map_location=device)
+
+    if isinstance(first_ckpt, dict):
+        # It's a dict with state_dict + possibly metadata
+        # Use any stored metadata if present
+        model_name = first_ckpt.get('model_name', model_name)
+        pretrained = first_ckpt.get('pretrained', pretrained)
+        dropout_rate = first_ckpt.get('dropout_rate', dropout_rate)
+        use_checkpoint = first_ckpt.get('use_checkpoint', use_checkpoint)
+
+        # Initialize the fused model
+        fused_model = TorchModel(
+            model_name=model_name,
+            pretrained=pretrained,
+            dropout_rate=dropout_rate,
+            use_checkpoint=use_checkpoint
+        ).to(device)
+
+        # We'll collect state dicts in a list
+        state_dicts = [first_ckpt['model']]  # the actual weights
+    elif isinstance(first_ckpt, TorchModel):
+        # The checkpoint is directly a TorchModel instance
+        fused_model = first_ckpt.to(device)
+        state_dicts = [fused_model.state_dict()]
+    else:
+        raise ValueError("Unsupported checkpoint format. Must be a dict or a TorchModel instance.")
+
+    # --- 2. Load the rest of the checkpoints ---
+    for path in model_paths[1:]:
+        print(f"Loading model from: {path}")
+        ckpt = torch.load(path, map_location=device)
+        if isinstance(ckpt, dict):
+            state_dicts.append(ckpt['model'])  # Just the state dict portion
+        elif isinstance(ckpt, TorchModel):
+            state_dicts.append(ckpt.state_dict())
+        else:
+            raise ValueError(f"Unsupported checkpoint format in {path} (must be dict or TorchModel).")
+
+    # --- 3. Verify all state dicts have the same keys ---
+    fused_sd = fused_model.state_dict()
+    for sd in state_dicts:
+        if fused_sd.keys() != sd.keys():
+            raise ValueError("All models must have identical architecture/state_dict keys.")
+
+    # --- 4. Define aggregator logic ---
+    def combine_tensors(tensor_list, mode='mean'):
+        """Given a list of Tensors, combine them using the chosen aggregator."""
+        # stack along new dimension => shape (num_models, *tensor.shape)
+        stacked = torch.stack(tensor_list, dim=0).float()
+
+        if mode == 'mean':
+            return stacked.mean(dim=0)
+        elif mode == 'geomean':
+            # geometric mean = exp(mean(log(tensor)))
+            # caution: requires all > 0
+            return torch.exp(torch.log(stacked).mean(dim=0))
+        elif mode == 'median':
+            return stacked.median(dim=0).values
+        elif mode == 'sum':
+            return stacked.sum(dim=0)
+        elif mode == 'max':
+            return stacked.max(dim=0).values
+        elif mode == 'min':
+            return stacked.min(dim=0).values
+        else:
+            raise ValueError(f"Unsupported aggregator: {mode}")
+
+    # --- 5. Combine the weights ---
+    for key in fused_sd.keys():
+        # gather all versions of this tensor
+        all_tensors = [sd[key] for sd in state_dicts]
+        fused_sd[key] = combine_tensors(all_tensors, mode=aggregator)
+
+    # Load combined weights into the fused model
+    fused_model.load_state_dict(fused_sd)
+
+    # --- 6. Save the entire TorchModel object ---
+    torch.save(fused_model, save_path)
+    print(f"Fused model (aggregator='{aggregator}') saved as a full TorchModel to: {save_path}")
+
+    return fused_model
+
+def annotate_filter_vision(settings):
+
+    from .utils import annotate_conditions
+
+    def filter_csv_by_png(csv_file):
+        """
+        Filters a DataFrame by removing rows that match PNG filenames in a folder.
+
+        Parameters:
+            csv_file (str): Path to the CSV file.
+
+        Returns:
+            pd.DataFrame: Filtered DataFrame.
+        """
+        # Split the path to identify the datasets folder and build the training folder path
+        before_datasets, after_datasets = csv_file.split(os.sep + "datasets" + os.sep, 1)
+        train_fldr = os.path.join(before_datasets, 'datasets', 'training', 'train')
+
+        # Paths for train/nc and train/pc
+        nc_folder = os.path.join(train_fldr, 'nc')
+        pc_folder = os.path.join(train_fldr, 'pc')
+
+        # Load the CSV file into a DataFrame
+        df = pd.read_csv(csv_file)
+
+        # Collect PNG filenames from train/nc and train/pc
+        png_files = set()
+        for folder in [nc_folder, pc_folder]:
+            if os.path.exists(folder):  # Ensure the folder exists
+                png_files.update({file for file in os.listdir(folder) if file.endswith(".png")})
+
+        # Filter the DataFrame by excluding rows where filenames match PNG files
+        filtered_df = df[~df['path'].isin(png_files)]
+
+        return filtered_df
+
+    if isinstance(settings['src'], str):
+        settings['src'] = [settings['src']]
+
+    for src in settings['src']:
+        ann_src, ext = os.path.splitext(src)
+        output_csv = ann_src+'_annotated_filtered.csv'
+        print(output_csv)
+
+        df = pd.read_csv(src)
+
+        if 'column_name' not in df.columns:
+            df['column_name'] = df['column']
+
+        df = annotate_conditions(df,
+                                 cells=settings['cells'],
+                                 cell_loc=settings['cell_loc'],
+                                 pathogens=settings['pathogens'],
+                                 pathogen_loc=settings['pathogen_loc'],
+                                 treatments=settings['treatments'],
+                                 treatment_loc=settings['treatment_loc'])
+
+        if not settings['filter_column'] is None:
+            if settings['filter_column'] in df.columns:
+                filtered_df = df[(df[settings['filter_column']] > settings['upper_threshold']) | (df[settings['filter_column']] < settings['lower_threshold'])]
+                print(f'Filtered DataFrame with {len(df)} rows to {len(filtered_df)} rows.')
+            else:
+                print(f"{settings['filter_column']} not in DataFrame columns.")
+                filtered_df = df
+        else:
+            filtered_df = df
+
+        filtered_df.to_csv(output_csv, index=False)
+
+        if settings['remove_train']:
+            df = filter_csv_by_png(output_csv)
+            df.to_csv(output_csv, index=False)
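
Note on the new deep_spacr.py functions above: model_knowledge_transfer and model_fusion document their own usage, but annotate_filter_vision has no docstring. A hypothetical settings sketch, inferred only from the keys the function reads in this diff (all values are placeholders; annotate_conditions defines what the cells/pathogens/treatments values mean):

    from spacr.deep_spacr import annotate_filter_vision

    settings = {
        'src': '/path/to/results.csv',  # str or list of str; one CSV per source
        'cells': None,                  # passed through to annotate_conditions
        'cell_loc': None,
        'pathogens': None,
        'pathogen_loc': None,
        'treatments': None,
        'treatment_loc': None,
        'filter_column': None,          # if set, keeps rows above upper_threshold or below lower_threshold
        'upper_threshold': 0.95,        # placeholder values
        'lower_threshold': 0.05,
        'remove_train': False,          # drop rows whose 'path' matches PNGs in datasets/training/train/{nc,pc}
    }
    annotate_filter_vision(settings)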