spacr 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -6,7 +6,6 @@ import pandas as pd
6
6
 
7
7
  import cellpose
8
8
  from cellpose import models as cp_models
9
- from cellpose import denoise
10
9
 
11
10
  import statsmodels.formula.api as smf
12
11
  import statsmodels.api as sm
@@ -29,17 +28,9 @@ matplotlib.use('Agg')
29
28
  import torchvision.transforms as transforms
30
29
  from sklearn.model_selection import train_test_split
31
30
  from sklearn.ensemble import IsolationForest
32
-
33
31
  from .logger import log_function_call
34
32
 
35
- #from .io import TarImageDataset, NoClassDataset, MyDataset, read_db, _copy_missclassified, read_mask, load_normalized_images_and_labels, load_images_and_labels
36
- #from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
37
- #from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
38
- #from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
39
- #from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
40
- #from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model
41
33
 
42
- @log_function_call
43
34
  def analyze_plaques(folder):
44
35
  summary_data = []
45
36
  details_data = []
@@ -76,7 +67,6 @@ def analyze_plaques(folder):
76
67
 
77
68
  print(f"Analysis completed and saved to database '{db_name}'.")
78
69
 
79
- @log_function_call
80
70
  def compare_masks(dir1, dir2, dir3, verbose=False):
81
71
 
82
72
  from .io import _read_mask
@@ -178,10 +168,9 @@ def generate_cp_masks(settings):
178
168
 
179
169
  dst = os.path.join(src,'masks')
180
170
  os.makedirs(dst, exist_ok=True)
181
-
171
+
182
172
  identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
183
173
 
184
- @log_function_call
185
174
  def train_cellpose(settings):
186
175
 
187
176
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
@@ -281,7 +270,6 @@ def train_cellpose(settings):
281
270
 
282
271
  return print(f"Model saved at: {model_save_path}/{model_name}")
283
272
 
284
- @log_function_call
285
273
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
286
274
 
287
275
  from .plot import _reg_v_plot
@@ -430,7 +418,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', t
430
418
 
431
419
  return result
432
420
 
433
- @log_function_call
434
421
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
435
422
 
436
423
  from .plot import _reg_v_plot
@@ -609,7 +596,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=5
609
596
 
610
597
  return max_effects, max_effects_pvalues, model, df
611
598
 
612
- @log_function_call
613
599
  def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
614
600
 
615
601
  from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
@@ -777,7 +763,6 @@ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wel
777
763
 
778
764
  return
779
765
 
