spacr 0.2.56__py3-none-any.whl → 0.2.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -16,7 +16,6 @@ import seaborn as sns
 import cellpose
 from skimage.measure import regionprops, label
 from skimage.transform import resize as resizescikit
-from torch.utils.data import DataLoader
 
 from skimage import measure
 from sklearn.model_selection import train_test_split
@@ -43,6 +42,16 @@ import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
 
+from torchvision import transforms
+from torch.utils.data import DataLoader, random_split
+from collections import defaultdict
+import os
+import random
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+
+
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
@@ -976,173 +985,6 @@ def generate_dataset(settings={}):
 
     return tar_name
 
-def generate_dataset_v1(src, file_metadata=None, experiment='TSG101_screen', sample=None):
-
-    from .utils import initiate_counter, add_images_to_tar
-
-    db_path = os.path.join(src, 'measurements', 'measurements.db')
-    dst = os.path.join(src, 'datasets')
-    all_paths = []
-
-    # Connect to the database and retrieve the image paths
-    print(f'Reading DataBase: {db_path}')
-    try:
-        with sqlite3.connect(db_path) as conn:
-            cursor = conn.cursor()
-            if file_metadata:
-                if isinstance(file_metadata, str):
-                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
-            else:
-                cursor.execute("SELECT png_path FROM png_list")
-
-            while True:
-                rows = cursor.fetchmany(1000)
-                if not rows:
-                    break
-                all_paths.extend([row[0] for row in rows])
-
-    except sqlite3.Error as e:
-        print(f"Database error: {e}")
-        return
-    except Exception as e:
-        print(f"Error: {e}")
-        return
-
-    if isinstance(sample, int):
-        selected_paths = random.sample(all_paths, sample)
-        print(f'Random selection of {len(selected_paths)} paths')
-    else:
-        selected_paths = all_paths
-        random.shuffle(selected_paths)
-        print(f'All paths: {len(selected_paths)} paths')
-
-    total_images = len(selected_paths)
-    print(f'Found {total_images} images')
-
-    # Create a temp folder in dst
-    temp_dir = os.path.join(dst, "temp_tars")
-    os.makedirs(temp_dir, exist_ok=True)
-
-    # Chunking the data
-    num_procs = max(2, cpu_count() - 2)
-    chunk_size = len(selected_paths) // num_procs
-    remainder = len(selected_paths) % num_procs
-
-    paths_chunks = []
-    start = 0
-    for i in range(num_procs):
-        end = start + chunk_size + (1 if i < remainder else 0)
-        paths_chunks.append(selected_paths[start:end])
-        start = end
-
-    temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
-
-    print(f'Generating temporary tar files in {dst}')
-
-    # Initialize shared counter and lock
-    counter = Value('i', 0)
-    lock = Lock()
-
-    with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
-        pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
-
-    # Combine the temporary tar files into a final tar
-    date_name = datetime.date.today().strftime('%y%m%d')
-    if not file_metadata is None:
-        tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
-    else:
-        tar_name = f'{date_name}_{experiment}.tar'
-    tar_name = os.path.join(dst, tar_name)
-    if os.path.exists(tar_name):
-        number = random.randint(1, 100)
-        tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
-        print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
-        tar_name = os.path.join(dst, tar_name_2)
-
-    print(f'Merging temporary files')
-
-    with tarfile.open(tar_name, 'w') as final_tar:
-        for temp_tar_path in temp_tar_files:
-            with tarfile.open(temp_tar_path, 'r') as temp_tar:
-                for member in temp_tar.getmembers():
-                    file_obj = temp_tar.extractfile(member)
-                    final_tar.addfile(member, file_obj)
-            os.remove(temp_tar_path)
-
-    # Delete the temp folder
-    shutil.rmtree(temp_dir)
-    print(f"\nSaved {total_images} images to {tar_name}")
-
-def apply_model_to_tar_v1(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
-
-    from .io import TarImageDataset
-    from .utils import process_vision_results, print_progress
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    if normalize:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size)),
-            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-    else:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size))])
-
-    if verbose:
-        print(f'Loading model from {model_path}')
-        print(f'Loading dataset from {tar_path}')
-
-    model = torch.load(model_path)
-
-    dataset = TarImageDataset(tar_path, transform=transform)
-    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs, pin_memory=True)
-
-    model_name = os.path.splitext(os.path.basename(model_path))[0]
-    dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
-    date_name = datetime.date.today().strftime('%y%m%d')
-    dst = os.path.dirname(tar_path)
-    result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
-
-    model.eval()
-    model = model.to(device)
-
-    if verbose:
-        print(model)
-        print(f'Generated dataset with {len(dataset)} images')
-        print(f'Generating loader from {len(data_loader)} batches')
-        print(f'Results wil be saved in: {result_loc}')
-        print(f'Model is in eval mode')
-        print(f'Model loaded to device')
-
-    prediction_pos_probs = []
-    filenames_list = []
-    time_ls = []
-    gc.collect()
-    with torch.no_grad():
-        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
-            start = time.time()
-            images = batch_images.to(torch.float).to(device)
-            outputs = model(images)
-            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
-            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
-            filenames_list.extend(filenames)
-            stop = time.time()
-            duration = stop - start
-            time_ls.append(duration)
-            files_processed = batch_idx*batch_size
-            files_to_process = len(data_loader)
-            print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Tar dataset")
-
-    data = {'path':filenames_list, 'pred':prediction_pos_probs}
-    df = pd.DataFrame(data, index=None)
-    df = process_vision_results(df, threshold)
-
-    df.to_csv(result_loc, index=True, header=True, mode='w')
-    torch.cuda.empty_cache()
-    torch.cuda.memory.empty_cache()
-    return df
-
 def apply_model_to_tar(settings={}):
 
     from .io import TarImageDataset
@@ -1397,107 +1239,6 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
 
     return os.path.join(dst, 'train'), os.path.join(dst, 'test')
 
-def generate_training_dataset_v1(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
-
-    from .io import _read_and_merge_data, _read_db
-    from .utils import get_paths_from_db, annotate_conditions
-
-    db_path = os.path.join(src, 'measurements','measurements.db')
-    dst = os.path.join(src, 'datasets', 'training')
-
-    if os.path.exists(dst):
-        for i in range(1, 1000):
-            dst = os.path.join(src, 'datasets', f'training_{i}')
-            if not os.path.exists(dst):
-                print(f'Creating new directory for training: {dst}')
-                break
-
-    if mode == 'annotation':
-        class_paths_ls_2 = []
-        class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
-        for class_paths in class_paths_ls:
-            class_paths_temp = random.sample(class_paths, size)
-            class_paths_ls_2.append(class_paths_temp)
-        class_paths_ls = class_paths_ls_2
-
-    elif mode == 'metadata':
-        class_paths_ls = []
-        class_len_ls = []
-        [df] = _read_db(db_loc=db_path, tables=['png_list'])
-        df['metadata_based_class'] = pd.NA
-        for i, class_ in enumerate(classes):
-            ls = class_metadata[i]
-            df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
-
-        for class_ in classes:
-            if size == None:
-                c_s = []
-                for c in classes:
-                    c_s_t_df = df[df['metadata_based_class'] == c]
-                    c_s.append(len(c_s_t_df))
-                    print(f'Found {len(c_s_t_df)} images for class {c}')
-                size = min(c_s)
-                print(f'Using the smallest class size: {size}')
-
-            class_temp_df = df[df['metadata_based_class'] == class_]
-            class_len_ls.append(len(class_temp_df))
-            print(f'Found {len(class_temp_df)} images for class {class_}')
-            class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
-            class_paths_ls.append(class_paths_temp)
-
-    elif mode == 'recruitment':
-        class_paths_ls = []
-        if not isinstance(tables, list):
-            tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
-
-        df, _ = _read_and_merge_data(locs=[db_path],
-                                     tables=tables,
-                                     verbose=False,
-                                     include_multinucleated=True,
-                                     include_multiinfected=True,
-                                     include_noninfected=True)
-
-        print('length df 1', len(df))
-
-        df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
-        print('length df 2', len(df))
-        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
-
-        if custom_measurement != None:
-
-            if not isinstance(custom_measurement, list):
-                print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
-                return
-
-            if isinstance(custom_measurement, list):
-                if len(custom_measurement) == 2:
-                    print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
-                    df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
-                if len(custom_measurement) == 1:
-                    print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
-                    df['recruitment'] = df[f'{custom_measurement[0]}']
-        else:
-            print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
-            df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
-
-        q25 = df['recruitment'].quantile(0.25)
-        q75 = df['recruitment'].quantile(0.75)
-        df_lower = df[df['recruitment'] <= q25]
-        df_upper = df[df['recruitment'] >= q75]
-
-        class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
-
-        class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
-        class_paths_ls.append(class_paths_lower)
-
-        class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
-        class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
-        class_paths_ls.append(class_paths_upper)
-
-    generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
-
-    return
-
 def generate_training_dataset(settings):
 
     from .io import _read_and_merge_data, _read_db
@@ -1602,21 +1343,19 @@ def generate_training_dataset(settings):
 
     return train_class_dir, test_class_dir
 
-def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
+def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):
 
     """
     Generate data loaders for training and validation/test datasets.
 
     Parameters:
     - src (str): The source directory containing the data.
-    - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
     - mode (str): The mode of operation. Options are 'train' or 'test'.
     - image_size (int): The size of the input images.
     - batch_size (int): The batch size for the data loaders.
     - classes (list): The list of classes to consider.
     - n_jobs (int): The number of worker threads for data loading.
-    - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
-    - max_show (int): The maximum number of images to show when verbose is True.
+    - validation_split (float): The fraction of data to use for validation.
     - pin_memory (bool): Whether to pin memory for faster data transfer.
     - normalize (bool): Whether to normalize the input images.
     - verbose (bool): Whether to print additional information and show images.
@@ -1625,18 +1364,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     Returns:
     - train_loaders (list): List of data loaders for training datasets.
     - val_loaders (list): List of data loaders for validation datasets.
-    - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
     """
 
-    from .io import MyDataset
-    from .plot import _imshow
-    from torchvision import transforms
-    from torch.utils.data import DataLoader, random_split
-    from collections import defaultdict
-    import os
-    import random
-    from PIL import Image
-    from torchvision.transforms import ToTensor
+    from .io import spacrDataset, spacrDataLoader
+    from .plot import _imshow_gpu
     from .utils import SelectChannels, augment_dataset
     chans = []
 
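With train_mode and max_show gone, a call against the updated generate_loaders signature might look roughly like the sketch below; the dataset path and argument values are illustrative only and are not taken from the package:

    from spacr.core import generate_loaders

    # Hypothetical dataset folder; mode selects the train or test portion under src.
    train_loader, val_loader, _ = generate_loaders(
        src='/path/to/datasets/training',
        mode='train',
        image_size=224,
        batch_size=32,
        classes=['nc', 'pc'],
        n_jobs=4,
        validation_split=0.1,
        normalize=True,
        channels=[1, 2, 3],
        augment=False,
        verbose=True,
    )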
@@ -1653,12 +1384,9 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     if verbose:
         print(f'Training a network on channels: {channels}')
         print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
-
-    plate_to_filenames = defaultdict(list)
-    plate_to_labels = defaultdict(list)
+
     train_loaders = []
     val_loaders = []
-    plate_names = []
 
     if normalize:
         transform = transforms.Compose([
@@ -1686,157 +1414,114 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
         print(f'mode:{mode} is not valid, use mode = train or test')
         return
 
-    if train_mode == 'erm':
-
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
-        if validation_split > 0:
-            train_size = int((1 - validation_split) * len(data))
-            val_size = len(data) - train_size
-            if not augment:
-                print(f'Train data:{train_size}, Validation data:{val_size}')
-            train_dataset, val_dataset = random_split(data, [train_size, val_size])
-
-            if augment:
-
-                print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
-                train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
-                #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
-                print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
-            train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-            val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-        else:
-            train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
-    elif train_mode == 'irm':
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
-        for filename, label in zip(data.filenames, data.labels):
-            plate = data.get_plate(filename)
-            plate_to_filenames[plate].append(filename)
-            plate_to_labels[plate].append(label)
-
-        for plate, filenames in plate_to_filenames.items():
-            labels = plate_to_labels[plate]
-            plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
-            plate_names.append(plate)
-
-            if validation_split > 0:
-                train_size = int((1 - validation_split) * len(plate_data))
-                val_size = len(plate_data) - train_size
-                if not augment:
-                    print(f'Train data:{train_size}, Validation data:{val_size}')
-                train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
-
-                if augment:
-
-                    print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
-                    train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
-                    #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
-                    print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
-                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
-                train_loaders.append(train_loader)
-                val_loaders.append(val_loader)
-            else:
-                train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-                train_loaders.append(train_loader)
-                val_loaders.append(None)
+    data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+    num_workers = n_jobs if n_jobs is not None else 0
 
-    else:
-        print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
-        return
-
-
-    if train_mode == 'erm':
-        for idx, (images, labels, filenames) in enumerate(train_loaders):
-            if idx >= max_show:
-                break
-            images = images.cpu()
-            label_strings = [str(label.item()) for label in labels]
-            train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
-            if verbose:
-                plt.show()
+    if validation_split > 0:
+        train_size = int((1 - validation_split) * len(data))
+        val_size = len(data) - train_size
+        if not augment:
+            print(f'Train data:{train_size}, Validation data:{val_size}')
+        train_dataset, val_dataset = random_split(data, [train_size, val_size])
 
-    elif train_mode == 'irm':
-        for plate_name, train_loader in zip(plate_names, train_loaders):
-            print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
-            for idx, (images, labels, filenames) in enumerate(train_loader):
-                if idx >= max_show:
-                    break
-                images = images.cpu()
-                label_strings = [str(label.item()) for label in labels]
-                train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
-                if verbose:
-                    plt.show()
+        if augment:
 
-    return train_loaders, val_loaders, plate_names, train_fig
+            print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+            train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+            print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+        print(f'Generating Dataloader with {n_jobs} workers')
+        #train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+        #train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
 
-def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
+        train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+        val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+    else:
+        train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+        #train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+
+    #dataset (Dataset) – dataset from which to load the data.
+    #batch_size (int, optional) – how many samples per batch to load (default: 1).
+    #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+    #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+    #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+    #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+    #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+    #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+    #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+    #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+    #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+    #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+    #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+    #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+    #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+    #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+    #images, labels, filenames = next(iter(train_loaders))
+    #images = images.cpu()
+    #label_strings = [str(label.item()) for label in labels]
+    #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+    #if verbose:
+    #    plt.show()
+
+    train_fig = None
+
+    return train_loaders, val_loaders, train_fig
+
+def analyze_recruitment(settings={}):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
 
     Parameters:
-    src (str): The source of the recruitment data.
-    metadata_settings (dict): The settings for metadata.
-    advanced_settings (dict): The advanced settings for recruitment analysis.
+    settings (dict): settings.
 
     Returns:
     None
     """
 
     from .io import _read_and_merge_data, _results_to_csv
-    from .plot import plot_merged, _plot_controls, _plot_recruitment
-    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+    from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
+    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
     from .settings import get_analyze_recruitment_default_settings
 
-    settings = get_analyze_recruitment_default_settings(settings)
-
-    settings_dict = {**metadata_settings, **advanced_settings}
-    settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','analyze_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
+    settings = get_analyze_recruitment_default_settings(settings=settings)
+    save_settings(settings, name='recruitment')
 
     # metadata settings
-    target = metadata_settings['target']
-    cell_types = metadata_settings['cell_types']
-    cell_plate_metadata = metadata_settings['cell_plate_metadata']
-    pathogen_types = metadata_settings['pathogen_types']
-    pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
-    treatments = metadata_settings['treatments']
-    treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
-    metadata_types = metadata_settings['metadata_types']
-    channel_dims = metadata_settings['channel_dims']
-    cell_chann_dim = metadata_settings['cell_chann_dim']
-    cell_mask_dim = metadata_settings['cell_mask_dim']
-    nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
-    nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
-    pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
-    pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
-    channel_of_interest = metadata_settings['channel_of_interest']
+    src = settings['src']
+    target = settings['target']
+    cell_types = settings['cell_types']
+    cell_plate_metadata = settings['cell_plate_metadata']
+    pathogen_types = settings['pathogen_types']
+    pathogen_plate_metadata = settings['pathogen_plate_metadata']
+    treatments = settings['treatments']
+    treatment_plate_metadata = settings['treatment_plate_metadata']
+    metadata_types = settings['metadata_types']
+    channel_dims = settings['channel_dims']
+    cell_chann_dim = settings['cell_chann_dim']
+    cell_mask_dim = settings['cell_mask_dim']
+    nucleus_chann_dim = settings['nucleus_chann_dim']
+    nucleus_mask_dim = settings['nucleus_mask_dim']
+    pathogen_chann_dim = settings['pathogen_chann_dim']
+    pathogen_mask_dim = settings['pathogen_mask_dim']
+    channel_of_interest = settings['channel_of_interest']
 
     # Advanced settings
-    plot = advanced_settings['plot']
-    plot_nr = advanced_settings['plot_nr']
-    plot_control = advanced_settings['plot_control']
-    figuresize = advanced_settings['figuresize']
-    remove_background = advanced_settings['remove_background']
-    backgrounds = advanced_settings['backgrounds']
-    include_noninfected = advanced_settings['include_noninfected']
-    include_multiinfected = advanced_settings['include_multiinfected']
-    include_multinucleated = advanced_settings['include_multinucleated']
-    cells_per_well = advanced_settings['cells_per_well']
-    pathogen_size_range = advanced_settings['pathogen_size_range']
-    nucleus_size_range = advanced_settings['nucleus_size_range']
-    cell_size_range = advanced_settings['cell_size_range']
-    pathogen_intensity_range = advanced_settings['pathogen_intensity_range']
-    nucleus_intensity_range = advanced_settings['nucleus_intensity_range']
-    cell_intensity_range = advanced_settings['cell_intensity_range']
-    target_intensity_min = advanced_settings['target_intensity_min']
+    plot = settings['plot']
+    plot_nr = settings['plot_nr']
+    plot_control = settings['plot_control']
+    figuresize = settings['figuresize']
+    include_noninfected = settings['include_noninfected']
+    include_multiinfected = settings['include_multiinfected']
+    include_multinucleated = settings['include_multinucleated']
+    cells_per_well = settings['cells_per_well']
+    pathogen_size_range = settings['pathogen_size_range']
+    nucleus_size_range = settings['nucleus_size_range']
+    cell_size_range = settings['cell_size_range']
+    pathogen_intensity_range = settings['pathogen_intensity_range']
+    nucleus_intensity_range = settings['nucleus_intensity_range']
+    cell_intensity_range = settings['cell_intensity_range']
+    target_intensity_min = settings['target_intensity_min']
 
     print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
     print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
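The commented-out notes added inside generate_loaders above restate the torch.utils.data.DataLoader keyword arguments. As a minimal, self-contained illustration of the options that matter here (worker processes, pinned memory, persistent workers, prefetching), assuming nothing about spacr itself:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-in for spacrDataset; tensors are random placeholders.
    data = TensorDataset(torch.randn(256, 3, 224, 224), torch.randint(0, 2, (256,)))

    loader = DataLoader(
        data,
        batch_size=32,
        shuffle=True,
        num_workers=2,            # worker subprocesses; 0 loads batches in the main process
        pin_memory=True,          # page-locked host memory for faster transfer to the GPU
        persistent_workers=True,  # keep workers alive between epochs (needs num_workers > 0)
        prefetch_factor=2,        # batches loaded in advance by each worker
    )

    images, labels = next(iter(loader))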
@@ -1854,9 +1539,6 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     else:
         metadata_types = metadata_types
 
-    if isinstance(backgrounds, (int,float)):
-        backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
-
     sns.color_palette("mako", as_cmap=True)
     print(f'channel:{channel_of_interest} = {target}')
     overlay_channels = channel_dims
@@ -1866,11 +1548,11 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     db_loc = [src+'/measurements/measurements.db']
     tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
     df, _ = _read_and_merge_data(db_loc,
-                                 tables,
-                                 verbose=True,
-                                 include_multinucleated=include_multinucleated,
-                                 include_multiinfected=include_multiinfected,
-                                 include_noninfected=include_noninfected)
+                                 tables,
+                                 verbose=True,
+                                 include_multinucleated=include_multinucleated,
+                                 include_multiinfected=include_multiinfected,
+                                 include_noninfected=include_noninfected)
 
     df = annotate_conditions(df,
                              cells=cell_types,
@@ -1889,48 +1571,31 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     random.shuffle(files)
 
     _max = 10**100
-
-    if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
-        filter_min_max = None
-    else:
-        if cell_size_range is None:
-            cell_size_range = [0,_max]
-        if nucleus_size_range is None:
-            nucleus_size_range = [0,_max]
-        if pathogen_size_range is None:
-            pathogen_size_range = [0,_max]
-
-        filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+    if cell_size_range is None:
+        cell_size_range = [0,_max]
+    if nucleus_size_range is None:
+        nucleus_size_range = [0,_max]
+    if pathogen_size_range is None:
+        pathogen_size_range = [0,_max]
 
     if plot:
-        plot_settings = {'include_noninfected':include_noninfected,
-                         'include_multiinfected':include_multiinfected,
-                         'include_multinucleated':include_multinucleated,
-                         'remove_background':remove_background,
-                         'filter_min_max':filter_min_max,
-                         'channel_dims':channel_dims,
-                         'backgrounds':backgrounds,
-                         'cell_mask_dim':mask_dims[0],
-                         'nucleus_mask_dim':mask_dims[1],
-                         'pathogen_mask_dim':mask_dims[2],
-                         'overlay_chans':overlay_channels,
-                         'outline_thickness':3,
-                         'outline_color':'gbr',
-                         'overlay_chans':overlay_channels,
-                         'overlay':True,
-                         'normalization_percentiles':[1,99],
-                         'normalize':True,
-                         'print_object_number':True,
-                         'nr':plot_nr,
-                         'figuresize':20,
-                         'cmap':'inferno',
-                         'verbose':False}
-
-        if os.path.exists(os.path.join(src,'merged')):
-            try:
-                plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
-            except Exception as e:
-                print(f'Failed to plot images with outlines, Error: {e}')
+        merged_path = os.path.join(src,'merged')
+        if os.path.exists(merged_path):
+            try:
+                for idx, file in enumerate(os.listdir(merged_path)):
+                    file_path = os.path.join(merged_path,file)
+                    if idx <= plot_nr:
+                        plot_image_mask_overlay(file_path,
+                                                channel_dims,
+                                                cell_chann_dim,
+                                                nucleus_chann_dim,
+                                                pathogen_chann_dim,
+                                                figuresize=10,
+                                                normalize=True,
+                                                thickness=3,
+                                                save_pdf=True)
+            except Exception as e:
+                print(f'Failed to plot images with outlines, Error: {e}')
 
     if not cell_chann_dim is None:
         df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
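Taken together with the settings extraction earlier in this file, the reworked analyze_recruitment is driven by a single dict. A call might look like the following, assuming get_analyze_recruitment_default_settings fills in any keys that are omitted; the values shown are placeholders rather than package defaults:

    from spacr.core import analyze_recruitment

    settings = {
        'src': '/path/to/experiment',   # hypothetical folder holding measurements/measurements.db and merged/
        'target': 'reporter_protein',   # placeholder name for the recruited target
        'channel_of_interest': 3,
        'plot': True,
        'plot_nr': 3,
    }

    analyze_recruitment(settings=settings)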
@@ -1968,14 +1633,12 @@ def preprocess_generate_masks(src, settings={}):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_image_mask_overlay, plot_arrays
-    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
    from .settings import set_default_settings_preprocess_generate_masks
 
     settings = set_default_settings_preprocess_generate_masks(src, settings)
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
+    settings['src'] = src
+    save_settings(settings)
 
     if not settings['pathogen_channel'] is None:
         custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
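Here, as in analyze_recruitment above, inline settings export is replaced by save_settings from .utils. The diff does not show that helper, but judging from the removed lines it is roughly equivalent to this hypothetical sketch (names and output location assumed, not taken from the package):

    import os
    import pandas as pd

    def save_settings_sketch(settings, name='settings'):
        # Mirror of the removed inline blocks: dump the settings dict as a
        # two-column CSV under <src>/settings/.
        settings_dir = os.path.join(settings['src'], 'settings')
        os.makedirs(settings_dir, exist_ok=True)
        settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
        settings_df.to_csv(os.path.join(settings_dir, f'{name}.csv'), index=False)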
@@ -2266,7 +1929,7 @@ def generate_cellpose_masks(src, settings, object_type):
     settings_df['setting_value'] = settings_df['setting_value'].apply(str)
     display(settings_df)
 
-    figuresize=25
+    figuresize=10
     timelapse = settings['timelapse']
 
     if timelapse: