spacr 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/submodules.py CHANGED
@@ -36,7 +36,7 @@ def analyze_recruitment(settings={}):
36
36
  sns.color_palette("mako", as_cmap=True)
37
37
  print(f"channel:{settings['channel_of_interest']} = {settings['target']}")
38
38
 
39
- df, _ = _read_and_merge_data(db_loc=[settings['src']+'/measurements/measurements.db'],
39
+ df, _ = _read_and_merge_data(locs=[settings['src']+'/measurements/measurements.db'],
40
40
  tables=['cell', 'nucleus', 'pathogen','cytoplasm'],
41
41
  verbose=True,
42
42
  nuclei_limit=settings['nuclei_limit'],
@@ -89,15 +89,16 @@ def analyze_recruitment(settings={}):
89
89
 
90
90
  if not settings['cell_chann_dim'] is None:
91
91
  df = _object_filter(df, 'cell', settings['cell_size_range'], settings['cell_intensity_range'], mask_chans, 0)
92
- if not settings['target_intensity_min'] is None:
93
- df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95'] > settings['target_intensity_min"]]
92
+ if not settings['target_intensity_min'] is None or not settings['target_intensity_min'] is 0:
93
+ df = df[df[f"cell_channel_{settings['channel_of_interest']}_percentile_95"] > settings['target_intensity_min']]
94
94
  print(f"After channel {settings['channel_of_interest']} filtration", len(df))
95
95
  if not settings['nucleus_chann_dim'] is None:
96
96
  df = _object_filter(df, 'nucleus', settings['nucleus_size_range'], settings['nucleus_intensity_range'], mask_chans, 1)
97
97
  if not settings['pathogen_chann_dim'] is None:
98
98
  df = _object_filter(df, 'pathogen', settings['pathogen_size_range'], settings['pathogen_intensity_range'], mask_chans, 2)
99
99
 
100
- df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity']/df[f'cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
100
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
101
+
101
102
  for chan in settings['channel_dims']:
102
103
  df = _calculate_recruitment(df, channel=chan)
103
104
  print(f'calculated recruitment for: {len(df)} rows')
@@ -114,9 +115,9 @@ def analyze_recruitment(settings={}):
114
115
  _plot_controls(df, mask_chans, settings['channel_of_interest'], figuresize=5)
115
116
 
116
117
  print(f'PV level: {len(df)} rows')
117
- _plot_recruitment(df, 'by PV', settings['channel_of_interest'], settings['target'], settings['figuresize'])
118
+ _plot_recruitment(df, 'by PV', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
118
119
  print(f'well level: {len(df_well)} rows')
119
- _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], settings['target'], settings['figuresize'])
120
+ _plot_recruitment(df_well, 'by well', settings['channel_of_interest'], columns=[], figuresize=settings['figuresize'])
120
121
  cells,wells = _results_to_csv(settings['src'], df, df_well)
121
122
 
122
123
  return [cells,wells]
spacr/toxo.py CHANGED
@@ -112,10 +112,15 @@ def go_term_enrichment_by_column(significant_df, metadata_path, go_term_columns=
112
112
  - Plot the enrichment score vs -log10(p-value).
113
113
  """
114
114
 
115
- significant_df['variable'].fillna(significant_df['feature'], inplace=True)
116
- split_columns = significant_df['variable'].str.split('_', expand=True)
117
- significant_df['gene_nr'] = split_columns[0]
118
- gene_list = significant_df['gene_nr'].to_list()
115
+ #significant_df['variable'].fillna(significant_df['feature'], inplace=True)
116
+ #split_columns = significant_df['variable'].str.split('_', expand=True)
117
+ #significant_df['gene_nr'] = split_columns[0]
118
+ #gene_list = significant_df['gene_nr'].to_list()
119
+
120
+ significant_df = significant_df.dropna(subset=['n_gene'])
121
+ significant_df = significant_df[significant_df['n_gene'] != None]
122
+
123
+ gene_list = significant_df['n_gene'].to_list()
119
124
 
120
125
  # Load metadata
121
126
  metadata = pd.read_csv(metadata_path)
spacr/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests
1
+ import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
2
2
 
3
3
  import numpy as np
4
4
  import pandas as pd
@@ -12,6 +12,7 @@ from skimage.transform import resize as resizescikit
12
12
  from skimage.morphology import dilation, square
13
13
  from skimage.measure import find_contours
14
14
  from skimage.segmentation import clear_border
15
+ from scipy.stats import pearsonr
15
16
 
16
17
  from collections import defaultdict, OrderedDict
17
18
  from PIL import Image
@@ -37,6 +38,7 @@ from torchvision import models
37
38
  from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
38
39
  import torchvision.transforms as transforms
39
40
  from torchvision.models import resnet50
41
+ from torchvision.utils import make_grid
40
42
 
41
43
  import seaborn as sns
42
44
  import matplotlib.pyplot as plt
@@ -66,13 +68,270 @@ from huggingface_hub import list_repo_files
66
68
  import umap.umap_ as umap
67
69
  #import umap
68
70
 
71
def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
    """Parse PNG file names and append them to the ``png_list`` table.

    Each path's base name is split into metadata fields by ``_map_wells_png``
    and stored, together with the full path, in
    ``<source_folder>/measurements/measurements.db``.

    Args:
        img_paths: Iterable of PNG file paths.
        settings: Settings dict; ``settings['timelapse']`` is forwarded to
            ``_map_wells_png`` and adds a ``time_id`` column when truthy.
        source_folder: Folder that contains the ``measurements`` directory.
        crop_mode: ``'cell'``, ``'nucleus'``, ``'pathogen'`` or
            ``'cytoplasm'``; selects the name of the trailing object-id column.

    SQLite ``OperationalError``s are printed with a traceback and otherwise
    swallowed (best-effort write). The connection is always closed.
    """
    png_df = pd.DataFrame(img_paths, columns=['png_path'])
    png_df['file_name'] = png_df['png_path'].apply(os.path.basename)

    parts = png_df['file_name'].apply(
        lambda name: pd.Series(_map_wells_png(name, timelapse=settings['timelapse'])))

    # Column order must match the tuple returned by _map_wells_png.
    columns = ['plate', 'row', 'col', 'field']
    if settings['timelapse']:
        columns.append('time_id')
    columns.append('prcfo')
    if crop_mode in ('cell', 'nucleus', 'pathogen', 'cytoplasm'):
        columns.append(f'{crop_mode}_id')

    png_df[columns] = parts

    conn = None
    try:
        conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
        png_df.to_sql('png_list', conn, if_exists='append', index=False)
        conn.commit()
    except sqlite3.OperationalError as e:
        print(f"SQLite error: {e}", flush=True)
        traceback.print_exc()
    finally:
        # Previously the connection was never closed, leaking a handle per call.
        if conn is not None:
            conn.close()
107
+
108
def activation_maps_to_database(img_paths, source_folder, settings):
    """Append activation-map PNG paths and parsed metadata to a per-dataset DB.

    The target database is ``<source_folder>/measurements/<dataset>.db``.
    ``<dataset>`` is the base name, without extension, of
    ``settings['dataset']``. The database is created via ``_create_database``
    if missing. Rows are appended to the table ``<settings['cam_type']>_list``.

    SQLite ``OperationalError``s are printed with a traceback and otherwise
    swallowed (best-effort write). The connection is always closed.
    """
    from .io import _create_database

    png_df = pd.DataFrame(img_paths, columns=['png_path'])
    png_df['file_name'] = png_df['png_path'].apply(os.path.basename)
    parts = png_df['file_name'].apply(
        lambda name: pd.Series(_map_wells_png(name, timelapse=False)))
    # Order matches the non-timelapse tuple returned by _map_wells_png.
    png_df[['plate', 'row', 'col', 'field', 'prcfo', 'object']] = parts

    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
    database_name = f"{source_folder}/measurements/{dataset_name}.db"

    if not os.path.exists(database_name):
        _create_database(database_name)

    conn = None
    try:
        conn = sqlite3.connect(database_name, timeout=5)
        png_df.to_sql(f"{settings['cam_type']}_list", conn, if_exists='append', index=False)
        conn.commit()
    except sqlite3.OperationalError as e:
        print(f"SQLite error: {e}", flush=True)
        traceback.print_exc()
    finally:
        if conn is not None:
            conn.close()
130
+
131
def activation_correlations_to_database(df, img_paths, source_folder, settings):
    """Store per-image activation correlations joined with PNG metadata.

    Args:
        df: DataFrame with a ``file_name`` column plus correlation columns
            (e.g. the output of ``calculate_activation_correlations``). It is
            NOT modified.
        img_paths: PNG paths whose base names are parsed by ``_map_wells_png``.
        source_folder: Folder that contains the ``measurements`` directory.
        settings: Must provide ``'dataset'`` (its base name names the DB file)
            and ``'cam_type'`` (prefix of the ``*_correlations`` table).

    SQLite ``OperationalError``s are printed with a traceback and otherwise
    swallowed (best-effort write). The connection is always closed.
    """
    from .io import _create_database

    png_df = pd.DataFrame(img_paths, columns=['png_path'])
    png_df['file_name'] = png_df['png_path'].apply(os.path.basename)
    parts = png_df['file_name'].apply(
        lambda name: pd.Series(_map_wells_png(name, timelapse=False)))
    png_df[['plate', 'row', 'col', 'field', 'prcfo', 'object']] = parts

    # Align both frames on file_name. Use non-inplace set_index so the
    # caller's df keeps its original index (it used to be mutated).
    merged_df = pd.concat(
        [png_df.set_index('file_name'), df.set_index('file_name')], axis=1
    ).reset_index()

    dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
    database_name = f"{source_folder}/measurements/{dataset_name}.db"

    if not os.path.exists(database_name):
        _create_database(database_name)

    conn = None
    try:
        conn = sqlite3.connect(database_name, timeout=5)
        merged_df.to_sql(f"{settings['cam_type']}_correlations", conn, if_exists='append', index=False)
        conn.commit()
    except sqlite3.OperationalError as e:
        print(f"SQLite error: {e}", flush=True)
        traceback.print_exc()
    finally:
        if conn is not None:
            conn.close()
160
+
161
def _activation_pearson(x, y):
    """Return Pearson's r for two aligned 1-D arrays, or NaN when undefined."""
    # pearsonr needs at least two points and non-constant inputs.
    if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
        return np.nan
    r, _ = pearsonr(x, y)
    return r


def _activation_manders(x, y, threshold):
    """Return the (M1, M2) overlap ratios for pixels above *threshold* percentile in both arrays."""
    if x.size == 0:
        return np.nan, np.nan
    mask = (x >= np.percentile(x, threshold)) & (y >= np.percentile(y, threshold))
    if not mask.any():
        return np.nan, np.nan
    # NOTE(review): these are sum(I*A)/sum(I^2) and sum(I*A)/sum(A^2), not the
    # classic Manders M1/M2 (fraction of total intensity in colocalized
    # pixels). The formula is kept as-is for compatibility with stored results.
    overlap = np.sum(x[mask] * y[mask])
    x_energy = np.sum(x[mask] ** 2)
    y_energy = np.sum(y[mask] ** 2)
    m1 = overlap / x_energy if x_energy > 0 else np.nan
    m2 = overlap / y_energy if y_energy > 0 else np.nan
    return m1, m2


def calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=(15, 50, 75)):
    """
    Calculate Pearson and Manders-style correlations between input channels and activation-map channels.

    Args:
        inputs: Tensor of shape (batch_size, channels, height, width).
        activation_maps: Tensor of shape (batch_size, channels, h, w) or
            (batch_size, h, w). A 3-D tensor gets a singleton channel axis.
            Spatial sizes that differ from *inputs* are bilinearly resized.
        file_names: Sequence with one file name per batch item.
        manders_thresholds: Iterable of intensity percentiles for the
            thresholded overlap ratios.

    Returns:
        DataFrame with one row per batch item. Columns are ``file_name``,
        ``channel_{i}_activation_{a}_pearsons``, and
        ``channel_{i}_activation_{a}_{t}_M1`` / ``..._M2`` for each threshold
        ``t``. Undefined values are NaN.
    """
    inputs = inputs.detach().cpu()
    activation_maps = activation_maps.detach().cpu()

    batch_size, in_channels, height, width = inputs.shape

    if activation_maps.dim() == 3:
        # No channel axis: treat as a single activation channel.
        activation_maps = activation_maps.unsqueeze(1)

    _, act_channels, act_height, act_width = activation_maps.shape

    if (height, width) != (act_height, act_width):
        activation_maps = torch.nn.functional.interpolate(
            activation_maps, size=(height, width), mode='bilinear', align_corners=False)

    correlations_dict = {'file_name': []}
    for in_c in range(in_channels):
        for act_c in range(act_channels):
            correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'] = []
            for threshold in manders_thresholds:
                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'] = []
                correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'] = []

    for b in range(batch_size):
        correlations_dict['file_name'].append(file_names[b])

        for in_c in range(in_channels):
            raw_input = inputs[b, in_c].flatten().numpy()

            for act_c in range(act_channels):
                raw_act = activation_maps[b, act_c].flatten().numpy()

                # Drop a pixel from BOTH vectors if either value is NaN/inf so
                # they stay aligned pixel-for-pixel. Filtering each separately
                # produced arrays of different lengths. Cast to float64 so
                # integer images cannot overflow when squared.
                valid = np.isfinite(raw_input) & np.isfinite(raw_act)
                input_channel = raw_input[valid].astype(np.float64)
                activation_channel = raw_act[valid].astype(np.float64)

                prefix = f'channel_{in_c}_activation_{act_c}'
                correlations_dict[f'{prefix}_pearsons'].append(
                    _activation_pearson(input_channel, activation_channel))

                for threshold in manders_thresholds:
                    m1, m2 = _activation_manders(input_channel, activation_channel, threshold)
                    correlations_dict[f'{prefix}_{threshold}_M1'].append(m1)
                    correlations_dict[f'{prefix}_{threshold}_M2'].append(m2)

    return pd.DataFrame(correlations_dict)
256
+
257
def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
    """
    Convert a two-column settings CSV into a dictionary of typed values.

    Cell values are parsed as follows:
    - empty/NaN and ``'None'`` become ``None``;
    - ``'True'``/``'False'`` become booleans;
    - text starting with ``(``, ``[`` or ``{`` goes through
      ``ast.literal_eval``, and dict values are parsed recursively;
    - integers and floats (including exponent notation such as ``1e-3``)
      become numbers;
    - anything else is returned as the original string.

    Values that pandas already typed (numeric or bool columns) are returned as
    plain Python scalars.

    Args:
        csv_file_path (str): Path to the CSV file.
        show (bool): Whether to display the raw DataFrame (for debugging).
        setting_key (str): Name of the column holding the keys.
        setting_value (str): Name of the column holding the values.
            Note that ``save_settings`` in this module writes ``'Key'`` and
            ``'Value'`` columns, so pass those names to read its output.

    Returns:
        dict: Mapping from each key to its parsed value.

    Raises:
        ValueError: If either column is missing.
    """
    df = pd.read_csv(csv_file_path)

    if show:
        display(df)

    if setting_key not in df.columns or setting_value not in df.columns:
        raise ValueError(f"CSV file must contain {setting_key} and {setting_value} columns.")

    def parse_value(value):
        """Parse one cell (or one nested dict value) into a Python value."""
        if not isinstance(value, str):
            # Already typed by pandas or by literal_eval. Calling str methods
            # on these used to raise AttributeError.
            if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
                return None
            if isinstance(value, dict):
                return {k: parse_value(v) for k, v in value.items()}
            if isinstance(value, np.generic):
                return value.item()
            return value

        if value == '' or value == 'None':
            return None
        if value == 'True':
            return True
        if value == 'False':
            return False

        if value.startswith(('(', '[', '{')):
            try:
                parsed_value = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                pass  # Not a valid literal; fall through to numeric/string handling.
            else:
                if isinstance(parsed_value, dict):
                    parsed_value = {k: parse_value(v) for k, v in parsed_value.items()}
                return parsed_value

        # Try int first so '5' stays an int; float also accepts forms like '1e-3'.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                pass

        return value

    return {key: parse_value(value) for key, value in zip(df[setting_key], df[setting_value])}
319
+
320
+
69
321
def save_settings(settings, name='settings', show=False):
    """Write *settings* as Key/Value rows to ``<src>/settings/<name>.csv``.

    ``src`` is ``settings['src']``. When that entry is a list, its first
    element is used and ``_list`` is appended to *name*. The ``settings``
    directory is created if needed.
    """
    rows = list(settings.items())
    settings_df = pd.DataFrame(rows, columns=['Key', 'Value'])
    if show:
        display(settings_df)

    src = settings['src']
    if isinstance(src, list):
        src, name = src[0], f"{name}_list"

    settings_dir = os.path.join(src, 'settings')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(os.path.join(settings_dir, f'{name}.csv'), index=False)
77
336
 
78
337
  def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
@@ -820,7 +1079,7 @@ def _map_wells_png(file_name, timelapse=False):
820
1079
  print(f"Error: {e}")
821
1080
  plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
822
1081
  if timelapse:
823
- return plate, row, column, field, timeid, prcfo, object_id,
1082
+ return plate, row, column, field, timeid, prcfo, object_id
824
1083
  else:
825
1084
  return plate, row, column, field, prcfo, object_id
826
1085
 
@@ -2987,7 +3246,6 @@ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=Tru
2987
3246
  input_tensor = transform(image).unsqueeze(0)
2988
3247
  return image, input_tensor
2989
3248
 
2990
-
2991
3249
  class SaliencyMapGenerator:
2992
3250
  def __init__(self, model):
2993
3251
  self.model = model
@@ -3008,18 +3266,194 @@ class SaliencyMapGenerator:
3008
3266
  saliency = X.grad.abs()
3009
3267
  return saliency
3010
3268
 
3011
- def plot_saliency_maps(self, X, y, saliency, class_names):
3269
+ def compute_saliency_and_predictions(self, X):
3270
+ self.model.eval()
3271
+ X.requires_grad_()
3272
+
3273
+ # Forward pass to get predictions (logits)
3274
+ scores = self.model(X).squeeze()
3275
+
3276
+ # Get predicted class (0 or 1 for binary classification)
3277
+ predictions = (scores > 0).long()
3278
+
3279
+ # Compute saliency maps
3280
+ self.model.zero_grad()
3281
+ target_scores = scores * (2 * predictions - 1)
3282
+ target_scores.backward(torch.ones_like(target_scores))
3283
+
3284
+ saliency = X.grad.abs()
3285
+
3286
+ return saliency, predictions
3287
+
3288
+ def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
3012
3289
  N = X.shape[0]
3290
+ rows = (N + 7) // 8
3291
+ fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
3292
+
3013
3293
  for i in range(N):
3014
- plt.subplot(2, N, i + 1)
3015
- plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
3016
- plt.axis('off')
3017
- plt.title(class_names[y[i]])
3018
- plt.subplot(2, N, N + i + 1)
3019
- plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
3020
- plt.axis('off')
3021
- plt.gcf().set_size_inches(12, 5)
3022
- plt.show()
3294
+ ax = axs[i // 8, i % 8]
3295
+ saliency_map = saliency[i].cpu().numpy() # Move to CPU and convert to numpy
3296
+
3297
+ if saliency_map.shape[0] == 3: # Channels first, reshape to (H, W, 3)
3298
+ saliency_map = np.transpose(saliency_map, (1, 2, 0))
3299
+
3300
+ # Normalize image channels to 2nd and 98th percentiles
3301
+ if overlay:
3302
+ img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
3303
+ if normalize:
3304
+ img_np = self.percentile_normalize(img_np)
3305
+ ax.imshow(img_np)
3306
+ ax.imshow(saliency_map, cmap='jet', alpha=0.5)
3307
+
3308
+ # Add class label in the top-left corner
3309
+ ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
3310
+ bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
3311
+ ax.axis('off')
3312
+
3313
+ plt.tight_layout(pad=0)
3314
+ return fig
3315
+
3316
+ def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
3317
+ """
3318
+ Normalize each channel of the image to the given percentiles.
3319
+ Args:
3320
+ img: Input image as numpy array with shape (H, W, C)
3321
+ lower_percentile: Lower percentile for normalization (default 2)
3322
+ upper_percentile: Upper percentile for normalization (default 98)
3323
+ Returns:
3324
+ img: Normalized image
3325
+ """
3326
+ img_normalized = np.zeros_like(img)
3327
+
3328
+ for c in range(img.shape[2]): # Iterate over each channel
3329
+ low = np.percentile(img[:, :, c], lower_percentile)
3330
+ high = np.percentile(img[:, :, c], upper_percentile)
3331
+ img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
3332
+
3333
+ return img_normalized
3334
+
3335
+
3336
class GradCAMGenerator:
    """Grad-CAM heat maps for a binary classifier that outputs one logit per image.

    A forward hook on ``target_layer`` records that layer's output. A tensor
    hook on the recorded output captures its gradient during backward. The
    layer is given as a dotted attribute path resolved from the model, e.g.
    ``'layer4'`` or ``'layer4.1.conv2'``.
    """

    def __init__(self, model, target_layer, cam_type='gradcam'):
        self.model = model
        self.model.eval()
        self.target_layer = target_layer
        self.cam_type = cam_type  # stored for callers; not used by the methods below
        self.gradients = None
        self.activations = None
        self._hook_handles = []

        self.target_layer_module = self.get_layer(self.model, self.target_layer)
        self.hook_layers()

    def hook_layers(self):
        """Register the forward hook that records activations and their gradients."""
        def save_gradient(grad):
            self.gradients = grad

        def forward_hook(module, inputs, output):
            self.activations = output
            # A tensor hook is used instead of the deprecated
            # Module.register_backward_hook, which can report the gradient of
            # the wrong tensor for modules made of several autograd ops.
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(save_gradient)

        self._hook_handles.append(self.target_layer_module.register_forward_hook(forward_hook))

    def remove_hooks(self):
        """Detach the hooks registered by :meth:`hook_layers` from the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def get_layer(self, model, target_layer):
        """Resolve a dotted attribute path such as ``'layer4.1'`` on *model*."""
        layer = model
        for name in target_layer.split('.'):
            layer = getattr(layer, name)
        return layer

    def compute_gradcam_maps(self, X, y):
        """Compute a min-max-normalized Grad-CAM map for *X* toward class(es) *y*.

        Args:
            X: Input tensor (B, C, H, W); normally B == 1.
            y: Class label(s), 0 or 1, broadcastable against the squeezed logits.

        Returns:
            numpy array of shape (H, W) when B == 1, otherwise (B, H, W).
            Values are in [0, 1], and an all-flat map is returned as zeros.

        Raises:
            RuntimeError: If the hooked layer produced no activations or
                gradients during this pass.
        """
        if not X.requires_grad:
            X.requires_grad_()

        self.gradients = None  # never reuse a gradient from a previous pass
        scores = self.model(X).squeeze()

        # Flip the logit sign for class 0 so gradients point toward class y.
        target_scores = scores * (2 * y - 1)
        self.model.zero_grad()
        target_scores.backward(torch.ones_like(target_scores))

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                f"No activations/gradients captured for layer '{self.target_layer}'.")

        # Work on detached copies. The old code multiplied the autograd
        # activations in place.
        activations = self.activations.detach()
        gradients = self.gradients.detach()

        # Channel weights = spatially averaged gradients, per sample.
        weights = gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * activations).mean(dim=1, keepdim=True))
        cam = torch.nn.functional.interpolate(cam, size=X.shape[2:], mode='bilinear', align_corners=False)
        cam = cam[:, 0].cpu().numpy()  # (B, H, W)

        low = cam.min(axis=(1, 2), keepdims=True)
        span = cam.max(axis=(1, 2), keepdims=True) - low
        # Guard against 0/0 when the map is flat (e.g. all zero after ReLU).
        cam = np.divide(cam - low, span, out=np.zeros_like(cam), where=span > 0)

        return cam[0] if cam.shape[0] == 1 else cam

    def compute_gradcam_and_predictions(self, X):
        """Predict each image's class (logit > 0) and compute its Grad-CAM map.

        Returns:
            (maps, predictions): maps is a float tensor (N, H, W), and
            predictions is a 1-D long tensor of length N.
        """
        self.model.eval()
        if not X.requires_grad:
            X.requires_grad_()

        # Predictions alone need no autograd graph.
        with torch.no_grad():
            scores = torch.atleast_1d(self.model(X).squeeze())
        predictions = (scores > 0).long()

        gradcam_maps = [
            self.compute_gradcam_maps(X[i].unsqueeze(0), predictions[i])
            for i in range(X.size(0))
        ]
        if not gradcam_maps:
            return torch.empty((0, *X.shape[2:])), predictions

        return torch.from_numpy(np.stack(gradcam_maps)), predictions

    def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
        """Plot Grad-CAM maps in an 8-column grid, optionally over the input images.

        Returns:
            The matplotlib Figure.
        """
        n_images = X.shape[0]
        n_cols = 8
        n_rows = max(1, (n_images + n_cols - 1) // n_cols)
        # squeeze=False keeps axs 2-D even for a single row (N <= 8).
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, n_rows * 2), squeeze=False)

        for i in range(n_images):
            ax = axs[i // n_cols, i % n_cols]
            gradcam_map = gradcam[i]
            if isinstance(gradcam_map, torch.Tensor):
                gradcam_map = gradcam_map.detach().cpu().numpy()
            else:
                gradcam_map = np.asarray(gradcam_map)

            if overlay:
                img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
                if normalize:
                    img_np = self.percentile_normalize(img_np)
                ax.imshow(img_np)
            ax.imshow(gradcam_map, cmap='jet', alpha=0.5)

            ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
                    bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))

        for ax in axs.flat:
            ax.axis('off')

        plt.tight_layout(pad=0)
        return fig

    def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
        """
        Rescale each channel so its lower/upper percentiles map to 0/1, clipped to [0, 1].

        Args:
            img: Array of shape (H, W, C).
            lower_percentile: Percentile mapped to 0 (default 2).
            upper_percentile: Percentile mapped to 1 (default 98).
        Returns:
            Floating-point array of the same shape. A channel whose two
            percentiles coincide is returned as zeros instead of NaN/inf.
        """
        img = np.asarray(img)
        out_dtype = np.result_type(img.dtype, np.float32)
        img_normalized = np.zeros(img.shape, dtype=out_dtype)

        for c in range(img.shape[2]):
            channel = img[:, :, c].astype(out_dtype)
            low, high = np.percentile(channel, [lower_percentile, upper_percentile])
            if high > low:
                img_normalized[:, :, c] = np.clip((channel - low) / (high - low), 0, 1)

        return img_normalized
3023
3457
 
3024
3458
  def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
3025
3459
  preprocess = transforms.Compose([
@@ -3560,7 +3994,7 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
3560
3994
  plt.show()
3561
3995
  return grid_fig
3562
3996
 
3563
- def generate_path_list_from_db(db_path, file_metadata):
3997
+ def generate_path_list_from_db_v1(db_path, file_metadata):
3564
3998
 
3565
3999
  all_paths = []
3566
4000
 
@@ -3590,6 +4024,44 @@ def generate_path_list_from_db(db_path, file_metadata):
3590
4024
 
3591
4025
  return all_paths
3592
4026
 
4027
def generate_path_list_from_db(db_path, file_metadata):
    """Return ``png_path`` values from the ``png_list`` table of *db_path*.

    Args:
        db_path: Path to the SQLite database.
        file_metadata: Substring filter(s) applied with SQL ``LIKE '%...%'``.
            - A string keeps paths containing it.
            - A list, tuple or set keeps paths containing ANY of its items.
            - Any other truthy value is converted with ``str()`` and used as a
              single substring.
            - A falsy value (None, '', []) returns all paths.

    Returns:
        list[str] of paths, or ``None`` if a database or other error occurred.
        The error is printed.
    """
    from contextlib import closing

    print(f"Reading DataBase: {db_path}")
    try:
        # closing() guarantees the connection is released. sqlite3's own
        # context manager only commits/rolls back and leaves it open.
        with closing(sqlite3.connect(db_path)) as conn:
            cursor = conn.cursor()

            if not file_metadata:
                cursor.execute("SELECT png_path FROM png_list")
            else:
                if isinstance(file_metadata, (list, tuple, set)):
                    patterns = [f"%{meta}%" for meta in file_metadata]
                else:
                    patterns = [f"%{file_metadata}%"]
                # Only placeholders are interpolated; values are bound as parameters.
                query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
                    ["png_path LIKE ?"] * len(patterns))
                cursor.execute(query, patterns)

            all_paths = [row[0] for row in cursor]

    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return
    except Exception as e:
        print(f"Error: {e}")
        return

    return all_paths
4064
+
3593
4065
  def correct_paths(df, base_path):
3594
4066
 
3595
4067
  if isinstance(df, pd.DataFrame):
@@ -4548,3 +5020,25 @@ def download_models(repo_id="einarolafsson/models", local_dir=None, retries=5, d
4548
5020
  time.sleep(delay)
4549
5021
 
4550
5022
  raise Exception("Failed to download model files after multiple attempts.")
5023
+
5024
def generate_cytoplasm_mask(nucleus_mask, cell_mask, pathogen_mask=None):
    """
    Generate a cytoplasm mask from nucleus and cell masks.

    Parameters:
    - nucleus_mask (np.array): Binary or labeled nucleus mask (non-zero = nucleus).
    - cell_mask (np.array): Binary or labeled whole-cell mask (non-zero = cell).
    - pathogen_mask (np.array, optional): Binary or labeled pathogen mask
      (non-zero = pathogen). When given, those pixels are also excluded.

    Returns:
    - cytoplasm_mask (np.array): ``cell_mask`` values (labels are preserved)
      wherever the pixel is not nucleus (or pathogen), and 0 elsewhere.
    """
    nucleus_mask = np.asarray(nucleus_mask)
    cell_mask = np.asarray(cell_mask)

    excluded = nucleus_mask != 0
    if pathogen_mask is not None:
        # np.logical_or needs two operands. The previous one-argument call
        # raised TypeError on every invocation.
        excluded = np.logical_or(excluded, np.asarray(pathogen_mask) != 0)

    return np.where(excluded, 0, cell_mask)