spacr 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/core.py +105 -1
- spacr/deep_spacr.py +191 -141
- spacr/gui.py +1 -0
- spacr/gui_core.py +13 -4
- spacr/gui_utils.py +29 -1
- spacr/io.py +84 -125
- spacr/measure.py +1 -38
- spacr/ml.py +153 -66
- spacr/plot.py +429 -7
- spacr/settings.py +55 -10
- spacr/submodules.py +7 -6
- spacr/toxo.py +9 -4
- spacr/utils.py +510 -16
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/METADATA +28 -25
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/RECORD +19 -19
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.2.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/submodules.py
CHANGED
(Lines ending in `…` were truncated in the upstream rendering of this diff.)

```diff
@@ -36,7 +36,7 @@ def analyze_recruitment(settings={}):
     sns.color_palette("mako", as_cmap=True)
     print(f"channel:{settings['channel_of_interest']} = {settings['target']}")
 
-    df, _ = _read_and_merge_data(
+    df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
                                  tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
                                  verbose=True,
                                  nuclei_limit=settings['nuclei_limit'],
@@ -89,15 +89,16 @@ def analyze_recruitment(settings={}):
 
     if not settings['cell_chann_dim'] is None:
         df = _object_filter(df, 'cell', settings['cell_size_range'], settings['cell_intensity_range'], mask_chans, 0)
-    if not settings['target_intensity_min'] is None:
-        df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95…
+    if not settings['target_intensity_min'] is None or not settings['target_intensity_min'] is 0:
+        df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95"] > settings['target_intensity_min']]
     print(f"After channel {settings['channel_of_interest']} filtration", len(df))
     if not settings['nucleus_chann_dim'] is None:
         df = _object_filter(df, 'nucleus', settings['nucleus_size_range'], settings['nucleus_intensity_range'], mask_chans, 1)
     if not settings['pathogen_chann_dim'] is None:
         df = _object_filter(df, 'pathogen', settings['pathogen_size_range'], settings['pathogen_intensity_range'], mask_chans, 2)
 
-    df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity…
+    df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
     for chan in settings['channel_dims']:
         df = _calculate_recruitment(df, channel=chan)
     print(f'calculated recruitment for: {len(df)} rows')
@@ -114,9 +115,9 @@ def analyze_recruitment(settings={}):
     _plot_controls(df, mask_chans, settings['channel_of_interest'], figuresize=5)
 
     print(f'PV level: {len(df)} rows')
-    _plot_recruitment(df, 'by PV', settings['channel_of_interest'],…
+    _plot_recruitment(df, 'by PV', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
    print(f'well level: {len(df_well)} rows')
-    _plot_recruitment(df_well, 'by well', settings['channel_of_interest'],…
+    _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
     cells,wells = _results_to_csv(settings['src'], df, df_well)
 
     return [cells,wells]
```
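The recruitment score added at new line 100 is a per-object ratio: the channel of interest's mean intensity inside the pathogen mask divided by its mean intensity in the cytoplasm, so values above 1 indicate enrichment at the pathogen. One note on the new guard: `x is 0` tests object identity, and no value is both `None` and `0`, so `not ... is None or not ... is 0` is always true; `settings['target_intensity_min'] not in (None, 0)` would express the intended check. A minimal sketch of the ratio (hypothetical values; column names follow the diff):

```python
import pandas as pd

# Minimal sketch of the recruitment ratio; data is hypothetical.
channel = 1  # stands in for settings['channel_of_interest']
df = pd.DataFrame({
    f"pathogen_channel_{channel}_mean_intensity": [120.0, 80.0],
    f"cytoplasm_channel_{channel}_mean_intensity": [60.0, 80.0],
})

df["recruitment"] = (df[f"pathogen_channel_{channel}_mean_intensity"]
                     / df[f"cytoplasm_channel_{channel}_mean_intensity"])
print(df["recruitment"].tolist())  # [2.0, 1.0]; > 1 means enriched at the pathogen
```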
spacr/toxo.py
CHANGED

```diff
@@ -112,10 +112,15 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=…
     - Plot the enrichment score vs -log10(p-value).
     """
 
-    significant_df['variable'].fillna(significant_df['feature'], inplace=True)
-    split_columns = significant_df['variable'].str.split('_', expand=True)
-    significant_df['gene_nr'] = split_columns[0]
-    gene_list = significant_df['gene_nr'].to_list()
+    #significant_df['variable'].fillna(significant_df['feature'], inplace=True)
+    #split_columns = significant_df['variable'].str.split('_', expand=True)
+    #significant_df['gene_nr'] = split_columns[0]
+    #gene_list = significant_df['gene_nr'].to_list()
+
+    significant_df = significant_df.dropna(subset=['n_gene'])
+    significant_df = significant_df[significant_df['n_gene'] != None]
+
+    gene_list = significant_df['n_gene'].to_list()
 
     # Load metadata
     metadata = pd.read_csv(metadata_path)
```
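The gene list for enrichment now comes from the `n_gene` column. Since `dropna(subset=['n_gene'])` already removes missing rows, the following elementwise `!= None` comparison evaluates True for every remaining row, so the `dropna` call does the filtering. A minimal sketch (hypothetical data; only the column name is taken from the diff):

```python
import numpy as np
import pandas as pd

# Hypothetical input; only the 'n_gene' column name comes from the diff.
significant_df = pd.DataFrame({"n_gene": ["gene_001", np.nan, "gene_002"]})

significant_df = significant_df.dropna(subset=["n_gene"])  # removes the NaN row
gene_list = significant_df["n_gene"].to_list()
print(gene_list)  # ['gene_001', 'gene_002']
```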
spacr/utils.py
CHANGED

```diff
@@ -1,4 +1,4 @@
-import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests
+import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
 
 import numpy as np
 import pandas as pd
@@ -12,6 +12,7 @@ from skimage.transform import resize as resizescikit
 from skimage.morphology import dilation, square
 from skimage.measure import find_contours
 from skimage.segmentation import clear_border
+from scipy.stats import pearsonr
 
 from collections import defaultdict, OrderedDict
 from PIL import Image
@@ -37,6 +38,7 @@ from torchvision import models
 from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
 import torchvision.transforms as transforms
 from torchvision.models import resnet50
+from torchvision.utils import make_grid
 
 import seaborn as sns
 import matplotlib.pyplot as plt
```
```diff
@@ -66,13 +68,270 @@ from huggingface_hub import list_repo_files
 import umap.umap_ as umap
 #import umap
 
+def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
+
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
+
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
+
+    columns = ['plate', 'row', 'col', 'field']
+
+    if settings['timelapse']:
+        columns = columns + ['time_id']
+
+    columns = columns + ['prcfo']
+
+    if crop_mode == 'cell':
+        columns = columns + ['cell_id']
+
+    if crop_mode == 'nucleus':
+        columns = columns + ['nucleus_id']
+
+    if crop_mode == 'pathogen':
+        columns = columns + ['pathogen_id']
+
+    if crop_mode == 'cytoplasm':
+        columns = columns + ['cytoplasm_id']
+
+    png_df[columns] = parts
+
+    try:
+        conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
+        png_df.to_sql('png_list', conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def activation_maps_to_database(img_paths, source_folder, settings):
+    from .io import _create_database
+
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+    columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+    png_df[columns] = parts
+
+    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+    if not os.path.exists(database_name):
+        _create_database(database_name)
+
+    try:
+        conn = sqlite3.connect(database_name, timeout=5)
+        png_df.to_sql(f"{settings['cam_type']}_list", conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def activation_correlations_to_database(df, img_paths, source_folder, settings):
+    from .io import _create_database
+
+    png_df = pd.DataFrame(img_paths, columns=['png_path'])
+    png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+    parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+    columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+    png_df[columns] = parts
+
+    # Align both DataFrames by file_name
+    png_df.set_index('file_name', inplace=True)
+    df.set_index('file_name', inplace=True)
+
+    merged_df = pd.concat([png_df, df], axis=1)
+    merged_df.reset_index(inplace=True)
+
+    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+    database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+    if not os.path.exists(database_name):
+        _create_database(database_name)
+
+    try:
+        conn = sqlite3.connect(database_name, timeout=5)
+        merged_df.to_sql(f"{settings['cam_type']}_correlations", conn, if_exists='append', index=False)
+        conn.commit()
+    except sqlite3.OperationalError as e:
+        print(f"SQLite error: {e}", flush=True)
+        traceback.print_exc()
+
+def calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=[15, 50, 75]):
+    """
+    Calculates Pearson and Manders correlations between input image channels and activation map channels.
+
+    Args:
+        inputs: A batch of input images, Tensor of shape (batch_size, channels, height, width)
+        activation_maps: A batch of activation maps, Tensor of shape (batch_size, channels, height, width)
+        file_names: List of file names corresponding to each image in the batch.
+        manders_thresholds: List of intensity percentiles to calculate Manders correlation.
+
+    Returns:
+        df_correlations: A DataFrame with columns for pairwise correlations (Pearson and Manders)
+                         between input channels and activation map channels.
+    """
+
+    # Ensure tensors are detached and moved to CPU before converting to numpy
+    inputs = inputs.detach().cpu()
+    activation_maps = activation_maps.detach().cpu()
+
+    batch_size, in_channels, height, width = inputs.shape
+
+    if activation_maps.dim() == 3:
+        # If activation maps have no channels, add a dummy channel dimension
+        activation_maps = activation_maps.unsqueeze(1)  # Now shape is (batch_size, 1, height, width)
+
+    _, act_channels, act_height, act_width = activation_maps.shape
+
+    # Ensure that the inputs and activation maps are the same size
+    if (height != act_height) or (width != act_width):
+        activation_maps = torch.nn.functional.interpolate(activation_maps, size=(height, width), mode='bilinear')
+
+    # Dictionary to collect correlation results
+    correlations_dict = {'file_name': []}
+
+    # Initialize correlation columns based on input channels and activation map channels
+    for in_c in range(in_channels):
+        for act_c in range(act_channels):
+            correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'] = []
+            for threshold in manders_thresholds:
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'] = []
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'] = []
+
+    # Loop over the batch
+    for b in range(batch_size):
+        input_img = inputs[b]  # Input image channels (C, H, W)
+        activation_map = activation_maps[b]  # Activation map channels (C, H, W)
+
+        # Add the file name to the current row
+        correlations_dict['file_name'].append(file_names[b])
+
+        # Calculate correlations for each channel pair
+        for in_c in range(in_channels):
+            input_channel = input_img[in_c].flatten().numpy()  # Flatten the input image channel
+            input_channel = input_channel[np.isfinite(input_channel)]  # Remove NaN or inf values
+
+            for act_c in range(act_channels):
+                activation_channel = activation_map[act_c].flatten().numpy()  # Flatten the activation map channel
+                activation_channel = activation_channel[np.isfinite(activation_channel)]  # Remove NaN or inf values
+
+                # Check if there are valid (non-empty) arrays left to calculate the Pearson correlation
+                if input_channel.size > 0 and activation_channel.size > 0:
+                    pearson_corr, _ = pearsonr(input_channel, activation_channel)
+                else:
+                    pearson_corr = np.nan  # Assign NaN if there are no valid data points
+                correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'].append(pearson_corr)
+
+                # Compute Manders correlations for each threshold
+                for threshold in manders_thresholds:
+                    # Get the top percentile pixels based on intensity in both channels
+                    if input_channel.size > 0 and activation_channel.size > 0:
+                        input_threshold = np.percentile(input_channel, threshold)
+                        activation_threshold = np.percentile(activation_channel, threshold)
+
+                        # Mask the pixels above the threshold
+                        mask = (input_channel >= input_threshold) & (activation_channel >= activation_threshold)
+
+                        # If we have enough pixels, calculate Manders correlation
+                        if np.sum(mask) > 0:
+                            manders_corr_M1 = np.sum(input_channel[mask] * activation_channel[mask]) / np.sum(input_channel[mask] ** 2)
+                            manders_corr_M2 = np.sum(activation_channel[mask] * input_channel[mask]) / np.sum(activation_channel[mask] ** 2)
+                        else:
+                            manders_corr_M1 = np.nan
+                            manders_corr_M2 = np.nan
+                    else:
+                        manders_corr_M1 = np.nan
+                        manders_corr_M2 = np.nan
+
+                    # Store the Manders correlation for this threshold
+                    correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'].append(manders_corr_M1)
+                    correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'].append(manders_corr_M2)
+
+    # Convert the dictionary to a DataFrame
+    df_correlations = pd.DataFrame(correlations_dict)
+
+    return df_correlations
+
+def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
+    """
+    Convert a CSV file with 'settings_key' and 'settings_value' columns into a dictionary.
+    Handles special cases where values are lists, tuples, booleans, None, integers, floats, and nested dictionaries.
+
+    Args:
+        csv_file_path (str): The path to the CSV file.
+        show (bool): Whether to display the dataframe (for debugging).
+        setting_key (str): The name of the column that contains the setting keys.
+        setting_value (str): The name of the column that contains the setting values.
+
+    Returns:
+        dict: A dictionary where 'settings_key' are the keys and 'settings_value' are the values.
+    """
+    # Read the CSV file into a DataFrame
+    df = pd.read_csv(csv_file_path)
+
+    if show:
+        display(df)
+
+    # Ensure the columns 'setting_key' and 'setting_value' exist
+    if setting_key not in df.columns or setting_value not in df.columns:
+        raise ValueError(f"CSV file must contain {setting_key} and {setting_value} columns.")
+
+    def parse_value(value):
+        """Parse the string value into the appropriate Python data type."""
+        # Handle empty values
+        if pd.isna(value) or value == '':
+            return None
+
+        # Handle boolean values
+        if value == 'True':
+            return True
+        if value == 'False':
+            return False
+
+        # Handle lists, tuples, dictionaries, and other literals
+        if value.startswith(('(', '[', '{')):  # If it starts with (, [ or {, use ast.literal_eval
+            try:
+                parsed_value = ast.literal_eval(value)
+                # If parsed_value is a dict, recursively parse its values
+                if isinstance(parsed_value, dict):
+                    parsed_value = {k: parse_value(v) for k, v in parsed_value.items()}
+                return parsed_value
+            except (ValueError, SyntaxError):
+                pass  # If there's an error, return the value as-is
+
+        # Handle numeric values (integers and floats)
+        try:
+            if '.' in value:
+                return float(value)  # If it contains a dot, convert to float
+            return int(value)  # Otherwise, convert to integer
+        except ValueError:
+            pass  # If it's not a valid number, return the value as-is
+
+        # Return the original value if no other type matched
+        return value
+
+    # Convert the DataFrame to a dictionary, with parsing of each value
+    result_dict = {key: parse_value(value) for key, value in zip(df[setting_key], df[setting_value])}
+
+    return result_dict
+
+
 def save_settings(settings, name='settings', show=False):
 
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
     if show:
         display(settings_df)
-
-
+
+    if isinstance(settings['src'], list):
+        src = settings['src'][0]
+        name = f"{name}_list"
+    else:
+        src = settings['src']
+
+    settings_csv = os.path.join(src,'settings',f'{name}.csv')
+    os.makedirs(os.path.join(src,'settings'), exist_ok=True)
     settings_df.to_csv(settings_csv, index=False)
 
 def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
```
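The new `load_settings` is the inverse of `save_settings`: values are stored as text in a two-column CSV and parsed back with explicit handling for booleans, numbers, empty cells, and `ast.literal_eval`-compatible containers (dict values are parsed recursively). One wrinkle: `save_settings` writes columns named `Key`/`Value`, while `load_settings` defaults to `setting_key`/`setting_value`, so the column names must be passed explicitly to round-trip. A sketch under those assumptions (paths and settings are hypothetical):

```python
from spacr.utils import save_settings, load_settings

# Hypothetical settings; save_settings writes to <src>/settings/<name>.csv.
settings = {'src': '/tmp/experiment', 'timelapse': False, 'channel_dims': [0, 1, 2, 3]}
save_settings(settings, name='settings')

# Column names differ between the two functions, so pass them explicitly.
restored = load_settings('/tmp/experiment/settings/settings.csv',
                         setting_key='Key', setting_value='Value')
assert restored['timelapse'] is False            # 'False' parsed back to a bool
assert restored['channel_dims'] == [0, 1, 2, 3]  # list restored via ast.literal_eval
```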
```diff
@@ -820,7 +1079,7 @@ def _map_wells_png(file_name, timelapse=False):
         print(f"Error: {e}")
         plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
     if timelapse:
-        return plate, row, column, field, timeid, prcfo, object_id
+        return plate, row, column, field, timeid, prcfo, object_id
     else:
         return plate, row, column, field, prcfo, object_id
 
```
```diff
@@ -2987,7 +3246,6 @@ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=True):
     input_tensor = transform(image).unsqueeze(0)
     return image, input_tensor
 
-
 class SaliencyMapGenerator:
     def __init__(self, model):
         self.model = model
```
```diff
@@ -3008,18 +3266,194 @@ class SaliencyMapGenerator:
         saliency = X.grad.abs()
         return saliency
 
-    def …
+    def compute_saliency_and_predictions(self, X):
+        self.model.eval()
+        X.requires_grad_()
+
+        # Forward pass to get predictions (logits)
+        scores = self.model(X).squeeze()
+
+        # Get predicted class (0 or 1 for binary classification)
+        predictions = (scores > 0).long()
+
+        # Compute saliency maps
+        self.model.zero_grad()
+        target_scores = scores * (2 * predictions - 1)
+        target_scores.backward(torch.ones_like(target_scores))
+
+        saliency = X.grad.abs()
+
+        return saliency, predictions
+
+    def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
         N = X.shape[0]
+        rows = (N + 7) // 8
+        fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
         for i in range(N):
-
-
-
-
-
-
-
-
-
+            ax = axs[i // 8, i % 8]
+            saliency_map = saliency[i].cpu().numpy()  # Move to CPU and convert to numpy
+
+            if saliency_map.shape[0] == 3:  # Channels first, reshape to (H, W, 3)
+                saliency_map = np.transpose(saliency_map, (1, 2, 0))
+
+            # Normalize image channels to 2nd and 98th percentiles
+            if overlay:
+                img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                if normalize:
+                    img_np = self.percentile_normalize(img_np)
+                ax.imshow(img_np)
+            ax.imshow(saliency_map, cmap='jet', alpha=0.5)
+
+            # Add class label in the top-left corner
+            ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                    bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+            ax.axis('off')
+
+        plt.tight_layout(pad=0)
+        return fig
+
+    def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+        """
+        Normalize each channel of the image to the given percentiles.
+        Args:
+            img: Input image as numpy array with shape (H, W, C)
+            lower_percentile: Lower percentile for normalization (default 2)
+            upper_percentile: Upper percentile for normalization (default 98)
+        Returns:
+            img: Normalized image
+        """
+        img_normalized = np.zeros_like(img)
+
+        for c in range(img.shape[2]):  # Iterate over each channel
+            low = np.percentile(img[:, :, c], lower_percentile)
+            high = np.percentile(img[:, :, c], upper_percentile)
+            img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+        return img_normalized
+
+
+class GradCAMGenerator:
+    def __init__(self, model, target_layer, cam_type='gradcam'):
+        self.model = model
+        self.model.eval()
+        self.target_layer = target_layer
+        self.cam_type = cam_type
+        self.gradients = None
+        self.activations = None
+
+        # Hook the target layer
+        self.target_layer_module = self.get_layer(self.model, self.target_layer)
+        self.hook_layers()
+
+    def hook_layers(self):
+        # Forward hook to get activations
+        def forward_hook(module, input, output):
+            self.activations = output
+
+        # Backward hook to get gradients
+        def backward_hook(module, grad_input, grad_output):
+            self.gradients = grad_output[0]
+
+        self.target_layer_module.register_forward_hook(forward_hook)
+        self.target_layer_module.register_backward_hook(backward_hook)
+
+    def get_layer(self, model, target_layer):
+        # Recursively find the layer specified in target_layer
+        modules = target_layer.split('.')
+        layer = model
+        for module in modules:
+            layer = getattr(layer, module)
+        return layer
+
+    def compute_gradcam_maps(self, X, y):
+        X.requires_grad_()
+
+        # Forward pass
+        scores = self.model(X).squeeze()
+
+        # Perform backward pass
+        target_scores = scores * (2 * y - 1)
+        self.model.zero_grad()
+        target_scores.backward(torch.ones_like(target_scores))
+
+        # Compute GradCAM
+        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
+        for i in range(self.activations.size(1)):
+            self.activations[:, i, :, :] *= pooled_gradients[i]
+
+        gradcam = torch.mean(self.activations, dim=1).squeeze()
+        gradcam = F.relu(gradcam)
+        gradcam = F.interpolate(gradcam.unsqueeze(0).unsqueeze(0), size=X.shape[2:], mode='bilinear')
+        gradcam = gradcam.squeeze().cpu().detach().numpy()
+        gradcam = (gradcam - gradcam.min()) / (gradcam.max() - gradcam.min())
+
+        return gradcam
+
+    def compute_gradcam_and_predictions(self, X):
+        self.model.eval()
+        X.requires_grad_()
+
+        # Forward pass to get predictions (logits)
+        scores = self.model(X).squeeze()
+
+        # Get predicted class (0 or 1 for binary classification)
+        predictions = (scores > 0).long()
+
+        # Compute gradcam maps
+        gradcam_maps = []
+        for i in range(X.size(0)):
+            gradcam_map = self.compute_gradcam_maps(X[i].unsqueeze(0), predictions[i])
+            gradcam_maps.append(gradcam_map)
+
+        return torch.tensor(gradcam_maps), predictions
+
+    def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
+        N = X.shape[0]
+        rows = (N + 7) // 8
+        fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
+        for i in range(N):
+            ax = axs[i // 8, i % 8]
+            gradcam_map = gradcam[i].cpu().numpy()
+
+            # Normalize image channels to 2nd and 98th percentiles
+            if overlay:
+                img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                if normalize:
+                    img_np = self.percentile_normalize(img_np)
+                ax.imshow(img_np)
+            ax.imshow(gradcam_map, cmap='jet', alpha=0.5)
+
+            #ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy())  # Original image
+            #ax.imshow(gradcam_map, cmap='jet', alpha=0.5)  # Overlay the gradcam map
+
+            # Add class label in the top-left corner
+            ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                    bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+            ax.axis('off')
+
+        plt.tight_layout(pad=0)
+        return fig
+
+    def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+        """
+        Normalize each channel of the image to the given percentiles.
+        Args:
+            img: Input image as numpy array with shape (H, W, C)
+            lower_percentile: Lower percentile for normalization (default 2)
+            upper_percentile: Upper percentile for normalization (default 98)
+        Returns:
+            img: Normalized image
+        """
+        img_normalized = np.zeros_like(img)
+
+        for c in range(img.shape[2]):  # Iterate over each channel
+            low = np.percentile(img[:, :, c], lower_percentile)
+            high = np.percentile(img[:, :, c], upper_percentile)
+            img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+        return img_normalized
 
 def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
     preprocess = transforms.Compose([
```
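Both generators follow the same calling pattern: compute per-image maps plus 0/1 predictions from a single-logit binary classifier, then tile everything on an 8-wide grid. A minimal usage sketch with a toy stand-in model (hypothetical; any module emitting one logit per image works). A batch of 16 is used because `plot_activation_grid` indexes its axes two-dimensionally, which needs more than one row of 8:

```python
import torch
import torch.nn as nn
from spacr.utils import SaliencyMapGenerator

# Toy binary classifier standing in for a trained spacr model (hypothetical).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),  # one logit per image; > 0 is treated as class 1
)

X = torch.rand(16, 3, 224, 224)  # batch of 16 three-channel crops

gen = SaliencyMapGenerator(model)
saliency, predictions = gen.compute_saliency_and_predictions(X)
# saliency: (16, 3, 224, 224) absolute input gradients; predictions: tensor of 0s/1s

fig = gen.plot_activation_grid(X, saliency, predictions, overlay=True, normalize=True)
fig.savefig('saliency_grid.png')  # hypothetical output path
```

`GradCAMGenerator` is driven the same way, with an additional `target_layer` argument giving a dotted path to the layer to hook (resolved via repeated `getattr`).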
```diff
@@ -3560,7 +3994,7 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
     plt.show()
     return grid_fig
 
-def generate_path_list_from_db(db_path, file_metadata):
+def generate_path_list_from_db_v1(db_path, file_metadata):
 
     all_paths = []
 
```
```diff
@@ -3590,6 +4024,44 @@ def generate_path_list_from_db(db_path, file_metadata):
 
     return all_paths
 
+def generate_path_list_from_db(db_path, file_metadata):
+    all_paths = []
+
+    # Connect to the database and retrieve the image paths
+    print(f"Reading DataBase: {db_path}")
+    try:
+        with sqlite3.connect(db_path) as conn:
+            cursor = conn.cursor()
+
+            if file_metadata:
+                if isinstance(file_metadata, str):
+                    # If file_metadata is a single string
+                    cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+                elif isinstance(file_metadata, list):
+                    # If file_metadata is a list of strings
+                    query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
+                        ["png_path LIKE ?" for _ in file_metadata])
+                    params = [f"%{meta}%" for meta in file_metadata]
+                    cursor.execute(query, params)
+            else:
+                # If file_metadata is None or empty
+                cursor.execute("SELECT png_path FROM png_list")
+
+            while True:
+                rows = cursor.fetchmany(1000)
+                if not rows:
+                    break
+                all_paths.extend([row[0] for row in rows])
+
+    except sqlite3.Error as e:
+        print(f"Database error: {e}")
+        return
+    except Exception as e:
+        print(f"Error: {e}")
+        return
+
+    return all_paths
+
 def correct_paths(df, base_path):
 
     if isinstance(df, pd.DataFrame):
```
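The rewritten `generate_path_list_from_db` streams rows with `fetchmany(1000)` and, for a list of metadata strings, ORs together one parameterized `LIKE` clause per entry. Usage sketch (hypothetical database path and filter strings):

```python
from spacr.utils import generate_path_list_from_db

# Hypothetical path/filters: returns every png_path containing 'plate1' or 'plate2'.
paths = generate_path_list_from_db(
    db_path='/data/experiment/measurements/measurements.db',
    file_metadata=['plate1', 'plate2'],
)
# SQL issued: SELECT png_path FROM png_list WHERE png_path LIKE ? OR png_path LIKE ?
# with parameters ['%plate1%', '%plate2%']
```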
```diff
@@ -4548,3 +5020,25 @@ def download_models(repo_id="einarolafsson/models", local_dir=None, retries=5, d…
             time.sleep(delay)
 
     raise Exception("Failed to download model files after multiple attempts.")
+
+def generate_cytoplasm_mask(nucleus_mask, cell_mask):
+
+    """
+    Generates a cytoplasm mask from nucleus and cell masks.
+
+    Parameters:
+    - nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
+    - cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
+
+    Returns:
+    - cytoplasm_mask (np.array): Mask for the cytoplasm (1 for cytoplasm, 0 for nucleus and pathogens).
+    """
+
+    # Make sure the nucleus and cell masks are numpy arrays
+    nucleus_mask = np.array(nucleus_mask)
+    cell_mask = np.array(cell_mask)
+
+    # Generate cytoplasm mask
+    cytoplasm_mask = np.where(np.logical_or(nucleus_mask != 0), 0, cell_mask)
+
+    return cytoplasm_mask
```