spacr 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +245 -2494
- spacr/deep_spacr.py +335 -163
- spacr/gui.py +2 -0
- spacr/gui_core.py +85 -65
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +375 -7
- spacr/io.py +680 -141
- spacr/logger.py +28 -9
- spacr/measure.py +108 -133
- spacr/mediar.py +0 -3
- spacr/ml.py +1051 -0
- spacr/openai.py +37 -0
- spacr/plot.py +707 -20
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +181 -50
- spacr/sim.py +0 -2
- spacr/submodules.py +349 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +238 -0
- spacr/utils.py +776 -182
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/METADATA +31 -22
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
@@ -1,6 +1,7 @@
-import
+import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
 
 import numpy as np
+import pandas as pd
 from cellpose import models as cp_models
 from cellpose import denoise
 
@@ -11,10 +12,10 @@ from skimage.transform import resize as resizescikit
 from skimage.morphology import dilation, square
 from skimage.measure import find_contours
 from skimage.segmentation import clear_border
+from scipy.stats import pearsonr
 
 from collections import defaultdict, OrderedDict
 from PIL import Image
-import pandas as pd
 from statsmodels.stats.outliers_influence import variance_inflation_factor
 from statsmodels.stats.stattools import durbin_watson
 import statsmodels.formula.api as smf
@@ -24,7 +25,7 @@ from itertools import combinations
 from functools import reduce
 from IPython.display import display
 
-from multiprocessing import Pool, cpu_count
+from multiprocessing import Pool, cpu_count, set_start_method, get_start_method
 from concurrent.futures import ThreadPoolExecutor
 
 import torch.nn as nn
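
0.3.3 promotes set_start_method and get_start_method into the main multiprocessing import. A minimal sketch of the pattern these two usually support (forcing 'spawn' so CUDA-backed workers are safe); the helper name is illustrative, not spacr API:

from multiprocessing import get_start_method, set_start_method

def _ensure_spawn():
    # CUDA contexts do not survive fork(), so torch-based workers
    # are typically started with the 'spawn' method instead.
    if get_start_method(allow_none=True) != 'spawn':
        set_start_method('spawn', force=True)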
@@ -33,65 +34,304 @@ from torch.utils.checkpoint import checkpoint
 from torch.utils.data import Subset
 from torch.autograd import grad
 
+from torchvision import models
+from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+import torchvision.transforms as transforms
+from torchvision.models import resnet50
+from torchvision.utils import make_grid
+
 import seaborn as sns
 import matplotlib.pyplot as plt
 from matplotlib.offsetbox import OffsetImage, AnnotationBbox
 
+from scipy import stats
 import scipy.ndimage as ndi
 from scipy.spatial import distance
-from scipy.stats import fisher_exact
+from scipy.stats import fisher_exact, f_oneway, kruskal
 from scipy.ndimage.filters import gaussian_filter
 from scipy.spatial import ConvexHull
 from scipy.interpolate import splprep, splev
 from scipy.ndimage import binary_dilation
 
-from sklearn.preprocessing import StandardScaler
 from skimage.exposure import rescale_intensity
 from sklearn.metrics import auc, precision_recall_curve
 from sklearn.model_selection import train_test_split
 from sklearn.linear_model import Lasso, Ridge
-from sklearn.preprocessing import OneHotEncoder
-from sklearn.cluster import KMeans
-from sklearn.preprocessing import StandardScaler
-from sklearn.cluster import DBSCAN
-from sklearn.cluster import KMeans
+from sklearn.preprocessing import OneHotEncoder, StandardScaler
+from sklearn.cluster import KMeans, DBSCAN
 from sklearn.manifold import TSNE
-from sklearn.cluster import KMeans
 from sklearn.decomposition import PCA
+from sklearn.ensemble import RandomForestClassifier
+
+from huggingface_hub import list_repo_files
 
 import umap.umap_ as umap
+#import umap
 
-
-from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
-import torchvision.transforms as transforms
+def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
 
-
-from sklearn.preprocessing import StandardScaler
-from scipy.stats import f_oneway, kruskal
-from sklearn.cluster import KMeans
-from scipy import stats
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
 
-
-from multiprocessing import set_start_method, get_start_method
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
 
-
-
-
-
-
-
-
-
-
-
-
-
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
+
+    columns = ['plate', 'row', 'col', 'field']
+
+    if settings['timelapse']:
+        columns = columns + ['time_id']
+
+    columns = columns + ['prcfo']
+
+    if crop_mode == 'cell':
+        columns = columns + ['cell_id']
+
+    if crop_mode == 'nucleus':
+        columns = columns + ['nucleus_id']
+
+    if crop_mode == 'pathogen':
+        columns = columns + ['pathogen_id']
+
+    if crop_mode == 'cytoplasm':
+        columns = columns + ['cytoplasm_id']
+
+    png_df[columns] = parts
+
+    try:
+        conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
+        png_df.to_sql('png_list', conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def activation_maps_to_database(img_paths, source_folder, settings):
+    from .io import _create_database
+
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+    columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+    png_df[columns] = parts
+
+    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+    if not os.path.exists(database_name):
+        _create_database(database_name)
+
+    try:
+        conn = sqlite3.connect(database_name, timeout=5)
+        png_df.to_sql(f"{settings['cam_type']}_list", conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def activation_correlations_to_database(df, img_paths, source_folder, settings):
+    from .io import _create_database
+
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+    columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+    png_df[columns] = parts
+
+    # Align both DataFrames by file_name
+    png_df.set_index('file_name', inplace=True)
+    df.set_index('file_name', inplace=True)
+
+    merged_df = pd.concat([png_df, df], axis=1)
+    merged_df.reset_index(inplace=True)
+
+    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+    if not os.path.exists(database_name):
+        _create_database(database_name)
+
+    try:
+        conn = sqlite3.connect(database_name, timeout=5)
+        merged_df.to_sql(f"{settings['cam_type']}_correlations", conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=[15, 50, 75]):
+    """
+    Calculates Pearson and Manders correlations between input image channels and activation map channels.
+
+    Args:
+        inputs: A batch of input images, Tensor of shape (batch_size, channels, height, width)
+        activation_maps: A batch of activation maps, Tensor of shape (batch_size, channels, height, width)
+        file_names: List of file names corresponding to each image in the batch.
+        manders_thresholds: List of intensity percentiles to calculate Manders correlation.
+
+    Returns:
+        df_correlations: A DataFrame with columns for pairwise correlations (Pearson and Manders)
+                         between input channels and activation map channels.
+    """
+
+    # Ensure tensors are detached and moved to CPU before converting to numpy
+    inputs = inputs.detach().cpu()
+    activation_maps = activation_maps.detach().cpu()
+
+    batch_size, in_channels, height, width = inputs.shape
+
+    if activation_maps.dim() == 3:
+        # If activation maps have no channels, add a dummy channel dimension
+        activation_maps = activation_maps.unsqueeze(1)  # Now shape is (batch_size, 1, height, width)
+
+    _, act_channels, act_height, act_width = activation_maps.shape
+
+    # Ensure that the inputs and activation maps are the same size
+    if (height != act_height) or (width != act_width):
+        activation_maps = torch.nn.functional.interpolate(activation_maps, size=(height, width), mode='bilinear')
+
+    # Dictionary to collect correlation results
+    correlations_dict = {'file_name': []}
+
+    # Initialize correlation columns based on input channels and activation map channels
+    for in_c in range(in_channels):
+        for act_c in range(act_channels):
+            correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'] = []
+            for threshold in manders_thresholds:
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'] = []
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'] = []
+
+    # Loop over the batch
+    for b in range(batch_size):
+        input_img = inputs[b]  # Input image channels (C, H, W)
+        activation_map = activation_maps[b]  # Activation map channels (C, H, W)
+
+        # Add the file name to the current row
+        correlations_dict['file_name'].append(file_names[b])
+
+        # Calculate correlations for each channel pair
+        for in_c in range(in_channels):
+            input_channel = input_img[in_c].flatten().numpy()  # Flatten the input image channel
+            input_channel = input_channel[np.isfinite(input_channel)]  # Remove NaN or inf values
+
+            for act_c in range(act_channels):
+                activation_channel = activation_map[act_c].flatten().numpy()  # Flatten the activation map channel
+                activation_channel = activation_channel[np.isfinite(activation_channel)]  # Remove NaN or inf values
+
+                # Check if there are valid (non-empty) arrays left to calculate the Pearson correlation
+                if input_channel.size > 0 and activation_channel.size > 0:
+                    pearson_corr, _ = pearsonr(input_channel, activation_channel)
+                else:
+                    pearson_corr = np.nan  # Assign NaN if there are no valid data points
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'].append(pearson_corr)
+
+                # Compute Manders correlations for each threshold
+                for threshold in manders_thresholds:
+                    # Get the top percentile pixels based on intensity in both channels
+                    if input_channel.size > 0 and activation_channel.size > 0:
+                        input_threshold = np.percentile(input_channel, threshold)
+                        activation_threshold = np.percentile(activation_channel, threshold)
+
+                        # Mask the pixels above the threshold
+                        mask = (input_channel >= input_threshold) & (activation_channel >= activation_threshold)
+
+                        # If we have enough pixels, calculate Manders correlation
+                        if np.sum(mask) > 0:
+                            manders_corr_M1 = np.sum(input_channel[mask] * activation_channel[mask]) / np.sum(input_channel[mask] ** 2)
+                            manders_corr_M2 = np.sum(activation_channel[mask] * input_channel[mask]) / np.sum(activation_channel[mask] ** 2)
+                        else:
+                            manders_corr_M1 = np.nan
+                            manders_corr_M2 = np.nan
+                    else:
+                        manders_corr_M1 = np.nan
+                        manders_corr_M2 = np.nan
+
+                    # Store the Manders correlation for this threshold
+                    correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'].append(manders_corr_M1)
+                    correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'].append(manders_corr_M2)
+
+    # Convert the dictionary to a DataFrame
+    df_correlations = pd.DataFrame(correlations_dict)
+
+    return df_correlations
+
+def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
+    """
+    Convert a CSV file with 'settings_key' and 'settings_value' columns into a dictionary.
+    Handles special cases where values are lists, tuples, booleans, None, integers, floats, and nested dictionaries.
+
+    Args:
+        csv_file_path (str): The path to the CSV file.
+        show (bool): Whether to display the dataframe (for debugging).
+        setting_key (str): The name of the column that contains the setting keys.
+        setting_value (str): The name of the column that contains the setting values.
+
+    Returns:
+        dict: A dictionary where 'settings_key' are the keys and 'settings_value' are the values.
+    """
+    # Read the CSV file into a DataFrame
+    df = pd.read_csv(csv_file_path)
+
+    if show:
+        display(df)
+
+    # Ensure the columns 'setting_key' and 'setting_value' exist
+    if setting_key not in df.columns or setting_value not in df.columns:
+        raise ValueError(f"CSV file must contain {setting_key} and {setting_value} columns.")
+
+    def parse_value(value):
+        """Parse the string value into the appropriate Python data type."""
+        # Handle empty values
+        if pd.isna(value) or value == '':
+            return None
+
+        # Handle boolean values
+        if value == 'True':
+            return True
+        if value == 'False':
+            return False
+
+        # Handle lists, tuples, dictionaries, and other literals
+        if value.startswith(('(', '[', '{')):  # If it starts with (, [ or {, use ast.literal_eval
+            try:
+                parsed_value = ast.literal_eval(value)
+                # If parsed_value is a dict, recursively parse its values
+                if isinstance(parsed_value, dict):
+                    parsed_value = {k: parse_value(v) for k, v in parsed_value.items()}
+                return parsed_value
+            except (ValueError, SyntaxError):
+                pass  # If there's an error, return the value as-is
+
+        # Handle numeric values (integers and floats)
+        try:
+            if '.' in value:
+                return float(value)  # If it contains a dot, convert to float
+            return int(value)  # Otherwise, convert to integer
+        except ValueError:
+            pass  # If it's not a valid number, return the value as-is
+
+        # Return the original value if no other type matched
+        return value
+
+    # Convert the DataFrame to a dictionary, with parsing of each value
+    result_dict = {key: parse_value(value) for key, value in zip(df[setting_key], df[setting_value])}
 
-
+    return result_dict
+
+
+def save_settings(settings, name='settings', show=False):
 
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-
-
+    if show:
+        display(settings_df)
+
+    if isinstance(settings['src'], list):
+        src = settings['src'][0]
+        name = f"{name}_list"
+    else:
+        src = settings['src']
+
+    settings_csv = os.path.join(src,'settings',f'{name}.csv')
+    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
     settings_df.to_csv(settings_csv, index=False)
 
 def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
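
The new calculate_activation_correlations helper is self-describing from its docstring; a sketch of driving it, with random tensors standing in for a real batch (shapes and file names are illustrative):

import torch
from spacr.utils import calculate_activation_correlations

inputs = torch.rand(4, 3, 224, 224)        # batch of 4 three-channel images
activation_maps = torch.rand(4, 224, 224)  # one map per image; a channel dim is added internally
file_names = [f"img_{i}.png" for i in range(4)]

df_corr = calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=[15, 50, 75])
# One row per image: a Pearson column plus M1/M2 columns at each threshold
# for every (input channel, activation channel) pair.
print(df_corr.columns.tolist())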
@@ -303,7 +543,7 @@ def _get_cellpose_batch_size():
     except Exception as e:
         return 8
 
-def
+def _extract_filename_metadata_v1(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
 
     images_by_key = defaultdict(list)
 
@@ -353,6 +593,57 @@ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type
 
     return images_by_key
 
+def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+
+    images_by_key = defaultdict(list)
+
+    for filename in filenames:
+        match = regular_expression.match(filename)
+        if match:
+            try:
+                try:
+                    plate = match.group('plateID')
+                except:
+                    plate = os.path.basename(src)
+
+                well = match.group('wellID')
+                field = match.group('fieldID')
+                channel = match.group('chanID')
+                mode = None
+
+                if well[0].isdigit():
+                    well = str(_safe_int_convert(well))
+                if field[0].isdigit():
+                    field = str(_safe_int_convert(field))
+                if channel[0].isdigit():
+                    channel = str(_safe_int_convert(channel))
+
+                if metadata_type =='cq1':
+                    orig_wellID = wellID
+                    wellID = _convert_cq1_well_id(wellID)
+                    print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
+
+                if pick_slice:
+                    try:
+                        mode = match.group('AID')
+                    except IndexError:
+                        sliceid = '00'
+
+                    if mode == skip_mode:
+                        continue
+
+                key = (plate, well, field, channel, mode)
+                file_path = os.path.join(src, filename)  # Store the full path
+                images_by_key[key].append(file_path)
+
+            except IndexError:
+                print(f"Could not extract information from filename {filename} using provided regex")
+        else:
+            print(f"Filename {filename} did not match provided regex")
+            continue
+
+    return images_by_key
+
 def mask_object_count(mask):
     """
     Counts the number of objects in a given mask.
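
The rewritten _extract_filename_metadata expects a compiled regex with named groups (plateID, wellID, fieldID, chanID, optionally AID). A sketch with a hypothetical pattern, using the default 'cellvoyager' metadata handling:

import re
from spacr.utils import _extract_filename_metadata

# Hypothetical naming scheme, not one shipped with spacr.
regex = re.compile(r"(?P<plateID>plate\d+)_(?P<wellID>[A-P]\d{2})_(?P<fieldID>f\d+)_(?P<chanID>ch\d)\.tif")
filenames = ["plate1_A01_f01_ch1.tif", "plate1_A01_f01_ch2.tif", "notes.txt"]

images_by_key = _extract_filename_metadata(filenames, src="/data/plate1", regular_expression=regex)
# Keys are (plate, well, field, channel, mode) tuples; values are full file paths.
# Non-matching names ('notes.txt') are reported and skipped.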
@@ -443,7 +734,7 @@ def _generate_representative_images(db_path, cells=['HeLa'], cell_loc=None, path
     from .plot import _plot_images_on_grid
 
     df = _read_and_join_tables(db_path)
-    df =
+    df = annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments, treatment_loc)
 
     if update_db:
         _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo'])
@@ -489,34 +780,6 @@ def _map_values(row, values, locs):
         return value_dict.get(row[type_], None)
     return values[0] if values else None
 
-def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None):
-    """
-    Annotates conditions in the given DataFrame based on the provided parameters.
-
-    Args:
-        df (pandas.DataFrame): The DataFrame to annotate.
-        cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
-        cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
-        pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
-        pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
-        treatments (list, optional): The list of treatments. Defaults to ['cm'].
-        treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
-
-    Returns:
-        pandas.DataFrame: The annotated DataFrame with the 'host_cells', 'pathogen', 'treatment', and 'condition' columns.
-    """
-
-
-    # Apply mappings or defaults
-    df['host_cells'] = [cells[0]] * len(df) if cell_loc is None else df.apply(_map_values, args=(cells, cell_loc), axis=1)
-    df['pathogen'] = [pathogens[0]] * len(df) if pathogen_loc is None else df.apply(_map_values, args=(pathogens, pathogen_loc), axis=1)
-    df['treatment'] = [treatments[0]] * len(df) if treatment_loc is None else df.apply(_map_values, args=(treatments, treatment_loc), axis=1)
-
-    # Construct condition column
-    df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
-    df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
-    return df
-
 def is_list_of_lists(var):
     if isinstance(var, list) and all(isinstance(i, list) for i in var):
         return True
@@ -816,7 +1079,7 @@ def _map_wells_png(file_name, timelapse=False):
         print(f"Error: {e}")
         plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
     if timelapse:
-        return plate, row, column, field, timeid, prcfo, object_id
+        return plate, row, column, field, timeid, prcfo, object_id
     else:
         return plate, row, column, field, prcfo, object_id
 
@@ -1085,67 +1348,74 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
     else:
         cellpose_channels['cell'] = [0,0]
     return cellpose_channels
-
-def annotate_conditions(df, cells=
+
+def annotate_conditions(df, cells=None, cell_loc=None, pathogens=None, pathogen_loc=None, treatments=None, treatment_loc=None):
     """
-    Annotates conditions in a DataFrame based on specified criteria.
+    Annotates conditions in a DataFrame based on specified criteria and combines them into a 'condition' column.
+    NaN is used for missing values, and they are excluded from the 'condition' column.
 
     Args:
         df (pandas.DataFrame): The DataFrame to annotate.
-        cells (list, optional):
-        cell_loc (list, optional):
-        pathogens (list, optional):
-        pathogen_loc (list, optional):
-        treatments (list, optional):
-        treatment_loc (list, optional):
-        types (list, optional): List of column types for host cells, pathogens, and treatments. Defaults to ['col','col','col'].
+        cells (list/str, optional): Host cell types. Defaults to None.
+        cell_loc (list of lists, optional): Values for each host cell type. Defaults to None.
+        pathogens (list/str, optional): Pathogens. Defaults to None.
+        pathogen_loc (list of lists, optional): Values for each pathogen. Defaults to None.
+        treatments (list/str, optional): Treatments. Defaults to None.
+        treatment_loc (list of lists, optional): Values for each treatment. Defaults to None.
 
     Returns:
-        pandas.DataFrame:
+        pandas.DataFrame: Annotated DataFrame with a combined 'condition' column.
     """
+
+    def _get_type(val):
+        """Determine if a value maps to 'row' or 'col'."""
+        if isinstance(val, str) and val.startswith('c'):
+            return 'col'
+        elif isinstance(val, str) and val.startswith('r'):
+            return 'row'
+        return None
 
-
-    def _map_values(row, dict_, type_='col'):
+    def _map_or_default(column_name, values, loc, df):
         """
-
+        Consolidates the logic for mapping values or assigning defaults when loc is None.
 
         Args:
-
-
-
-
-        Returns:
-            str: The mapped value if found, otherwise None.
+            column_name (str): The column in the DataFrame to annotate.
+            values (list/str): The list of values or a single string to annotate.
+            loc (list of lists): Location mapping for the values, or None if not used.
+            df (pandas.DataFrame): The DataFrame to modify.
         """
-
-
-
+        if isinstance(values, str) or (isinstance(values, list) and loc is None):
+            # Assign all rows the first value in the list or the single string
+            df[column_name] = values if isinstance(values, str) else values[0]
+        elif values is not None and loc is not None:
+            # Perform the location-based mapping
+            value_dict = {val: key for key, loc_list in zip(values, loc) for val in loc_list}
+            df[column_name] = np.nan
+            for val, key in value_dict.items():
+                loc_type = _get_type(val)
+                if loc_type:
+                    df.loc[df[loc_type] == val, column_name] = key
+
+    # Handle cells, pathogens, and treatments using the consolidated logic
+    _map_or_default('host_cells', cells, cell_loc, df)
+    _map_or_default('pathogen', pathogens, pathogen_loc, df)
+    _map_or_default('treatment', treatments, treatment_loc, df)
+
+    # Conditionally fill NaN for pathogen and treatment columns if applicable
+    if pathogens is not None:
+        df['pathogen'].fillna(np.nan, inplace=True)
+    if treatments is not None:
+        df['treatment'].fillna(np.nan, inplace=True)
+
+    # Create the 'condition' column by excluding any NaN values, safely checking if 'host_cells', 'pathogen', and 'treatment' exist
+    df['condition'] = df.apply(
+        lambda x: '_'.join([str(v) for v in [x.get('host_cells'), x.get('pathogen'), x.get('treatment')] if pd.notna(v)]),
+        axis=1
+    )
 
-    if cell_loc is None:
-        df['host_cells'] = cells[0]
-    else:
-        cells_dict = dict(zip(cells, cell_loc))
-        df['host_cells'] = df.apply(lambda row: _map_values(row, cells_dict, type_=types[0]), axis=1)
-    if pathogen_loc is None:
-        if pathogens != None:
-            df['pathogen'] = 'none'
-    else:
-        pathogens_dict = dict(zip(pathogens, pathogen_loc))
-        df['pathogen'] = df.apply(lambda row: _map_values(row, pathogens_dict, type_=types[1]), axis=1)
-    if treatment_loc is None:
-        df['treatment'] = 'cm'
-    else:
-        treatments_dict = dict(zip(treatments, treatment_loc))
-        df['treatment'] = df.apply(lambda row: _map_values(row, treatments_dict, type_=types[2]), axis=1)
-    if pathogens != None:
-        df['condition'] = df['pathogen']+'_'+df['treatment']
-    else:
-        df['condition'] = df['treatment']
     return df
-
 
-
 def _split_data(df, group_by, object_type):
     """
     Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
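
Compared with the removed versions, annotate_conditions now infers the row/col mapping from the location values themselves ('r...' vs 'c...'). A sketch on a toy layout:

import pandas as pd
from spacr.utils import annotate_conditions

df = pd.DataFrame({'row': ['r1', 'r1', 'r2'], 'col': ['c1', 'c2', 'c3']})
df = annotate_conditions(df,
                         cells='HeLa',                   # a single string is applied to every row
                         pathogens=['rh', 'dgra14'],
                         pathogen_loc=[['c1'], ['c2']])  # one list of locations per pathogen
# Column c3 gets NaN for 'pathogen', so its 'condition' is just 'HeLa';
# c1 becomes 'HeLa_rh' and c2 becomes 'HeLa_dgra14'.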
@@ -1951,9 +2221,10 @@ def add_images_to_tar(paths_chunk, tar_path, total_images):
             tar.add(img_path, arcname=arcname)
             with lock:
                 counter.value += 1
-                if counter.value %
-                    progress = (counter.value / total_images) * 100
-                    print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                if counter.value % 10 == 0:  # Print every 10 updates
+                    #progress = (counter.value / total_images) * 100
+                    #print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                    print_progress(counter.value, total_images, n_jobs=1, time_ls=None, batch_size=None, operation_type="generating .tar dataset")
         except FileNotFoundError:
             print(f"File not found: {img_path}")
 
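
The tar writer now delegates to the shared print_progress helper whose signature appears in the context above; a call sketch using keyword arguments:

print_progress(files_processed=counter.value, files_to_process=total_images, n_jobs=1,
               time_ls=None, batch_size=None, operation_type="generating .tar dataset")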
@@ -2070,52 +2341,6 @@ def check_multicollinearity(x):
     vif_data["VIF"] = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
     return vif_data
 
-def generate_dependent_variable(df, dv_loc, pc_min=0.95, nc_max=0.05, agg_type='mean'):
-
-    from .plot import _plot_histograms_and_stats, _plot_plates
-
-    def qstring_to_float(qstr):
-        number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
-        return number / 100.0
-
-    print("Unique values in plate:", df['plate'].unique())
-    dv_cell_loc = f'{dv_loc}/dv_cell.csv'
-    dv_well_loc = f'{dv_loc}/dv_well.csv'
-
-    df['pred'] = 1-df['pred'] #if you swiched pc and nc
-    df = df[(df['pred'] <= nc_max) | (df['pred'] >= pc_min)]
-
-    if 'prc' not in df.columns:
-        df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
-
-    if agg_type.startswith('q'):
-        val = qstring_to_float(agg_type)
-        agg_type = lambda x: x.quantile(val)
-
-    # Aggregating for mean prediction and total count
-    df_grouped = df.groupby('prc').agg(
-        pred=('pred', agg_type),
-        recruitment=('recruitment', agg_type),
-        count_prc=('prc', 'size'),
-        #count_above_95=('pred', lambda x: (x > 0.95).sum()),
-        mean_pathogen_area=('pathogen_area', 'mean')
-    )
-
-    df_cell = df[['prc', 'pred', 'pathogen_area', 'recruitment']]
-
-    df_cell.to_csv(dv_cell_loc, index=True, header=True, mode='w')
-    df_grouped.to_csv(dv_well_loc, index=True, header=True, mode='w')  # Changed from loc to dv_loc
-    display(df)
-    _plot_histograms_and_stats(df)
-    df_grouped = df_grouped.sort_values(by='count_prc', ascending=True)
-    display(df_grouped)
-    print('pred')
-    _plot_plates(df=df_cell, variable='pred', grouping='mean', min_max='allq', cmap='viridis')
-    print('recruitment')
-    _plot_plates(df=df_cell, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis')
-
-    return df_grouped
-
 def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
     # Separate predictors and response
     X = merged_df[['gene', 'grna', 'plate', 'row', 'column']]
@@ -3021,7 +3246,6 @@ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=Tru
     input_tensor = transform(image).unsqueeze(0)
     return image, input_tensor
 
-
 class SaliencyMapGenerator:
     def __init__(self, model):
         self.model = model
@@ -3042,18 +3266,194 @@ class SaliencyMapGenerator:
         saliency = X.grad.abs()
         return saliency
 
-    def
+    def compute_saliency_and_predictions(self, X):
+        self.model.eval()
+        X.requires_grad_()
+
+        # Forward pass to get predictions (logits)
+        scores = self.model(X).squeeze()
+
+        # Get predicted class (0 or 1 for binary classification)
+        predictions = (scores > 0).long()
+
+        # Compute saliency maps
+        self.model.zero_grad()
+        target_scores = scores * (2 * predictions - 1)
+        target_scores.backward(torch.ones_like(target_scores))
+
+        saliency = X.grad.abs()
+
+        return saliency, predictions
+
+    def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
         N = X.shape[0]
+        rows = (N + 7) // 8
+        fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
         for i in range(N):
-
-
-
-
-
-
-
-
-
+            ax = axs[i // 8, i % 8]
+            saliency_map = saliency[i].cpu().numpy()  # Move to CPU and convert to numpy
+
+            if saliency_map.shape[0] == 3:  # Channels first, reshape to (H, W, 3)
+                saliency_map = np.transpose(saliency_map, (1, 2, 0))
+
+            # Normalize image channels to 2nd and 98th percentiles
+            if overlay:
+                img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                if normalize:
+                    img_np = self.percentile_normalize(img_np)
+                ax.imshow(img_np)
+            ax.imshow(saliency_map, cmap='jet', alpha=0.5)
+
+            # Add class label in the top-left corner
+            ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                    bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+            ax.axis('off')
+
+        plt.tight_layout(pad=0)
+        return fig
+
+    def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+        """
+        Normalize each channel of the image to the given percentiles.
+        Args:
+            img: Input image as numpy array with shape (H, W, C)
+            lower_percentile: Lower percentile for normalization (default 2)
+            upper_percentile: Upper percentile for normalization (default 98)
+        Returns:
+            img: Normalized image
+        """
+        img_normalized = np.zeros_like(img)
+
+        for c in range(img.shape[2]):  # Iterate over each channel
+            low = np.percentile(img[:, :, c], lower_percentile)
+            high = np.percentile(img[:, :, c], upper_percentile)
+            img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+        return img_normalized
+
+
+class GradCAMGenerator:
+    def __init__(self, model, target_layer, cam_type='gradcam'):
+        self.model = model
+        self.model.eval()
+        self.target_layer = target_layer
+        self.cam_type = cam_type
+        self.gradients = None
+        self.activations = None
+
+        # Hook the target layer
+        self.target_layer_module = self.get_layer(self.model, self.target_layer)
+        self.hook_layers()
+
+    def hook_layers(self):
+        # Forward hook to get activations
+        def forward_hook(module, input, output):
+            self.activations = output
+
+        # Backward hook to get gradients
+        def backward_hook(module, grad_input, grad_output):
+            self.gradients = grad_output[0]
+
+        self.target_layer_module.register_forward_hook(forward_hook)
+        self.target_layer_module.register_backward_hook(backward_hook)
+
+    def get_layer(self, model, target_layer):
+        # Recursively find the layer specified in target_layer
+        modules = target_layer.split('.')
+        layer = model
+        for module in modules:
+            layer = getattr(layer, module)
+        return layer
+
+    def compute_gradcam_maps(self, X, y):
+        X.requires_grad_()
+
+        # Forward pass
+        scores = self.model(X).squeeze()
+
+        # Perform backward pass
+        target_scores = scores * (2 * y - 1)
+        self.model.zero_grad()
+        target_scores.backward(torch.ones_like(target_scores))
+
+        # Compute GradCAM
+        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
+        for i in range(self.activations.size(1)):
+            self.activations[:, i, :, :] *= pooled_gradients[i]
+
+        gradcam = torch.mean(self.activations, dim=1).squeeze()
+        gradcam = F.relu(gradcam)
+        gradcam = F.interpolate(gradcam.unsqueeze(0).unsqueeze(0), size=X.shape[2:], mode='bilinear')
+        gradcam = gradcam.squeeze().cpu().detach().numpy()
+        gradcam = (gradcam - gradcam.min()) / (gradcam.max() - gradcam.min())
+
+        return gradcam
+
+    def compute_gradcam_and_predictions(self, X):
+        self.model.eval()
+        X.requires_grad_()
+
+        # Forward pass to get predictions (logits)
+        scores = self.model(X).squeeze()
+
+        # Get predicted class (0 or 1 for binary classification)
+        predictions = (scores > 0).long()
+
+        # Compute gradcam maps
+        gradcam_maps = []
+        for i in range(X.size(0)):
+            gradcam_map = self.compute_gradcam_maps(X[i].unsqueeze(0), predictions[i])
+            gradcam_maps.append(gradcam_map)
+
+        return torch.tensor(gradcam_maps), predictions
+
+    def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
+        N = X.shape[0]
+        rows = (N + 7) // 8
+        fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
+        for i in range(N):
+            ax = axs[i // 8, i % 8]
+            gradcam_map = gradcam[i].cpu().numpy()
+
+            # Normalize image channels to 2nd and 98th percentiles
+            if overlay:
+                img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                if normalize:
+                    img_np = self.percentile_normalize(img_np)
+                ax.imshow(img_np)
+            ax.imshow(gradcam_map, cmap='jet', alpha=0.5)
+
+            #ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy())  # Original image
+            #ax.imshow(gradcam_map, cmap='jet', alpha=0.5)  # Overlay the gradcam map
+
+            # Add class label in the top-left corner
+            ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                    bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+            ax.axis('off')
+
+        plt.tight_layout(pad=0)
+        return fig
+
+    def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+        """
+        Normalize each channel of the image to the given percentiles.
+        Args:
+            img: Input image as numpy array with shape (H, W, C)
+            lower_percentile: Lower percentile for normalization (default 2)
+            upper_percentile: Upper percentile for normalization (default 98)
+        Returns:
+            img: Normalized image
+        """
+        img_normalized = np.zeros_like(img)
+
+        for c in range(img.shape[2]):  # Iterate over each channel
+            low = np.percentile(img[:, :, c], lower_percentile)
+            high = np.percentile(img[:, :, c], upper_percentile)
+            img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+        return img_normalized
 
 def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
     preprocess = transforms.Compose([
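
A sketch of driving the new GradCAMGenerator end to end with a torchvision ResNet; the binary head, 'layer4' target, and random batch are illustrative choices (plot_activation_grid indexes its axes as a 2-D grid, so a batch larger than 8 is used here):

import torch
from torchvision.models import resnet50
from spacr.utils import GradCAMGenerator

model = resnet50(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 1)  # single logit, matching the (scores > 0) convention

cam = GradCAMGenerator(model, target_layer='layer4')
X = torch.rand(16, 3, 224, 224)
gradcam_maps, predictions = cam.compute_gradcam_and_predictions(X)
fig = cam.plot_activation_grid(X, gradcam_maps, predictions, overlay=True, normalize=True)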
@@ -3594,13 +3994,86 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
     plt.show()
     return grid_fig
 
-def
+def generate_path_list_from_db_v1(db_path, file_metadata):
+
+    all_paths = []
+
+    # Connect to the database and retrieve the image paths
+    print(f"Reading DataBase: {db_path}")
+    try:
+        with sqlite3.connect(db_path) as conn:
+            cursor = conn.cursor()
+            if file_metadata:
+                if isinstance(file_metadata, str):
+                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+            else:
+                cursor.execute("SELECT png_path FROM png_list")
+
+            while True:
+                rows = cursor.fetchmany(1000)
+                if not rows:
+                    break
+                all_paths.extend([row[0] for row in rows])
 
-
-    print("
-    return
+    except sqlite3.Error as e:
+        print(f"Database error: {e}")
+        return
+    except Exception as e:
+        print(f"Error: {e}")
+        return
 
-
+    return all_paths
+
+def generate_path_list_from_db(db_path, file_metadata):
+    all_paths = []
+
+    # Connect to the database and retrieve the image paths
+    print(f"Reading DataBase: {db_path}")
+    try:
+        with sqlite3.connect(db_path) as conn:
+            cursor = conn.cursor()
+
+            if file_metadata:
+                if isinstance(file_metadata, str):
+                    # If file_metadata is a single string
+                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+                elif isinstance(file_metadata, list):
+                    # If file_metadata is a list of strings
+                    query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
+                        ["png_path LIKE ?" for _ in file_metadata])
+                    params = [f"%{meta}%" for meta in file_metadata]
+                    cursor.execute(query, params)
+            else:
+                # If file_metadata is None or empty
+                cursor.execute("SELECT png_path FROM png_list")
+
+            while True:
+                rows = cursor.fetchmany(1000)
+                if not rows:
+                    break
+                all_paths.extend([row[0] for row in rows])
+
+    except sqlite3.Error as e:
+        print(f"Database error: {e}")
+        return
+    except Exception as e:
+        print(f"Error: {e}")
+        return
+
+    return all_paths
+
+def correct_paths(df, base_path):
+
+    if isinstance(df, pd.DataFrame):
+
+        if 'png_path' not in df.columns:
+            print("No 'png_path' column found in the dataframe.")
+            return df, None
+        else:
+            image_paths = df['png_path'].to_list()
+
+    elif isinstance(df, list):
+        image_paths = df
 
     adjusted_image_paths = []
     for path in image_paths:
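
generate_path_list_from_db now accepts a list for file_metadata and ORs the LIKE patterns together; a usage sketch with illustrative paths:

from spacr.utils import generate_path_list_from_db

paths = generate_path_list_from_db('/data/plate1/measurements/measurements.db',
                                   file_metadata=['plate1_A01', 'plate1_A02'])  # patterns combined with OR
single = generate_path_list_from_db('/data/plate1/measurements/measurements.db',
                                    file_metadata='plate1_A01')                 # single substring match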
@@ -3614,9 +4087,11 @@ def correct_paths(df, base_path):
         else:
             adjusted_image_paths.append(path)
 
-    df
-
-
+    if isinstance(df, pd.DataFrame):
+        df['png_path'] = adjusted_image_paths
+        return df, adjusted_image_paths
+    else:
+        return adjusted_image_paths
 
 def delete_folder(folder_path):
     if os.path.exists(folder_path) and os.path.isdir(folder_path):
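
correct_paths now returns differently shaped results depending on the input type, so callers unpack accordingly (paths here are illustrative):

from spacr.utils import correct_paths

df, paths = correct_paths(df, base_path='/new/storage/plate1')                  # DataFrame in, (df, list) out
paths = correct_paths(['/old/plate1/a.png'], base_path='/new/storage/plate1')  # list in, list out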
@@ -4424,7 +4899,7 @@ def convert_and_relabel_masks(folder_path):
 
 def correct_masks(src):
 
-    from .
+    from .io import _load_and_concatenate_arrays
 
     cell_path = os.path.join(src,'norm_channel_stack', 'cell_mask_stack')
     convert_and_relabel_masks(cell_path)
@@ -4447,4 +4922,123 @@ def get_cuda_version():
     except (subprocess.CalledProcessError, FileNotFoundError):
         return None
 
+def all_elements_match(list1, list2):
+    # Check if all elements in list1 are in list2
+    return all(element in list2 for element in list1)
+
+def prepare_batch_for_segmentation(batch):
+    # Ensure the batch is of dtype float32
+    if batch.dtype != np.float32:
+        batch = batch.astype(np.float32)
+
+    # Normalize each image in the batch
+    for i in range(batch.shape[0]):
+        if batch[i].max() > 1:
+            batch[i] = batch[i] / batch[i].max()
+
+    return batch
+
+def check_index(df, elements=5, split_char='_'):
+    problematic_indices = []
+    for idx in df.index:
+        parts = str(idx).split(split_char)
+        if len(parts) != elements:
+            problematic_indices.append(idx)
+    if problematic_indices:
+        print("Indices that cannot be separated into 5 parts:")
+        for idx in problematic_indices:
+            print(idx)
+        raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+# Define the mapping function
+def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
+    if col_value == neg:
+        return 'neg'
+    elif col_value == pos:
+        return 'pos'
+    elif col_value == mix:
+        return 'mix'
+    else:
+        return 'screen'
+
+def download_models(repo_id="einarolafsson/models", local_dir=None, retries=5, delay=5):
+    """
+    Downloads all model files from Hugging Face and stores them in the specified local directory.
 
+    Args:
+        repo_id (str): The repository ID on Hugging Face (default is 'einarolafsson/models').
+        local_dir (str): The local directory where models will be saved. Defaults to '/home/carruthers/Desktop/test'.
+        retries (int): Number of retry attempts in case of failure.
+        delay (int): Delay in seconds between retries.
+
+    Returns:
+        str: The local path to the downloaded models.
+    """
+    # Create the local directory if it doesn't exist
+    if not os.path.exists(local_dir):
+        os.makedirs(local_dir)
+    elif len(os.listdir(local_dir)) > 0:
+        print(f"Models already downloaded to: {local_dir}")
+        return local_dir
+
+    attempt = 0
+    while attempt < retries:
+        try:
+            # List all files in the repo
+            files = list_repo_files(repo_id, repo_type="dataset")
+            print(f"Files in repository: {files}")  # Debugging print to check file list
+
+            # Download each file
+            for file_name in files:
+                for download_attempt in range(retries):
+                    try:
+                        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{file_name}?download=true"
+                        print(f"Downloading file from: {url}")  # Debugging
+
+                        response = requests.get(url, stream=True)
+                        print(f"HTTP response status: {response.status_code}")  # Debugging
+                        response.raise_for_status()
+
+                        # Save the file locally
+                        local_file_path = os.path.join(local_dir, os.path.basename(file_name))
+                        with open(local_file_path, 'wb') as file:
+                            for chunk in response.iter_content(chunk_size=8192):
+                                file.write(chunk)
+                        print(f"Downloaded model file: {file_name} to {local_file_path}")
+                        break  # Exit the retry loop if successful
+                    except (requests.HTTPError, requests.Timeout) as e:
+                        print(f"Error downloading {file_name}: {e}. Retrying in {delay} seconds...")
+                        time.sleep(delay)
+                else:
+                    raise Exception(f"Failed to download {file_name} after multiple attempts.")
+
+            return local_dir  # Return the directory where models are saved
+
+        except (requests.HTTPError, requests.Timeout) as e:
+            print(f"Error downloading files: {e}. Retrying in {delay} seconds...")
+            attempt += 1
+            time.sleep(delay)
+
+    raise Exception("Failed to download model files after multiple attempts.")
+
+def generate_cytoplasm_mask(nucleus_mask, cell_mask):
+
+    """
+    Generates a cytoplasm mask from nucleus and cell masks.
+
+    Parameters:
+    - nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
+    - cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
+
+    Returns:
+    - cytoplasm_mask (np.array): Mask for the cytoplasm (1 for cytoplasm, 0 for nucleus and pathogens).
+    """
+
+    # Make sure the nucleus and cell masks are numpy arrays
+    nucleus_mask = np.array(nucleus_mask)
+    cell_mask = np.array(cell_mask)
+
+    # Generate cytoplasm mask
+    cytoplasm_mask = np.where(np.logical_or(nucleus_mask != 0), 0, cell_mask)
+
+    return cytoplasm_mask