spacr 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +335 -163
  5. spacr/gui.py +2 -0
  6. spacr/gui_core.py +85 -65
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +375 -7
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +108 -133
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +181 -50
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +776 -182
  27. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/utils.py CHANGED
@@ -1,6 +1,7 @@
- import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess
+ import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
 
  import numpy as np
+ import pandas as pd
  from cellpose import models as cp_models
  from cellpose import denoise
 
@@ -11,10 +12,10 @@ from skimage.transform import resize as resizescikit
  from skimage.morphology import dilation, square
  from skimage.measure import find_contours
  from skimage.segmentation import clear_border
+ from scipy.stats import pearsonr
 
  from collections import defaultdict, OrderedDict
  from PIL import Image
- import pandas as pd
  from statsmodels.stats.outliers_influence import variance_inflation_factor
  from statsmodels.stats.stattools import durbin_watson
  import statsmodels.formula.api as smf
@@ -24,7 +25,7 @@ from itertools import combinations
  from functools import reduce
  from IPython.display import display
 
- from multiprocessing import Pool, cpu_count
+ from multiprocessing import Pool, cpu_count, set_start_method, get_start_method
  from concurrent.futures import ThreadPoolExecutor
 
  import torch.nn as nn
@@ -33,65 +34,304 @@ from torch.utils.checkpoint import checkpoint
  from torch.utils.data import Subset
  from torch.autograd import grad
 
+ from torchvision import models
+ from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+ import torchvision.transforms as transforms
+ from torchvision.models import resnet50
+ from torchvision.utils import make_grid
+
  import seaborn as sns
  import matplotlib.pyplot as plt
  from matplotlib.offsetbox import OffsetImage, AnnotationBbox
 
+ from scipy import stats
  import scipy.ndimage as ndi
  from scipy.spatial import distance
- from scipy.stats import fisher_exact
+ from scipy.stats import fisher_exact, f_oneway, kruskal
  from scipy.ndimage.filters import gaussian_filter
  from scipy.spatial import ConvexHull
  from scipy.interpolate import splprep, splev
  from scipy.ndimage import binary_dilation
 
- from sklearn.preprocessing import StandardScaler
  from skimage.exposure import rescale_intensity
  from sklearn.metrics import auc, precision_recall_curve
  from sklearn.model_selection import train_test_split
  from sklearn.linear_model import Lasso, Ridge
- from sklearn.preprocessing import OneHotEncoder
- from sklearn.cluster import KMeans
- from sklearn.preprocessing import StandardScaler
- from sklearn.cluster import DBSCAN
- from sklearn.cluster import KMeans
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
+ from sklearn.cluster import KMeans, DBSCAN
  from sklearn.manifold import TSNE
- from sklearn.cluster import KMeans
  from sklearn.decomposition import PCA
+ from sklearn.ensemble import RandomForestClassifier
+
+ from huggingface_hub import list_repo_files
 
  import umap.umap_ as umap
+ #import umap
 
- from torchvision import models
- from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
- import torchvision.transforms as transforms
+ def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
 
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.preprocessing import StandardScaler
- from scipy.stats import f_oneway, kruskal
- from sklearn.cluster import KMeans
- from scipy import stats
+     png_df = pd.DataFrame(img_paths, columns=['png_path'])
 
- from .logger import log_function_call
- from multiprocessing import set_start_method, get_start_method
+     png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
 
- import torch
- import torchvision.transforms as transforms
- from torchvision.models import resnet50
- from PIL import Image
- import numpy as np
- import umap
- import pandas as pd
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.preprocessing import StandardScaler
- from scipy.stats import f_oneway, kruskal
- from sklearn.cluster import KMeans
- from scipy import stats
+     parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
+
+     columns = ['plate', 'row', 'col', 'field']
+
+     if settings['timelapse']:
+         columns = columns + ['time_id']
+
+     columns = columns + ['prcfo']
+
+     if crop_mode == 'cell':
+         columns = columns + ['cell_id']
+
+     if crop_mode == 'nucleus':
+         columns = columns + ['nucleus_id']
+
+     if crop_mode == 'pathogen':
+         columns = columns + ['pathogen_id']
+
+     if crop_mode == 'cytoplasm':
+         columns = columns + ['cytoplasm_id']
+
+     png_df[columns] = parts
+
+     try:
+         conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
+         png_df.to_sql('png_list', conn, if_exists='append', index=False)
+         conn.commit()
+     except sqlite3.OperationalError as e:
+         print(f"SQLite error: {e}", flush=True)
+         traceback.print_exc()
+
+ def activation_maps_to_database(img_paths, source_folder, settings):
+     from .io import _create_database
+
+     png_df = pd.DataFrame(img_paths, columns=['png_path'])
+     png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+     parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+     columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+     png_df[columns] = parts
+
+     dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+     database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+     if not os.path.exists(database_name):
+         _create_database(database_name)
+
+     try:
+         conn = sqlite3.connect(database_name, timeout=5)
+         png_df.to_sql(f"{settings['cam_type']}_list", conn, if_exists='append', index=False)
+         conn.commit()
+     except sqlite3.OperationalError as e:
+         print(f"SQLite error: {e}", flush=True)
+         traceback.print_exc()
+
+ def activation_correlations_to_database(df, img_paths, source_folder, settings):
+     from .io import _create_database
+
+     png_df = pd.DataFrame(img_paths, columns=['png_path'])
+     png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
+     parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=False)))
+     columns = ['plate', 'row', 'col', 'field', 'prcfo', 'object']
+     png_df[columns] = parts
+
+     # Align both DataFrames by file_name
+     png_df.set_index('file_name', inplace=True)
+     df.set_index('file_name', inplace=True)
+
+     merged_df = pd.concat([png_df, df], axis=1)
+     merged_df.reset_index(inplace=True)
+
+     dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0]
+     database_name = f"{source_folder}/measurements/{dataset_name}.db"
+
+     if not os.path.exists(database_name):
+         _create_database(database_name)
+
+     try:
+         conn = sqlite3.connect(database_name, timeout=5)
+         merged_df.to_sql(f"{settings['cam_type']}_correlations", conn, if_exists='append', index=False)
+         conn.commit()
+     except sqlite3.OperationalError as e:
+         print(f"SQLite error: {e}", flush=True)
+         traceback.print_exc()
+
+ def calculate_activation_correlations(inputs, activation_maps, file_names, manders_thresholds=[15, 50, 75]):
+     """
+     Calculates Pearson and Manders correlations between input image channels and activation map channels.
+
+     Args:
+         inputs: A batch of input images, Tensor of shape (batch_size, channels, height, width)
+         activation_maps: A batch of activation maps, Tensor of shape (batch_size, channels, height, width)
+         file_names: List of file names corresponding to each image in the batch.
+         manders_thresholds: List of intensity percentiles to calculate Manders correlation.
+
+     Returns:
+         df_correlations: A DataFrame with columns for pairwise correlations (Pearson and Manders)
+             between input channels and activation map channels.
+     """
+
+     # Ensure tensors are detached and moved to CPU before converting to numpy
+     inputs = inputs.detach().cpu()
+     activation_maps = activation_maps.detach().cpu()
+
+     batch_size, in_channels, height, width = inputs.shape
+
+     if activation_maps.dim() == 3:
+         # If activation maps have no channel dimension, add a dummy one
+         activation_maps = activation_maps.unsqueeze(1)  # Now shape is (batch_size, 1, height, width)
+
+     _, act_channels, act_height, act_width = activation_maps.shape
+
+     # Ensure that the inputs and activation maps are the same size
+     if (height != act_height) or (width != act_width):
+         activation_maps = torch.nn.functional.interpolate(activation_maps, size=(height, width), mode='bilinear')
+
+     # Dictionary to collect correlation results
+     correlations_dict = {'file_name': []}
+
+     # Initialize correlation columns based on input channels and activation map channels
+     for in_c in range(in_channels):
+         for act_c in range(act_channels):
+             correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'] = []
+             for threshold in manders_thresholds:
+                 correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'] = []
+                 correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'] = []
+
+     # Loop over the batch
+     for b in range(batch_size):
+         input_img = inputs[b]  # Input image channels (C, H, W)
+         activation_map = activation_maps[b]  # Activation map channels (C, H, W)
+
+         # Add the file name to the current row
+         correlations_dict['file_name'].append(file_names[b])
+
+         # Calculate correlations for each channel pair
+         for in_c in range(in_channels):
+             input_channel = input_img[in_c].flatten().numpy()  # Flatten the input image channel
+             input_channel = input_channel[np.isfinite(input_channel)]  # Remove NaN or inf values
+
+             for act_c in range(act_channels):
+                 activation_channel = activation_map[act_c].flatten().numpy()  # Flatten the activation map channel
+                 activation_channel = activation_channel[np.isfinite(activation_channel)]  # Remove NaN or inf values
+                 # Note: the isfinite filtering is applied per channel; pearsonr below assumes
+                 # both flattened channels keep equal length (i.e. no non-finite values present)
+
+                 # Check that there are valid (non-empty) arrays left to calculate the Pearson correlation
+                 if input_channel.size > 0 and activation_channel.size > 0:
+                     pearson_corr, _ = pearsonr(input_channel, activation_channel)
+                 else:
+                     pearson_corr = np.nan  # Assign NaN if there are no valid data points
+                 correlations_dict[f'channel_{in_c}_activation_{act_c}_pearsons'].append(pearson_corr)
+
+                 # Compute Manders correlations for each threshold
+                 for threshold in manders_thresholds:
+                     # Get the top percentile pixels based on intensity in both channels
+                     if input_channel.size > 0 and activation_channel.size > 0:
+                         input_threshold = np.percentile(input_channel, threshold)
+                         activation_threshold = np.percentile(activation_channel, threshold)
+
+                         # Mask the pixels above the threshold
+                         mask = (input_channel >= input_threshold) & (activation_channel >= activation_threshold)
+
+                         # If we have enough pixels, calculate Manders correlation
+                         if np.sum(mask) > 0:
+                             manders_corr_M1 = np.sum(input_channel[mask] * activation_channel[mask]) / np.sum(input_channel[mask] ** 2)
+                             manders_corr_M2 = np.sum(activation_channel[mask] * input_channel[mask]) / np.sum(activation_channel[mask] ** 2)
+                         else:
+                             manders_corr_M1 = np.nan
+                             manders_corr_M2 = np.nan
+                     else:
+                         manders_corr_M1 = np.nan
+                         manders_corr_M2 = np.nan
+
+                     # Store the Manders correlation for this threshold
+                     correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M1'].append(manders_corr_M1)
+                     correlations_dict[f'channel_{in_c}_activation_{act_c}_{threshold}_M2'].append(manders_corr_M2)
+
+     # Convert the dictionary to a DataFrame
+     df_correlations = pd.DataFrame(correlations_dict)
+
+     return df_correlations
+
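For orientation, a minimal call with random tensors (shapes and file names here are illustrative only, not taken from the package):

    import torch
    inputs = torch.rand(4, 3, 224, 224)          # 4 images, 3 input channels
    activation_maps = torch.rand(4, 112, 112)    # 3-D: a channel dim is added internally
    file_names = [f"img_{i}.png" for i in range(4)]
    df = calculate_activation_correlations(inputs, activation_maps, file_names,
                                           manders_thresholds=[50, 75])
    # One row per image; one Pearson column plus M1/M2 columns per
    # (input channel, activation channel, threshold) combination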
+ def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
+     """
+     Convert a CSV file with setting-key and setting-value columns into a dictionary.
+     Handles special cases where values are lists, tuples, booleans, None, integers, floats, and nested dictionaries.
+
+     Args:
+         csv_file_path (str): The path to the CSV file.
+         show (bool): Whether to display the dataframe (for debugging).
+         setting_key (str): The name of the column that contains the setting keys.
+         setting_value (str): The name of the column that contains the setting values.
+
+     Returns:
+         dict: A dictionary keyed by the setting-key column, with parsed values.
+     """
+     # Read the CSV file into a DataFrame
+     df = pd.read_csv(csv_file_path)
+
+     if show:
+         display(df)
+
+     # Ensure the key and value columns exist
+     if setting_key not in df.columns or setting_value not in df.columns:
+         raise ValueError(f"CSV file must contain {setting_key} and {setting_value} columns.")
+
+     def parse_value(value):
+         """Parse the string value into the appropriate Python data type."""
+         # Handle empty values
+         if pd.isna(value) or value == '':
+             return None
+
+         # Handle boolean values
+         if value == 'True':
+             return True
+         if value == 'False':
+             return False
+
+         # Handle lists, tuples, dictionaries, and other literals
+         if value.startswith(('(', '[', '{')):  # If it starts with (, [ or {, use ast.literal_eval
+             try:
+                 parsed_value = ast.literal_eval(value)
+                 # If parsed_value is a dict, recursively parse its values
+                 if isinstance(parsed_value, dict):
+                     parsed_value = {k: parse_value(v) for k, v in parsed_value.items()}
+                 return parsed_value
+             except (ValueError, SyntaxError):
+                 pass  # If there's an error, fall through and return the value as-is
+
+         # Handle numeric values (integers and floats)
+         try:
+             if '.' in value:
+                 return float(value)  # If it contains a dot, convert to float
+             return int(value)  # Otherwise, convert to integer
+         except ValueError:
+             pass  # If it's not a valid number, fall through
+
+         # Return the original value if no other type matched
+         return value
+
+     # Convert the DataFrame to a dictionary, with parsing of each value
+     result_dict = {key: parse_value(value) for key, value in zip(df[setting_key], df[setting_value])}
 
- def save_settings(settings, name='settings'):
+     return result_dict
+
+
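A small round-trip sketch (file path hypothetical). Note that save_settings below writes 'Key'/'Value' headers, so reading its output requires overriding the column names:

    import pandas as pd
    pd.DataFrame({'setting_key': ['src', 'channels', 'timelapse', 'diameter'],
                  'setting_value': ['/data/plate1', '[0, 1, 2]', 'True', '30.5']}
                 ).to_csv('settings.csv', index=False)
    settings = load_settings('settings.csv')
    # {'src': '/data/plate1', 'channels': [0, 1, 2], 'timelapse': True, 'diameter': 30.5}
    # For a file written by save_settings:
    # settings = load_settings('settings.csv', setting_key='Key', setting_value='Value')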
+ def save_settings(settings, name='settings', show=False):
 
      settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-     settings_csv = os.path.join(settings['src'],'settings',f'{name}.csv')
-     os.makedirs(os.path.join(settings['src'],'settings'), exist_ok=True)
+     if show:
+         display(settings_df)
+
+     if isinstance(settings['src'], list):
+         src = settings['src'][0]
+         name = f"{name}_list"
+     else:
+         src = settings['src']
+
+     settings_csv = os.path.join(src,'settings',f'{name}.csv')
+     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
      settings_df.to_csv(settings_csv, index=False)
 
  def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
@@ -303,7 +543,7 @@ def _get_cellpose_batch_size():
      except Exception as e:
          return 8
 
- def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+ def _extract_filename_metadata_v1(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
 
      images_by_key = defaultdict(list)
 
@@ -353,6 +593,57 @@ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type
 
      return images_by_key
 
+ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+
+     images_by_key = defaultdict(list)
+
+     for filename in filenames:
+         match = regular_expression.match(filename)
+         if match:
+             try:
+                 try:
+                     plate = match.group('plateID')
+                 except:
+                     plate = os.path.basename(src)
+
+                 well = match.group('wellID')
+                 field = match.group('fieldID')
+                 channel = match.group('chanID')
+                 mode = None
+
+                 if well[0].isdigit():
+                     well = str(_safe_int_convert(well))
+                 if field[0].isdigit():
+                     field = str(_safe_int_convert(field))
+                 if channel[0].isdigit():
+                     channel = str(_safe_int_convert(channel))
+
+                 if metadata_type == 'cq1':
+                     orig_well = well
+                     well = _convert_cq1_well_id(well)
+                     print(f'Converted Well ID: {orig_well} to {well}', end='\r', flush=True)
+
+                 if pick_slice:
+                     try:
+                         mode = match.group('AID')
+                     except IndexError:
+                         mode = '00'  # no AID group in the regex: fall back to a default slice ID
+
+                     if mode == skip_mode:
+                         continue
+
+                 key = (plate, well, field, channel, mode)
+                 file_path = os.path.join(src, filename)  # Store the full path
+                 images_by_key[key].append(file_path)
+
+             except IndexError:
+                 print(f"Could not extract information from filename {filename} using provided regex")
+         else:
+             print(f"Filename {filename} did not match provided regex")
+             continue
+
+     return images_by_key
+
 
  def mask_object_count(mask):
      """
      Counts the number of objects in a given mask.
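For context, a hypothetical regular expression with the named groups the function looks up (plateID, wellID, fieldID, chanID); spacr's actual patterns live in its settings and are not part of this diff:

    import re
    rx = re.compile(r'(?P<plateID>.+)_(?P<wellID>[A-P]\d{2})_T0001F(?P<fieldID>\d{3})L01A01Z01C(?P<chanID>\d{2})\.tif')
    m = rx.match('plate1_C03_T0001F001L01A01Z01C01.tif')
    # m.group('wellID') -> 'C03', m.group('fieldID') -> '001', m.group('chanID') -> '01'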
@@ -443,7 +734,7 @@ def _generate_representative_images(db_path, cells=['HeLa'], cell_loc=None, path
      from .plot import _plot_images_on_grid
 
      df = _read_and_join_tables(db_path)
-     df = _annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments,treatment_loc)
+     df = annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments, treatment_loc)
 
      if update_db:
          _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo'])
@@ -489,34 +780,6 @@ def _map_values(row, values, locs):
          return value_dict.get(row[type_], None)
      return values[0] if values else None
 
- def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None):
-     """
-     Annotates conditions in the given DataFrame based on the provided parameters.
-
-     Args:
-         df (pandas.DataFrame): The DataFrame to annotate.
-         cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
-         cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
-         pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
-         pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
-         treatments (list, optional): The list of treatments. Defaults to ['cm'].
-         treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
-
-     Returns:
-         pandas.DataFrame: The annotated DataFrame with the 'host_cells', 'pathogen', 'treatment', and 'condition' columns.
-     """
-
-
-     # Apply mappings or defaults
-     df['host_cells'] = [cells[0]] * len(df) if cell_loc is None else df.apply(_map_values, args=(cells, cell_loc), axis=1)
-     df['pathogen'] = [pathogens[0]] * len(df) if pathogen_loc is None else df.apply(_map_values, args=(pathogens, pathogen_loc), axis=1)
-     df['treatment'] = [treatments[0]] * len(df) if treatment_loc is None else df.apply(_map_values, args=(treatments, treatment_loc), axis=1)
-
-     # Construct condition column
-     df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
-     df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
-     return df
-
 
  def is_list_of_lists(var):
      if isinstance(var, list) and all(isinstance(i, list) for i in var):
          return True
@@ -816,7 +1079,7 @@ def _map_wells_png(file_name, timelapse=False):
          print(f"Error: {e}")
          plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
      if timelapse:
-         return plate, row, column, field, timeid, prcfo, object_id,
+         return plate, row, column, field, timeid, prcfo, object_id
      else:
          return plate, row, column, field, prcfo, object_id
 
@@ -1085,67 +1348,74 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
      else:
          cellpose_channels['cell'] = [0,0]
      return cellpose_channels
-
- def annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, types = ['col','col','col']):
+
+ def annotate_conditions(df, cells=None, cell_loc=None, pathogens=None, pathogen_loc=None, treatments=None, treatment_loc=None):
      """
-     Annotates conditions in a DataFrame based on specified criteria.
+     Annotates conditions in a DataFrame based on specified criteria and combines them into a 'condition' column.
+     NaN is used for missing values, and they are excluded from the 'condition' column.
 
      Args:
          df (pandas.DataFrame): The DataFrame to annotate.
-         cells (list, optional): List of host cell types. Defaults to ['HeLa'].
-         cell_loc (list, optional): List of corresponding values for each host cell type. Defaults to None.
-         pathogens (list, optional): List of pathogens. Defaults to ['rh'].
-         pathogen_loc (list, optional): List of corresponding values for each pathogen. Defaults to None.
-         treatments (list, optional): List of treatments. Defaults to ['cm'].
-         treatment_loc (list, optional): List of corresponding values for each treatment. Defaults to None.
-         types (list, optional): List of column types for host cells, pathogens, and treatments. Defaults to ['col','col','col'].
+         cells (list/str, optional): Host cell types. Defaults to None.
+         cell_loc (list of lists, optional): Values for each host cell type. Defaults to None.
+         pathogens (list/str, optional): Pathogens. Defaults to None.
+         pathogen_loc (list of lists, optional): Values for each pathogen. Defaults to None.
+         treatments (list/str, optional): Treatments. Defaults to None.
+         treatment_loc (list of lists, optional): Values for each treatment. Defaults to None.
 
      Returns:
-         pandas.DataFrame: The annotated DataFrame.
+         pandas.DataFrame: Annotated DataFrame with a combined 'condition' column.
      """
+
+     def _get_type(val):
+         """Determine if a value maps to 'row' or 'col'."""
+         if isinstance(val, str) and val.startswith('c'):
+             return 'col'
+         elif isinstance(val, str) and val.startswith('r'):
+             return 'row'
+         return None
 
-     # Function to apply to each row
-     def _map_values(row, dict_, type_='col'):
+     def _map_or_default(column_name, values, loc, df):
          """
-         Maps the values in a row to corresponding keys in a dictionary.
+         Consolidates the logic for mapping values or assigning defaults when loc is None.
 
          Args:
-             row (dict): The row containing the values to be mapped.
-             dict_ (dict): The dictionary containing the mapping values.
-             type_ (str, optional): The type of mapping to perform. Defaults to 'col'.
-
-         Returns:
-             str: The mapped value if found, otherwise None.
+             column_name (str): The column in the DataFrame to annotate.
+             values (list/str): The list of values or a single string to annotate.
+             loc (list of lists): Location mapping for the values, or None if not used.
+             df (pandas.DataFrame): The DataFrame to modify.
          """
-         for values, cols in dict_.items():
-             if row[type_] in cols:
-                 return values
-         return None
+         if isinstance(values, str) or (isinstance(values, list) and loc is None):
+             # Assign all rows the first value in the list or the single string
+             df[column_name] = values if isinstance(values, str) else values[0]
+         elif values is not None and loc is not None:
+             # Perform the location-based mapping
+             value_dict = {val: key for key, loc_list in zip(values, loc) for val in loc_list}
+             df[column_name] = np.nan
+             for val, key in value_dict.items():
+                 loc_type = _get_type(val)
+                 if loc_type:
+                     df.loc[df[loc_type] == val, column_name] = key
+
+     # Handle cells, pathogens, and treatments using the consolidated logic
+     _map_or_default('host_cells', cells, cell_loc, df)
+     _map_or_default('pathogen', pathogens, pathogen_loc, df)
+     _map_or_default('treatment', treatments, treatment_loc, df)
+
+     # Conditionally fill NaN for pathogen and treatment columns if applicable
+     if pathogens is not None:
+         df['pathogen'].fillna(np.nan, inplace=True)
+     if treatments is not None:
+         df['treatment'].fillna(np.nan, inplace=True)
+
+     # Create the 'condition' column by excluding any NaN values, safely checking that 'host_cells', 'pathogen', and 'treatment' exist
+     df['condition'] = df.apply(
+         lambda x: '_'.join([str(v) for v in [x.get('host_cells'), x.get('pathogen'), x.get('treatment')] if pd.notna(v)]),
+         axis=1
+     )
 
-     if cell_loc is None:
-         df['host_cells'] = cells[0]
-     else:
-         cells_dict = dict(zip(cells, cell_loc))
-         df['host_cells'] = df.apply(lambda row: _map_values(row, cells_dict, type_=types[0]), axis=1)
-     if pathogen_loc is None:
-         if pathogens != None:
-             df['pathogen'] = 'none'
-     else:
-         pathogens_dict = dict(zip(pathogens, pathogen_loc))
-         df['pathogen'] = df.apply(lambda row: _map_values(row, pathogens_dict, type_=types[1]), axis=1)
-     if treatment_loc is None:
-         df['treatment'] = 'cm'
-     else:
-         treatments_dict = dict(zip(treatments, treatment_loc))
-         df['treatment'] = df.apply(lambda row: _map_values(row, treatments_dict, type_=types[2]), axis=1)
-     if pathogens != None:
-         df['condition'] = df['pathogen']+'_'+df['treatment']
-     else:
-         df['condition'] = df['treatment']
 
      return df
-
 
-
  def _split_data(df, group_by, object_type):
      """
      Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
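A toy example of the new location-based mapping ('c…' values select the 'col' column, 'r…' values the 'row' column, both of which must already exist in df):

    import pandas as pd
    df = pd.DataFrame({'row': ['r1', 'r1', 'r2'], 'col': ['c1', 'c2', 'c1']})
    df = annotate_conditions(df, cells='HeLa',
                             pathogens=['rh', 'dgra14'], pathogen_loc=[['c1'], ['c2']],
                             treatments=['cm'], treatment_loc=None)
    # df['condition'] -> 'HeLa_rh_cm', 'HeLa_dgra14_cm', 'HeLa_rh_cm'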
@@ -1951,9 +2221,10 @@ def add_images_to_tar(paths_chunk, tar_path, total_images):
              tar.add(img_path, arcname=arcname)
              with lock:
                  counter.value += 1
-                 if counter.value % 100 == 0:  # Print every 100 updates
-                     progress = (counter.value / total_images) * 100
-                     print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                 if counter.value % 10 == 0:  # Update progress every 10 images
+                     #progress = (counter.value / total_images) * 100
+                     #print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                     print_progress(counter.value, total_images, n_jobs=1, time_ls=None, batch_size=None, operation_type="generating .tar dataset")
          except FileNotFoundError:
              print(f"File not found: {img_path}")
 
@@ -2070,52 +2341,6 @@ def check_multicollinearity(x):
      vif_data["VIF"] = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
      return vif_data
 
- def generate_dependent_variable(df, dv_loc, pc_min=0.95, nc_max=0.05, agg_type='mean'):
-
-     from .plot import _plot_histograms_and_stats, _plot_plates
-
-     def qstring_to_float(qstr):
-         number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
-         return number / 100.0
-
-     print("Unique values in plate:", df['plate'].unique())
-     dv_cell_loc = f'{dv_loc}/dv_cell.csv'
-     dv_well_loc = f'{dv_loc}/dv_well.csv'
-
-     df['pred'] = 1-df['pred']  # if you switched pc and nc
-     df = df[(df['pred'] <= nc_max) | (df['pred'] >= pc_min)]
-
-     if 'prc' not in df.columns:
-         df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
-
-     if agg_type.startswith('q'):
-         val = qstring_to_float(agg_type)
-         agg_type = lambda x: x.quantile(val)
-
-     # Aggregating for mean prediction and total count
-     df_grouped = df.groupby('prc').agg(
-         pred=('pred', agg_type),
-         recruitment=('recruitment', agg_type),
-         count_prc=('prc', 'size'),
-         #count_above_95=('pred', lambda x: (x > 0.95).sum()),
-         mean_pathogen_area=('pathogen_area', 'mean')
-     )
-
-     df_cell = df[['prc', 'pred', 'pathogen_area', 'recruitment']]
-
-     df_cell.to_csv(dv_cell_loc, index=True, header=True, mode='w')
-     df_grouped.to_csv(dv_well_loc, index=True, header=True, mode='w')  # Changed from loc to dv_loc
-     display(df)
-     _plot_histograms_and_stats(df)
-     df_grouped = df_grouped.sort_values(by='count_prc', ascending=True)
-     display(df_grouped)
-     print('pred')
-     _plot_plates(df=df_cell, variable='pred', grouping='mean', min_max='allq', cmap='viridis')
-     print('recruitment')
-     _plot_plates(df=df_cell, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis')
-
-     return df_grouped
-
 
  def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
      # Separate predictors and response
      X = merged_df[['gene', 'grna', 'plate', 'row', 'column']]
@@ -3021,7 +3246,6 @@ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=Tru
      input_tensor = transform(image).unsqueeze(0)
      return image, input_tensor
 
-
  class SaliencyMapGenerator:
      def __init__(self, model):
          self.model = model
@@ -3042,18 +3266,194 @@ class SaliencyMapGenerator:
          saliency = X.grad.abs()
          return saliency
 
-     def plot_saliency_maps(self, X, y, saliency, class_names):
+     def compute_saliency_and_predictions(self, X):
+         self.model.eval()
+         X.requires_grad_()
+
+         # Forward pass to get predictions (logits)
+         scores = self.model(X).squeeze()
+
+         # Get predicted class (0 or 1 for binary classification)
+         predictions = (scores > 0).long()
+
+         # Compute saliency maps; flipping the sign for class-0 predictions makes the
+         # gradient reflect the score of the predicted class in both cases
+         self.model.zero_grad()
+         target_scores = scores * (2 * predictions - 1)
+         target_scores.backward(torch.ones_like(target_scores))
+
+         saliency = X.grad.abs()
+
+         return saliency, predictions
+
+     def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
          N = X.shape[0]
+         rows = (N + 7) // 8
+         fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
          for i in range(N):
-             plt.subplot(2, N, i + 1)
-             plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
-             plt.axis('off')
-             plt.title(class_names[y[i]])
-             plt.subplot(2, N, N + i + 1)
-             plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
-             plt.axis('off')
-         plt.gcf().set_size_inches(12, 5)
-         plt.show()
+             ax = axs[i // 8, i % 8]
+             saliency_map = saliency[i].cpu().numpy()  # Move to CPU and convert to numpy
+
+             if saliency_map.shape[0] == 3:  # Channels first, reshape to (H, W, 3)
+                 saliency_map = np.transpose(saliency_map, (1, 2, 0))
+
+             # Normalize image channels to 2nd and 98th percentiles
+             if overlay:
+                 img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                 if normalize:
+                     img_np = self.percentile_normalize(img_np)
+                 ax.imshow(img_np)
+             ax.imshow(saliency_map, cmap='jet', alpha=0.5)
+
+             # Add class label in the top-left corner
+             ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                     bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+             ax.axis('off')
+
+         plt.tight_layout(pad=0)
+         return fig
+
+     def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+         """
+         Normalize each channel of the image to the given percentiles.
+         Args:
+             img: Input image as numpy array with shape (H, W, C)
+             lower_percentile: Lower percentile for normalization (default 2)
+             upper_percentile: Upper percentile for normalization (default 98)
+         Returns:
+             img: Normalized image
+         """
+         img_normalized = np.zeros_like(img)
+
+         for c in range(img.shape[2]):  # Iterate over each channel
+             low = np.percentile(img[:, :, c], lower_percentile)
+             high = np.percentile(img[:, :, c], upper_percentile)
+             img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+         return img_normalized
+
+
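A usage sketch, assuming a single-logit binary classifier (which the scores > 0 test above implies); the model and batch here are placeholders:

    model = resnet50(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 1)  # single-logit binary head
    gen = SaliencyMapGenerator(model)
    X = torch.rand(16, 3, 224, 224)  # 16 images -> a 2x8 grid (the plot helper indexes axs two-dimensionally)
    saliency, preds = gen.compute_saliency_and_predictions(X)
    fig = gen.plot_activation_grid(X, saliency, preds, overlay=True, normalize=True)
    fig.savefig('saliency_grid.png', dpi=150)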
+ class GradCAMGenerator:
+     def __init__(self, model, target_layer, cam_type='gradcam'):
+         self.model = model
+         self.model.eval()
+         self.target_layer = target_layer
+         self.cam_type = cam_type
+         self.gradients = None
+         self.activations = None
+
+         # Hook the target layer
+         self.target_layer_module = self.get_layer(self.model, self.target_layer)
+         self.hook_layers()
+
+     def hook_layers(self):
+         # Forward hook to get activations
+         def forward_hook(module, input, output):
+             self.activations = output
+
+         # Backward hook to get gradients
+         def backward_hook(module, grad_input, grad_output):
+             self.gradients = grad_output[0]
+
+         self.target_layer_module.register_forward_hook(forward_hook)
+         self.target_layer_module.register_backward_hook(backward_hook)
+
+     def get_layer(self, model, target_layer):
+         # Recursively find the layer specified in target_layer
+         modules = target_layer.split('.')
+         layer = model
+         for module in modules:
+             layer = getattr(layer, module)
+         return layer
+
+     def compute_gradcam_maps(self, X, y):
+         X.requires_grad_()
+
+         # Forward pass
+         scores = self.model(X).squeeze()
+
+         # Perform backward pass
+         target_scores = scores * (2 * y - 1)
+         self.model.zero_grad()
+         target_scores.backward(torch.ones_like(target_scores))
+
+         # Compute GradCAM
+         pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
+         for i in range(self.activations.size(1)):
+             self.activations[:, i, :, :] *= pooled_gradients[i]
+
+         gradcam = torch.mean(self.activations, dim=1).squeeze()
+         gradcam = F.relu(gradcam)
+         gradcam = F.interpolate(gradcam.unsqueeze(0).unsqueeze(0), size=X.shape[2:], mode='bilinear')
+         gradcam = gradcam.squeeze().cpu().detach().numpy()
+         gradcam = (gradcam - gradcam.min()) / (gradcam.max() - gradcam.min())
+
+         return gradcam
+
+     def compute_gradcam_and_predictions(self, X):
+         self.model.eval()
+         X.requires_grad_()
+
+         # Forward pass to get predictions (logits)
+         scores = self.model(X).squeeze()
+
+         # Get predicted class (0 or 1 for binary classification)
+         predictions = (scores > 0).long()
+
+         # Compute gradcam maps
+         gradcam_maps = []
+         for i in range(X.size(0)):
+             gradcam_map = self.compute_gradcam_maps(X[i].unsqueeze(0), predictions[i])
+             gradcam_maps.append(gradcam_map)
+
+         return torch.tensor(gradcam_maps), predictions
+
+     def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
+         N = X.shape[0]
+         rows = (N + 7) // 8
+         fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
+
+         for i in range(N):
+             ax = axs[i // 8, i % 8]
+             gradcam_map = gradcam[i].cpu().numpy()
+
+             # Normalize image channels to 2nd and 98th percentiles
+             if overlay:
+                 img_np = X[i].permute(1, 2, 0).detach().cpu().numpy()
+                 if normalize:
+                     img_np = self.percentile_normalize(img_np)
+                 ax.imshow(img_np)
+             ax.imshow(gradcam_map, cmap='jet', alpha=0.5)
+
+             #ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy())  # Original image
+             #ax.imshow(gradcam_map, cmap='jet', alpha=0.5)  # Overlay the gradcam map
+
+             # Add class label in the top-left corner
+             ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                     bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+             ax.axis('off')
+
+         plt.tight_layout(pad=0)
+         return fig
+
+     def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
+         """
+         Normalize each channel of the image to the given percentiles.
+         Args:
+             img: Input image as numpy array with shape (H, W, C)
+             lower_percentile: Lower percentile for normalization (default 2)
+             upper_percentile: Upper percentile for normalization (default 98)
+         Returns:
+             img: Normalized image
+         """
+         img_normalized = np.zeros_like(img)
+
+         for c in range(img.shape[2]):  # Iterate over each channel
+             low = np.percentile(img[:, :, c], lower_percentile)
+             high = np.percentile(img[:, :, c], upper_percentile)
+             img_normalized[:, :, c] = np.clip((img[:, :, c] - low) / (high - low), 0, 1)
+
+         return img_normalized
 
  def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
      preprocess = transforms.Compose([
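And the GradCAM counterpart, under the same assumptions; 'layer4' is the conventional final convolutional block of a torchvision ResNet:

    model = resnet50(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 1)
    cam = GradCAMGenerator(model, target_layer='layer4')
    X = torch.rand(16, 3, 224, 224)
    gradcam_maps, preds = cam.compute_gradcam_and_predictions(X)  # maps: (16, 224, 224), scaled to [0, 1]
    fig = cam.plot_activation_grid(X, gradcam_maps, preds, overlay=True, normalize=True)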
@@ -3594,13 +3994,86 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
      plt.show()
      return grid_fig
 
- def correct_paths(df, base_path):
+ def generate_path_list_from_db_v1(db_path, file_metadata):
+
+     all_paths = []
+
+     # Connect to the database and retrieve the image paths
+     print(f"Reading DataBase: {db_path}")
+     try:
+         with sqlite3.connect(db_path) as conn:
+             cursor = conn.cursor()
+             if file_metadata:
+                 if isinstance(file_metadata, str):
+                     cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+             else:
+                 cursor.execute("SELECT png_path FROM png_list")
+
+             while True:
+                 rows = cursor.fetchmany(1000)
+                 if not rows:
+                     break
+                 all_paths.extend([row[0] for row in rows])
 
-     if 'png_path' not in df.columns:
-         print("No 'png_path' column found in the dataframe.")
-         return df, None
+     except sqlite3.Error as e:
+         print(f"Database error: {e}")
+         return
+     except Exception as e:
+         print(f"Error: {e}")
+         return
 
-     image_paths = df['png_path'].to_list()
+     return all_paths
+
+ def generate_path_list_from_db(db_path, file_metadata):
+     all_paths = []
+
+     # Connect to the database and retrieve the image paths
+     print(f"Reading DataBase: {db_path}")
+     try:
+         with sqlite3.connect(db_path) as conn:
+             cursor = conn.cursor()
+
+             if file_metadata:
+                 if isinstance(file_metadata, str):
+                     # If file_metadata is a single string
+                     cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+                 elif isinstance(file_metadata, list):
+                     # If file_metadata is a list of strings
+                     query = "SELECT png_path FROM png_list WHERE " + " OR ".join(
+                         ["png_path LIKE ?" for _ in file_metadata])
+                     params = [f"%{meta}%" for meta in file_metadata]
+                     cursor.execute(query, params)
+             else:
+                 # If file_metadata is None or empty
+                 cursor.execute("SELECT png_path FROM png_list")
+
+             while True:
+                 rows = cursor.fetchmany(1000)
+                 if not rows:
+                     break
+                 all_paths.extend([row[0] for row in rows])
+
+     except sqlite3.Error as e:
+         print(f"Database error: {e}")
+         return
+     except Exception as e:
+         print(f"Error: {e}")
+         return
+
+     return all_paths
+
+ def correct_paths(df, base_path):
+
+     if isinstance(df, pd.DataFrame):
+
+         if 'png_path' not in df.columns:
+             print("No 'png_path' column found in the dataframe.")
+             return df, None
+         else:
+             image_paths = df['png_path'].to_list()
+
+     elif isinstance(df, list):
+         image_paths = df
 
      adjusted_image_paths = []
      for path in image_paths:
@@ -3614,9 +4087,11 @@ def correct_paths(df, base_path):
      else:
          adjusted_image_paths.append(path)
 
-     df['png_path'] = adjusted_image_paths
-     image_paths = df['png_path'].to_list()
-     return df, image_paths
+     if isinstance(df, pd.DataFrame):
+         df['png_path'] = adjusted_image_paths
+         return df, adjusted_image_paths
+     else:
+         return adjusted_image_paths
 
  def delete_folder(folder_path):
      if os.path.exists(folder_path) and os.path.isdir(folder_path):
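Illustrative call (database path hypothetical); a list of metadata fragments becomes OR-chained LIKE clauses:

    paths = generate_path_list_from_db('/data/plate1/measurements/measurements.db',
                                       file_metadata=['plate1_r1', 'plate1_r2'])
    # Executes: SELECT png_path FROM png_list WHERE png_path LIKE ? OR png_path LIKE ?
    # with params ['%plate1_r1%', '%plate1_r2%']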
@@ -4424,7 +4899,7 @@ def convert_and_relabel_masks(folder_path):
 
  def correct_masks(src):
 
-     from .utils import _load_and_concatenate_arrays
+     from .io import _load_and_concatenate_arrays
 
      cell_path = os.path.join(src,'norm_channel_stack', 'cell_mask_stack')
      convert_and_relabel_masks(cell_path)
@@ -4447,4 +4922,123 @@ def get_cuda_version():
      except (subprocess.CalledProcessError, FileNotFoundError):
          return None
 
+ def all_elements_match(list1, list2):
+     # Check if all elements in list1 are in list2
+     return all(element in list2 for element in list1)
+
+ def prepare_batch_for_segmentation(batch):
+     # Ensure the batch is of dtype float32
+     if batch.dtype != np.float32:
+         batch = batch.astype(np.float32)
+
+     # Normalize each image in the batch
+     for i in range(batch.shape[0]):
+         if batch[i].max() > 1:
+             batch[i] = batch[i] / batch[i].max()
+
+     return batch
+
+ def check_index(df, elements=5, split_char='_'):
+     problematic_indices = []
+     for idx in df.index:
+         parts = str(idx).split(split_char)
+         if len(parts) != elements:
+             problematic_indices.append(idx)
+     if problematic_indices:
+         print(f"Indices that cannot be separated into {elements} parts:")
+         for idx in problematic_indices:
+             print(idx)
+         raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+ # Define the mapping function
+ def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
+     if col_value == neg:
+         return 'neg'
+     elif col_value == pos:
+         return 'pos'
+     elif col_value == mix:
+         return 'mix'
+     else:
+         return 'screen'
+
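For example, mapping a plate-layout column to experimental conditions (column IDs hypothetical):

    import pandas as pd
    df = pd.DataFrame({'col': ['c1', 'c2', 'c3', 'c7']})
    df['condition'] = df['col'].apply(map_condition)
    # -> ['neg', 'pos', 'mix', 'screen']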
+ def download_models(repo_id="einarolafsson/models", local_dir=None, retries=5, delay=5):
+     """
+     Downloads all model files from Hugging Face and stores them in the specified local directory.
 
+     Args:
+         repo_id (str): The repository ID on Hugging Face (default is 'einarolafsson/models').
+         local_dir (str): The local directory where models will be saved. Defaults to None.
+         retries (int): Number of retry attempts in case of failure.
+         delay (int): Delay in seconds between retries.
+
+     Returns:
+         str: The local path to the downloaded models.
+     """
+     # Create the local directory if it doesn't exist
+     if not os.path.exists(local_dir):
+         os.makedirs(local_dir)
+     elif len(os.listdir(local_dir)) > 0:
+         print(f"Models already downloaded to: {local_dir}")
+         return local_dir
+
+     attempt = 0
+     while attempt < retries:
+         try:
+             # List all files in the repo
+             files = list_repo_files(repo_id, repo_type="dataset")
+             print(f"Files in repository: {files}")  # Debugging print to check file list
+
+             # Download each file
+             for file_name in files:
+                 for download_attempt in range(retries):
+                     try:
+                         url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{file_name}?download=true"
+                         print(f"Downloading file from: {url}")  # Debugging
+
+                         response = requests.get(url, stream=True)
+                         print(f"HTTP response status: {response.status_code}")  # Debugging
+                         response.raise_for_status()
+
+                         # Save the file locally
+                         local_file_path = os.path.join(local_dir, os.path.basename(file_name))
+                         with open(local_file_path, 'wb') as file:
+                             for chunk in response.iter_content(chunk_size=8192):
+                                 file.write(chunk)
+                         print(f"Downloaded model file: {file_name} to {local_file_path}")
+                         break  # Exit the retry loop if successful
+                     except (requests.HTTPError, requests.Timeout) as e:
+                         print(f"Error downloading {file_name}: {e}. Retrying in {delay} seconds...")
+                         time.sleep(delay)
+                 else:  # for-else: runs only if every retry failed for this file
+                     raise Exception(f"Failed to download {file_name} after multiple attempts.")
+
+             return local_dir  # Return the directory where models are saved
+
+         except (requests.HTTPError, requests.Timeout) as e:
+             print(f"Error downloading files: {e}. Retrying in {delay} seconds...")
+             attempt += 1
+             time.sleep(delay)
+
+     raise Exception("Failed to download model files after multiple attempts.")
+
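Hedged usage sketch (target directory hypothetical). Note that local_dir is effectively required: the None default would fail at os.path.exists(None):

    model_dir = download_models(local_dir='/tmp/spacr_models')
    # Lists the Hugging Face dataset repo, streams each file into /tmp/spacr_models,
    # and returns early if the directory already contains files.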
5024
+ def generate_cytoplasm_mask(nucleus_mask, cell_mask):
5025
+
5026
+ """
5027
+ Generates a cytoplasm mask from nucleus and cell masks.
5028
+
5029
+ Parameters:
5030
+ - nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
5031
+ - cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
5032
+
5033
+ Returns:
5034
+ - cytoplasm_mask (np.array): Mask for the cytoplasm (1 for cytoplasm, 0 for nucleus and pathogens).
5035
+ """
5036
+
5037
+ # Make sure the nucleus and cell masks are numpy arrays
5038
+ nucleus_mask = np.array(nucleus_mask)
5039
+ cell_mask = np.array(cell_mask)
5040
+
5041
+ # Generate cytoplasm mask
5042
+ cytoplasm_mask = np.where(np.logical_or(nucleus_mask != 0), 0, cell_mask)
5043
+
5044
+ return cytoplasm_mask
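A quick sanity check with toy arrays:

    import numpy as np
    cell = np.array([[1, 1, 1],
                     [1, 1, 1],
                     [0, 0, 0]])
    nucleus = np.array([[0, 1, 0],
                        [0, 1, 0],
                        [0, 0, 0]])
    generate_cytoplasm_mask(nucleus, cell)
    # array([[1, 0, 1],
    #        [1, 0, 1],
    #        [0, 0, 0]])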