spacr-0.2.56-py3-none-any.whl → spacr-0.2.61-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +135 -472
- spacr/deep_spacr.py +189 -270
- spacr/gui_core.py +296 -87
- spacr/gui_elements.py +34 -81
- spacr/gui_utils.py +61 -47
- spacr/io.py +104 -41
- spacr/plot.py +47 -1
- spacr/settings.py +27 -31
- spacr/utils.py +14 -13
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/METADATA +1 -1
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/RECORD +15 -15
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/LICENSE +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/WHEEL +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.56.dist-info → spacr-0.2.61.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -16,7 +16,6 @@ import seaborn as sns
 import cellpose
 from skimage.measure import regionprops, label
 from skimage.transform import resize as resizescikit
-from torch.utils.data import DataLoader
 
 from skimage import measure
 from sklearn.model_selection import train_test_split
@@ -43,6 +42,16 @@ import warnings
 warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
 
 
+from torchvision import transforms
+from torch.utils.data import DataLoader, random_split
+from collections import defaultdict
+import os
+import random
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+
+
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
@@ -976,173 +985,6 @@ def generate_dataset(settings={}):
 
     return tar_name
 
-def generate_dataset_v1(src, file_metadata=None, experiment='TSG101_screen', sample=None):
-
-    from .utils import initiate_counter, add_images_to_tar
-
-    db_path = os.path.join(src, 'measurements', 'measurements.db')
-    dst = os.path.join(src, 'datasets')
-    all_paths = []
-
-    # Connect to the database and retrieve the image paths
-    print(f'Reading DataBase: {db_path}')
-    try:
-        with sqlite3.connect(db_path) as conn:
-            cursor = conn.cursor()
-            if file_metadata:
-                if isinstance(file_metadata, str):
-                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
-            else:
-                cursor.execute("SELECT png_path FROM png_list")
-
-            while True:
-                rows = cursor.fetchmany(1000)
-                if not rows:
-                    break
-                all_paths.extend([row[0] for row in rows])
-
-    except sqlite3.Error as e:
-        print(f"Database error: {e}")
-        return
-    except Exception as e:
-        print(f"Error: {e}")
-        return
-
-    if isinstance(sample, int):
-        selected_paths = random.sample(all_paths, sample)
-        print(f'Random selection of {len(selected_paths)} paths')
-    else:
-        selected_paths = all_paths
-        random.shuffle(selected_paths)
-        print(f'All paths: {len(selected_paths)} paths')
-
-    total_images = len(selected_paths)
-    print(f'Found {total_images} images')
-
-    # Create a temp folder in dst
-    temp_dir = os.path.join(dst, "temp_tars")
-    os.makedirs(temp_dir, exist_ok=True)
-
-    # Chunking the data
-    num_procs = max(2, cpu_count() - 2)
-    chunk_size = len(selected_paths) // num_procs
-    remainder = len(selected_paths) % num_procs
-
-    paths_chunks = []
-    start = 0
-    for i in range(num_procs):
-        end = start + chunk_size + (1 if i < remainder else 0)
-        paths_chunks.append(selected_paths[start:end])
-        start = end
-
-    temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
-
-    print(f'Generating temporary tar files in {dst}')
-
-    # Initialize shared counter and lock
-    counter = Value('i', 0)
-    lock = Lock()
-
-    with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
-        pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
-
-    # Combine the temporary tar files into a final tar
-    date_name = datetime.date.today().strftime('%y%m%d')
-    if not file_metadata is None:
-        tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
-    else:
-        tar_name = f'{date_name}_{experiment}.tar'
-    tar_name = os.path.join(dst, tar_name)
-    if os.path.exists(tar_name):
-        number = random.randint(1, 100)
-        tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
-        print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
-        tar_name = os.path.join(dst, tar_name_2)
-
-    print(f'Merging temporary files')
-
-    with tarfile.open(tar_name, 'w') as final_tar:
-        for temp_tar_path in temp_tar_files:
-            with tarfile.open(temp_tar_path, 'r') as temp_tar:
-                for member in temp_tar.getmembers():
-                    file_obj = temp_tar.extractfile(member)
-                    final_tar.addfile(member, file_obj)
-            os.remove(temp_tar_path)
-
-    # Delete the temp folder
-    shutil.rmtree(temp_dir)
-    print(f"\nSaved {total_images} images to {tar_name}")
-
-def apply_model_to_tar_v1(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
-
-    from .io import TarImageDataset
-    from .utils import process_vision_results, print_progress
-
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    if normalize:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size)),
-            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
-    else:
-        transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.CenterCrop(size=(image_size, image_size))])
-
-    if verbose:
-        print(f'Loading model from {model_path}')
-        print(f'Loading dataset from {tar_path}')
-
-    model = torch.load(model_path)
-
-    dataset = TarImageDataset(tar_path, transform=transform)
-    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs, pin_memory=True)
-
-    model_name = os.path.splitext(os.path.basename(model_path))[0]
-    dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
-    date_name = datetime.date.today().strftime('%y%m%d')
-    dst = os.path.dirname(tar_path)
-    result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
-
-    model.eval()
-    model = model.to(device)
-
-    if verbose:
-        print(model)
-        print(f'Generated dataset with {len(dataset)} images')
-        print(f'Generating loader from {len(data_loader)} batches')
-        print(f'Results wil be saved in: {result_loc}')
-        print(f'Model is in eval mode')
-        print(f'Model loaded to device')
-
-    prediction_pos_probs = []
-    filenames_list = []
-    time_ls = []
-    gc.collect()
-    with torch.no_grad():
-        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
-            start = time.time()
-            images = batch_images.to(torch.float).to(device)
-            outputs = model(images)
-            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
-            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
-            filenames_list.extend(filenames)
-            stop = time.time()
-            duration = stop - start
-            time_ls.append(duration)
-            files_processed = batch_idx*batch_size
-            files_to_process = len(data_loader)
-            print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Tar dataset")
-
-    data = {'path':filenames_list, 'pred':prediction_pos_probs}
-    df = pd.DataFrame(data, index=None)
-    df = process_vision_results(df, threshold)
-
-    df.to_csv(result_loc, index=True, header=True, mode='w')
-    torch.cuda.empty_cache()
-    torch.cuda.memory.empty_cache()
-    return df
-
 def apply_model_to_tar(settings={}):
 
     from .io import TarImageDataset
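The two `_v1` helpers removed above took their options as individual arguments; the retained `apply_model_to_tar(settings={})` takes a single settings dict instead. A minimal usage sketch, assuming the dict mirrors the removed v1 arguments (the key names and paths below are illustrative and not confirmed by this diff):

    from spacr.core import apply_model_to_tar

    # Hypothetical keys, mirrored from the removed apply_model_to_tar_v1 signature
    # (tar_path, model_path, image_size, batch_size, normalize, n_jobs, threshold, verbose).
    settings = {
        'tar_path': '/data/datasets/240101_TSG101_screen.tar',
        'model_path': '/data/models/cell_classifier.pth',
        'image_size': 224,
        'batch_size': 64,
        'normalize': True,
        'n_jobs': 10,
        'threshold': 0.5,
        'verbose': False,
    }
    # v1 returned a DataFrame of per-image prediction scores; the settings-based
    # version is assumed to behave the same.
    df = apply_model_to_tar(settings)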
@@ -1397,107 +1239,6 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
 
     return os.path.join(dst, 'train'), os.path.join(dst, 'test')
 
-def generate_training_dataset_v1(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
-
-    from .io import _read_and_merge_data, _read_db
-    from .utils import get_paths_from_db, annotate_conditions
-
-    db_path = os.path.join(src, 'measurements','measurements.db')
-    dst = os.path.join(src, 'datasets', 'training')
-
-    if os.path.exists(dst):
-        for i in range(1, 1000):
-            dst = os.path.join(src, 'datasets', f'training_{i}')
-            if not os.path.exists(dst):
-                print(f'Creating new directory for training: {dst}')
-                break
-
-    if mode == 'annotation':
-        class_paths_ls_2 = []
-        class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
-        for class_paths in class_paths_ls:
-            class_paths_temp = random.sample(class_paths, size)
-            class_paths_ls_2.append(class_paths_temp)
-        class_paths_ls = class_paths_ls_2
-
-    elif mode == 'metadata':
-        class_paths_ls = []
-        class_len_ls = []
-        [df] = _read_db(db_loc=db_path, tables=['png_list'])
-        df['metadata_based_class'] = pd.NA
-        for i, class_ in enumerate(classes):
-            ls = class_metadata[i]
-            df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
-
-        for class_ in classes:
-            if size == None:
-                c_s = []
-                for c in classes:
-                    c_s_t_df = df[df['metadata_based_class'] == c]
-                    c_s.append(len(c_s_t_df))
-                    print(f'Found {len(c_s_t_df)} images for class {c}')
-                size = min(c_s)
-                print(f'Using the smallest class size: {size}')
-
-            class_temp_df = df[df['metadata_based_class'] == class_]
-            class_len_ls.append(len(class_temp_df))
-            print(f'Found {len(class_temp_df)} images for class {class_}')
-            class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
-            class_paths_ls.append(class_paths_temp)
-
-    elif mode == 'recruitment':
-        class_paths_ls = []
-        if not isinstance(tables, list):
-            tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
-
-        df, _ = _read_and_merge_data(locs=[db_path],
-                                     tables=tables,
-                                     verbose=False,
-                                     include_multinucleated=True,
-                                     include_multiinfected=True,
-                                     include_noninfected=True)
-
-        print('length df 1', len(df))
-
-        df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
-        print('length df 2', len(df))
-        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
-
-        if custom_measurement != None:
-
-            if not isinstance(custom_measurement, list):
-                print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
-                return
-
-            if isinstance(custom_measurement, list):
-                if len(custom_measurement) == 2:
-                    print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
-                    df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
-                if len(custom_measurement) == 1:
-                    print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
-                    df['recruitment'] = df[f'{custom_measurement[0]}']
-        else:
-            print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
-            df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
-
-        q25 = df['recruitment'].quantile(0.25)
-        q75 = df['recruitment'].quantile(0.75)
-        df_lower = df[df['recruitment'] <= q25]
-        df_upper = df[df['recruitment'] >= q75]
-
-        class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
-
-        class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
-        class_paths_ls.append(class_paths_lower)
-
-        class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
-        class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
-        class_paths_ls.append(class_paths_upper)
-
-    generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
-
-    return
-
 def generate_training_dataset(settings):
 
     from .io import _read_and_merge_data, _read_db
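As with the dataset helpers, the positional `generate_training_dataset_v1` above is dropped in favour of the settings-based `generate_training_dataset(settings)`. A sketch of the call pattern, with key names inferred from the removed v1 parameters and therefore only illustrative:

    from spacr.core import generate_training_dataset

    # Hypothetical keys, mirrored from the removed v1 arguments
    # (mode, annotation_column, annotated_classes, classes, size, test_split, ...).
    settings = {
        'src': '/data/plate1',        # hypothetical experiment folder
        'mode': 'annotation',         # 'annotation', 'metadata' or 'recruitment' in the v1 logic
        'annotation_column': 'test',
        'annotated_classes': [1, 2],
        'classes': ['nc', 'pc'],
        'size': 200,
        'test_split': 0.1,
    }
    generate_training_dataset(settings)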
@@ -1602,21 +1343,19 @@ def generate_training_dataset(settings):
 
     return train_class_dir, test_class_dir
 
-def generate_loaders(src,
+def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):
 
     """
     Generate data loaders for training and validation/test datasets.
 
     Parameters:
     - src (str): The source directory containing the data.
-    - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
     - mode (str): The mode of operation. Options are 'train' or 'test'.
     - image_size (int): The size of the input images.
     - batch_size (int): The batch size for the data loaders.
     - classes (list): The list of classes to consider.
     - n_jobs (int): The number of worker threads for data loading.
-    - validation_split (float): The fraction of data to use for validation
-    - max_show (int): The maximum number of images to show when verbose is True.
+    - validation_split (float): The fraction of data to use for validation.
     - pin_memory (bool): Whether to pin memory for faster data transfer.
     - normalize (bool): Whether to normalize the input images.
     - verbose (bool): Whether to print additional information and show images.
@@ -1625,18 +1364,10 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     Returns:
     - train_loaders (list): List of data loaders for training datasets.
     - val_loaders (list): List of data loaders for validation datasets.
-    - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
     """
 
-    from .io import
-    from .plot import
-    from torchvision import transforms
-    from torch.utils.data import DataLoader, random_split
-    from collections import defaultdict
-    import os
-    import random
-    from PIL import Image
-    from torchvision.transforms import ToTensor
+    from .io import spacrDataset, spacrDataLoader
+    from .plot import _imshow_gpu
     from .utils import SelectChannels, augment_dataset
 
     chans = []
@@ -1653,12 +1384,9 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     if verbose:
         print(f'Training a network on channels: {channels}')
         print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
-
-    plate_to_filenames = defaultdict(list)
-    plate_to_labels = defaultdict(list)
+
     train_loaders = []
     val_loaders = []
-    plate_names = []
 
     if normalize:
         transform = transforms.Compose([
@@ -1686,157 +1414,114 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
         print(f'mode:{mode} is not valid, use mode = train or test')
         return
 
-
-
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
-        if validation_split > 0:
-            train_size = int((1 - validation_split) * len(data))
-            val_size = len(data) - train_size
-            if not augment:
-                print(f'Train data:{train_size}, Validation data:{val_size}')
-            train_dataset, val_dataset = random_split(data, [train_size, val_size])
-
-            if augment:
-
-                print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
-                train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
-                #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
-                print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
-            train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-            val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-        else:
-            train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
-    elif train_mode == 'irm':
-        data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
-
-        for filename, label in zip(data.filenames, data.labels):
-            plate = data.get_plate(filename)
-            plate_to_filenames[plate].append(filename)
-            plate_to_labels[plate].append(label)
-
-        for plate, filenames in plate_to_filenames.items():
-            labels = plate_to_labels[plate]
-            plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
-            plate_names.append(plate)
-
-            if validation_split > 0:
-                train_size = int((1 - validation_split) * len(plate_data))
-                val_size = len(plate_data) - train_size
-                if not augment:
-                    print(f'Train data:{train_size}, Validation data:{val_size}')
-                train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
-
-                if augment:
-
-                    print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
-                    train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
-                    #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
-                    print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
-
-                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-
-                train_loaders.append(train_loader)
-                val_loaders.append(val_loader)
-            else:
-                train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=n_jobs if n_jobs is not None else 0, pin_memory=pin_memory)
-                train_loaders.append(train_loader)
-                val_loaders.append(None)
+    data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+    num_workers = n_jobs if n_jobs is not None else 0
 
-
-
-
-
-
-
-        for idx, (images, labels, filenames) in enumerate(train_loaders):
-            if idx >= max_show:
-                break
-            images = images.cpu()
-            label_strings = [str(label.item()) for label in labels]
-            train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
-            if verbose:
-                plt.show()
+    if validation_split > 0:
+        train_size = int((1 - validation_split) * len(data))
+        val_size = len(data) - train_size
+        if not augment:
+            print(f'Train data:{train_size}, Validation data:{val_size}')
+        train_dataset, val_dataset = random_split(data, [train_size, val_size])
 
-
-        for plate_name, train_loader in zip(plate_names, train_loaders):
-            print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
-            for idx, (images, labels, filenames) in enumerate(train_loader):
-                if idx >= max_show:
-                    break
-                images = images.cpu()
-                label_strings = [str(label.item()) for label in labels]
-                train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
-                if verbose:
-                    plt.show()
+        if augment:
 
-
+            print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+            train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+            print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+        print(f'Generating Dataloader with {n_jobs} workers')
+        #train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+        #train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
 
-
+        train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+        val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+    else:
+        train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+        #train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
+
+    #dataset (Dataset) – dataset from which to load the data.
+    #batch_size (int, optional) – how many samples per batch to load (default: 1).
+    #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+    #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+    #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+    #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+    #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+    #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+    #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+    #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+    #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+    #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+    #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+    #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+    #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+    #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+    #images, labels, filenames = next(iter(train_loaders))
+    #images = images.cpu()
+    #label_strings = [str(label.item()) for label in labels]
+    #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+    #if verbose:
+    #    plt.show()
+
+    train_fig = None
+
+    return train_loaders, val_loaders, train_fig
+
+def analyze_recruitment(settings={}):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
 
     Parameters:
-
-    metadata_settings (dict): The settings for metadata.
-    advanced_settings (dict): The advanced settings for recruitment analysis.
+    settings (dict): settings.
 
     Returns:
     None
     """
 
     from .io import _read_and_merge_data, _results_to_csv
-    from .plot import
-    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+    from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
+    from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
     from .settings import get_analyze_recruitment_default_settings
 
-    settings = get_analyze_recruitment_default_settings(settings)
-
-    settings_dict = {**metadata_settings, **advanced_settings}
-    settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
-    settings_csv = os.path.join(src,'settings','analyze_settings.csv')
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
+    settings = get_analyze_recruitment_default_settings(settings=settings)
+    save_settings(settings, name='recruitment')
 
     # metadata settings
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    src = settings['src']
+    target = settings['target']
+    cell_types = settings['cell_types']
+    cell_plate_metadata = settings['cell_plate_metadata']
+    pathogen_types = settings['pathogen_types']
+    pathogen_plate_metadata = settings['pathogen_plate_metadata']
+    treatments = settings['treatments']
+    treatment_plate_metadata = settings['treatment_plate_metadata']
+    metadata_types = settings['metadata_types']
+    channel_dims = settings['channel_dims']
+    cell_chann_dim = settings['cell_chann_dim']
+    cell_mask_dim = settings['cell_mask_dim']
+    nucleus_chann_dim = settings['nucleus_chann_dim']
+    nucleus_mask_dim = settings['nucleus_mask_dim']
+    pathogen_chann_dim = settings['pathogen_chann_dim']
+    pathogen_mask_dim = settings['pathogen_mask_dim']
+    channel_of_interest = settings['channel_of_interest']
 
     # Advanced settings
-    plot =
-    plot_nr =
-    plot_control =
-    figuresize =
-
-
-
-
-
-
-
-
-
-
-
-    cell_intensity_range = advanced_settings['cell_intensity_range']
-    target_intensity_min = advanced_settings['target_intensity_min']
+    plot = settings['plot']
+    plot_nr = settings['plot_nr']
+    plot_control = settings['plot_control']
+    figuresize = settings['figuresize']
+    include_noninfected = settings['include_noninfected']
+    include_multiinfected = settings['include_multiinfected']
+    include_multinucleated = settings['include_multinucleated']
+    cells_per_well = settings['cells_per_well']
+    pathogen_size_range = settings['pathogen_size_range']
+    nucleus_size_range = settings['nucleus_size_range']
+    cell_size_range = settings['cell_size_range']
+    pathogen_intensity_range = settings['pathogen_intensity_range']
+    nucleus_intensity_range = settings['nucleus_intensity_range']
+    cell_intensity_range = settings['cell_intensity_range']
+    target_intensity_min = settings['target_intensity_min']
 
     print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
     print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
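The reworked generate_loaders signature is shown in full in the hunks above, and its new body returns train_loaders, val_loaders, train_fig. A minimal usage sketch based on that signature (the src path is illustrative):

    from spacr.core import generate_loaders

    train_loader, val_loader, train_fig = generate_loaders(
        src='/data/datasets/training',  # hypothetical folder containing train/ and test/ subfolders
        mode='train',
        image_size=224,
        batch_size=32,
        classes=['nc', 'pc'],
        n_jobs=4,
        validation_split=0.1,
        pin_memory=True,
        normalize=False,
        channels=[1, 2, 3],
        augment=False,
        verbose=False,
    )
    # In 0.2.61 the returned loaders are plain torch DataLoaders built with num_workers=1
    # and persistent_workers=True (the spacrDataLoader calls are commented out), and
    # train_fig is always None.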
@@ -1854,9 +1539,6 @@ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     else:
         metadata_types = metadata_types
 
-    if isinstance(backgrounds, (int,float)):
-        backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
-
     sns.color_palette("mako", as_cmap=True)
     print(f'channel:{channel_of_interest} = {target}')
     overlay_channels = channel_dims
@@ -1866,11 +1548,11 @@
     db_loc = [src+'/measurements/measurements.db']
     tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
     df, _ = _read_and_merge_data(db_loc,
-
-
-
-
-
+                                 tables,
+                                 verbose=True,
+                                 include_multinucleated=include_multinucleated,
+                                 include_multiinfected=include_multiinfected,
+                                 include_noninfected=include_noninfected)
 
     df = annotate_conditions(df,
                              cells=cell_types,
@@ -1889,48 +1571,31 @@
     random.shuffle(files)
 
     _max = 10**100
-
-
-
-
-
-
-    if nucleus_size_range is None:
-        nucleus_size_range = [0,_max]
-    if pathogen_size_range is None:
-        pathogen_size_range = [0,_max]
-
-    filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+    if cell_size_range is None:
+        cell_size_range = [0,_max]
+    if nucleus_size_range is None:
+        nucleus_size_range = [0,_max]
+    if pathogen_size_range is None:
+        pathogen_size_range = [0,_max]
 
     if plot:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            'print_object_number':True,
-            'nr':plot_nr,
-            'figuresize':20,
-            'cmap':'inferno',
-            'verbose':False}
-
-        if os.path.exists(os.path.join(src,'merged')):
-            try:
-                plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
-            except Exception as e:
-                print(f'Failed to plot images with outlines, Error: {e}')
+        merged_path = os.path.join(src,'merged')
+        if os.path.exists(merged_path):
+            try:
+                for idx, file in enumerate(os.listdir(merged_path)):
+                    file_path = os.path.join(merged_path,file)
+                    if idx <= plot_nr:
+                        plot_image_mask_overlay(file_path,
+                                                channel_dims,
+                                                cell_chann_dim,
+                                                nucleus_chann_dim,
+                                                pathogen_chann_dim,
+                                                figuresize=10,
+                                                normalize=True,
+                                                thickness=3,
+                                                save_pdf=True)
+            except Exception as e:
+                print(f'Failed to plot images with outlines, Error: {e}')
 
     if not cell_chann_dim is None:
         df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
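analyze_recruitment now takes a single settings dict; the keys it reads are exactly those unpacked in the hunks above, and missing keys are filled by get_analyze_recruitment_default_settings before the run configuration is saved with save_settings. A minimal sketch with illustrative values:

    from spacr.core import analyze_recruitment

    settings = {
        'src': '/data/plate1',       # hypothetical folder; the code reads src/measurements/measurements.db
        'target': 'TSG101',
        'channel_of_interest': 3,
        'cell_types': ['HeLa'],
        'pathogen_types': ['pathogen'],
        'treatments': ['nc', 'pc'],
        'plot': True,
        'plot_nr': 5,
    }
    # Returns None; keys not supplied here fall back to the package defaults.
    analyze_recruitment(settings)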
@@ -1968,14 +1633,12 @@ def preprocess_generate_masks(src, settings={}):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_image_mask_overlay, plot_arrays
-    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
     from .settings import set_default_settings_preprocess_generate_masks
 
     settings = set_default_settings_preprocess_generate_masks(src, settings)
-
-
-    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-    settings_df.to_csv(settings_csv, index=False)
+    settings['src'] = src
+    save_settings(settings)
 
     if not settings['pathogen_channel'] is None:
         custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
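preprocess_generate_masks keeps its (src, settings) signature but now records the run configuration itself: the defaulted settings are stamped with 'src' and written out through save_settings. A sketch of the call; settings keys other than 'pathogen_channel' (read in the context above) are hypothetical:

    from spacr.core import preprocess_generate_masks

    # Remaining keys are filled in by set_default_settings_preprocess_generate_masks.
    settings = {'pathogen_channel': 2}
    preprocess_generate_masks('/data/plate1', settings=settings)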
@@ -2266,7 +1929,7 @@ def generate_cellpose_masks(src, settings, object_type):
     settings_df['setting_value'] = settings_df['setting_value'].apply(str)
     display(settings_df)
 
-    figuresize=
+    figuresize=10
     timelapse = settings['timelapse']
 
     if timelapse: