spacr 0.2.53__py3-none-any.whl → 0.2.61__py3-none-any.whl

This diff compares the contents of two package versions as published to one of the supported public registries. It is provided for informational purposes only and reflects the packages exactly as they appear in those registries.
spacr/core.py CHANGED
@@ -16,7 +16,6 @@ import seaborn as sns
  import cellpose
  from skimage.measure import regionprops, label
  from skimage.transform import resize as resizescikit
- from torch.utils.data import DataLoader

  from skimage import measure
  from sklearn.model_selection import train_test_split
@@ -43,6 +42,16 @@ import warnings
  warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")


+ from torchvision import transforms
+ from torch.utils.data import DataLoader, random_split
+ from collections import defaultdict
+ import os
+ import random
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+
+
  def analyze_plaques(folder):
  summary_data = []
  details_data = []
@@ -877,22 +886,22 @@ def annotate_results(pred_loc):
  display(df)
  return df

- def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
+ def generate_dataset(settings={}):

  from .utils import initiate_counter, add_images_to_tar

- db_path = os.path.join(src, 'measurements', 'measurements.db')
- dst = os.path.join(src, 'datasets')
+ db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
+ dst = os.path.join(settings['src'], 'datasets')
  all_paths = []

  # Connect to the database and retrieve the image paths
- print(f'Reading DataBase: {db_path}')
+ print(f"Reading DataBase: {db_path}")
  try:
  with sqlite3.connect(db_path) as conn:
  cursor = conn.cursor()
- if file_metadata:
- if isinstance(file_metadata, str):
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+ if settings['file_metadata']:
+ if isinstance(settings['file_metadata'], str):
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{settings['file_metadata']}%",))
  else:
  cursor.execute("SELECT png_path FROM png_list")

@@ -909,16 +918,16 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  print(f"Error: {e}")
  return

- if isinstance(sample, int):
- selected_paths = random.sample(all_paths, sample)
- print(f'Random selection of {len(selected_paths)} paths')
+ if isinstance(settings['sample'], int):
+ selected_paths = random.sample(all_paths, settings['sample'])
+ print(f"Random selection of {len(selected_paths)} paths")
  else:
  selected_paths = all_paths
  random.shuffle(selected_paths)
- print(f'All paths: {len(selected_paths)} paths')
+ print(f"All paths: {len(selected_paths)} paths")

  total_images = len(selected_paths)
- print(f'Found {total_images} images')
+ print(f"Found {total_images} images")

  # Create a temp folder in dst
  temp_dir = os.path.join(dst, "temp_tars")
@@ -936,9 +945,9 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  paths_chunks.append(selected_paths[start:end])
  start = end

- temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
+ temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]

- print(f'Generating temporary tar files in {dst}')
+ print(f"Generating temporary tar files in {dst}")

  # Initialize shared counter and lock
  counter = Value('i', 0)
@@ -949,18 +958,18 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample

  # Combine the temporary tar files into a final tar
  date_name = datetime.date.today().strftime('%y%m%d')
- if not file_metadata is None:
- tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
+ if not settings['file_metadata'] is None:
+ tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
  else:
- tar_name = f'{date_name}_{experiment}.tar'
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
  tar_name = os.path.join(dst, tar_name)
  if os.path.exists(tar_name):
  number = random.randint(1, 100)
- tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
- print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
+ tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
+ print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
  tar_name = os.path.join(dst, tar_name_2)

- print(f'Merging temporary files')
+ print(f"Merging temporary files")

  with tarfile.open(tar_name, 'w') as final_tar:
  for temp_tar_path in temp_tar_files:
@@ -974,41 +983,43 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  shutil.rmtree(temp_dir)
  print(f"\nSaved {total_images} images to {tar_name}")

- def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
+ return tar_name
+
+ def apply_model_to_tar(settings={}):

  from .io import TarImageDataset
  from .utils import process_vision_results, print_progress

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- if normalize:
+ if settings['normalize']:
  transform = transforms.Compose([
  transforms.ToTensor(),
- transforms.CenterCrop(size=(image_size, image_size)),
+ transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
  else:
  transform = transforms.Compose([
  transforms.ToTensor(),
- transforms.CenterCrop(size=(image_size, image_size))])
+ transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])

- if verbose:
- print(f'Loading model from {model_path}')
- print(f'Loading dataset from {tar_path}')
+ if settings['verbose']:
+ print(f"Loading model from {settings['model_path']}")
+ print(f"Loading dataset from {settings['tar_path']}")

- model = torch.load(model_path)
+ model = torch.load(settings['model_path'])

- dataset = TarImageDataset(tar_path, transform=transform)
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs, pin_memory=True)
+ dataset = TarImageDataset(settings['tar_path'], transform=transform)
+ data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)

- model_name = os.path.splitext(os.path.basename(model_path))[0]
- dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
+ model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
+ dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
  date_name = datetime.date.today().strftime('%y%m%d')
- dst = os.path.dirname(tar_path)
+ dst = os.path.dirname(settings['tar_path'])
  result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'

  model.eval()
  model = model.to(device)

- if verbose:
+ if settings['verbose']:
  print(model)
  print(f'Generated dataset with {len(dataset)} images')
  print(f'Generating loader from {len(data_loader)} batches')
@@ -1031,13 +1042,13 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
  stop = time.time()
  duration = stop - start
  time_ls.append(duration)
- files_processed = batch_idx*batch_size
+ files_processed = batch_idx*settings['batch_size']
  files_to_process = len(data_loader)
- print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Tar dataset")
+ print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")

  data = {'path':filenames_list, 'pred':prediction_pos_probs}
  df = pd.DataFrame(data, index=None)
- df = process_vision_results(df, threshold)
+ df = process_vision_results(df, settings['score_threshold'])

  df.to_csv(result_loc, index=True, header=True, mode='w')
  torch.cuda.empty_cache()
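Taken together, the hunks above replace the long positional signatures of generate_dataset() and apply_model_to_tar() with a single settings dict; generate_dataset() now also returns the path of the tar it writes. The sketch below is illustrative only: it uses just the keys whose lookups are visible in this diff, the values are hypothetical, and any omitted keys are assumed to be filled in by spacr's own defaults.

    from spacr.core import generate_dataset, apply_model_to_tar

    # Hypothetical values; only the dictionary keys are taken from the diff above.
    tar_path = generate_dataset({
        'src': '/path/to/experiment',      # folder containing measurements/measurements.db
        'file_metadata': None,             # optional substring filter on png_path
        'experiment': 'experiment_name',   # used in the output tar file name
        'sample': None,                    # int to subsample paths at random, None to keep all
    })

    apply_model_to_tar({
        'tar_path': tar_path,
        'model_path': '/path/to/model.pth',
        'image_size': 224,
        'batch_size': 64,
        'normalize': True,
        'n_jobs': 10,
        'score_threshold': 0.5,
        'verbose': False,
    })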
@@ -1206,19 +1217,19 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
  for path in train_data:
  start = time.time()
  shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
- processed_files += 1
  duration = time.time() - start
  time_ls.append(duration)
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
+ processed_files += 1

  # Copy test files
  for path in test_data:
  start = time.time()
  shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
- processed_files += 1
  duration = time.time() - start
  time_ls.append(duration)
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
+ processed_files += 1

  # Print summary
  for cls in classes:
@@ -1226,44 +1237,47 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
  test_class_dir = os.path.join(dst, f'test/{cls}')
  print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')

- return
+ return os.path.join(dst, 'train'), os.path.join(dst, 'test')

- def generate_training_dataset(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
+ def generate_training_dataset(settings):

  from .io import _read_and_merge_data, _read_db
  from .utils import get_paths_from_db, annotate_conditions
+ from .settings import set_generate_training_dataset_defaults
+
+ settings = set_generate_training_dataset_defaults(settings)

- db_path = os.path.join(src, 'measurements','measurements.db')
- dst = os.path.join(src, 'datasets', 'training')
+ db_path = os.path.join(settings['src'], 'measurements','measurements.db')
+ dst = os.path.join(settings['src'], 'datasets', 'training')

  if os.path.exists(dst):
  for i in range(1, 1000):
- dst = os.path.join(src, 'datasets', f'training_{i}')
+ dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
  if not os.path.exists(dst):
  print(f'Creating new directory for training: {dst}')
  break

- if mode == 'annotation':
+ if settings['dataset_mode'] == 'annotation':
  class_paths_ls_2 = []
- class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
+ class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
  for class_paths in class_paths_ls:
- class_paths_temp = random.sample(class_paths, size)
+ class_paths_temp = random.sample(class_paths, settings['size'])
  class_paths_ls_2.append(class_paths_temp)
  class_paths_ls = class_paths_ls_2

- elif mode == 'metadata':
+ elif settings['dataset_mode'] == 'metadata':
  class_paths_ls = []
  class_len_ls = []
  [df] = _read_db(db_loc=db_path, tables=['png_list'])
  df['metadata_based_class'] = pd.NA
- for i, class_ in enumerate(classes):
- ls = class_metadata[i]
- df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
+ for i, class_ in enumerate(settings['classes']):
+ ls = settings['class_metadata'][i]
+ df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_

- for class_ in classes:
- if size == None:
+ for class_ in settings['classes']:
+ if settings['size'] == None:
  c_s = []
- for c in classes:
+ for c in settings['classes']:
  c_s_t_df = df[df['metadata_based_class'] == c]
  c_s.append(len(c_s_t_df))
  print(f'Found {len(c_s_t_df)} images for class {c}')
@@ -1273,12 +1287,12 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
  class_temp_df = df[df['metadata_based_class'] == class_]
  class_len_ls.append(len(class_temp_df))
  print(f'Found {len(class_temp_df)} images for class {class_}')
- class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
+ class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_temp)

- elif mode == 'recruitment':
+ elif settings['dataset_mode'] == 'recruitment':
  class_paths_ls = []
- if not isinstance(tables, list):
+ if not isinstance(settings['tables'], list):
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']

  df, _ = _read_and_merge_data(locs=[db_path],
@@ -1290,60 +1304,58 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

  print('length df 1', len(df))

- df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
+ df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types = settings['metadata_type_by'])
  print('length df 2', len(df))
  [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])

- if custom_measurement != None:
+ if settings['custom_measurement'] != None:

- if not isinstance(custom_measurement, list):
+ if not isinstance(settings['custom_measurement'], list):
  print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
  return

- if isinstance(custom_measurement, list):
- if len(custom_measurement) == 2:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
- df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
- if len(custom_measurement) == 1:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
- df['recruitment'] = df[f'{custom_measurement[0]}']
+ if isinstance(settings['custom_measurement'], list):
+ if len(settings['custom_measurement']) == 2:
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]}/{settings['custom_measurement'][1]})")
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}']/df[f'{settings['custom_measurement'][1]}"]
+ if len(settings['custom_measurement']) == 1:
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]})")
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
  else:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {settings['channel_of_interest']})")
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity']/df[f'cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]

  q25 = df['recruitment'].quantile(0.25)
  q75 = df['recruitment'].quantile(0.75)
  df_lower = df[df['recruitment'] <= q25]
  df_upper = df[df['recruitment'] >= q75]

- class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])

- class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_lower)

- class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
- class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_upper)

- generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
+ train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])

- return
+ return train_class_dir, test_class_dir

- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
+ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):

  """
  Generate data loaders for training and validation/test datasets.

  Parameters:
  - src (str): The source directory containing the data.
- - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
  - mode (str): The mode of operation. Options are 'train' or 'test'.
  - image_size (int): The size of the input images.
  - batch_size (int): The batch size for the data loaders.
  - classes (list): The list of classes to consider.
  - n_jobs (int): The number of worker threads for data loading.
- - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
- - max_show (int): The maximum number of images to show when verbose is True.
+ - validation_split (float): The fraction of data to use for validation.
  - pin_memory (bool): Whether to pin memory for faster data transfer.
  - normalize (bool): Whether to normalize the input images.
  - verbose (bool): Whether to print additional information and show images.
@@ -1352,18 +1364,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  Returns:
  - train_loaders (list): List of data loaders for training datasets.
  - val_loaders (list): List of data loaders for validation datasets.
- - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
  """

- from .io import MyDataset
- from .plot import _imshow
- from torchvision import transforms
- from torch.utils.data import DataLoader, random_split
- from collections import defaultdict
- import os
- import random
- from PIL import Image
- from torchvision.transforms import ToTensor
+ from .io import spacrDataset, spacrDataLoader
+ from .plot import _imshow_gpu
  from .utils import SelectChannels, augment_dataset

  chans = []
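Between them, the hunks above finish migrating generate_training_dataset() to a single settings dict (with defaults filled in by set_generate_training_dataset_defaults) and have it return the train and test directories, while generate_loaders() loses its 'irm' mode. A hedged sketch of the new generate_training_dataset() call follows; only the dictionary keys are taken from the lookups in this diff, the values are illustrative, and any remaining keys are assumed to come from the defaults.

    from spacr.core import generate_training_dataset

    train_dir, test_dir = generate_training_dataset({
        'src': '/path/to/experiment',
        'dataset_mode': 'metadata',           # 'annotation', 'metadata' or 'recruitment'
        'classes': ['nc', 'pc'],
        'class_metadata': [['c1'], ['c2']],   # metadata values defining each class
        'metadata_type_by': 'col',
        'size': 200,                          # images sampled per class
        'test_split': 0.1,
    })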
@@ -1380,12 +1384,9 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  if verbose:
  print(f'Training a network on channels: {channels}')
  print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
-
- plate_to_filenames = defaultdict(list)
- plate_to_labels = defaultdict(list)
+
  train_loaders = []
  val_loaders = []
- plate_names = []

  if normalize:
  transform = transforms.Compose([
@@ -1413,157 +1414,114 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  print(f'mode:{mode} is not valid, use mode = train or test')
  return

- if train_mode == 'erm':
-
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
- if validation_split > 0:
- train_size = int((1 - validation_split) * len(data))
- val_size = len(data) - train_size
- if not augment:
- print(f'Train data:{train_size}, Validation data:{val_size}')
- train_dataset, val_dataset = random_split(data, [train_size, val_size])
-
- if augment:
-
- print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
- train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
- #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
- print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
- train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- else:
- train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
- elif train_mode == 'irm':
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
- for filename, label in zip(data.filenames, data.labels):
- plate = data.get_plate(filename)
- plate_to_filenames[plate].append(filename)
- plate_to_labels[plate].append(label)
-
- for plate, filenames in plate_to_filenames.items():
- labels = plate_to_labels[plate]
- plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
- plate_names.append(plate)
-
- if validation_split > 0:
- train_size = int((1 - validation_split) * len(plate_data))
- val_size = len(plate_data) - train_size
- if not augment:
- print(f'Train data:{train_size}, Validation data:{val_size}')
- train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
-
- if augment:
-
- print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
- train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
- #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
- print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
- train_loaders.append(train_loader)
- val_loaders.append(val_loader)
- else:
- train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- train_loaders.append(train_loader)
- val_loaders.append(None)
-
- else:
- print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
- return
-
+ data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+ num_workers = n_jobs if n_jobs is not None else 0

- if train_mode == 'erm':
- for idx, (images, labels, filenames) in enumerate(train_loaders):
- if idx >= max_show:
- break
- images = images.cpu()
- label_strings = [str(label.item()) for label in labels]
- train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
- if verbose:
- plt.show()
+ if validation_split > 0:
+ train_size = int((1 - validation_split) * len(data))
+ val_size = len(data) - train_size
+ if not augment:
+ print(f'Train data:{train_size}, Validation data:{val_size}')
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])

- elif train_mode == 'irm':
- for plate_name, train_loader in zip(plate_names, train_loaders):
- print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
- for idx, (images, labels, filenames) in enumerate(train_loader):
- if idx >= max_show:
- break
- images = images.cpu()
- label_strings = [str(label.item()) for label in labels]
- train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
- if verbose:
- plt.show()
+ if augment:

- return train_loaders, val_loaders, plate_names, train_fig
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+ print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+ print(f'Generating Dataloader with {n_jobs} workers')
+ #train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+ #train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)

- def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ else:
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ #train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+
+ #dataset (Dataset) – dataset from which to load the data.
+ #batch_size (int, optional) – how many samples per batch to load (default: 1).
+ #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+ #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+ #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+ #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+ #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+ #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+ #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+ #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+ #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+ #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+ #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+ #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+ #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+ #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+ #images, labels, filenames = next(iter(train_loaders))
+ #images = images.cpu()
+ #label_strings = [str(label.item()) for label in labels]
+ #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+ #if verbose:
+ # plt.show()
+
+ train_fig = None
+
+ return train_loaders, val_loaders, train_fig
+
+ def analyze_recruitment(settings={}):
  """
  Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.

  Parameters:
- src (str): The source of the recruitment data.
- metadata_settings (dict): The settings for metadata.
- advanced_settings (dict): The advanced settings for recruitment analysis.
+ settings (dict): settings.

  Returns:
  None
  """

  from .io import _read_and_merge_data, _results_to_csv
- from .plot import plot_merged, _plot_controls, _plot_recruitment
- from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+ from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
+ from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
  from .settings import get_analyze_recruitment_default_settings

- settings = get_analyze_recruitment_default_settings(settings)
-
- settings_dict = {**metadata_settings, **advanced_settings}
- settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
- settings_csv = os.path.join(src,'settings','analyze_settings.csv')
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
- settings_df.to_csv(settings_csv, index=False)
+ settings = get_analyze_recruitment_default_settings(settings=settings)
+ save_settings(settings, name='recruitment')

  # metadata settings
- target = metadata_settings['target']
- cell_types = metadata_settings['cell_types']
- cell_plate_metadata = metadata_settings['cell_plate_metadata']
- pathogen_types = metadata_settings['pathogen_types']
- pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
- treatments = metadata_settings['treatments']
- treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
- metadata_types = metadata_settings['metadata_types']
- channel_dims = metadata_settings['channel_dims']
- cell_chann_dim = metadata_settings['cell_chann_dim']
- cell_mask_dim = metadata_settings['cell_mask_dim']
- nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
- nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
- pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
- pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
- channel_of_interest = metadata_settings['channel_of_interest']
+ src = settings['src']
+ target = settings['target']
+ cell_types = settings['cell_types']
+ cell_plate_metadata = settings['cell_plate_metadata']
+ pathogen_types = settings['pathogen_types']
+ pathogen_plate_metadata = settings['pathogen_plate_metadata']
+ treatments = settings['treatments']
+ treatment_plate_metadata = settings['treatment_plate_metadata']
+ metadata_types = settings['metadata_types']
+ channel_dims = settings['channel_dims']
+ cell_chann_dim = settings['cell_chann_dim']
+ cell_mask_dim = settings['cell_mask_dim']
+ nucleus_chann_dim = settings['nucleus_chann_dim']
+ nucleus_mask_dim = settings['nucleus_mask_dim']
+ pathogen_chann_dim = settings['pathogen_chann_dim']
+ pathogen_mask_dim = settings['pathogen_mask_dim']
+ channel_of_interest = settings['channel_of_interest']

  # Advanced settings
- plot = advanced_settings['plot']
- plot_nr = advanced_settings['plot_nr']
- plot_control = advanced_settings['plot_control']
- figuresize = advanced_settings['figuresize']
- remove_background = advanced_settings['remove_background']
- backgrounds = advanced_settings['backgrounds']
- include_noninfected = advanced_settings['include_noninfected']
- include_multiinfected = advanced_settings['include_multiinfected']
- include_multinucleated = advanced_settings['include_multinucleated']
- cells_per_well = advanced_settings['cells_per_well']
- pathogen_size_range = advanced_settings['pathogen_size_range']
- nucleus_size_range = advanced_settings['nucleus_size_range']
- cell_size_range = advanced_settings['cell_size_range']
- pathogen_intensity_range = advanced_settings['pathogen_intensity_range']
- nucleus_intensity_range = advanced_settings['nucleus_intensity_range']
- cell_intensity_range = advanced_settings['cell_intensity_range']
- target_intensity_min = advanced_settings['target_intensity_min']
+ plot = settings['plot']
+ plot_nr = settings['plot_nr']
+ plot_control = settings['plot_control']
+ figuresize = settings['figuresize']
+ include_noninfected = settings['include_noninfected']
+ include_multiinfected = settings['include_multiinfected']
+ include_multinucleated = settings['include_multinucleated']
+ cells_per_well = settings['cells_per_well']
+ pathogen_size_range = settings['pathogen_size_range']
+ nucleus_size_range = settings['nucleus_size_range']
+ cell_size_range = settings['cell_size_range']
+ pathogen_intensity_range = settings['pathogen_intensity_range']
+ nucleus_intensity_range = settings['nucleus_intensity_range']
+ cell_intensity_range = settings['cell_intensity_range']
+ target_intensity_min = settings['target_intensity_min']

  print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
  print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
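With the 'irm' branch gone, generate_loaders() now always builds one spacrDataset, optionally splits it with random_split, and returns (train_loaders, val_loaders, train_fig), where train_fig is currently always None. A minimal, illustrative call under the new signature; the path and parameter values are assumptions, not taken from the diff:

    from spacr.core import generate_loaders

    train_loader, val_loader, _ = generate_loaders(
        src='/path/to/datasets/training',  # dataset root laid out as expected by spacrDataset
        mode='train',
        image_size=224,
        batch_size=32,
        classes=['nc', 'pc'],
        n_jobs=4,
        validation_split=0.1,
        normalize=True,
        channels=[1, 2, 3],
        augment=False,
        verbose=False,
    )

Note that, as committed, the loaders are plain torch DataLoader objects created with num_workers=1 and persistent_workers=True; the spacrDataLoader calls that would honor n_jobs and preload_batches are still commented out.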
@@ -1581,9 +1539,6 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  else:
  metadata_types = metadata_types

- if isinstance(backgrounds, (int,float)):
- backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
-
  sns.color_palette("mako", as_cmap=True)
  print(f'channel:{channel_of_interest} = {target}')
  overlay_channels = channel_dims
@@ -1593,11 +1548,11 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  db_loc = [src+'/measurements/measurements.db']
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
  df, _ = _read_and_merge_data(db_loc,
- tables,
- verbose=True,
- include_multinucleated=include_multinucleated,
- include_multiinfected=include_multiinfected,
- include_noninfected=include_noninfected)
+ tables,
+ verbose=True,
+ include_multinucleated=include_multinucleated,
+ include_multiinfected=include_multiinfected,
+ include_noninfected=include_noninfected)

  df = annotate_conditions(df,
  cells=cell_types,
@@ -1616,48 +1571,31 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  random.shuffle(files)

  _max = 10**100
-
- if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
- filter_min_max = None
- else:
- if cell_size_range is None:
- cell_size_range = [0,_max]
- if nucleus_size_range is None:
- nucleus_size_range = [0,_max]
- if pathogen_size_range is None:
- pathogen_size_range = [0,_max]
-
- filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+ if cell_size_range is None:
+ cell_size_range = [0,_max]
+ if nucleus_size_range is None:
+ nucleus_size_range = [0,_max]
+ if pathogen_size_range is None:
+ pathogen_size_range = [0,_max]

  if plot:
- plot_settings = {'include_noninfected':include_noninfected,
- 'include_multiinfected':include_multiinfected,
- 'include_multinucleated':include_multinucleated,
- 'remove_background':remove_background,
- 'filter_min_max':filter_min_max,
- 'channel_dims':channel_dims,
- 'backgrounds':backgrounds,
- 'cell_mask_dim':mask_dims[0],
- 'nucleus_mask_dim':mask_dims[1],
- 'pathogen_mask_dim':mask_dims[2],
- 'overlay_chans':overlay_channels,
- 'outline_thickness':3,
- 'outline_color':'gbr',
- 'overlay_chans':overlay_channels,
- 'overlay':True,
- 'normalization_percentiles':[1,99],
- 'normalize':True,
- 'print_object_number':True,
- 'nr':plot_nr,
- 'figuresize':20,
- 'cmap':'inferno',
- 'verbose':False}
-
- if os.path.exists(os.path.join(src,'merged')):
- try:
- plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
- except Exception as e:
- print(f'Failed to plot images with outlines, Error: {e}')
+ merged_path = os.path.join(src,'merged')
+ if os.path.exists(merged_path):
+ try:
+ for idx, file in enumerate(os.listdir(merged_path)):
+ file_path = os.path.join(merged_path,file)
+ if idx <= plot_nr:
+ plot_image_mask_overlay(file_path,
+ channel_dims,
+ cell_chann_dim,
+ nucleus_chann_dim,
+ pathogen_chann_dim,
+ figuresize=10,
+ normalize=True,
+ thickness=3,
+ save_pdf=True)
+ except Exception as e:
+ print(f'Failed to plot images with outlines, Error: {e}')

  if not cell_chann_dim is None:
  df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
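analyze_recruitment() likewise collapses metadata_settings and advanced_settings into one dict, fills defaults through get_analyze_recruitment_default_settings(), and persists them with save_settings(settings, name='recruitment'). A hedged sketch of the new entry point, listing only keys read in the hunks above; the values are illustrative and unspecified keys are assumed to fall back to the defaults:

    from spacr.core import analyze_recruitment

    analyze_recruitment({
        'src': '/path/to/experiment',
        'target': 'protein_of_interest',
        'cell_types': ['HeLa'],
        'pathogen_types': ['pathogen'],
        'treatments': ['nc', 'pc'],
        'channel_of_interest': 3,
        'plot': True,
        'plot_nr': 10,
        'cells_per_well': 0,
        # remaining keys (mask/channel dims, size and intensity ranges, ...)
        # are filled in by get_analyze_recruitment_default_settings()
    })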
@@ -1695,14 +1633,12 @@ def preprocess_generate_masks(src, settings={}):

  from .io import preprocess_img_data, _load_and_concatenate_arrays
  from .plot import plot_image_mask_overlay, plot_arrays
- from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress
+ from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
  from .settings import set_default_settings_preprocess_generate_masks

  settings = set_default_settings_preprocess_generate_masks(src, settings)
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
- settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
- settings_df.to_csv(settings_csv, index=False)
+ settings['src'] = src
+ save_settings(settings)

  if not settings['pathogen_channel'] is None:
  custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
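Several functions in this release replace the inline settings-to-CSV boilerplate (visible in the removed lines above) with a shared save_settings() helper imported from .utils. Its implementation is not part of this diff; the sketch below only mirrors what the removed inline code did and is not the actual spacr.utils.save_settings:

    import os
    import pandas as pd

    def save_settings_sketch(settings, name='settings'):
        # Illustrative only: reproduces the removed inline behavior, assuming settings['src'] is set.
        settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
        settings_dir = os.path.join(settings['src'], 'settings')
        os.makedirs(settings_dir, exist_ok=True)
        settings_df.to_csv(os.path.join(settings_dir, f'{name}_settings.csv'), index=False)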
@@ -1993,7 +1929,7 @@ def generate_cellpose_masks(src, settings, object_type):
  settings_df['setting_value'] = settings_df['setting_value'].apply(str)
  display(settings_df)

- figuresize=25
+ figuresize=10
  timelapse = settings['timelapse']

  if timelapse:
@@ -2497,7 +2433,6 @@ def ml_analysis(df, channel_of_interest=3, location_column='col', positive_contr
  df_metadata = df[[location_column]].copy()
  df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)

-
  if verbose:
  print(f'Found {len(features)} numerical features in the dataframe')
  print(f'Features used in training: {features}')
@@ -2642,7 +2577,6 @@ def check_index(df, elements=5, split_char='_'):
  print(idx)
  raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")

- #def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
  def generate_ml_scores(src, settings):

  from .io import _read_and_merge_data
@@ -2680,7 +2614,7 @@ def generate_ml_scores(src, settings):
  settings['top_features'],
  settings['n_estimators'],
  settings['test_size'],
- settings['model_type'],
+ settings['model_type_ml'],
  settings['n_jobs'],
  settings['remove_low_variance_features'],
  settings['remove_highly_correlated_features'],
@@ -2701,7 +2635,7 @@ def generate_ml_scores(src, settings):
  min_count=settings['minimum_cell_count'],
  verbose=settings['verbose'])

- data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
+ data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
  df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output

  settings_df.to_csv(settings_csv, index=False)
@@ -2858,6 +2792,7 @@ def generate_image_umap(settings={}):
  settings['plot_outlines'] = False
  settings['smooth_lines'] = False

+ print(f'Generating Image UMAP ...')
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
  settings_dir = os.path.join(settings['src'][0],'settings')
  settings_csv = os.path.join(settings_dir,'embedding_settings.csv')