spacr 0.2.4__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (63)
  1. spacr/__init__.py +1 -11
  2. spacr/core.py +277 -349
  3. spacr/deep_spacr.py +248 -269
  4. spacr/gui.py +58 -54
  5. spacr/gui_core.py +689 -535
  6. spacr/gui_elements.py +1002 -153
  7. spacr/gui_utils.py +452 -107
  8. spacr/io.py +158 -91
  9. spacr/measure.py +199 -151
  10. spacr/plot.py +159 -47
  11. spacr/resources/font/open_sans/OFL.txt +93 -0
  12. spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
  13. spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
  14. spacr/resources/font/open_sans/README.txt +100 -0
  15. spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
  16. spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
  17. spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
  18. spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
  19. spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
  20. spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
  21. spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
  22. spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
  23. spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
  24. spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
  25. spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
  26. spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
  27. spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
  28. spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
  29. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
  30. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
  31. spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
  32. spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
  33. spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
  34. spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
  35. spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
  36. spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
  37. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
  38. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
  39. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
  40. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
  41. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
  42. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
  43. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
  44. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
  45. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
  46. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
  47. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
  48. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
  49. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
  50. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
  51. spacr/resources/icons/logo.pdf +2786 -6
  52. spacr/resources/icons/logo_spacr.png +0 -0
  53. spacr/resources/icons/logo_spacr_1.png +0 -0
  54. spacr/sequencing.py +477 -587
  55. spacr/settings.py +217 -144
  56. spacr/utils.py +46 -46
  57. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/METADATA +46 -35
  58. spacr-0.2.8.dist-info/RECORD +100 -0
  59. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/WHEEL +1 -1
  60. spacr-0.2.4.dist-info/RECORD +0 -58
  61. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/LICENSE +0 -0
  62. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/entry_points.txt +0 -0
  63. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/top_level.txt +0 -0
spacr/core.py CHANGED
@@ -16,7 +16,6 @@ import seaborn as sns
  import cellpose
  from skimage.measure import regionprops, label
  from skimage.transform import resize as resizescikit
- from torch.utils.data import DataLoader

  from skimage import measure
  from sklearn.model_selection import train_test_split
@@ -39,6 +38,20 @@ matplotlib.use('Agg')

  from .logger import log_function_call

+ import warnings
+ warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
+
+
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, random_split
+ from collections import defaultdict
+ import os
+ import random
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+
+
  def analyze_plaques(folder):
  summary_data = []
  details_data = []
@@ -80,7 +93,6 @@ def analyze_plaques(folder):

  print(f"Analysis completed and saved to database '{db_name}'.")

-
  def train_cellpose(settings):

  from .io import _load_normalized_images_and_labels, _load_images_and_labels
@@ -874,22 +886,22 @@ def annotate_results(pred_loc):
  display(df)
  return df

- def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
+ def generate_dataset(settings={}):

  from .utils import initiate_counter, add_images_to_tar

- db_path = os.path.join(src, 'measurements', 'measurements.db')
- dst = os.path.join(src, 'datasets')
+ db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
+ dst = os.path.join(settings['src'], 'datasets')
  all_paths = []

  # Connect to the database and retrieve the image paths
- print(f'Reading DataBase: {db_path}')
+ print(f"Reading DataBase: {db_path}")
  try:
  with sqlite3.connect(db_path) as conn:
  cursor = conn.cursor()
- if file_metadata:
- if isinstance(file_metadata, str):
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+ if settings['file_metadata']:
+ if isinstance(settings['file_metadata'], str):
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{settings['file_metadata']}%",))
  else:
  cursor.execute("SELECT png_path FROM png_list")

@@ -906,16 +918,16 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  print(f"Error: {e}")
  return

- if isinstance(sample, int):
- selected_paths = random.sample(all_paths, sample)
- print(f'Random selection of {len(selected_paths)} paths')
+ if isinstance(settings['sample'], int):
+ selected_paths = random.sample(all_paths, settings['sample'])
+ print(f"Random selection of {len(selected_paths)} paths")
  else:
  selected_paths = all_paths
  random.shuffle(selected_paths)
- print(f'All paths: {len(selected_paths)} paths')
+ print(f"All paths: {len(selected_paths)} paths")

  total_images = len(selected_paths)
- print(f'Found {total_images} images')
+ print(f"Found {total_images} images")

  # Create a temp folder in dst
  temp_dir = os.path.join(dst, "temp_tars")
@@ -933,9 +945,9 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  paths_chunks.append(selected_paths[start:end])
  start = end

- temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
+ temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]

- print(f'Generating temporary tar files in {dst}')
+ print(f"Generating temporary tar files in {dst}")

  # Initialize shared counter and lock
  counter = Value('i', 0)
@@ -946,18 +958,18 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample

  # Combine the temporary tar files into a final tar
  date_name = datetime.date.today().strftime('%y%m%d')
- if not file_metadata is None:
- tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
+ if not settings['file_metadata'] is None:
+ tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
  else:
- tar_name = f'{date_name}_{experiment}.tar'
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
  tar_name = os.path.join(dst, tar_name)
  if os.path.exists(tar_name):
  number = random.randint(1, 100)
- tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
- print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
+ tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
+ print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
  tar_name = os.path.join(dst, tar_name_2)

- print(f'Merging temporary files')
+ print(f"Merging temporary files")

  with tarfile.open(tar_name, 'w') as final_tar:
  for temp_tar_path in temp_tar_files:
@@ -971,41 +983,43 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  shutil.rmtree(temp_dir)
  print(f"\nSaved {total_images} images to {tar_name}")

- def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
+ return tar_name
+
+ def apply_model_to_tar(settings={}):

  from .io import TarImageDataset
  from .utils import process_vision_results, print_progress

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- if normalize:
+ if settings['normalize']:
  transform = transforms.Compose([
  transforms.ToTensor(),
- transforms.CenterCrop(size=(image_size, image_size)),
+ transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
  else:
  transform = transforms.Compose([
  transforms.ToTensor(),
- transforms.CenterCrop(size=(image_size, image_size))])
+ transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])

- if verbose:
- print(f'Loading model from {model_path}')
- print(f'Loading dataset from {tar_path}')
+ if settings['verbose']:
+ print(f"Loading model from {settings['model_path']}")
+ print(f"Loading dataset from {settings['tar_path']}")

- model = torch.load(model_path)
+ model = torch.load(settings['model_path'])

- dataset = TarImageDataset(tar_path, transform=transform)
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs, pin_memory=True)
+ dataset = TarImageDataset(settings['tar_path'], transform=transform)
+ data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)

- model_name = os.path.splitext(os.path.basename(model_path))[0]
- dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
+ model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
+ dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
  date_name = datetime.date.today().strftime('%y%m%d')
- dst = os.path.dirname(tar_path)
+ dst = os.path.dirname(settings['tar_path'])
  result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'

  model.eval()
  model = model.to(device)

- if verbose:
+ if settings['verbose']:
  print(model)
  print(f'Generated dataset with {len(dataset)} images')
  print(f'Generating loader from {len(data_loader)} batches')
@@ -1028,17 +1042,13 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
  stop = time.time()
  duration = stop - start
  time_ls.append(duration)
- files_processed = batch_idx*batch_size
+ files_processed = batch_idx*settings['batch_size']
  files_to_process = len(data_loader)
- print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Tar dataset")
-
-
-
-
+ print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")

  data = {'path':filenames_list, 'pred':prediction_pos_probs}
  df = pd.DataFrame(data, index=None)
- df = process_vision_results(df, threshold)
+ df = process_vision_results(df, settings['score_threshold'])

  df.to_csv(result_loc, index=True, header=True, mode='w')
  torch.cuda.empty_cache()
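Note on the refactor above: generate_dataset and apply_model_to_tar now read a single settings dict instead of positional arguments, and generate_dataset returns the path of the tar it writes, so the two can be chained. A minimal sketch of that chain, using only keys that appear in this diff; the paths and values are placeholders, and any defaults normally filled in by spacr's settings helpers are assumed:

from spacr.core import generate_dataset, apply_model_to_tar

# Build a tar of single-cell PNGs from the project's measurements.db (path is hypothetical).
tar_path = generate_dataset(settings={
    'src': '/data/experiment1',
    'file_metadata': None,         # or a substring to filter png_path entries
    'experiment': 'TSG101_screen',
    'sample': None,                # or an int to randomly subsample images
})

# Score every image in the tar with a saved torch model (model path is hypothetical).
apply_model_to_tar(settings={
    'tar_path': tar_path,
    'model_path': '/models/classifier.pth',
    'image_size': 224,
    'batch_size': 64,
    'normalize': True,
    'n_jobs': 10,
    'score_threshold': 0.5,
    'verbose': False,
})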
@@ -1207,19 +1217,19 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
  for path in train_data:
  start = time.time()
  shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
- processed_files += 1
  duration = time.time() - start
  time_ls.append(duration)
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
+ processed_files += 1

  # Copy test files
  for path in test_data:
  start = time.time()
  shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
- processed_files += 1
  duration = time.time() - start
  time_ls.append(duration)
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
+ processed_files += 1

  # Print summary
  for cls in classes:
@@ -1227,44 +1237,47 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
  test_class_dir = os.path.join(dst, f'test/{cls}')
  print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')

- return
+ return os.path.join(dst, 'train'), os.path.join(dst, 'test')

- def generate_training_dataset(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
+ def generate_training_dataset(settings):

  from .io import _read_and_merge_data, _read_db
  from .utils import get_paths_from_db, annotate_conditions
+ from .settings import set_generate_training_dataset_defaults
+
+ settings = set_generate_training_dataset_defaults(settings)

- db_path = os.path.join(src, 'measurements','measurements.db')
- dst = os.path.join(src, 'datasets', 'training')
+ db_path = os.path.join(settings['src'], 'measurements','measurements.db')
+ dst = os.path.join(settings['src'], 'datasets', 'training')

  if os.path.exists(dst):
  for i in range(1, 1000):
- dst = os.path.join(src, 'datasets', f'training_{i}')
+ dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
  if not os.path.exists(dst):
  print(f'Creating new directory for training: {dst}')
  break

- if mode == 'annotation':
+ if settings['dataset_mode'] == 'annotation':
  class_paths_ls_2 = []
- class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
+ class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
  for class_paths in class_paths_ls:
- class_paths_temp = random.sample(class_paths, size)
+ class_paths_temp = random.sample(class_paths, settings['size'])
  class_paths_ls_2.append(class_paths_temp)
  class_paths_ls = class_paths_ls_2

- elif mode == 'metadata':
+ elif settings['dataset_mode'] == 'metadata':
  class_paths_ls = []
  class_len_ls = []
  [df] = _read_db(db_loc=db_path, tables=['png_list'])
  df['metadata_based_class'] = pd.NA
- for i, class_ in enumerate(classes):
- ls = class_metadata[i]
- df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
+ for i, class_ in enumerate(settings['classes']):
+ ls = settings['class_metadata'][i]
+ df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_

- for class_ in classes:
- if size == None:
+ for class_ in settings['classes']:
+ if settings['size'] == None:
  c_s = []
- for c in classes:
+ for c in settings['classes']:
  c_s_t_df = df[df['metadata_based_class'] == c]
  c_s.append(len(c_s_t_df))
  print(f'Found {len(c_s_t_df)} images for class {c}')
@@ -1274,12 +1287,12 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
  class_temp_df = df[df['metadata_based_class'] == class_]
  class_len_ls.append(len(class_temp_df))
  print(f'Found {len(class_temp_df)} images for class {class_}')
- class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
+ class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_temp)

- elif mode == 'recruitment':
+ elif settings['dataset_mode'] == 'recruitment':
  class_paths_ls = []
- if not isinstance(tables, list):
+ if not isinstance(settings['tables'], list):
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']

  df, _ = _read_and_merge_data(locs=[db_path],
@@ -1291,60 +1304,58 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

  print('length df 1', len(df))

- df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
+ df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types = settings['metadata_type_by'])
  print('length df 2', len(df))
  [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])

- if custom_measurement != None:
+ if settings['custom_measurement'] != None:

- if not isinstance(custom_measurement, list):
+ if not isinstance(settings['custom_measurement'], list):
  print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
  return

- if isinstance(custom_measurement, list):
- if len(custom_measurement) == 2:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
- df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
- if len(custom_measurement) == 1:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
- df['recruitment'] = df[f'{custom_measurement[0]}']
+ if isinstance(settings['custom_measurement'], list):
+ if len(settings['custom_measurement']) == 2:
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]}/{settings['custom_measurement'][1]})")
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]/df[f"{settings['custom_measurement'][1]}"]
+ if len(settings['custom_measurement']) == 1:
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]})")
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
  else:
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+ print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {settings['channel_of_interest']})")
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]

  q25 = df['recruitment'].quantile(0.25)
  q75 = df['recruitment'].quantile(0.75)
  df_lower = df[df['recruitment'] <= q25]
  df_upper = df[df['recruitment'] >= q75]

- class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])

- class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_lower)

- class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
- class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
  class_paths_ls.append(class_paths_upper)

- generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
+ train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])

- return
+ return train_class_dir, test_class_dir

- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
+ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):

  """
  Generate data loaders for training and validation/test datasets.

  Parameters:
  - src (str): The source directory containing the data.
- - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
  - mode (str): The mode of operation. Options are 'train' or 'test'.
  - image_size (int): The size of the input images.
  - batch_size (int): The batch size for the data loaders.
  - classes (list): The list of classes to consider.
  - n_jobs (int): The number of worker threads for data loading.
- - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
- - max_show (int): The maximum number of images to show when verbose is True.
+ - validation_split (float): The fraction of data to use for validation.
  - pin_memory (bool): Whether to pin memory for faster data transfer.
  - normalize (bool): Whether to normalize the input images.
  - verbose (bool): Whether to print additional information and show images.
@@ -1353,18 +1364,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  Returns:
  - train_loaders (list): List of data loaders for training datasets.
  - val_loaders (list): List of data loaders for validation datasets.
- - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
  """

- from .io import MyDataset
- from .plot import _imshow
- from torchvision import transforms
- from torch.utils.data import DataLoader, random_split
- from collections import defaultdict
- import os
- import random
- from PIL import Image
- from torchvision.transforms import ToTensor
+ from .io import spacrDataset, spacrDataLoader
+ from .plot import _imshow_gpu
  from .utils import SelectChannels, augment_dataset

  chans = []
@@ -1381,12 +1384,9 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  if verbose:
  print(f'Training a network on channels: {channels}')
  print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
-
- plate_to_filenames = defaultdict(list)
- plate_to_labels = defaultdict(list)
+
  train_loaders = []
  val_loaders = []
- plate_names = []

  if normalize:
  transform = transforms.Compose([
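With the erm/irm split removed in the hunks above, generate_loaders builds one spacrDataset and returns (train_loaders, val_loaders, train_fig). A hedged usage sketch based on the new signature shown above; the dataset folder is a placeholder and is assumed to contain train/<class> and test/<class> subfolders:

from spacr.core import generate_loaders

train_loaders, val_loaders, train_fig = generate_loaders(
    src='/data/experiment1/datasets/training',  # hypothetical output of generate_training_dataset
    mode='train',
    image_size=224,
    batch_size=32,
    classes=['nc', 'pc'],
    n_jobs=4,
    validation_split=0.1,
    normalize=True,
    channels=[1, 2, 3],
    augment=False,
    preload_batches=3,
    verbose=False,
)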
@@ -1414,157 +1414,114 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  print(f'mode:{mode} is not valid, use mode = train or test')
  return

- if train_mode == 'erm':
-
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
- if validation_split > 0:
- train_size = int((1 - validation_split) * len(data))
- val_size = len(data) - train_size
- if not augment:
- print(f'Train data:{train_size}, Validation data:{val_size}')
- train_dataset, val_dataset = random_split(data, [train_size, val_size])
-
- if augment:
-
- print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
- train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
- #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
- print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
- train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- else:
- train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
- elif train_mode == 'irm':
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
- for filename, label in zip(data.filenames, data.labels):
- plate = data.get_plate(filename)
- plate_to_filenames[plate].append(filename)
- plate_to_labels[plate].append(label)
-
- for plate, filenames in plate_to_filenames.items():
- labels = plate_to_labels[plate]
- plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
- plate_names.append(plate)
-
- if validation_split > 0:
- train_size = int((1 - validation_split) * len(plate_data))
- val_size = len(plate_data) - train_size
- if not augment:
- print(f'Train data:{train_size}, Validation data:{val_size}')
- train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
-
- if augment:
-
- print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
- train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
- #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
- print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
- train_loaders.append(train_loader)
- val_loaders.append(val_loader)
- else:
- train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
- train_loaders.append(train_loader)
- val_loaders.append(None)
-
- else:
- print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
- return
-
+ data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+ num_workers = n_jobs if n_jobs is not None else 0

- if train_mode == 'erm':
- for idx, (images, labels, filenames) in enumerate(train_loaders):
- if idx >= max_show:
- break
- images = images.cpu()
- label_strings = [str(label.item()) for label in labels]
- train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
- if verbose:
- plt.show()
+ if validation_split > 0:
+ train_size = int((1 - validation_split) * len(data))
+ val_size = len(data) - train_size
+ if not augment:
+ print(f'Train data:{train_size}, Validation data:{val_size}')
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])

- elif train_mode == 'irm':
- for plate_name, train_loader in zip(plate_names, train_loaders):
- print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
- for idx, (images, labels, filenames) in enumerate(train_loader):
- if idx >= max_show:
- break
- images = images.cpu()
- label_strings = [str(label.item()) for label in labels]
- train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
- if verbose:
- plt.show()
+ if augment:

- return train_loaders, val_loaders, plate_names, train_fig
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+ print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+ print(f'Generating Dataloader with {n_jobs} workers')
+ #train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+ #train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)

- def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ else:
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ #train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+
+ #dataset (Dataset) – dataset from which to load the data.
+ #batch_size (int, optional) – how many samples per batch to load (default: 1).
+ #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+ #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+ #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+ #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+ #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+ #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+ #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+ #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+ #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+ #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+ #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+ #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+ #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+ #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+ #images, labels, filenames = next(iter(train_loaders))
+ #images = images.cpu()
+ #label_strings = [str(label.item()) for label in labels]
+ #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+ #if verbose:
+ # plt.show()
+
+ train_fig = None
+
+ return train_loaders, val_loaders, train_fig
+
+ def analyze_recruitment(settings={}):
  """
  Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.

  Parameters:
- src (str): The source of the recruitment data.
- metadata_settings (dict): The settings for metadata.
- advanced_settings (dict): The advanced settings for recruitment analysis.
+ settings (dict): settings.

  Returns:
  None
  """

  from .io import _read_and_merge_data, _results_to_csv
- from .plot import plot_merged, _plot_controls, _plot_recruitment
- from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+ from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
+ from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
  from .settings import get_analyze_recruitment_default_settings

- settings = get_analyze_recruitment_default_settings(settings)
-
- settings_dict = {**metadata_settings, **advanced_settings}
- settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
- settings_csv = os.path.join(src,'settings','analyze_settings.csv')
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
- settings_df.to_csv(settings_csv, index=False)
+ settings = get_analyze_recruitment_default_settings(settings=settings)
+ save_settings(settings, name='recruitment')

  # metadata settings
- target = metadata_settings['target']
- cell_types = metadata_settings['cell_types']
- cell_plate_metadata = metadata_settings['cell_plate_metadata']
- pathogen_types = metadata_settings['pathogen_types']
- pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
- treatments = metadata_settings['treatments']
- treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
- metadata_types = metadata_settings['metadata_types']
- channel_dims = metadata_settings['channel_dims']
- cell_chann_dim = metadata_settings['cell_chann_dim']
- cell_mask_dim = metadata_settings['cell_mask_dim']
- nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
- nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
- pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
- pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
- channel_of_interest = metadata_settings['channel_of_interest']
+ src = settings['src']
+ target = settings['target']
+ cell_types = settings['cell_types']
+ cell_plate_metadata = settings['cell_plate_metadata']
+ pathogen_types = settings['pathogen_types']
+ pathogen_plate_metadata = settings['pathogen_plate_metadata']
+ treatments = settings['treatments']
+ treatment_plate_metadata = settings['treatment_plate_metadata']
+ metadata_types = settings['metadata_types']
+ channel_dims = settings['channel_dims']
+ cell_chann_dim = settings['cell_chann_dim']
+ cell_mask_dim = settings['cell_mask_dim']
+ nucleus_chann_dim = settings['nucleus_chann_dim']
+ nucleus_mask_dim = settings['nucleus_mask_dim']
+ pathogen_chann_dim = settings['pathogen_chann_dim']
+ pathogen_mask_dim = settings['pathogen_mask_dim']
+ channel_of_interest = settings['channel_of_interest']

  # Advanced settings
- plot = advanced_settings['plot']
- plot_nr = advanced_settings['plot_nr']
- plot_control = advanced_settings['plot_control']
- figuresize = advanced_settings['figuresize']
- remove_background = advanced_settings['remove_background']
- backgrounds = advanced_settings['backgrounds']
- include_noninfected = advanced_settings['include_noninfected']
- include_multiinfected = advanced_settings['include_multiinfected']
- include_multinucleated = advanced_settings['include_multinucleated']
- cells_per_well = advanced_settings['cells_per_well']
- pathogen_size_range = advanced_settings['pathogen_size_range']
- nucleus_size_range = advanced_settings['nucleus_size_range']
- cell_size_range = advanced_settings['cell_size_range']
- pathogen_intensity_range = advanced_settings['pathogen_intensity_range']
- nucleus_intensity_range = advanced_settings['nucleus_intensity_range']
- cell_intensity_range = advanced_settings['cell_intensity_range']
- target_intensity_min = advanced_settings['target_intensity_min']
+ plot = settings['plot']
+ plot_nr = settings['plot_nr']
+ plot_control = settings['plot_control']
+ figuresize = settings['figuresize']
+ include_noninfected = settings['include_noninfected']
+ include_multiinfected = settings['include_multiinfected']
+ include_multinucleated = settings['include_multinucleated']
+ cells_per_well = settings['cells_per_well']
+ pathogen_size_range = settings['pathogen_size_range']
+ nucleus_size_range = settings['nucleus_size_range']
+ cell_size_range = settings['cell_size_range']
+ pathogen_intensity_range = settings['pathogen_intensity_range']
+ nucleus_intensity_range = settings['nucleus_intensity_range']
+ cell_intensity_range = settings['cell_intensity_range']
+ target_intensity_min = settings['target_intensity_min']

  print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
  print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
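The commented-out block earlier in this hunk reproduces the torch.utils.data.DataLoader parameter documentation, while the loaders actually constructed pin num_workers to 1 with persistent_workers=True. A standalone sketch of that loader pattern; the TensorDataset stand-in is hypothetical (spacr uses its own spacrDataset here):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset: 64 RGB images of 224x224 with binary labels.
dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 2, (64,)))

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=1,             # the new code hard-codes a single worker process
    pin_memory=False,
    persistent_workers=True,   # requires num_workers > 0; keeps the worker alive between epochs
)

for images, labels in loader:
    pass  # training or inference step would go here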
@@ -1582,9 +1539,6 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  else:
  metadata_types = metadata_types

- if isinstance(backgrounds, (int,float)):
- backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
-
  sns.color_palette("mako", as_cmap=True)
  print(f'channel:{channel_of_interest} = {target}')
  overlay_channels = channel_dims
@@ -1594,11 +1548,11 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  db_loc = [src+'/measurements/measurements.db']
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
  df, _ = _read_and_merge_data(db_loc,
- tables,
- verbose=True,
- include_multinucleated=include_multinucleated,
- include_multiinfected=include_multiinfected,
- include_noninfected=include_noninfected)
+ tables,
+ verbose=True,
+ include_multinucleated=include_multinucleated,
+ include_multiinfected=include_multiinfected,
+ include_noninfected=include_noninfected)

  df = annotate_conditions(df,
  cells=cell_types,
@@ -1617,48 +1571,31 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  random.shuffle(files)

  _max = 10**100
-
- if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
- filter_min_max = None
- else:
- if cell_size_range is None:
- cell_size_range = [0,_max]
- if nucleus_size_range is None:
- nucleus_size_range = [0,_max]
- if pathogen_size_range is None:
- pathogen_size_range = [0,_max]
-
- filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+ if cell_size_range is None:
+ cell_size_range = [0,_max]
+ if nucleus_size_range is None:
+ nucleus_size_range = [0,_max]
+ if pathogen_size_range is None:
+ pathogen_size_range = [0,_max]

  if plot:
- plot_settings = {'include_noninfected':include_noninfected,
- 'include_multiinfected':include_multiinfected,
- 'include_multinucleated':include_multinucleated,
- 'remove_background':remove_background,
- 'filter_min_max':filter_min_max,
- 'channel_dims':channel_dims,
- 'backgrounds':backgrounds,
- 'cell_mask_dim':mask_dims[0],
- 'nucleus_mask_dim':mask_dims[1],
- 'pathogen_mask_dim':mask_dims[2],
- 'overlay_chans':overlay_channels,
- 'outline_thickness':3,
- 'outline_color':'gbr',
- 'overlay_chans':overlay_channels,
- 'overlay':True,
- 'normalization_percentiles':[1,99],
- 'normalize':True,
- 'print_object_number':True,
- 'nr':plot_nr,
- 'figuresize':20,
- 'cmap':'inferno',
- 'verbose':False}
-
- if os.path.exists(os.path.join(src,'merged')):
- try:
- plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
- except Exception as e:
- print(f'Failed to plot images with outlines, Error: {e}')
+ merged_path = os.path.join(src,'merged')
+ if os.path.exists(merged_path):
+ try:
+ for idx, file in enumerate(os.listdir(merged_path)):
+ file_path = os.path.join(merged_path,file)
+ if idx <= plot_nr:
+ plot_image_mask_overlay(file_path,
+ channel_dims,
+ cell_chann_dim,
+ nucleus_chann_dim,
+ pathogen_chann_dim,
+ figuresize=10,
+ normalize=True,
+ thickness=3,
+ save_pdf=True)
+ except Exception as e:
+ print(f'Failed to plot images with outlines, Error: {e}')

  if not cell_chann_dim is None:
  df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
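analyze_recruitment is now driven by one flat settings dict, filled in by get_analyze_recruitment_default_settings and written to disk via save_settings. A hedged sketch of a call using only keys read in this diff; every value is a placeholder and the defaults helper is assumed to supply anything omitted:

from spacr.core import analyze_recruitment

analyze_recruitment(settings={
    'src': '/data/experiment1',          # hypothetical project folder with measurements/measurements.db
    'target': 'protein_of_interest',
    'cell_types': ['HeLa'],
    'cell_plate_metadata': None,
    'pathogen_types': ['pathogen'],
    'pathogen_plate_metadata': None,
    'treatments': ['nc', 'pc'],
    'treatment_plate_metadata': None,
    'metadata_types': 'col',
    'channel_dims': [0, 1, 2, 3],
    'channel_of_interest': 3,
    'plot': True,
    'plot_nr': 3,
    'cells_per_well': 10,
})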
@@ -1695,15 +1632,13 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  def preprocess_generate_masks(src, settings={}):

  from .io import preprocess_img_data, _load_and_concatenate_arrays
- from .plot import plot_merged, plot_arrays
- from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks
- from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings
+ from .plot import plot_image_mask_overlay, plot_arrays
+ from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
+ from .settings import set_default_settings_preprocess_generate_masks

  settings = set_default_settings_preprocess_generate_masks(src, settings)
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
- settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
- settings_df.to_csv(settings_csv, index=False)
+ settings['src'] = src
+ save_settings(settings, name='gen_mask')

  if not settings['pathogen_channel'] is None:
  custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
@@ -1730,20 +1665,47 @@ def preprocess_generate_masks(src, settings={}):

  if settings['preprocess']:
  settings, src = preprocess_img_data(settings)
-
+
+ files_to_process = 3
+ files_processed = 0
  if settings['masks']:
  mask_src = os.path.join(src, 'norm_channel_stack')
  if settings['cell_channel'] != None:
+ time_ls=[]
  if check_mask_folder(src, 'cell_mask_stack'):
+ start = time.time()
  generate_cellpose_masks(mask_src, settings, 'cell')
+ stop = time.time()
+ duration = (stop - start)
+ time_ls.append(duration)
+ files_processed += 1
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'cell_mask_gen')

  if settings['nucleus_channel'] != None:
+ time_ls=[]
  if check_mask_folder(src, 'nucleus_mask_stack'):
+ start = time.time()
  generate_cellpose_masks(mask_src, settings, 'nucleus')
+ stop = time.time()
+ duration = (stop - start)
+ time_ls.append(duration)
+ files_processed += 1
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'nucleus_mask_gen')

  if settings['pathogen_channel'] != None:
+ time_ls=[]
  if check_mask_folder(src, 'pathogen_mask_stack'):
+ start = time.time()
  generate_cellpose_masks(mask_src, settings, 'pathogen')
+ stop = time.time()
+ duration = (stop - start)
+ time_ls.append(duration)
+ files_processed += 1
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'pathogen_mask_gen')
+
+ #if settings['organelle'] != None:
+ # if check_mask_folder(src, 'organelle_mask_stack'):
+ # generate_cellpose_masks(mask_src, settings, 'organelle')

  if settings['adjust_cells']:
  if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
@@ -1752,12 +1714,8 @@ def preprocess_generate_masks(src, settings={}):
  cell_folder = os.path.join(mask_src, 'cell_mask_stack')
  nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
  parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
- #image_folder = os.path.join(src, 'stack')
+ #organelle_folder = os.path.join(mask_src, 'organelle_mask_stack')

- #process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
- #process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
- #process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
-
  adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
  stop = time.time()
  adjust_time = (stop-start)/60
@@ -1771,38 +1729,28 @@ def preprocess_generate_masks(src, settings={}):

  if settings['plot']:
  if not settings['timelapse']:
- plot_dims = len(settings['channels'])
- overlay_channels = [2,1,0]
- cell_mask_dim = nucleus_mask_dim = pathogen_mask_dim = None
- plot_counter = plot_dims
-
- if settings['cell_channel'] is not None:
- cell_mask_dim = plot_counter
- plot_counter += 1
-
- if settings['nucleus_channel'] is not None:
- nucleus_mask_dim = plot_counter
- plot_counter += 1
-
- if settings['pathogen_channel'] is not None:
- pathogen_mask_dim = plot_counter
-
- overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
- overlay_channels = [element for element in overlay_channels if element is not None]
-
- plot_settings = set_default_plot_merge_settings()
- plot_settings['channel_dims'] = settings['channels']
- plot_settings['cell_mask_dim'] = cell_mask_dim
- plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
- plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
- plot_settings['overlay_chans'] = overlay_channels
- plot_settings['nr'] = settings['examples_to_plot']

  if settings['test_mode'] == True:
- plot_settings['nr'] = len(os.path.join(src,'merged'))
+ settings['examples_to_plot'] = len(os.path.join(src,'merged'))

  try:
- fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
+ merged_src = os.path.join(src,'merged')
+ files = os.listdir(merged_src)
+ random.shuffle(files)
+ time_ls = []
+
+ for i, file in enumerate(files):
+ start = time.time()
+ if i+1 <= settings['examples_to_plot']:
+ file_path = os.path.join(merged_src, file)
+ plot_image_mask_overlay(file_path, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'], figuresize=10, normalize=True, thickness=3, save_pdf=True)
+ stop = time.time()
+ duration = stop-start
+ time_ls.append(duration)
+ files_processed = i+1
+ files_to_process = settings['examples_to_plot']
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Plot mask outlines")
+ print("Successfully completed run")
  except Exception as e:
  print(f'Failed to plot image mask overly. Error: {e}')
  else:
@@ -1981,7 +1929,7 @@ def generate_cellpose_masks(src, settings, object_type):
  settings_df['setting_value'] = settings_df['setting_value'].apply(str)
  display(settings_df)

- figuresize=25
+ figuresize=10
  timelapse = settings['timelapse']

  if timelapse:
@@ -2010,7 +1958,7 @@ def generate_cellpose_masks(src, settings, object_type):

  if object_type == 'pathogen' and not settings['pathogen_model'] is None:
  model_name = settings['pathogen_model']
-
+
  model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)

  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
@@ -2022,16 +1970,18 @@ def generate_cellpose_masks(src, settings, object_type):

  average_sizes = []
  time_ls = []
+
  for file_index, path in enumerate(paths):
  name = os.path.basename(path)
  name, ext = os.path.splitext(name)
  output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
  os.makedirs(output_folder, exist_ok=True)
  overall_average_size = 0
+
  with np.load(path) as data:
  stack = data['data']
  filenames = data['filenames']
-
+
  for i, filename in enumerate(filenames):
  output_path = os.path.join(output_folder, filename)

@@ -2057,11 +2007,8 @@ def generate_cellpose_masks(src, settings, object_type):
  batch_size = len(stack)
  print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')

- files_processed = 0
  for i in range(0, stack.shape[0], batch_size):
  mask_stack = []
- start = time.time()
-
  if stack.shape[3] == 1:
  batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
  else:
@@ -2072,7 +2019,6 @@ def generate_cellpose_masks(src, settings, object_type):

  if not settings['plot']:
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
  if batch.size == 0:
- print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
  continue
  batch = prepare_batch_for_cellpose(batch)
@@ -2083,16 +2029,6 @@ def generate_cellpose_masks(src, settings, object_type):
  save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)

- if settings['verbose']:
- print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
-
- #cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
- # 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
- # 'normalize':True, #(if False, all following parameters ignored)
- # 'percentile':[2,98], #[perc_low, perc_high]
- # 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
- # 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
-
  output = model.eval(x=batch,
  batch_size=cellpose_batch_size,
  normalize=False,
@@ -2202,16 +2138,8 @@ def generate_cellpose_masks(src, settings, object_type):

  average_sizes.append(average_obj_size)
  overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
+ print(f'object_size:{object_type}: {overall_average_size:.3f} px2')

- stop = time.time()
- duration = (stop - start)
- time_ls.append(duration)
- files_processed += len(batch_filenames)
- #files_processed = (file_index+1)*(batch_size+1)
- files_to_process = (len(paths))*(batch_size)
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type=f'{object_type}_mask_gen')
- print(f'object_size:{object_type}): {overall_average_size:.3f} px2')
-
  if not timelapse:
  if settings['plot']:
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
@@ -2222,6 +2150,7 @@ def generate_cellpose_masks(src, settings, object_type):
  np.save(output_filename, mask)
  mask_stack = []
  batch_filenames = []
+
  gc.collect()
  torch.cuda.empty_cache()
  return
@@ -2504,7 +2433,6 @@ def ml_analysis(df, channel_of_interest=3, location_column='col', positive_contr
  df_metadata = df[[location_column]].copy()
  df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)

-
  if verbose:
  print(f'Found {len(features)} numerical features in the dataframe')
  print(f'Features used in training: {features}')
@@ -2649,7 +2577,6 @@ def check_index(df, elements=5, split_char='_'):
  print(idx)
  raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")

- #def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
  def generate_ml_scores(src, settings):

  from .io import _read_and_merge_data
@@ -2687,7 +2614,7 @@ def generate_ml_scores(src, settings):
  settings['top_features'],
  settings['n_estimators'],
  settings['test_size'],
- settings['model_type'],
+ settings['model_type_ml'],
  settings['n_jobs'],
  settings['remove_low_variance_features'],
  settings['remove_highly_correlated_features'],
@@ -2708,7 +2635,7 @@ def generate_ml_scores(src, settings):
  min_count=settings['minimum_cell_count'],
  verbose=settings['verbose'])

- data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
+ data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
  df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output

  settings_df.to_csv(settings_csv, index=False)
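The two hunks above rename the classifier key from model_type to model_type_ml in generate_ml_scores. A hedged sketch of the updated call; only keys visible in these hunks are shown, the values are placeholders, and the src path is hypothetical:

from spacr.core import generate_ml_scores

settings = {
    'model_type_ml': 'xgboost',            # was 'model_type' in 0.2.4
    'channel_of_interest': 3,
    'minimum_cell_count': 25,
    'n_estimators': 100,
    'test_size': 0.2,
    'top_features': 30,
    'n_jobs': -1,
    'remove_low_variance_features': True,
    'remove_highly_correlated_features': True,
    'verbose': True,
}
output = generate_ml_scores('/data/experiment1', settings)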
@@ -2865,6 +2792,7 @@ def generate_image_umap(settings={}):
  settings['plot_outlines'] = False
  settings['smooth_lines'] = False

+ print(f'Generating Image UMAP ...')
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
  settings_dir = os.path.join(settings['src'][0],'settings')
  settings_csv = os.path.join(settings_dir,'embedding_settings.csv')