spacr 0.0.82__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +4 -0
- spacr/annotate_app.py +4 -0
- spacr/annotate_app_v2.py +511 -0
- spacr/core.py +254 -172
- spacr/deep_spacr.py +137 -50
- spacr/graph_learning.py +28 -8
- spacr/io.py +227 -144
- spacr/measure.py +2 -1
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +23 -0
- spacr/plot.py +102 -6
- spacr/sequencing.py +140 -91
- spacr/settings.py +477 -0
- spacr/timelapse.py +0 -3
- spacr/utils.py +312 -275
- {spacr-0.0.82.dist-info → spacr-0.1.0.dist-info}/METADATA +1 -1
- spacr-0.1.0.dist-info/RECORD +40 -0
- spacr-0.0.82.dist-info/RECORD +0 -36
- {spacr-0.0.82.dist-info → spacr-0.1.0.dist-info}/LICENSE +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.0.dist-info}/WHEEL +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.82.dist-info → spacr-0.1.0.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -13,7 +13,7 @@ from IPython.display import display
 from multiprocessing import Pool, cpu_count, Value, Lock
 
 import seaborn as sns
-
+import cellpose
 from skimage.measure import regionprops, label
 from skimage.transform import resize as resizescikit
 from torch.utils.data import DataLoader
@@ -25,6 +25,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.inspection import permutation_importance
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
 from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import precision_recall_curve, f1_score
 
 from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
 
@@ -41,43 +42,51 @@ from .logger import log_function_call
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
+    stats_data = []
 
     for filename in os.listdir(folder):
         filepath = os.path.join(folder, filename)
         if os.path.isfile(filepath):
             # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
-            image = np.load(filepath)
-
+            #image = np.load(filepath)
+            image = cellpose.io.imread(filepath)
             labeled_image = label(image)
             regions = regionprops(labeled_image)
 
             object_count = len(regions)
             sizes = [region.area for region in regions]
             average_size = np.mean(sizes) if sizes else 0
+            std_dev_size = np.std(sizes) if sizes else 0
 
             summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
+            stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
             for size in sizes:
                 details_data.append({'file': filename, 'plaque_size': size})
 
     # Convert lists to pandas DataFrames
     summary_df = pd.DataFrame(summary_data)
     details_df = pd.DataFrame(details_data)
+    stats_df = pd.DataFrame(stats_data)
 
     # Save DataFrames to a SQLite database
-    db_name = 'plaques_analysis.db'
+    db_name = os.path.join(folder, 'plaques_analysis.db')
    conn = sqlite3.connect(db_name)
 
     summary_df.to_sql('summary', conn, if_exists='replace', index=False)
     details_df.to_sql('details', conn, if_exists='replace', index=False)
+    stats_df.to_sql('stats', conn, if_exists='replace', index=False)
 
     conn.close()
 
     print(f"Analysis completed and saved to database '{db_name}'.")
 
+
 def train_cellpose(settings):
 
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
-
+    from .settings import get_train_cellpose_default_settings#, resize_images_and_labels
+
+    settings = get_train_cellpose_default_settings()
 
     img_src = settings['img_src']
     mask_src = os.path.join(img_src, 'masks')
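Note: analyze_plaques now also collects per-file statistics into a new 'stats' table and writes plaques_analysis.db inside the analyzed folder. A minimal sketch of reading those results back, assuming only the table layout shown in the hunk above (the folder path is hypothetical):

import sqlite3
import pandas as pd

folder = "/path/to/plaque/images"  # hypothetical folder previously passed to analyze_plaques
conn = sqlite3.connect(f"{folder}/plaques_analysis.db")
# 'stats' holds one row per file: plaque_count, average_size, std_dev_size
stats = pd.read_sql("SELECT * FROM stats", conn)
conn.close()
print(stats.head())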
@@ -146,7 +155,7 @@ def train_cellpose(settings):
 
     image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
     label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-    images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
+    images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
     images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
 
     if test:
@@ -962,9 +971,10 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
     shutil.rmtree(temp_dir)
     print(f"\nSaved {total_images} images to {tar_name}")
 
-def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
+def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, threshold=0.5, verbose=False):
 
-    from .io import TarImageDataset
+    from .io import TarImageDataset
+    from .utils import process_vision_results
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if normalize:
@@ -1017,6 +1027,8 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
 
     data = {'path':filenames_list, 'pred':prediction_pos_probs}
     df = pd.DataFrame(data, index=None)
+    df = process_vision_results(df, threshold)
+
     df.to_csv(result_loc, index=True, header=True, mode='w')
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
@@ -1290,7 +1302,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
 
     return
 
-def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
+def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
 
     """
     Generate data loaders for training and validation/test datasets.
@@ -1325,7 +1337,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     import random
     from PIL import Image
     from torchvision.transforms import ToTensor
-    from .utils import SelectChannels
+    from .utils import SelectChannels, augment_dataset
 
     chans = []
 
@@ -1375,14 +1387,22 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
         return
 
     if train_mode == 'erm':
+
         data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+
         if validation_split > 0:
             train_size = int((1 - validation_split) * len(data))
             val_size = len(data) - train_size
+            if not augment:
+                print(f'Train data:{train_size}, Validation data:{val_size}')
+            train_dataset, val_dataset = random_split(data, [train_size, val_size])
 
-
+            if augment:
 
-
+                print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+                train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+                #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+                print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
 
             train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
             val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1405,10 +1425,16 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
         if validation_split > 0:
             train_size = int((1 - validation_split) * len(plate_data))
             val_size = len(plate_data) - train_size
+            if not augment:
+                print(f'Train data:{train_size}, Validation data:{val_size}')
+            train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
 
-
+            if augment:
 
-
+                print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
+                train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+                #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+                print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
 
             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
             val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1423,28 +1449,33 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     else:
         print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
         return
+
+
+    if train_mode == 'erm':
+        for idx, (images, labels, filenames) in enumerate(train_loaders):
+            if idx >= max_show:
+                break
+            images = images.cpu()
+            label_strings = [str(label.item()) for label in labels]
+            train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+            if verbose:
+                plt.show()
 
-
-
-
+    elif train_mode == 'irm':
+        for plate_name, train_loader in zip(plate_names, train_loaders):
+            print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
+            for idx, (images, labels, filenames) in enumerate(train_loader):
                 if idx >= max_show:
                     break
                 images = images.cpu()
                 label_strings = [str(label.item()) for label in labels]
-                _imshow(images, label_strings, nrow=20, fontsize=12)
-
-
-
-
-
-
-                images = images.cpu()
-                label_strings = [str(label.item()) for label in labels]
-                _imshow(images, label_strings, nrow=20, fontsize=12)
-
-    return train_loaders, val_loaders, plate_names
-
-def analyze_recruitment(src, metadata_settings, advanced_settings):
+                train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+                if verbose:
+                    plt.show()
+
+    return train_loaders, val_loaders, plate_names, train_fig
+
+def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
 
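Note: generate_loaders gains an augment flag and now returns the training figure as a fourth value. A hedged usage sketch; the argument values are illustrative, not package defaults:

train_loaders, val_loaders, plate_names, train_fig = generate_loaders(
    src='/path/to/dataset',   # hypothetical dataset folder
    train_mode='erm',
    classes=['nc', 'pc'],
    validation_split=0.1,
    augment=True,             # new in 0.1.0: augment the training split
    verbose=False)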
@@ -1460,6 +1491,9 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
     from .io import _read_and_merge_data, _results_to_csv
     from .plot import plot_merged, _plot_controls, _plot_recruitment
     from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+    from .settings import get_analyze_recruitment_default_settings
+
+    settings = get_analyze_recruitment_default_settings(settings)
 
     settings_dict = {**metadata_settings, **advanced_settings}
     settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
@@ -1634,8 +1668,8 @@ def preprocess_generate_masks(src, settings={}):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_merged, plot_arrays
-    from .utils import _pivot_counts_table,
-    from .
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
+    from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings
 
     settings = set_default_settings_preprocess_generate_masks(src, settings)
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
@@ -1756,36 +1790,14 @@ def identify_masks_finetune(settings):
     from .plot import print_mask_and_flows
     from .utils import get_files_from_dir, resize_images_and_labels
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
-
+    from .settings import get_identify_masks_finetune_default_settings
+
+    settings = get_identify_masks_finetune_default_settings(settings)
+
     #User defined settings
     src=settings['src']
     dst=settings['dst']
 
-
-    settings.setdefault('model_name', 'cyto')
-    settings.setdefault('custom_model', None)
-    settings.setdefault('channels', [0,0])
-    settings.setdefault('background', 100)
-    settings.setdefault('remove_background', False)
-    settings.setdefault('Signal_to_noise', 10)
-    settings.setdefault('CP_prob', 0)
-    settings.setdefault('diameter', 30)
-    settings.setdefault('batch_size', 50)
-    settings.setdefault('flow_threshold', 0.4)
-    settings.setdefault('save', False)
-    settings.setdefault('verbose', False)
-    settings.setdefault('normalize', True)
-    settings.setdefault('percentiles', None)
-    settings.setdefault('circular', False)
-    settings.setdefault('invert', False)
-    settings.setdefault('resize', False)
-    settings.setdefault('target_height', None)
-    settings.setdefault('target_width', None)
-    settings.setdefault('rescale', False)
-    settings.setdefault('resample', False)
-    settings.setdefault('grayscale', True)
-
-
     model_name=settings['model_name']
     custom_model=settings['custom_model']
     channels = settings['channels']
@@ -1844,23 +1856,25 @@ def identify_masks_finetune(settings):
     print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
 
     all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
-
+    mask_files = set(os.listdir(os.path.join(src, 'masks')))
+    all_image_files = [f for f in all_image_files if os.path.basename(f) not in mask_files]
     random.shuffle(all_image_files)
 
     time_ls = []
     for i in range(0, len(all_image_files), batch_size):
+        gc.collect()
         image_files = all_image_files[i:i+batch_size]
 
         if normalize:
-            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
+            images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise, target_height=target_height, target_width=target_width)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+            #orig_dims = [(image.shape[0], image.shape[1]) for image in images]
         else:
             images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
             orig_dims = [(image.shape[0], image.shape[1]) for image in images]
-
-
+        if resize:
+            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
 
         for file_index, stack in enumerate(images):
             start = time.time()
@@ -1899,6 +1913,8 @@ def identify_masks_finetune(settings):
                 os.makedirs(dst, exist_ok=True)
                 output_filename = os.path.join(dst, image_names[file_index])
                 cv2.imwrite(output_filename, mask)
+        del images, output, mask, flows
+        gc.collect()
     return
 
 def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
@@ -2125,10 +2141,11 @@ def prepare_batch_for_cellpose(batch):
 
 def generate_cellpose_masks(src, settings, object_type):
 
-    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size,
+    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count
     from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
     from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
     from .plot import plot_masks
+    from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings
 
     gc.collect()
     if not torch.cuda.is_available():
@@ -2457,32 +2474,15 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
 
 
 def check_cellpose_models(settings):
+
+    from .settings import get_check_cellpose_models_default_settings
 
+    settings = get_check_cellpose_models_default_settings(settings)
     src = settings['src']
-    settings.setdefault('batch_size', 10)
-    settings.setdefault('CP_prob', 0)
-    settings.setdefault('flow_threshold', 0.4)
-    settings.setdefault('save', True)
-    settings.setdefault('normalize', True)
-    settings.setdefault('channels', [0,0])
-    settings.setdefault('percentiles', None)
-    settings.setdefault('circular', False)
-    settings.setdefault('invert', False)
-    settings.setdefault('plot', True)
-    settings.setdefault('diameter', 40)
-    settings.setdefault('grayscale', True)
-    settings.setdefault('remove_background', False)
-    settings.setdefault('background', 100)
-    settings.setdefault('Signal_to_noise', 5)
-    settings.setdefault('verbose', False)
-    settings.setdefault('resize', False)
-    settings.setdefault('target_height', None)
-    settings.setdefault('target_width', None)
 
-
-
-
-    display(settings_df)
+    settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+    settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+    display(settings_df)
 
     cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -2622,8 +2622,24 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
 
     return df
 
-def
-
+def find_optimal_threshold(y_true, y_pred_proba):
+    """
+    Find the optimal threshold for binary classification based on the F1-score.
+
+    Args:
+        y_true (array-like): True binary labels.
+        y_pred_proba (array-like): Predicted probabilities for the positive class.
+
+    Returns:
+        float: The optimal threshold.
+    """
+    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
+    f1_scores = 2 * (precision * recall) / (precision + recall)
+    optimal_idx = np.argmax(f1_scores)
+    optimal_threshold = thresholds[optimal_idx]
+    return optimal_threshold
+
+def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
     """
     Calculates permutation importance for numerical features in the dataframe,
     comparing groups based on specified column values and uses the model to predict
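Note: the new find_optimal_threshold helper picks the probability cutoff that maximizes the F1-score along a precision-recall curve. A minimal usage sketch with synthetic data, assuming the function is importable from spacr.core in 0.1.0 as the hunk above shows:

import numpy as np
from spacr.core import find_optimal_threshold  # assumes the 0.1.0 module layout shown above

y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])                     # synthetic binary labels
y_scores = np.array([0.1, 0.4, 0.8, 0.3, 0.6, 0.9, 0.2, 0.7])   # synthetic positive-class probabilities

# scans the precision-recall curve and returns the threshold with the highest F1
print(find_optimal_threshold(y_true, y_scores))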
@@ -2632,12 +2648,11 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     Args:
         df (pandas.DataFrame): The DataFrame containing the data.
         feature_string (str): String to filter features that contain this substring.
-
-
+        location_column (str): Column name to use for comparing groups.
+        positive_control, negative_control (str): Values in location_column to create subsets for comparison.
         exclude (list or str, optional): Columns to exclude from features.
         n_repeats (int): Number of repeats for permutation importance.
-
-        nr_to_plot (int): Number of top features to plot based on permutation importance.
+        top_features (int): Number of top features to plot based on permutation importance.
         n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
         test_size (float): Proportion of the dataset to include in the test split.
         random_state (int): Random seed for reproducibility.
@@ -2650,38 +2665,48 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     """
 
     from .utils import filter_dataframe_features
+    from .plot import plot_permutation, plot_feature_importance
 
+    random_state = 42
+
     if 'cells_per_well' in df.columns:
         df = df.drop(columns=['cells_per_well'])
 
+    df_metadata = df[[location_column]].copy()
+    df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
+
+
+    if verbose:
+        print(f'Found {len(features)} numerical features in the dataframe')
+        print(f'Features used in training: {features}')
+    df = pd.concat([df, df_metadata[location_column]], axis=1)
+
     # Subset the dataframe based on specified column values
-    df1 = df[df[
-    df2 = df[df[
+    df1 = df[df[location_column] == negative_control].copy()
+    df2 = df[df[location_column] == positive_control].copy()
 
     # Create target variable
-    df1['target'] = 0
-    df2['target'] = 1
+    df1['target'] = 0 # Negative control
+    df2['target'] = 1 # Positive control
 
     # Combine the subsets for analysis
     combined_df = pd.concat([df1, df2])
-
-    if
-
-    elif not feature_string is 'morphology':
-        channel_of_interest = 'morphology'
-
-    _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
+    combined_df = combined_df.drop(columns=[location_column])
+    if verbose:
+        print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
 
     X = combined_df[features]
     y = combined_df['target']
 
     # Split the data into training and testing sets
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
-
-    #
+
+    # Add data usage labels
     combined_df['data_usage'] = 'train'
     combined_df.loc[X_test.index, 'data_usage'] = 'test'
-
+    df['data_usage'] = 'not_used'
+    df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']
+
     # Initialize the model based on model_type
     if model_type == 'random_forest':
         model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
@@ -2703,29 +2728,24 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
         'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
         'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
         'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
-    }).tail(
+    }).tail(top_features)
+
+    permutation_fig = plot_permutation(permutation_df)
+    if verbose:
+        permutation_fig.show()
 
-    # Plotting
-    fig, ax = plt.subplots()
-    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
-    ax.set_xlabel('Permutation Importance')
-    plt.tight_layout()
-    plt.show()
-
     # Feature importance for models that support it
     if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
         feature_importances = model.feature_importances_
         feature_importance_df = pd.DataFrame({
             'feature': features,
             'importance': feature_importances
-        }).sort_values(by='importance', ascending=False).head(
+        }).sort_values(by='importance', ascending=False).head(top_features)
 
-
-
-
-
-        plt.tight_layout()
-        plt.show()
+        feature_importance_fig = plot_feature_importance(feature_importance_df)
+        if verbose:
+            feature_importance_fig.show()
+
     else:
         feature_importance_df = pd.DataFrame()
 
@@ -2733,38 +2753,38 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     predictions_test = model.predict(X_test)
     combined_df.loc[X_test.index, 'predictions'] = predictions_test
 
-    #
-
-
+    # Get prediction probabilities for the test set
+    prediction_probabilities_test = model.predict_proba(X_test)
+
+    # Find the optimal threshold
+    optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
+    if verbose:
+        print(f'Optimal threshold: {optimal_threshold}')
 
     # Predicting the target variable for all other rows in the dataframe
     X_all = df[features]
     all_predictions = model.predict(X_all)
     df['predictions'] = all_predictions
 
-    #
-
-
-
-
-
-
-
-
-    print(f"Accuracy: {accuracy}")
-    print(f"Precision: {precision}")
-    print(f"Recall: {recall}")
-    print(f"F1 Score: {f1}")
-
-    # Printing class-specific accuracy metrics
-    print("\nClassification Report:")
-    print(classification_report(y_test, predictions_test))
+    # Get prediction probabilities for all rows in the dataframe
+    prediction_probabilities = model.predict_proba(X_all)
+    for i in range(prediction_probabilities.shape[1]):
+        df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
+    if verbose:
+        print("\nClassification Report:")
+        print(classification_report(y_test, predictions_test))
+    report_dict = classification_report(y_test, predictions_test, output_dict=True)
+    metrics_df = pd.DataFrame(report_dict).transpose()
 
-    df = _calculate_similarity(df, features,
+    df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
 
-
+    df['prcfo'] = df.index.astype(str)
+    df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
+    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
+
+    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
 
-def
+def shap_analysis(model, X_train, X_test):
 
     """
     Performs SHAP analysis on the given model and data.
@@ -2773,17 +2793,45 @@ def _shap_analysis(model, X_train, X_test):
         model: The trained model.
         X_train (pandas.DataFrame): Training feature set.
         X_test (pandas.DataFrame): Testing feature set.
+    Returns:
+        fig: Matplotlib figure object containing the SHAP summary plot.
     """
-
+
     explainer = shap.Explainer(model, X_train)
     shap_values = explainer(X_test)
-
+    # Create a new figure
+    fig, ax = plt.subplots()
     # Summary plot
-    shap.summary_plot(shap_values, X_test)
-
-
+    shap.summary_plot(shap_values, X_test, show=False)
+    # Save the current figure (the one that SHAP just created)
+    fig = plt.gcf()
+    plt.close(fig) # Close the figure to prevent it from displaying immediately
+    return fig
+
+def check_index(df, elements=5, split_char='_'):
+    problematic_indices = []
+    for idx in df.index:
+        parts = str(idx).split(split_char)
+        if len(parts) != elements:
+            problematic_indices.append(idx)
+    if problematic_indices:
+        print("Indices that cannot be separated into 5 parts:")
+        for idx in problematic_indices:
+            print(idx)
+        raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+#def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+def generate_ml_scores(src, settings):
+
     from .io import _read_and_merge_data
-    from .plot import
+    from .plot import plot_plates
+    from .utils import get_ml_results_paths
+    from .settings import set_default_analyze_screen
+
+    settings = set_default_analyze_screen(settings)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    display(settings_df)
 
     db_loc = [src+'/measurements/measurements.db']
     tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
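Note: ml_analysis now splits the dataframe index into plate_row_col_field_object parts, and the new check_index helper guards against indices that do not split cleanly. A small illustration of the expected index format, assuming check_index lands in spacr.core as the hunk above shows:

import pandas as pd
from spacr.core import check_index  # assumes the 0.1.0 layout shown above

df = pd.DataFrame({'value': [1.0, 2.0]},
                  index=['plate1_r1_c2_f01_o001', 'plate1_r1_c2_f01_o002'])
check_index(df, elements=5, split_char='_')  # passes: each index splits into 5 parts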
@@ -2791,27 +2839,60 @@ def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='m
 
     df, _ = _read_and_merge_data(db_loc,
                                  tables,
-                                 verbose
-                                 include_multinucleated
-                                 include_multiinfected
-                                 include_noninfected
-
-    if
-
-
-
-
-
-
-
+                                 settings['verbose'],
+                                 include_multinucleated,
+                                 include_multiinfected,
+                                 include_noninfected)
+
+    if settings['channel_of_interest'] in [0,1,2,3]:
+
+        df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
+        output, figs = ml_analysis(df,
+                                   settings['channel_of_interest'],
+                                   settings['location_column'],
+                                   settings['positive_control'],
+                                   settings['negative_control'],
+                                   settings['exclude'],
+                                   settings['n_repeats'],
+                                   settings['top_features'],
+                                   settings['n_estimators'],
+                                   settings['test_size'],
+                                   settings['model_type'],
+                                   settings['n_jobs'],
+                                   settings['remove_low_variance_features'],
+                                   settings['remove_highly_correlated_features'],
+                                   settings['verbose'])
+
+    shap_fig = shap_analysis(output[3], output[4], output[5])
 
     features = output[0].select_dtypes(include=[np.number]).columns.tolist()
 
-    if not
-        raise ValueError(f"Variable {
+    if not settings['heatmap_feature'] in features:
+        raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
 
-    plate_heatmap =
+    plate_heatmap = plot_plates(df=output[0],
+                                variable=settings['heatmap_feature'],
+                                grouping=settings['grouping'],
+                                min_max=settings['min_max'],
+                                cmap=settings['cmap'],
+                                min_count=settings['minimum_cell_count'],
+                                verbose=settings['verbose'])
+
+    data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
+    df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
+
+    settings_df.to_csv(settings_csv, index=False)
+    df.to_csv(data_path, mode='w', encoding='utf-8')
+    permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
+    feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
+    metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
+
+    plate_heatmap.savefig(plate_heatmap_path, format='pdf')
+    figs[0].savefig(permutation_fig_path, format='pdf')
+    figs[1].savefig(feature_importance_fig_path, format='pdf')
+    shap_fig.savefig(shap_fig_path, format='pdf')
+
     return [output, plate_heatmap]
 
 def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
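Note: generate_ml_scores(src, settings) drives ml_analysis, the SHAP summary, the plate heatmap, and result export from a single settings dict. A hedged sketch of a call; the keys are the ones read in the hunk above, the values are illustrative, and any missing keys are filled by set_default_analyze_screen:

from spacr.core import generate_ml_scores  # assumes the 0.1.0 layout shown above

settings = {
    'channel_of_interest': 3,
    'location_column': 'col',
    'positive_control': 'c2',
    'negative_control': 'c1',
    'model_type': 'xgboost',
    'heatmap_feature': 'predictions',
    'grouping': 'mean',
    'minimum_cell_count': 25,
    'verbose': True,
}
output, plate_heatmap = generate_ml_scores('/path/to/experiment', settings)  # hypothetical src folder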
@@ -2940,8 +3021,8 @@ def generate_image_umap(settings={}):
     """
 
     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid,
-
+    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, generate_umap_from_images
+    from .settings import get_umap_image_settings
     settings = get_umap_image_settings(settings)
 
     if isinstance(settings['src'], str):
@@ -3123,7 +3204,8 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
     """
 
     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
+    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
+    from .settings import get_umap_image_settings
 
     settings = get_umap_image_settings(settings)
     pointsize = settings['dot_size']