780
- @log_function_call
781
766
  def merge_pred_mes(src,
782
767
  pred_loc,
783
768
  target='protein of interest',
@@ -1088,7 +1073,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
1088
1073
  batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1089
1074
  prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1090
1075
  filenames_list.extend(filenames)
1091
- print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1076
+ print(f'batch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1092
1077
 
1093
1078
  data = {'path':filenames_list, 'pred':prediction_pos_probs}
1094
1079
  df = pd.DataFrame(data, index=None)
@@ -1649,15 +1634,27 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1649
1634
  cells,wells = _results_to_csv(src, df, df_well)
1650
1635
  return [cells,wells]
1651
1636
 
1652
- @log_function_call
1653
- def preprocess_generate_masks(src, settings={},advanced_settings={}):
1637
+ def preprocess_generate_masks(src, settings={}):
1654
1638
 
1655
1639
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1656
1640
  from .plot import plot_merged, plot_arrays
1657
1641
  from .utils import _pivot_counts_table
1658
-
1659
- settings = {**settings, **advanced_settings}
1642
+
1643
+ settings['fps'] = 2
1644
+ settings['remove_background'] = True
1645
+ settings['lower_quantile'] = 0.02
1646
+ settings['merge'] = False
1647
+ settings['normalize_plots'] = True
1648
+ settings['all_to_mip'] = False
1649
+ settings['pick_slice'] = False
1650
+ settings['skip_mode'] = src
1651
+ settings['workers'] = os.cpu_count()-4
1652
+ settings['verbose'] = True
1653
+ settings['examples_to_plot'] = 1
1660
1654
  settings['src'] = src
1655
+ settings['upscale'] = False
1656
+ settings['upscale_factor'] = 2.0
1657
+
1661
1658
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1662
1659
  settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
1663
1660
  os.makedirs(os.path.join(src,'settings'), exist_ok=True)
@@ -1676,7 +1673,7 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1676
1673
  settings['save'] = [settings['save']]*3
1677
1674
 
1678
1675
  if settings['preprocess']:
1679
- preprocess_img_data(settings)
1676
+ settings, src = preprocess_img_data(settings)
1680
1677
 
1681
1678
  if settings['masks']:
1682
1679
  mask_src = os.path.join(src, 'norm_channel_stack')
@@ -1747,6 +1744,7 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1747
1744
 
1748
1745
  torch.cuda.empty_cache()
1749
1746
  gc.collect()
1747
+ print("Successfully completed run")
1750
1748
  return
1751
1749
 
1752
1750
  def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
@@ -1836,8 +1834,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1836
1834
  cv2.imwrite(output_filename, mask)
1837
1835
  return
1838
1836
 
1839
- @log_function_call
1840
- def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
1837
+ def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
1841
1838
  """
1842
1839
  Identify masks from the source images.
1843
1840
 
@@ -1891,7 +1888,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1891
1888
  print(f'Torch CUDA is not available, using CPU')
1892
1889
 
1893
1890
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1894
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
1891
+ model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
1892
+
1895
1893
  if file_type == '.npz':
1896
1894
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
1897
1895
  else:
@@ -1918,9 +1916,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1918
1916
 
1919
1917
  average_sizes = []
1920
1918
  time_ls = []
1921
- moving_avg_q1 = 0
1922
- moving_avg_q3 = 0
1923
- moving_count = 0
1924
1919
  for file_index, path in enumerate(paths):
1925
1920
 
1926
1921
  name = os.path.basename(path)
@@ -1961,7 +1956,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1961
1956
  if not plot:
1962
1957
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
1963
1958
  if batch.size == 0:
1964
- print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
1959
+ print(f'Processing: {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}')
1960
+ #print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
1965
1961
  continue
1966
1962
  if batch.max() > 1:
1967
1963
  batch = batch / batch.max()
@@ -1977,7 +1973,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1977
1973
 
1978
1974
  cellpose_batch_size = _get_cellpose_batch_size()
1979
1975
 
1980
- model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
1976
+ #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
1981
1977
 
1982
1978
  masks, flows, _, _ = model.eval(x=batch,
1983
1979
  batch_size=cellpose_batch_size,
@@ -1989,9 +1985,9 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1989
1985
  cellprob_threshold=cellprob_threshold,
1990
1986
  rescale=None,
1991
1987
  resample=resample,
1992
- #net_avg=net_avg,
1993
1988
  stitch_threshold=stitch_threshold,
1994
1989
  progress=None)
1990
+
1995
1991
  print('Masks shape',masks.shape)
1996
1992
  if timelapse:
1997
1993
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
@@ -2015,7 +2011,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2015
2011
 
2016
2012
  else:
2017
2013
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2018
- mask_stack = _filter_cp_masks(masks, flows, refine_masks, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize)
2014
+ mask_stack = _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize)
2019
2015
  _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2020
2016
 
2021
2017
  if not np.any(mask_stack):
@@ -2032,7 +2028,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2032
2028
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2033
2029
  time_in_min = average_time/60
2034
2030
  time_per_mask = average_time/batch_size
2035
- print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2031
+ print(f'Processing: {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2032
+ #print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2036
2033
  if not timelapse:
2037
2034
  if plot:
2038
2035
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap=cmap, nr=batch_size, file_type='.npz')
@@ -2046,10 +2043,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2046
2043
  gc.collect()
2047
2044
  return
2048
2045
 
2049
- @log_function_call
2050
- def generate_cellpose_masks(src, settings, object_type):
2046
+ def all_elements_match(list1, list2):
2047
+ # Check if all elements in list1 are in list2
2048
+ return all(element in list2 for element in list1)
2049
+
2050
+ def generate_cellpose_masks_v1(src, settings, object_type):
2051
2051
 
2052
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
2052
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, mask_object_count
2053
2053
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2054
2054
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2055
2055
  from .plot import plot_masks
@@ -2076,7 +2076,9 @@ def generate_cellpose_masks(src, settings, object_type):
2076
2076
  object_settings = _get_object_settings(object_type, settings)
2077
2077
  model_name = object_settings['model_name']
2078
2078
 
2079
- cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2079
+ cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2080
+ if settings['verbose']:
2081
+ print(cellpose_channels)
2080
2082
  channels = cellpose_channels[object_type]
2081
2083
  cellpose_batch_size = _get_cellpose_batch_size()
2082
2084
 
@@ -2094,10 +2096,7 @@ def generate_cellpose_masks(src, settings, object_type):
2094
2096
 
2095
2097
  average_sizes = []
2096
2098
  time_ls = []
2097
- moving_avg_q1 = 0
2098
- moving_avg_q3 = 0
2099
- moving_count = 0
2100
-
2099
+
2101
2100
  for file_index, path in enumerate(paths):
2102
2101
  name = os.path.basename(path)
2103
2102
  name, ext = os.path.splitext(name)
@@ -2108,16 +2107,22 @@ def generate_cellpose_masks(src, settings, object_type):
2108
2107
  stack = data['data']
2109
2108
  filenames = data['filenames']
2110
2109
  if settings['timelapse']:
2110
+
2111
+ trackable_objects = ['cell','nucleus','pathogen']
2112
+ if not all_elements_match(settings['timelapse_objects'], trackable_objects):
2113
+ print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
2114
+ return
2115
+
2111
2116
  if len(stack) != batch_size:
2112
2117
  print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2113
- settings['batch_size'] = len(stack)
2118
+ settings['timelapse_batch_size'] = len(stack)
2114
2119
  batch_size = len(stack)
2115
2120
  if isinstance(timelapse_frame_limits, list):
2116
2121
  if len(timelapse_frame_limits) >= 2:
2117
2122
  stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2118
2123
  filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2119
2124
  batch_size = len(stack)
2120
- print(f'Cut batch an indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2125
+ print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2121
2126
 
2122
2127
  for i in range(0, stack.shape[0], batch_size):
2123
2128
  mask_stack = []
@@ -2133,7 +2138,7 @@ def generate_cellpose_masks(src, settings, object_type):
2133
2138
  if not settings['plot']:
2134
2139
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2135
2140
  if batch.size == 0:
2136
- print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
2141
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2137
2142
  continue
2138
2143
  if batch.max() > 1:
2139
2144
  batch = batch / batch.max()
@@ -2146,10 +2151,8 @@ def generate_cellpose_masks(src, settings, object_type):
2146
2151
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2147
2152
  else:
2148
2153
  stitch_threshold=0.0
2149
- #print(batch.shape)
2150
- #batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
2151
- #batch = np.stack((batch, batch), axis=-1)
2152
- #print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
2154
+
2155
+ print('batch.shape',batch.shape)
2153
2156
  masks, flows, _, _ = model.eval(x=batch,
2154
2157
  batch_size=cellpose_batch_size,
2155
2158
  normalize=False,
@@ -2161,9 +2164,220 @@ def generate_cellpose_masks(src, settings, object_type):
2161
2164
  rescale=None,
2162
2165
  resample=object_settings['resample'],
2163
2166
  stitch_threshold=stitch_threshold)
2164
- #progress=None)
2167
+
2168
+ if timelapse:
2169
+ if settings['plot']:
2170
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2171
+ if idx == 0:
2172
+ num_objects = mask_object_count(mask)
2173
+ print(f'Number of objects: {num_objects}')
2174
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2175
+
2176
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2177
+ if object_type in timelapse_objects:
2178
+ if timelapse_mode == 'btrack':
2179
+ if not timelapse_displacement is None:
2180
+ radius = timelapse_displacement
2181
+ else:
2182
+ radius = 100
2183
+
2184
+ workers = os.cpu_count()-2
2185
+ if workers < 1:
2186
+ workers = 1
2187
+
2188
+ mask_stack = _btrack_track_cells(src=src,
2189
+ name=name,
2190
+ batch_filenames=batch_filenames,
2191
+ object_type=object_type,
2192
+ plot=settings['plot'],
2193
+ save=settings['save'],
2194
+ masks_3D=masks,
2195
+ mode=timelapse_mode,
2196
+ timelapse_remove_transient=timelapse_remove_transient,
2197
+ radius=radius,
2198
+ workers=workers)
2199
+ if timelapse_mode == 'trackpy':
2200
+ mask_stack = _trackpy_track_cells(src=src,
2201
+ name=name,
2202
+ batch_filenames=batch_filenames,
2203
+ object_type=object_type,
2204
+ masks=masks,
2205
+ timelapse_displacement=timelapse_displacement,
2206
+ timelapse_memory=timelapse_memory,
2207
+ timelapse_remove_transient=timelapse_remove_transient,
2208
+ plot=settings['plot'],
2209
+ save=settings['save'],
2210
+ mode=timelapse_mode)
2211
+ else:
2212
+ mask_stack = _masks_to_masks_stack(masks)
2213
+
2214
+ else:
2215
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2216
+ mask_stack = _filter_cp_masks(masks=masks,
2217
+ flows=flows,
2218
+ filter_size=object_settings['filter_size'],
2219
+ filter_intensity=object_settings['filter_intensity'],
2220
+ minimum_size=object_settings['minimum_size'],
2221
+ maximum_size=object_settings['maximum_size'],
2222
+ remove_border_objects=object_settings['remove_border_objects'],
2223
+ merge=False,
2224
+ batch=batch,
2225
+ plot=settings['plot'],
2226
+ figuresize=figuresize)
2227
+
2228
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2229
+
2230
+ if not np.any(mask_stack):
2231
+ average_obj_size = 0
2232
+ else:
2233
+ average_obj_size = _get_avg_object_size(mask_stack)
2234
+
2235
+ average_sizes.append(average_obj_size)
2236
+ overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
2237
+
2238
+ stop = time.time()
2239
+ duration = (stop - start)
2240
+ time_ls.append(duration)
2241
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2242
+ time_in_min = average_time/60
2243
+ time_per_mask = average_time/batch_size
2244
+ print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2245
+ if not timelapse:
2246
+ if settings['plot']:
2247
+ plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
2248
+ if settings['save']:
2249
+ for mask_index, mask in enumerate(mask_stack):
2250
+ output_filename = os.path.join(output_folder, batch_filenames[mask_index])
2251
+ np.save(output_filename, mask)
2252
+ mask_stack = []
2253
+ batch_filenames = []
2254
+ gc.collect()
2255
+ torch.cuda.empty_cache()
2256
+ return
2257
+
2258
+ def generate_cellpose_masks(src, settings, object_type):
2259
+
2260
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
2261
+ from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2262
+ from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2263
+ from .plot import plot_masks
2264
+
2265
+ gc.collect()
2266
+ if not torch.cuda.is_available():
2267
+ print(f'Torch CUDA is not available, using CPU')
2268
+
2269
+ figuresize=25
2270
+ timelapse = settings['timelapse']
2271
+
2272
+ if timelapse:
2273
+ timelapse_displacement = settings['timelapse_displacement']
2274
+ timelapse_frame_limits = settings['timelapse_frame_limits']
2275
+ timelapse_memory = settings['timelapse_memory']
2276
+ timelapse_remove_transient = settings['timelapse_remove_transient']
2277
+ timelapse_mode = settings['timelapse_mode']
2278
+ timelapse_objects = settings['timelapse_objects']
2279
+
2280
+ batch_size = settings['batch_size']
2281
+ cellprob_threshold = settings[f'{object_type}_CP_prob']
2282
+ flow_threshold = 30
2283
+
2284
+ object_settings = _get_object_settings(object_type, settings)
2285
+ model_name = object_settings['model_name']
2286
+
2287
+ cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2288
+ if settings['verbose']:
2289
+ print(cellpose_channels)
2165
2290
 
2291
+ channels = cellpose_channels[object_type]
2292
+ cellpose_batch_size = _get_cellpose_batch_size()
2293
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2294
+ model = _choose_model(model_name, device, object_type='cell', restore_type=None)
2295
+ chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2296
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2297
+
2298
+ count_loc = os.path.dirname(src)+'/measurements/measurements.db'
2299
+ os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
2300
+ _create_database(count_loc)
2301
+
2302
+ average_sizes = []
2303
+ time_ls = []
2304
+ for file_index, path in enumerate(paths):
2305
+ name = os.path.basename(path)
2306
+ name, ext = os.path.splitext(name)
2307
+ output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
2308
+ os.makedirs(output_folder, exist_ok=True)
2309
+ overall_average_size = 0
2310
+ with np.load(path) as data:
2311
+ stack = data['data']
2312
+ filenames = data['filenames']
2313
+ if settings['timelapse']:
2314
+
2315
+ trackable_objects = ['cell','nucleus','pathogen']
2316
+ if not all_elements_match(settings['timelapse_objects'], trackable_objects):
2317
+ print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
2318
+ return
2319
+
2320
+ if len(stack) != batch_size:
2321
+ print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2322
+ settings['timelapse_batch_size'] = len(stack)
2323
+ batch_size = len(stack)
2324
+ if isinstance(timelapse_frame_limits, list):
2325
+ if len(timelapse_frame_limits) >= 2:
2326
+ stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2327
+ filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2328
+ batch_size = len(stack)
2329
+ print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2330
+
2331
+ for i in range(0, stack.shape[0], batch_size):
2332
+ mask_stack = []
2333
+ start = time.time()
2334
+
2335
+ if stack.shape[3] == 1:
2336
+ batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
2337
+ else:
2338
+ batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
2339
+
2340
+ batch_filenames = filenames[i: i+batch_size].tolist()
2341
+
2342
+ if not settings['plot']:
2343
+ batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2344
+ if batch.size == 0:
2345
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2346
+ continue
2347
+ if batch.max() > 1:
2348
+ batch = batch / batch.max()
2349
+
2350
+ if timelapse:
2351
+ stitch_threshold=100.0
2352
+ movie_path = os.path.join(os.path.dirname(src), 'movies')
2353
+ os.makedirs(movie_path, exist_ok=True)
2354
+ save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2355
+ _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2356
+ else:
2357
+ stitch_threshold=0.0
2358
+
2359
+ print('batch.shape',batch.shape)
2360
+ masks, flows, _, _ = model.eval(x=batch,
2361
+ batch_size=cellpose_batch_size,
2362
+ normalize=False,
2363
+ channels=chans,
2364
+ channel_axis=3,
2365
+ diameter=object_settings['diameter'],
2366
+ flow_threshold=flow_threshold,
2367
+ cellprob_threshold=cellprob_threshold,
2368
+ rescale=None,
2369
+ resample=object_settings['resample'],
2370
+ stitch_threshold=stitch_threshold)
2371
+
2166
2372
  if timelapse:
2373
+
2374
+ if settings['plot']:
2375
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2376
+ if idx == 0:
2377
+ num_objects = mask_object_count(mask)
2378
+ print(f'Number of objects: {num_objects}')
2379
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2380
+
2167
2381
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2168
2382
  if object_type in timelapse_objects:
2169
2383
  if timelapse_mode == 'btrack':
@@ -2192,13 +2406,13 @@ def generate_cellpose_masks(src, settings, object_type):
2192
2406
  name=name,
2193
2407
  batch_filenames=batch_filenames,
2194
2408
  object_type=object_type,
2195
- masks_3D=masks,
2409
+ masks=masks,
2196
2410
  timelapse_displacement=timelapse_displacement,
2197
2411
  timelapse_memory=timelapse_memory,
2198
2412
  timelapse_remove_transient=timelapse_remove_transient,
2199
2413
  plot=settings['plot'],
2200
2414
  save=settings['save'],
2201
- timelapse_mode=timelapse_mode)
2415
+ mode=timelapse_mode)
2202
2416
  else:
2203
2417
  mask_stack = _masks_to_masks_stack(masks)
2204
2418
 
@@ -2207,15 +2421,12 @@ def generate_cellpose_masks(src, settings, object_type):
2207
2421
  mask_stack = _filter_cp_masks(masks=masks,
2208
2422
  flows=flows,
2209
2423
  filter_size=object_settings['filter_size'],
2424
+ filter_intensity=object_settings['filter_intensity'],
2210
2425
  minimum_size=object_settings['minimum_size'],
2211
2426
  maximum_size=object_settings['maximum_size'],
2212
2427
  remove_border_objects=object_settings['remove_border_objects'],
2213
2428
  merge=False,
2214
- filter_dimm=object_settings['filter_dimm'],
2215
2429
  batch=batch,
2216
- moving_avg_q1=moving_avg_q1,
2217
- moving_avg_q3=moving_avg_q3,
2218
- moving_count=moving_count,
2219
2430
  plot=settings['plot'],
2220
2431
  figuresize=figuresize)
2221
2432
 
@@ -2235,7 +2446,7 @@ def generate_cellpose_masks(src, settings, object_type):
2235
2446
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2236
2447
  time_in_min = average_time/60
2237
2448
  time_per_mask = average_time/batch_size
2238
- print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2449
+ print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2239
2450
  if not timelapse:
2240
2451
  if settings['plot']:
2241
2452
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)