spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +316 -48
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +134 -47
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +419 -180
  27. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/utils.py CHANGED
@@ -1,6 +1,7 @@
- import sys, os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess
+ import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast

  import numpy as np
+ import pandas as pd
  from cellpose import models as cp_models
  from cellpose import denoise

@@ -14,7 +15,6 @@ from skimage.segmentation import clear_border

  from collections import defaultdict, OrderedDict
  from PIL import Image
- import pandas as pd
  from statsmodels.stats.outliers_influence import variance_inflation_factor
  from statsmodels.stats.stattools import durbin_watson
  import statsmodels.formula.api as smf
@@ -24,7 +24,7 @@ from itertools import combinations
  from functools import reduce
  from IPython.display import display

- from multiprocessing import Pool, cpu_count
+ from multiprocessing import Pool, cpu_count, set_start_method, get_start_method
  from concurrent.futures import ThreadPoolExecutor

  import torch.nn as nn
@@ -33,65 +33,118 @@ from torch.utils.checkpoint import checkpoint
  from torch.utils.data import Subset
  from torch.autograd import grad

+ from torchvision import models
+ from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
+ import torchvision.transforms as transforms
+ from torchvision.models import resnet50
+ from torchvision.utils import make_grid
+
  import seaborn as sns
  import matplotlib.pyplot as plt
  from matplotlib.offsetbox import OffsetImage, AnnotationBbox

+ from scipy import stats
  import scipy.ndimage as ndi
  from scipy.spatial import distance
- from scipy.stats import fisher_exact
+ from scipy.stats import fisher_exact, f_oneway, kruskal
  from scipy.ndimage.filters import gaussian_filter
  from scipy.spatial import ConvexHull
  from scipy.interpolate import splprep, splev
  from scipy.ndimage import binary_dilation

- from sklearn.preprocessing import StandardScaler
  from skimage.exposure import rescale_intensity
  from sklearn.metrics import auc, precision_recall_curve
  from sklearn.model_selection import train_test_split
  from sklearn.linear_model import Lasso, Ridge
- from sklearn.preprocessing import OneHotEncoder
- from sklearn.cluster import KMeans
- from sklearn.preprocessing import StandardScaler
- from sklearn.cluster import DBSCAN
- from sklearn.cluster import KMeans
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
+ from sklearn.cluster import KMeans, DBSCAN
  from sklearn.manifold import TSNE
- from sklearn.cluster import KMeans
  from sklearn.decomposition import PCA
+ from sklearn.ensemble import RandomForestClassifier
+
+ from huggingface_hub import list_repo_files

  import umap.umap_ as umap
+ #import umap

- from torchvision import models
- from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
- import torchvision.transforms as transforms
+ def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_value='setting_value'):
+     """
+     Convert a CSV file with setting-key and setting-value columns into a dictionary.
+     Handles special cases where values are lists, tuples, booleans, None, integers, floats, and nested dictionaries.

- from sklearn.ensemble import RandomForestClassifier
- from sklearn.preprocessing import StandardScaler
- from scipy.stats import f_oneway, kruskal
- from sklearn.cluster import KMeans
- from scipy import stats
+     Args:
+         csv_file_path (str): The path to the CSV file.
+         show (bool): Whether to display the dataframe (for debugging).
+         setting_key (str): The name of the column that contains the setting keys.
+         setting_value (str): The name of the column that contains the setting values.

- from .logger import log_function_call
- from multiprocessing import set_start_method, get_start_method
+     Returns:
+         dict: A dictionary where the setting keys are the dictionary keys and the parsed setting values are the values.
+     """
+     # Read the CSV file into a DataFrame
+     df = pd.read_csv(csv_file_path)

- import torch
- import torchvision.transforms as transforms
- from torchvision.models import resnet50
- from PIL import Image
- import numpy as np
- import umap
- import pandas as pd
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.preprocessing import StandardScaler
- from scipy.stats import f_oneway, kruskal
- from sklearn.cluster import KMeans
- from scipy import stats
+     if show:
+         display(df)
+
+     # Ensure the key and value columns exist
+     if setting_key not in df.columns or setting_value not in df.columns:
+         raise ValueError(f"CSV file must contain {setting_key} and {setting_value} columns.")
+
+     def parse_value(value):
+         """Parse the string value into the appropriate Python data type."""
+         # Handle empty values
+         if pd.isna(value) or value == '':
+             return None
+
+         # Handle boolean values
+         if value == 'True':
+             return True
+         if value == 'False':
+             return False
+
+         # Handle lists, tuples, dictionaries, and other literals
+         if value.startswith(('(', '[', '{')):  # If it starts with (, [ or {, use ast.literal_eval
+             try:
+                 parsed_value = ast.literal_eval(value)
+                 # If parsed_value is a dict, recursively parse its values
+                 if isinstance(parsed_value, dict):
+                     parsed_value = {k: parse_value(v) for k, v in parsed_value.items()}
+                 return parsed_value
+             except (ValueError, SyntaxError):
+                 pass  # If there's an error, return the value as-is
+
+         # Handle numeric values (integers and floats)
+         try:
+             if '.' in value:
+                 return float(value)  # If it contains a dot, convert to float
+             return int(value)  # Otherwise, convert to integer
+         except ValueError:
+             pass  # If it's not a valid number, return the value as-is
+
+         # Return the original value if no other type matched
+         return value
+
+     # Convert the DataFrame to a dictionary, parsing each value
+     result_dict = {key: parse_value(value) for key, value in zip(df[setting_key], df[setting_value])}

- def save_settings(settings, name='settings'):
+     return result_dict
+
+
+ def save_settings(settings, name='settings', show=False):

      settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-     settings_csv = os.path.join(settings['src'],'settings',f'{name}.csv')
-     os.makedirs(os.path.join(settings['src'],'settings'), exist_ok=True)
+     if show:
+         display(settings_df)
+
+     if isinstance(settings['src'], list):
+         src = settings['src'][0]
+         name = f"{name}_list"
+     else:
+         src = settings['src']
+
+     settings_csv = os.path.join(src,'settings',f'{name}.csv')
+     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
      settings_df.to_csv(settings_csv, index=False)

  def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
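The new load_settings/save_settings pair round-trips a settings dictionary through a CSV under <src>/settings/. A minimal sketch of that round trip (the path and keys are hypothetical; note that save_settings writes 'Key'/'Value' column headers, so load_settings needs its column-name arguments overridden):

    from spacr.utils import load_settings, save_settings

    settings = {'src': '/data/experiment1', 'channels': [0, 1, 2], 'diameter': 30.0, 'denoise': False}
    save_settings(settings, name='segmentation')  # writes /data/experiment1/settings/segmentation.csv
    restored = load_settings('/data/experiment1/settings/segmentation.csv',
                             setting_key='Key', setting_value='Value')
    assert restored['channels'] == [0, 1, 2] and restored['denoise'] is False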
@@ -303,7 +356,7 @@ def _get_cellpose_batch_size():
      except Exception as e:
          return 8

- def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+ def _extract_filename_metadata_v1(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):

      images_by_key = defaultdict(list)

@@ -353,6 +406,57 @@ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type

      return images_by_key

+ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+
+     images_by_key = defaultdict(list)
+
+     for filename in filenames:
+         match = regular_expression.match(filename)
+         if match:
+             try:
+                 try:
+                     plate = match.group('plateID')
+                 except:
+                     plate = os.path.basename(src)
+
+                 well = match.group('wellID')
+                 field = match.group('fieldID')
+                 channel = match.group('chanID')
+                 mode = None
+
+                 if well[0].isdigit():
+                     well = str(_safe_int_convert(well))
+                 if field[0].isdigit():
+                     field = str(_safe_int_convert(field))
+                 if channel[0].isdigit():
+                     channel = str(_safe_int_convert(channel))
+
+                 if metadata_type == 'cq1':
+                     orig_well = well
+                     well = _convert_cq1_well_id(well)
+                     print(f'Converted Well ID: {orig_well} to {well}', end='\r', flush=True)
+
+                 if pick_slice:
+                     try:
+                         mode = match.group('AID')
+                     except IndexError:
+                         mode = '00'
+
+                     if mode == skip_mode:
+                         continue
+
+                 key = (plate, well, field, channel, mode)
+                 file_path = os.path.join(src, filename)  # Store the full path
+                 images_by_key[key].append(file_path)
+
+             except IndexError:
+                 print(f"Could not extract information from filename {filename} using provided regex")
+         else:
+             print(f"Filename {filename} did not match provided regex")
+             continue
+
+     return images_by_key
+
  def mask_object_count(mask):
      """
      Counts the number of objects in a given mask.
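The rewritten _extract_filename_metadata expects a pre-compiled regular expression with named groups; plateID is optional (falling back to the source folder name) and AID is only consulted when pick_slice=True. A hypothetical CellVoyager-style pattern, where only the plateID, wellID, fieldID, chanID, and AID groups are read by the function and the other group names are illustrative padding:

    import re

    pattern = re.compile(
        r'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)'
        r'L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*)\.tif'
    )
    m = pattern.match('plate1_B02_T0001F003L01A01Z01C02.tif')
    print(m.group('plateID'), m.group('wellID'), m.group('fieldID'), m.group('chanID'))
    # plate1 B02 003 02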
@@ -443,7 +547,7 @@ def _generate_representative_images(db_path, cells=['HeLa'], cell_loc=None, path
      from .plot import _plot_images_on_grid

      df = _read_and_join_tables(db_path)
-     df = _annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments,treatment_loc)
+     df = annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments, treatment_loc)

      if update_db:
          _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo'])
@@ -489,34 +593,6 @@ def _map_values(row, values, locs):
          return value_dict.get(row[type_], None)
      return values[0] if values else None

- def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None):
-     """
-     Annotates conditions in the given DataFrame based on the provided parameters.
-
-     Args:
-         df (pandas.DataFrame): The DataFrame to annotate.
-         cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
-         cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
-         pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
-         pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
-         treatments (list, optional): The list of treatments. Defaults to ['cm'].
-         treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
-
-     Returns:
-         pandas.DataFrame: The annotated DataFrame with the 'host_cells', 'pathogen', 'treatment', and 'condition' columns.
-     """
-
-     # Apply mappings or defaults
-     df['host_cells'] = [cells[0]] * len(df) if cell_loc is None else df.apply(_map_values, args=(cells, cell_loc), axis=1)
-     df['pathogen'] = [pathogens[0]] * len(df) if pathogen_loc is None else df.apply(_map_values, args=(pathogens, pathogen_loc), axis=1)
-     df['treatment'] = [treatments[0]] * len(df) if treatment_loc is None else df.apply(_map_values, args=(treatments, treatment_loc), axis=1)
-
-     # Construct condition column
-     df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
-     df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
-     return df
-
  def is_list_of_lists(var):
      if isinstance(var, list) and all(isinstance(i, list) for i in var):
          return True
@@ -1085,67 +1161,74 @@ def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel)
      else:
          cellpose_channels['cell'] = [0,0]
      return cellpose_channels
-
- def annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, types = ['col','col','col']):
+
+ def annotate_conditions(df, cells=None, cell_loc=None, pathogens=None, pathogen_loc=None, treatments=None, treatment_loc=None):
      """
-     Annotates conditions in a DataFrame based on specified criteria.
+     Annotates conditions in a DataFrame based on specified criteria and combines them into a 'condition' column.
+     NaN is used for missing values, and they are excluded from the 'condition' column.

      Args:
          df (pandas.DataFrame): The DataFrame to annotate.
-         cells (list, optional): List of host cell types. Defaults to ['HeLa'].
-         cell_loc (list, optional): List of corresponding values for each host cell type. Defaults to None.
-         pathogens (list, optional): List of pathogens. Defaults to ['rh'].
-         pathogen_loc (list, optional): List of corresponding values for each pathogen. Defaults to None.
-         treatments (list, optional): List of treatments. Defaults to ['cm'].
-         treatment_loc (list, optional): List of corresponding values for each treatment. Defaults to None.
-         types (list, optional): List of column types for host cells, pathogens, and treatments. Defaults to ['col','col','col'].
+         cells (list/str, optional): Host cell types. Defaults to None.
+         cell_loc (list of lists, optional): Values for each host cell type. Defaults to None.
+         pathogens (list/str, optional): Pathogens. Defaults to None.
+         pathogen_loc (list of lists, optional): Values for each pathogen. Defaults to None.
+         treatments (list/str, optional): Treatments. Defaults to None.
+         treatment_loc (list of lists, optional): Values for each treatment. Defaults to None.

      Returns:
-         pandas.DataFrame: The annotated DataFrame.
+         pandas.DataFrame: Annotated DataFrame with a combined 'condition' column.
      """
+
+     def _get_type(val):
+         """Determine if a value maps to 'row' or 'col'."""
+         if isinstance(val, str) and val.startswith('c'):
+             return 'col'
+         elif isinstance(val, str) and val.startswith('r'):
+             return 'row'
+         return None

-     # Function to apply to each row
-     def _map_values(row, dict_, type_='col'):
+     def _map_or_default(column_name, values, loc, df):
          """
-         Maps the values in a row to corresponding keys in a dictionary.
+         Consolidates the logic for mapping values or assigning defaults when loc is None.

          Args:
-             row (dict): The row containing the values to be mapped.
-             dict_ (dict): The dictionary containing the mapping values.
-             type_ (str, optional): The type of mapping to perform. Defaults to 'col'.
-
-         Returns:
-             str: The mapped value if found, otherwise None.
+             column_name (str): The column in the DataFrame to annotate.
+             values (list/str): The list of values or a single string to annotate.
+             loc (list of lists): Location mapping for the values, or None if not used.
+             df (pandas.DataFrame): The DataFrame to modify.
          """
-         for values, cols in dict_.items():
-             if row[type_] in cols:
-                 return values
-         return None
+         if isinstance(values, str) or (isinstance(values, list) and loc is None):
+             # Assign all rows the first value in the list or the single string
+             df[column_name] = values if isinstance(values, str) else values[0]
+         elif values is not None and loc is not None:
+             # Perform the location-based mapping
+             value_dict = {val: key for key, loc_list in zip(values, loc) for val in loc_list}
+             df[column_name] = np.nan
+             for val, key in value_dict.items():
+                 loc_type = _get_type(val)
+                 if loc_type:
+                     df.loc[df[loc_type] == val, column_name] = key
+
+     # Handle cells, pathogens, and treatments using the consolidated logic
+     _map_or_default('host_cells', cells, cell_loc, df)
+     _map_or_default('pathogen', pathogens, pathogen_loc, df)
+     _map_or_default('treatment', treatments, treatment_loc, df)
+
+     # Conditionally fill NaN for pathogen and treatment columns if applicable
+     if pathogens is not None:
+         df['pathogen'].fillna(np.nan, inplace=True)
+     if treatments is not None:
+         df['treatment'].fillna(np.nan, inplace=True)
+
+     # Create the 'condition' column by excluding any NaN values, safely checking if 'host_cells', 'pathogen', and 'treatment' exist
+     df['condition'] = df.apply(
+         lambda x: '_'.join([str(v) for v in [x.get('host_cells'), x.get('pathogen'), x.get('treatment')] if pd.notna(v)]),
+         axis=1
+     )

-     if cell_loc is None:
-         df['host_cells'] = cells[0]
-     else:
-         cells_dict = dict(zip(cells, cell_loc))
-         df['host_cells'] = df.apply(lambda row: _map_values(row, cells_dict, type_=types[0]), axis=1)
-     if pathogen_loc is None:
-         if pathogens != None:
-             df['pathogen'] = 'none'
-     else:
-         pathogens_dict = dict(zip(pathogens, pathogen_loc))
-         df['pathogen'] = df.apply(lambda row: _map_values(row, pathogens_dict, type_=types[1]), axis=1)
-     if treatment_loc is None:
-         df['treatment'] = 'cm'
-     else:
-         treatments_dict = dict(zip(treatments, treatment_loc))
-         df['treatment'] = df.apply(lambda row: _map_values(row, treatments_dict, type_=types[2]), axis=1)
-     if pathogens != None:
-         df['condition'] = df['pathogen']+'_'+df['treatment']
-     else:
-         df['condition'] = df['treatment']
      return df
-

-
  def _split_data(df, group_by, object_type):
      """
      Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
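A toy illustration of the new annotate_conditions mapping (the 'row'/'col' columns and well IDs are assumed to exist in the upstream DataFrame; location values starting with 'r' map against the row column and values starting with 'c' against the col column):

    import pandas as pd
    from spacr.utils import annotate_conditions  # assumed import path

    df = pd.DataFrame({'row': ['r1', 'r1', 'r2'], 'col': ['c1', 'c2', 'c1']})
    df = annotate_conditions(df,
                             cells='HeLa',
                             pathogens=['rh', 'dgra14'], pathogen_loc=[['c1'], ['c2']],
                             treatments=['cm', 'lovastatin'], treatment_loc=[['r1'], ['r2']])
    print(df['condition'].tolist())
    # ['HeLa_rh_cm', 'HeLa_dgra14_cm', 'HeLa_rh_lovastatin']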
@@ -1951,9 +2034,10 @@ def add_images_to_tar(paths_chunk, tar_path, total_images):
              tar.add(img_path, arcname=arcname)
              with lock:
                  counter.value += 1
-                 if counter.value % 100 == 0:  # Print every 100 updates
-                     progress = (counter.value / total_images) * 100
-                     print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                 if counter.value % 10 == 0:  # Print every 10 updates
+                     #progress = (counter.value / total_images) * 100
+                     #print(f"Progress: {counter.value}/{total_images} ({progress:.2f}%)", end='\r', file=sys.stdout, flush=True)
+                     print_progress(counter.value, total_images, n_jobs=1, time_ls=None, batch_size=None, operation_type="generating .tar dataset")
          except FileNotFoundError:
              print(f"File not found: {img_path}")

@@ -2070,52 +2154,6 @@ def check_multicollinearity(x):
      vif_data["VIF"] = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
      return vif_data

- def generate_dependent_variable(df, dv_loc, pc_min=0.95, nc_max=0.05, agg_type='mean'):
-
-     from .plot import _plot_histograms_and_stats, _plot_plates
-
-     def qstring_to_float(qstr):
-         number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
-         return number / 100.0
-
-     print("Unique values in plate:", df['plate'].unique())
-     dv_cell_loc = f'{dv_loc}/dv_cell.csv'
-     dv_well_loc = f'{dv_loc}/dv_well.csv'
-
-     df['pred'] = 1-df['pred']  # if you switched pc and nc
-     df = df[(df['pred'] <= nc_max) | (df['pred'] >= pc_min)]
-
-     if 'prc' not in df.columns:
-         df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
-
-     if agg_type.startswith('q'):
-         val = qstring_to_float(agg_type)
-         agg_type = lambda x: x.quantile(val)
-
-     # Aggregating for mean prediction and total count
-     df_grouped = df.groupby('prc').agg(
-         pred=('pred', agg_type),
-         recruitment=('recruitment', agg_type),
-         count_prc=('prc', 'size'),
-         #count_above_95=('pred', lambda x: (x > 0.95).sum()),
-         mean_pathogen_area=('pathogen_area', 'mean')
-     )
-
-     df_cell = df[['prc', 'pred', 'pathogen_area', 'recruitment']]
-
-     df_cell.to_csv(dv_cell_loc, index=True, header=True, mode='w')
-     df_grouped.to_csv(dv_well_loc, index=True, header=True, mode='w')  # Changed from loc to dv_loc
-     display(df)
-     _plot_histograms_and_stats(df)
-     df_grouped = df_grouped.sort_values(by='count_prc', ascending=True)
-     display(df_grouped)
-     print('pred')
-     _plot_plates(df=df_cell, variable='pred', grouping='mean', min_max='allq', cmap='viridis')
-     print('recruitment')
-     _plot_plates(df=df_cell, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis')
-
-     return df_grouped
-
  def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
      # Separate predictors and response
      X = merged_df[['gene', 'grna', 'plate', 'row', 'column']]
@@ -3021,7 +3059,6 @@ def preprocess_image(image_path, image_size=224, channels=[1,2,3], normalize=Tru
      input_tensor = transform(image).unsqueeze(0)
      return image, input_tensor

-
  class SaliencyMapGenerator:
      def __init__(self, model):
          self.model = model
@@ -3042,17 +3079,63 @@ class SaliencyMapGenerator:
          saliency = X.grad.abs()
          return saliency

-     def plot_saliency_maps(self, X, y, saliency, class_names):
+     def compute_saliency_and_predictions(self, X):
+         self.model.eval()
+         X.requires_grad_()
+
+         # Forward pass to get predictions (logits)
+         scores = self.model(X).squeeze()
+
+         # Get predicted class (0 or 1 for binary classification)
+         predictions = (scores > 0).long()
+
+         # Compute saliency maps
+         self.model.zero_grad()
+         target_scores = scores * (2 * predictions - 1)
+         target_scores.backward(torch.ones_like(target_scores))
+
+         saliency = X.grad.abs()
+
+         return saliency, predictions
+
+     def plot_saliency_grid(self, X, saliency, predictions, mode='mean'):
          N = X.shape[0]
+         rows = (N + 7) // 8  # Ensure we can handle batches of different sizes
+         fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2), squeeze=False)  # squeeze=False keeps axs 2-D even when rows == 1
+
          for i in range(N):
-             plt.subplot(2, N, i + 1)
-             plt.imshow(X[i].permute(1, 2, 0).cpu().numpy())
-             plt.axis('off')
-             plt.title(class_names[y[i]])
-             plt.subplot(2, N, N + i + 1)
-             plt.imshow(saliency[i].cpu().numpy(), cmap=plt.cm.hot)
-             plt.axis('off')
-         plt.gcf().set_size_inches(12, 5)
+             ax = axs[i // 8, i % 8]
+
+             if mode == 'mean':
+                 saliency_map = saliency[i].mean(dim=0).cpu().numpy()  # Mean saliency over channels
+                 ax.imshow(X[i].permute(1, 2, 0).detach().cpu().numpy())  # Added .detach() here
+                 ax.imshow(saliency_map, cmap='jet', alpha=0.5)
+
+             elif mode == 'channel':
+                 # Plot individual channels in a loop if the image has multiple channels
+                 for j in range(X.shape[1]):
+                     saliency_map = saliency[i, j].cpu().numpy()
+                     ax.imshow(saliency_map, cmap='jet')
+                     ax.axis('off')
+
+             elif mode == '3-channel' and X.shape[1] == 3:
+                 saliency_map = saliency[i].cpu().numpy().transpose(1, 2, 0)
+                 ax.imshow(saliency_map)
+
+             elif mode == '2-channel' and X.shape[1] == 2:
+                 saliency_map = saliency[i].cpu().numpy().transpose(1, 2, 0)
+                 ax.imshow(saliency_map)
+
+             # Add class label in top-left corner
+             ax.text(5, 25, str(predictions[i].item()), fontsize=12, color='white', weight='bold',
+                     bbox=dict(facecolor='black', alpha=0.7, boxstyle='round,pad=0.2'))
+             ax.axis('off')
+
+         # Turn off unused axes
+         for j in range(N, rows * 8):
+             fig.delaxes(axs[j // 8, j % 8])
+
+         plt.tight_layout(pad=0)
          plt.show()

  def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
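A sketch of how the two new SaliencyMapGenerator methods chain together for a single-logit binary classifier (the ResNet-18 backbone, replacement head, and random batch below are placeholder assumptions, not the package's own pipeline):

    import torch
    import torchvision
    from spacr.utils import SaliencyMapGenerator  # assumed import path

    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 1)  # single-logit binary head

    X = torch.rand(8, 3, 224, 224)  # batch of 8 RGB images
    generator = SaliencyMapGenerator(model)
    saliency, predictions = generator.compute_saliency_and_predictions(X)
    generator.plot_saliency_grid(X, saliency, predictions, mode='mean')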
@@ -3594,13 +3677,48 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
      plt.show()
      return grid_fig

- def correct_paths(df, base_path):
+ def generate_path_list_from_db(db_path, file_metadata):

-     if 'png_path' not in df.columns:
-         print("No 'png_path' column found in the dataframe.")
-         return df, None
+     all_paths = []
+
+     # Connect to the database and retrieve the image paths
+     print(f"Reading DataBase: {db_path}")
+     try:
+         with sqlite3.connect(db_path) as conn:
+             cursor = conn.cursor()
+             if file_metadata and isinstance(file_metadata, str):
+                 cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+             else:
+                 cursor.execute("SELECT png_path FROM png_list")
+
+             while True:
+                 rows = cursor.fetchmany(1000)
+                 if not rows:
+                     break
+                 all_paths.extend([row[0] for row in rows])
+
+     except sqlite3.Error as e:
+         print(f"Database error: {e}")
+         return
+     except Exception as e:
+         print(f"Error: {e}")
+         return

-     image_paths = df['png_path'].to_list()
+     return all_paths
+
+ def correct_paths(df, base_path):
+
+     if isinstance(df, pd.DataFrame):
+
+         if 'png_path' not in df.columns:
+             print("No 'png_path' column found in the dataframe.")
+             return df, None
+         else:
+             image_paths = df['png_path'].to_list()
+
+     elif isinstance(df, list):
+         image_paths = df

      adjusted_image_paths = []
      for path in image_paths:
@@ -3614,9 +3732,11 @@ def correct_paths(df, base_path):
          else:
              adjusted_image_paths.append(path)

-     df['png_path'] = adjusted_image_paths
-     image_paths = df['png_path'].to_list()
-     return df, image_paths
+     if isinstance(df, pd.DataFrame):
+         df['png_path'] = adjusted_image_paths
+         return df, adjusted_image_paths
+     else:
+         return adjusted_image_paths

  def delete_folder(folder_path):
      if os.path.exists(folder_path) and os.path.isdir(folder_path):
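Together, the two functions above let PNG paths recorded in a measurements database be rebased onto the local filesystem. A sketch (the database path and base path are hypothetical):

    from spacr.utils import generate_path_list_from_db, correct_paths  # assumed import path

    paths = generate_path_list_from_db('/data/experiment1/measurements/measurements.db',
                                       file_metadata='plate1')
    # With a list input, correct_paths returns only the adjusted list
    paths = correct_paths(paths, base_path='/data/experiment1')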
@@ -4424,7 +4544,7 @@ def convert_and_relabel_masks(folder_path):

  def correct_masks(src):

-     from .utils import _load_and_concatenate_arrays
+     from .io import _load_and_concatenate_arrays

      cell_path = os.path.join(src,'norm_channel_stack', 'cell_mask_stack')
      convert_and_relabel_masks(cell_path)
@@ -4447,4 +4567,123 @@ def get_cuda_version():
      except (subprocess.CalledProcessError, FileNotFoundError):
          return None

+ def all_elements_match(list1, list2):
+     # Check if all elements in list1 are in list2
+     return all(element in list2 for element in list1)
+
+ def prepare_batch_for_segmentation(batch):
+     # Ensure the batch is of dtype float32
+     if batch.dtype != np.float32:
+         batch = batch.astype(np.float32)
+
+     # Normalize each image in the batch
+     for i in range(batch.shape[0]):
+         if batch[i].max() > 1:
+             batch[i] = batch[i] / batch[i].max()
+
+     return batch
+
+ def check_index(df, elements=5, split_char='_'):
+     problematic_indices = []
+     for idx in df.index:
+         parts = str(idx).split(split_char)
+         if len(parts) != elements:
+             problematic_indices.append(idx)
+     if problematic_indices:
+         print(f"Indices that cannot be separated into {elements} parts:")
+         for idx in problematic_indices:
+             print(idx)
+         raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+ # Define the mapping function
+ def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
+     if col_value == neg:
+         return 'neg'
+     elif col_value == pos:
+         return 'pos'
+     elif col_value == mix:
+         return 'mix'
+     else:
+         return 'screen'
+
+ def download_models(repo_id="einarolafsson/models", local_dir=None, retries=5, delay=5):
+     """
+     Downloads all model files from Hugging Face and stores them in the specified local directory.
+
+     Args:
+         repo_id (str): The repository ID on Hugging Face (default is 'einarolafsson/models').
+         local_dir (str): The local directory where the models will be saved; must be supplied by the caller.
+         retries (int): Number of retry attempts in case of failure.
+         delay (int): Delay in seconds between retries.
+
+     Returns:
+         str: The local path to the downloaded models.
+     """
+     # Create the local directory if it doesn't exist
+     if not os.path.exists(local_dir):
+         os.makedirs(local_dir)
+     elif len(os.listdir(local_dir)) > 0:
+         print(f"Models already downloaded to: {local_dir}")
+         return local_dir
+
+     attempt = 0
+     while attempt < retries:
+         try:
+             # List all files in the repo
+             files = list_repo_files(repo_id, repo_type="dataset")
+             print(f"Files in repository: {files}")  # Debugging print to check file list
+
+             # Download each file
+             for file_name in files:
+                 for download_attempt in range(retries):
+                     try:
+                         url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{file_name}?download=true"
+                         print(f"Downloading file from: {url}")  # Debugging
+
+                         response = requests.get(url, stream=True)
+                         print(f"HTTP response status: {response.status_code}")  # Debugging
+                         response.raise_for_status()
+
+                         # Save the file locally
+                         local_file_path = os.path.join(local_dir, os.path.basename(file_name))
+                         with open(local_file_path, 'wb') as file:
+                             for chunk in response.iter_content(chunk_size=8192):
+                                 file.write(chunk)
+                         print(f"Downloaded model file: {file_name} to {local_file_path}")
+                         break  # Exit the retry loop if successful
+                     except (requests.HTTPError, requests.Timeout) as e:
+                         print(f"Error downloading {file_name}: {e}. Retrying in {delay} seconds...")
+                         time.sleep(delay)
+                 else:
+                     raise Exception(f"Failed to download {file_name} after multiple attempts.")
+
+             return local_dir  # Return the directory where models are saved
+
+         except (requests.HTTPError, requests.Timeout) as e:
+             print(f"Error downloading files: {e}. Retrying in {delay} seconds...")
+             attempt += 1
+             time.sleep(delay)
+
+     raise Exception("Failed to download model files after multiple attempts.")
+
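Note that download_models needs an explicit local_dir: with the default of None, os.path.exists(None) raises a TypeError before any download starts. A minimal call with a hypothetical target directory:

    from spacr.utils import download_models  # assumed import path

    model_dir = download_models(local_dir='/tmp/spacr_models')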
+ def generate_cytoplasm_mask(nucleus_mask, cell_mask):
+
+     """
+     Generates a cytoplasm mask from nucleus and cell masks.
+
+     Parameters:
+     - nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
+     - cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
+
+     Returns:
+     - cytoplasm_mask (np.array): Mask for the cytoplasm (cell values outside the nucleus, 0 elsewhere).
+     """
+
+     # Make sure the nucleus and cell masks are numpy arrays
+     nucleus_mask = np.array(nucleus_mask)
+     cell_mask = np.array(cell_mask)
+
+     # Generate cytoplasm mask: zero out nucleus pixels, keep the cell mask elsewhere
+     cytoplasm_mask = np.where(nucleus_mask != 0, 0, cell_mask)
+
+     return cytoplasm_mask
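A toy check of the cytoplasm-mask logic (the arrays are illustrative):

    import numpy as np
    from spacr.utils import generate_cytoplasm_mask  # assumed import path

    cell_mask = np.array([[1, 1, 1],
                          [1, 1, 1],
                          [0, 0, 0]])
    nucleus_mask = np.array([[0, 1, 0],
                             [0, 1, 0],
                             [0, 0, 0]])
    print(generate_cytoplasm_mask(nucleus_mask, cell_mask))
    # [[1 0 1]
    #  [1 0 1]
    #  [0 0 0]]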