spacr 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/io.py CHANGED
@@ -22,6 +22,7 @@ from torchvision.transforms import ToTensor
  import seaborn as sns
  from nd2reader import ND2Reader
  from torchvision import transforms
+ from sklearn.model_selection import train_test_split

  def process_non_tif_non_2D_images(folder):
  """Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
@@ -984,47 +985,6 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
  shutil.move(os.path.join(src, filename), move)
  return

- def _merge_channels_v2(src, plot=False):
- from .plot import plot_arrays
- """
- Merge the channels in the given source directory and save the merged files in a 'stack' directory.
-
- Args:
- src (str): The path to the source directory containing the channel folders.
- plot (bool, optional): Whether to plot the merged arrays. Defaults to False.
-
- Returns:
- None
- """
- src = Path(src)
- stack_dir = src / 'stack'
- chan_dirs = [d for d in src.iterdir() if d.is_dir() and d.name in ['01', '02', '03', '04', '00', '1', '2', '3', '4','0']]
-
- chan_dirs.sort(key=lambda x: x.name)
- print(f'List of folders in src: {[d.name for d in chan_dirs]}. Single channel folders.')
- start_time = time.time()
-
- # First directory and its files
- dir_files = list(chan_dirs[0].iterdir())
-
- # Create the 'stack' directory if it doesn't exist
- stack_dir.mkdir(exist_ok=True)
- print(f'generated folder with merged arrays: {stack_dir}')
-
- if _is_dir_empty(stack_dir):
- with Pool(max(cpu_count() // 2, 1)) as pool:
- #with Pool(cpu_count()) as pool:
- merge_func = partial(_merge_file, chan_dirs, stack_dir)
- pool.map(merge_func, dir_files)
-
- avg_time = (time.time() - start_time) / len(dir_files)
- print(f'Average Time: {avg_time:.3f} sec')
-
- if plot:
- plot_arrays(src+'/stack')
-
- return
-
  def _merge_channels(src, plot=False):
  """
  Merge the channels in the given source directory and save the merged files in a 'stack' directory without using multiprocessing.
@@ -2384,12 +2344,8 @@ def _results_to_csv(src, df, df_well):
  wells.to_csv(wells_loc, index=True, header=True)
  cells.to_csv(cells_loc, index=True, header=True)
  return cells, wells
-
- ###################################################
- # Classify
- ###################################################

- def read_plot_model_stats(file_path ,save=False):
+ def read_plot_model_stats(train_file_path, val_file_path ,save=False):

  def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):

@@ -2418,37 +2374,19 @@ def read_plot_model_stats(file_path ,save=False):
  plt.savefig(pdf_path, format='pdf', dpi=dpi)
  else:
  plt.show()
- # Read the CSV into a dataframe
- df = pd.read_csv(file_path, index_col=0)
-
- # Split the dataframe into train and validation based on the index
- train_df = df.filter(like='_train', axis=0).copy()
- val_df = df.filter(like='_val', axis=0).copy()
-
- fldr_1 = os.path.dirname(file_path)
-
- train_csv_path = os.path.join(fldr_1, 'train.csv')
- val_csv_path = os.path.join(fldr_1, 'validation.csv')

- fldr_2 = os.path.dirname(fldr_1)
- fldr_3 = os.path.dirname(fldr_2)
- bn_1 = os.path.basename(fldr_1)
- bn_2 = os.path.basename(fldr_2)
- bn_3 = os.path.basename(fldr_3)
- model_name = str(f'{bn_1}_{bn_2}_{bn_3}')
+ # Read the CSVs into DataFrames
+ train_df = pd.read_csv(train_file_path, index_col=0)
+ val_df = pd.read_csv(val_file_path, index_col=0)

- # Extract epochs from index
- train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
- val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]
-
- # Save dataframes to a CSV file
- train_df.to_csv(train_csv_path)
- val_df.to_csv(val_csv_path)
+ # Get the folder path for saving plots
+ fldr_1 = os.path.dirname(train_file_path)

  if save:
  # Setting the style
  sns.set(style="whitegrid")

+ # Plot and save the results
  _plot_and_save(train_df, val_df, column='accuracy', save=save, path=fldr_1)
  _plot_and_save(train_df, val_df, column='neg_accuracy', save=save, path=fldr_1)
  _plot_and_save(train_df, val_df, column='pos_accuracy', save=save, path=fldr_1)
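
Note: read_plot_model_stats now takes two CSV paths (training and validation stats) rather than splitting a single file on its index. A minimal usage sketch, assuming hypothetical paths to the train.csv and validation.csv files that _save_progress (further down) writes into its destination folder:

    from spacr.io import read_plot_model_stats

    # Hypothetical paths; _save_progress writes these two files into its dst folder
    train_csv = '/data/experiment1/model/train.csv'
    val_csv = '/data/experiment1/model/validation.csv'

    # Reads both CSVs and plots the per-epoch stats; with save=True the plots are
    # written as PDFs into the folder that contains train.csv
    read_plot_model_stats(train_csv, val_csv, save=True)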
@@ -2496,50 +2434,53 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_

  return model_path

- def _save_progress(dst, results_df, result_type='train'):
+ def _save_progress(dst, train_df, validation_df):
  """
  Save the progress of the classification model.

  Parameters:
  dst (str): The destination directory to save the progress.
- results_df (pandas.DataFrame): The DataFrame containing accuracy, loss, and PRAUC.
- train_metrics_df (pandas.DataFrame): The DataFrame containing training metrics.
+ train_df (pandas.DataFrame): The DataFrame containing training stats.
+ validation_df (pandas.DataFrame): The DataFrame containing validation stats (if available).

  Returns:
  None
  """
+
+ def _save_df_to_csv(file_path, df):
+ """
+ Save the given DataFrame to the specified CSV file, either creating a new file or appending to an existing one.
+
+ Parameters:
+ file_path (str): The file path where the CSV will be saved.
+ df (pandas.DataFrame): The DataFrame to save.
+ """
+ if not os.path.exists(file_path):
+ with open(file_path, 'w') as f:
+ df.to_csv(f, index=True, header=True)
+ f.flush() # Ensure data is written to the file system
+ else:
+ with open(file_path, 'a') as f:
+ df.to_csv(f, index=True, header=False)
+ f.flush()
+
  # Save accuracy, loss, PRAUC
  os.makedirs(dst, exist_ok=True)
- results_path = os.path.join(dst, f'{result_type}.csv')
- if not os.path.exists(results_path):
- results_df.to_csv(results_path, index=True, header=True, mode='w')
- else:
- results_df.to_csv(results_path, index=True, header=False, mode='a')
+ results_path_train = os.path.join(dst, 'train.csv')
+ results_path_validation = os.path.join(dst, 'validation.csv')

- if result_type == 'train':
- read_plot_model_stats(results_path, save=True)
- return
+ # Save training data
+ _save_df_to_csv(results_path_train, train_df)

- def _save_settings(settings, src):
- """
- Save the settings dictionary to a CSV file.
+ # Save validation data if available
+ if validation_df is not None:
+ _save_df_to_csv(results_path_validation, validation_df)

- Parameters:
- - settings (dict): A dictionary containing the settings.
- - src (str): The source directory where the settings file will be saved.
+ # Call read_plot_model_stats after ensuring the files are saved
+ read_plot_model_stats(results_path_train, results_path_validation, save=True)

- Returns:
- None
- """
- dst = os.path.join(src,'model')
- settings_loc = os.path.join(dst,'settings.csv')
- os.makedirs(dst, exist_ok=True)
- settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
- display(settings_df)
- settings_df.to_csv(settings_loc, index=False)
  return

-
  def _copy_missclassified(df):
  misclassified = df[df['true_label'] != df['predicted_label']]
  for _, row in misclassified.iterrows():
@@ -2869,7 +2810,8 @@ def generate_dataset(settings={}):
  all_paths = []
  for i, src in enumerate(settings['src']):
  db_path = os.path.join(src, 'measurements', 'measurements.db')
- dst = os.path.join(src, 'datasets')
+ if i == 0:
+ dst = os.path.join(src, 'datasets')
  paths = generate_path_list_from_db(db_path, file_metadata=settings['file_metadata'])
  correct_paths(paths, src)
  all_paths.extend(paths)
@@ -2917,10 +2859,12 @@ def generate_dataset(settings={}):

  # Combine the temporary tar files into a final tar
  date_name = datetime.date.today().strftime('%y%m%d')
- if not settings['file_metadata'] is None:
- tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
- else:
- tar_name = f"{date_name}_{settings['experiment']}.tar"
+ if len(settings['src']) > 1:
+ date_name = f"{date_name}_combined"
+ #if not settings['file_metadata'] is None:
+ # tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
+ #else:
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
  tar_name = os.path.join(dst, tar_name)
  if os.path.exists(tar_name):
  number = random.randint(1, 100)
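
Note: generate_dataset now takes a list of source folders; the tar archive is written under the first source's datasets folder, and the date prefix gains a "_combined" suffix when more than one source is given. A hedged sketch of settings that would exercise this path (the paths and experiment label are made up, and the remaining generate_dataset settings are omitted):

    from spacr.io import generate_dataset

    settings = {
        'src': ['/data/plate1', '/data/plate2'],  # dst is derived from the first entry
        'experiment': 'screen1',                  # tar becomes e.g. <yymmdd>_combined_screen1.tar
        'file_metadata': None,
    }
    generate_dataset(settings)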
@@ -2967,7 +2911,6 @@ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=[
  - val_loaders (list): List of data loaders for validation datasets.
  """

- from .io import spacrDataset
  from .utils import SelectChannels, augment_dataset

  chans = []
@@ -3066,10 +3009,6 @@ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=[

  def generate_training_dataset(settings):

- from .io import _read_and_merge_data, _read_db
- from .utils import get_paths_from_db, annotate_conditions, save_settings
- from .settings import set_generate_training_dataset_defaults
-
  # Function to filter png_list_df by prcfo present in df without merging
  def filter_png_list(db_path, settings):
  tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
@@ -3173,34 +3112,55 @@ def generate_training_dataset(settings):
  class_paths_ls[i] = random.sample(class_paths, size)

  return class_paths_ls
+
+ from .io import _read_and_merge_data, _read_db
+ from .utils import get_paths_from_db, annotate_conditions, save_settings
+ from .settings import set_generate_training_dataset_defaults

  # Set default settings and save
  settings = set_generate_training_dataset_defaults(settings)
  save_settings(settings, 'cv_dataset', show=True)

- db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
- dst = os.path.join(settings['src'], 'datasets', 'training')
+ class_path_list = None

- # Create a new directory for training data if necessary
- if os.path.exists(dst):
- for i in range(1, 100000):
- dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
- if not os.path.exists(dst):
- print(f'Creating new directory for training: {dst}')
- break
+ if isinstance(settings['src'], str):
+ src = [settings['src']]

- # Select dataset based on dataset mode
- if settings['dataset_mode'] == 'annotation':
- class_paths_ls = annotation_based_selection(db_path, dst, settings)
+ for i, src in enumerate(settings['src']):
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
+
+ if len(settings['src']) > 1 and i == 0:
+ dst = os.path.join(src, 'datasets', 'training_all')
+ elif len(settings['src']) == 1:
+ dst = os.path.join(src, 'datasets', 'training')
+
+ # Create a new directory for training data if necessary
+ if os.path.exists(dst):
+ for i in range(1, 100000):
+ dst = dst + f'_{i}'
+ if not os.path.exists(dst):
+ print(f'Creating new directory for training: {dst}')
+ break

- elif settings['dataset_mode'] == 'metadata':
- class_paths_ls = metadata_based_selection(db_path, settings)
+ # Select dataset based on dataset mode
+ if settings['dataset_mode'] == 'annotation':
+ class_paths_ls = annotation_based_selection(db_path, dst, settings)
+
+ elif settings['dataset_mode'] == 'metadata':
+ class_paths_ls = metadata_based_selection(db_path, settings)
+
+ elif settings['dataset_mode'] == 'measurement':
+ class_paths_ls = measurement_based_selection(settings, db_path)
+
+ if class_path_list is None:
+ class_path_list = [[] for _ in range(len(class_paths_ls))]

- elif settings['dataset_mode'] == 'measurement':
- class_paths_ls = measurement_based_selection(settings, db_path)
+ # Extend each list in class_path_list with the corresponding list from class_paths_ls
+ for idx in range(len(class_paths_ls)):
+ class_path_list[idx].extend(class_paths_ls[idx])

  # Generate and return training and testing directories
- train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])
+ train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_path_list, classes=settings['classes'], test_split=settings['test_split'])

  return train_class_dir, test_class_dir

@@ -3234,7 +3194,6 @@ def training_dataset_from_annotation(db_path, dst, annotation_column='test', ann

  def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
  from .utils import print_progress
- from .deep_spacr import train_test_split
  # Make sure that the length of class_data matches the length of classes
  if len(class_data) != len(classes):
  raise ValueError("class_data and classes must have the same length.")
spacr/measure.py CHANGED
@@ -652,43 +652,6 @@ def img_list_to_grid(grid, titles=None):
  plt.tight_layout(pad=0.1)
  return fig

- def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
- from. utils import _map_wells_png
- png_df = pd.DataFrame(img_paths, columns=['png_path'])
-
- png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
-
- parts = png_df['file_name'].apply(lambda x: pd.Series(_map_wells_png(x, timelapse=settings['timelapse'])))
-
- columns = ['plate', 'row', 'col', 'field']
-
- if settings['timelapse']:
- columns = columns + ['time_id']
-
- columns = columns + ['prcfo']
-
- if crop_mode == 'cell':
- columns = columns + ['cell_id']
-
- if crop_mode == 'nucleus':
- columns = columns + ['nucleus_id']
-
- if crop_mode == 'pathogen':
- columns = columns + ['pathogen_id']
-
- if crop_mode == 'cytoplasm':
- columns = columns + ['cytoplasm_id']
-
- png_df[columns] = parts
-
- try:
- conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
- png_df.to_sql('png_list', conn, if_exists='append', index=False)
- conn.commit()
- except sqlite3.OperationalError as e:
- print(f"SQLite error: {e}", flush=True)
- traceback.print_exc()
-
  #@log_function_call
  def _measure_crop_core(index, time_ls, file, settings):

@@ -711,7 +674,7 @@ def _measure_crop_core(index, time_ls, file, settings):
  """

  from .plot import _plot_cropped_arrays
- from .utils import _merge_overlapping_objects, _filter_object, _relabel_parent_with_child_labels, _exclude_objects, normalize_to_dtype
+ from .utils import _merge_overlapping_objects, _filter_object, _relabel_parent_with_child_labels, _exclude_objects, normalize_to_dtype, filepaths_to_database
  from .utils import _merge_and_save_to_database, _crop_center, _find_bounding_box, _generate_names, _get_percentiles

  figs = {}
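
Note: filepaths_to_database is removed from spacr/measure.py and is now imported from .utils inside _measure_crop_core, so the helper appears to have moved to spacr/utils.py. A hedged import sketch for callers, assuming the function keeps the signature shown in the removed code above:

    from spacr.utils import filepaths_to_database

    # Same call shape as the removed measure.py version:
    # filepaths_to_database(img_paths, settings, source_folder, crop_mode)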