spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/timelapse.py CHANGED
@@ -1,4 +1,4 @@
- import cv2, os, re, glob, random, btrack
+ import cv2, os, re, glob, random, btrack, sqlite3
  import numpy as np
  import pandas as pd
  from collections import defaultdict
@@ -9,6 +9,10 @@ from IPython.display import Image as ipyimage
  import trackpy as tp
  from btrack import datasets as btrack_datasets
  from skimage.measure import regionprops
+ from scipy.signal import find_peaks
+ from scipy.optimize import curve_fit
+ from scipy.integrate import trapz
+ import matplotlib.pyplot as plt

  from .logger import log_function_call

@@ -144,56 +148,6 @@ def _sort_key(file_path):
      # Return a tuple that sorts this file as "earliest" or "lowest"
      return ('', '', '', 0)

- def _save_mask_timelapse_as_gif(masks, path, cmap, norm, filenames):
-     """
-     Save a timelapse of masks as a GIF.
-
-     Parameters:
-     masks (list): List of mask frames.
-     path (str): Path to save the GIF.
-     cmap: Colormap for displaying the masks.
-     norm: Normalization for the masks.
-     filenames (list): List of filenames corresponding to each mask frame.
-
-     Returns:
-     None
-     """
-     def _update(frame):
-         """
-         Update the plot with the given frame.
-
-         Parameters:
-         frame (int): The frame number to update the plot with.
-
-         Returns:
-         None
-         """
-         nonlocal filename_text_obj
-         if filename_text_obj is not None:
-             filename_text_obj.remove()
-         ax.clear()
-         ax.axis('off')
-         current_mask = masks[frame]
-         ax.imshow(current_mask, cmap=cmap, norm=norm)
-         ax.set_title(f'Frame: {frame}', fontsize=24, color='white')
-         filename_text = filenames[frame]
-         filename_text_obj = fig.text(0.5, 0.01, filename_text, ha='center', va='center', fontsize=20, color='white')
-         for label_value in np.unique(current_mask):
-             if label_value == 0: continue  # Skip background
-             y, x = np.mean(np.where(current_mask == label_value), axis=1)
-             ax.text(x, y, str(label_value), color='white', fontsize=24, ha='center', va='center')
-
-     fig, ax = plt.subplots(figsize=(50, 50), facecolor='black')
-     ax.set_facecolor('black')
-     ax.axis('off')
-     plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
-
-     filename_text_obj = None
-     anim = FuncAnimation(fig, _update, frames=len(masks), blit=False)
-     anim.save(path, writer='pillow', fps=2, dpi=80)  # Adjust DPI for size/quality
-     plt.close(fig)
-     print(f'Saved timelapse to {path}')
-
  def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
      """
      Converts a sequence of masks into a GIF file.
@@ -208,6 +162,9 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
      Returns:
      None
      """
+
+     from .io import _save_mask_timelapse_as_gif
+
      def _display_gif(path):
          with open(path, 'rb') as file:
              display(ipyimage(file.read()))
@@ -220,7 +177,7 @@ def _masks_to_gif(masks, gif_folder, name, filenames, object_type):
      norm = plt.cm.colors.Normalize(vmin=0, vmax=highest_label)

      save_path_gif = os.path.join(gif_folder, f'timelapse_masks_{object_type}_{name}.gif')
-     _save_mask_timelapse_as_gif(masks, save_path_gif, cmap, norm, filenames)
+     _save_mask_timelapse_as_gif(masks, None, save_path_gif, cmap, norm, filenames)
      #_display_gif(save_path_gif)

  def _timelapse_masks_to_gif(folder_path, mask_channels, object_types):
@@ -450,6 +407,8 @@ def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse
      from .plot import _visualize_and_save_timelapse_stack_with_tracks
      from .utils import _masks_to_masks_stack

+     print('Tracking objects with trackpy')
+
      if timelapse_displacement is None:
          features = _prepare_for_tracking(masks)
          timelapse_displacement = _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3)
@@ -574,3 +533,335 @@ def _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, mas

      mask_stack = _masks_to_masks_stack(masks)
      return mask_stack
+
+ def exponential_decay(x, a, b, c):
+     return a * np.exp(-b * x) + c
+
+ def preprocess_pathogen_data(pathogen_df):
+     # Group by identifiers and count the number of parasites
+     parasite_counts = pathogen_df.groupby(['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id']).size().reset_index(name='parasite_count')
+
+     # Aggregate numerical columns and take the first of object columns
+     agg_funcs = {col: 'mean' if np.issubdtype(pathogen_df[col].dtype, np.number) else 'first' for col in pathogen_df.columns if col not in ['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id', 'parasite_count']}
+     pathogen_agg = pathogen_df.groupby(['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id']).agg(agg_funcs).reset_index()
+
+     # Merge the counts back into the aggregated data
+     pathogen_agg = pathogen_agg.merge(parasite_counts, on=['plate', 'row', 'col', 'field', 'timeid', 'pathogen_cell_id'])
+
+     # Remove the object_label column, as it corresponds to the pathogen ID, not the cell ID
+     if 'object_label' in pathogen_agg.columns:
+         pathogen_agg.drop(columns=['object_label'], inplace=True)
+
+     # Rename pathogen_cell_id to object_label
+     pathogen_agg.rename(columns={'pathogen_cell_id': 'object_label'}, inplace=True)
+
+     return pathogen_agg
+
+ def plot_data(measurement, group, ax, label, marker='o', linestyle='-'):
+     ax.plot(group['time'], group['delta_' + measurement], marker=marker, linestyle=linestyle, label=label)
+
+ def infected_vs_noninfected(result_df, measurement):
+     # Separate the merged dataframe into two groups based on parasite_count
+     infected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]
+     uninfected_cells_df = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]
+
+     # Plotting
+     fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
+
+     # Plot cells that were infected at some time point
+     for group_id in infected_cells_df['plate_row_column_field_object'].unique():
+         group = infected_cells_df[infected_cells_df['plate_row_column_field_object'] == group_id]
+         plot_data(measurement, group, axs[0], 'Infected', marker='x')
+
+     # Plot cells that were never infected
+     for group_id in uninfected_cells_df['plate_row_column_field_object'].unique():
+         group = uninfected_cells_df[uninfected_cells_df['plate_row_column_field_object'] == group_id]
+         plot_data(measurement, group, axs[1], 'Uninfected')
+
+     # Set the titles and labels
+     axs[0].set_title('Cells Infected at Some Time')
+     axs[1].set_title('Cells Never Infected')
+     for ax in axs:
+         ax.set_xlabel('Time')
+         ax.set_ylabel('Normalized Delta ' + measurement)
+         all_timepoints = sorted(result_df['time'].unique())
+         ax.set_xticks(all_timepoints)
+         ax.set_xticklabels(all_timepoints, rotation=45, ha="right")
+
+     plt.tight_layout()
+     plt.show()
+
+ def save_figure(fig, src, figure_number):
+     source = os.path.dirname(src)
+     results_fldr = os.path.join(source, 'results')
+     os.makedirs(results_fldr, exist_ok=True)
+     fig_loc = os.path.join(results_fldr, f'figure_{figure_number}.pdf')
+     fig.savefig(fig_loc)
+     print(f'Saved figure: {fig_loc}')
+
+ def save_results_dataframe(df, src, results_name):
+     source = os.path.dirname(src)
+     results_fldr = os.path.join(source, 'results')
+     os.makedirs(results_fldr, exist_ok=True)
+     csv_loc = os.path.join(results_fldr, f'{results_name}.csv')
+     df.to_csv(csv_loc, index=True)
+     print(f'Saved results: {csv_loc}')
+
+ def summarize_per_well(peak_details_df):
+     # Step 1: Split the 'ID' column
+     split_columns = peak_details_df['ID'].str.split('_', expand=True)
+     peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+     # Step 2: Create 'well_ID' by combining 'row' and 'column'
+     peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+     # Filter entries where 'amplitude' is not null
+     filtered_df = peak_details_df[peak_details_df['amplitude'].notna()]
+
+     # Preparation for Step 3: Identify numeric columns for averaging from the filtered dataframe
+     numeric_cols = filtered_df.select_dtypes(include=['number']).columns
+
+     # Step 3: Calculate summary statistics per well
+     summary_df = filtered_df.groupby('well_ID').agg(
+         peaks_per_well=('ID', 'size'),
+         unique_IDs_with_amplitude=('ID', 'nunique'),  # Count unique IDs per well with non-null amplitude
+         **{col: (col, 'mean') for col in numeric_cols}  # average all numeric columns, including 'amplitude'
+     ).reset_index()
+
+     # Step 4: Count cells per well from the unfiltered dataframe and merge on well_ID,
+     # so wells dropped by the amplitude filter cannot misalign the counts
+     summary_df_2 = peak_details_df.groupby('well_ID').agg(
+         cells_per_well=('object_number', 'nunique'),
+     ).reset_index()
+
+     summary_df = summary_df.merge(summary_df_2, on='well_ID', how='left')
+     summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+     return summary_df
+
+ def summarize_per_well_inf_non_inf(peak_details_df):
+     # Step 1: Split the 'ID' column
+     split_columns = peak_details_df['ID'].str.split('_', expand=True)
+     peak_details_df[['plate', 'row', 'column', 'field', 'object_number']] = split_columns
+
+     # Step 2: Create 'well_ID' by combining 'row' and 'column'
+     peak_details_df['well_ID'] = peak_details_df['row'] + '_' + peak_details_df['column']
+
+     # An 'infected' value > 0 indicates infection; classify each cell accordingly
+     peak_details_df['infected_status'] = peak_details_df['infected'].apply(lambda x: 'infected' if x > 0 else 'non_infected')
+
+     # Preparation for Step 3: Identify numeric columns for averaging
+     numeric_cols = peak_details_df.select_dtypes(include=['number']).columns
+
+     # Step 3: Calculate summary statistics per well and infection status
+     summary_df = peak_details_df.groupby(['well_ID', 'infected_status']).agg(
+         cells_per_well=('object_number', 'nunique'),
+         peaks_per_well=('ID', 'size'),
+         **{col: (col, 'mean') for col in numeric_cols}
+     ).reset_index()
+
+     # Calculate peaks per cell
+     summary_df['peaks_per_cell'] = summary_df['peaks_per_well'] / summary_df['cells_per_well']
+
+     return summary_df
+
+ def analyze_calcium_oscillations(db_loc, measurement='cell_channel_1_mean_intensity', size_filter='cell_area', fluctuation_threshold=0.25, num_lines=None, peak_height=0.01, pathogen=None, cytoplasm=None, remove_transient=True, verbose=False, transience_threshold=0.9):
+     # Load data
+     conn = sqlite3.connect(db_loc)
+     # Load cell table
+     cell_df = pd.read_sql("SELECT * FROM cell", conn)
+
+     if pathogen:
+         pathogen_df = pd.read_sql("SELECT * FROM pathogen", conn)
+         pathogen_df['pathogen_cell_id'] = pathogen_df['pathogen_cell_id'].astype(float).astype('Int64')
+         pathogen_df = preprocess_pathogen_data(pathogen_df)
+         cell_df = cell_df.merge(pathogen_df, on=['plate', 'row', 'col', 'field', 'timeid', 'object_label'], how='left', suffixes=('', '_pathogen'))
+         cell_df['parasite_count'] = cell_df['parasite_count'].fillna(0)
+         print(f'After pathogen merge: {len(cell_df)} objects')
+
+     # Optionally load cytoplasm table and merge
+     if cytoplasm:
+         cytoplasm_df = pd.read_sql("SELECT * FROM cytoplasm", conn)
+         # Merge on specified columns
+         cell_df = cell_df.merge(cytoplasm_df, on=['plate', 'row', 'col', 'field', 'timeid', 'object_label'], how='left', suffixes=('', '_cytoplasm'))
+
+         print(f'After cytoplasm merge: {len(cell_df)} objects')
+
+     conn.close()
+
+     # Continue processing on cell_df, which now contains the merged data
+     # Prepare the DataFrame
+     prcf_components = cell_df['prcf'].str.split('_', expand=True)
+     cell_df['plate'] = prcf_components[0]
+     cell_df['row'] = prcf_components[1]
+     cell_df['column'] = prcf_components[2]
+     cell_df['field'] = prcf_components[3]
+     cell_df['time'] = prcf_components[4].str.extract(r't(\d+)').astype(int)
+     cell_df['object_number'] = cell_df['object_label']
+     cell_df['plate_row_column_field_object'] = cell_df['plate'].astype(str) + '_' + cell_df['row'].astype(str) + '_' + cell_df['column'].astype(str) + '_' + cell_df['field'].astype(str) + '_' + cell_df['object_label'].astype(str)
+
+     df = cell_df.copy()
+
+     # Without a pathogen table there are no parasite counts; treat all cells as uninfected
+     if 'parasite_count' not in df.columns:
+         df['parasite_count'] = 0
+
+     # Fit exponential decay model to all scaled fluorescence data
+     try:
+         params, _ = curve_fit(exponential_decay, df['time'], df[measurement], p0=[max(df[measurement]), 0.01, min(df[measurement])], maxfev=10000)
+         df['corrected_' + measurement] = df[measurement] / exponential_decay(df['time'], *params)
+     except RuntimeError as e:
+         print(f"Curve fitting failed for the entire dataset with error: {e}")
+         return
+     if verbose:
+         print(f'Analyzing: {len(df)} objects')
+
+     # Normalize corrected fluorescence for each cell
+     corrected_dfs = []
+     peak_details_list = []
+     total_timepoints = df['time'].nunique()
+     size_filter_removed = 0
+     transience_removed = 0
+
+     for unique_id, group in df.groupby('plate_row_column_field_object'):
+         group = group.sort_values('time')
+         if remove_transient:
+             threshold = int(transience_threshold * total_timepoints)
+
+             if verbose:
+                 print(f'Group length: {len(group)} Timelapse length: {total_timepoints}, threshold: {threshold}')
+
+             if len(group) <= threshold:
+                 transience_removed += 1
+                 if verbose:
+                     print(f'Removed group {unique_id} due to transience')
+                 continue
+
+         size_diff = group[size_filter].std() / group[size_filter].mean()
+
+         if size_diff <= fluctuation_threshold:
+             group['delta_' + measurement] = group['corrected_' + measurement].diff().fillna(0)
+             corrected_dfs.append(group)
+
+             # Detect peaks
+             peaks, properties = find_peaks(group['delta_' + measurement], height=peak_height)
+
+             # Clip negative deltas to 0 for the positive-only AUC
+             group_filtered = group.copy()
+             group_filtered['delta_' + measurement] = group['delta_' + measurement].clip(lower=0)
+             above_zero_auc = trapz(y=group_filtered['delta_' + measurement], x=group_filtered['time'])
+             auc = trapz(y=group['delta_' + measurement], x=group['time'])
+             is_infected = int((group['parasite_count'] > 0).any())
+
+             if len(peaks) == 0:
+                 peak_details_list.append({
+                     'ID': unique_id,
+                     'plate': group['plate'].iloc[0],
+                     'row': group['row'].iloc[0],
+                     'column': group['column'].iloc[0],
+                     'field': group['field'].iloc[0],
+                     'object_number': group['object_number'].iloc[0],
+                     'time': np.nan,  # The time of the peak
+                     'amplitude': np.nan,
+                     'delta': np.nan,
+                     'AUC': auc,
+                     'AUC_positive': above_zero_auc,
+                     'AUC_peak': np.nan,
+                     'infected': is_infected
+                 })
+
+             # Record details for each detected peak
+             for i, peak in enumerate(peaks):
+                 amplitude = properties['peak_heights'][i]
+                 peak_time = group['time'].iloc[peak]
+                 pathogen_count_at_peak = group['parasite_count'].iloc[peak]
+
+                 start_idx = max(peak - 1, 0)
+                 end_idx = min(peak + 1, len(group) - 1)
+
+                 # Using indices to slice for AUC calculation
+                 peak_segment_y = group['delta_' + measurement].iloc[start_idx:end_idx + 1]
+                 peak_segment_x = group['time'].iloc[start_idx:end_idx + 1]
+                 peak_auc = trapz(y=peak_segment_y, x=peak_segment_x)
+
+                 peak_details_list.append({
+                     'ID': unique_id,
+                     'plate': group['plate'].iloc[0],
+                     'row': group['row'].iloc[0],
+                     'column': group['column'].iloc[0],
+                     'field': group['field'].iloc[0],
+                     'object_number': group['object_number'].iloc[0],
+                     'time': peak_time,  # The time of the peak
+                     'amplitude': amplitude,
+                     'delta': group['delta_' + measurement].iloc[peak],
+                     'AUC': auc,
+                     'AUC_positive': above_zero_auc,
+                     'AUC_peak': peak_auc,
+                     'infected': pathogen_count_at_peak
+                 })
+         else:
+             size_filter_removed += 1
+
+     if verbose:
+         print(f'Removed {size_filter_removed} objects due to size filter fluctuation')
+         print(f'Removed {transience_removed} objects due to transience')
+
+     if len(corrected_dfs) > 0:
+         result_df = pd.concat(corrected_dfs)
+     else:
+         print("No suitable cells found for analysis")
+         return
+
+     peak_details_df = pd.DataFrame(peak_details_list)
+     summary_df = summarize_per_well(peak_details_df)
+     summary_df_inf_non_inf = summarize_per_well_inf_non_inf(peak_details_df)
+
+     save_results_dataframe(df=peak_details_df, src=db_loc, results_name='peak_details')
+     save_results_dataframe(df=result_df, src=db_loc, results_name='results')
+     save_results_dataframe(df=summary_df, src=db_loc, results_name='well_results')
+     save_results_dataframe(df=summary_df_inf_non_inf, src=db_loc, results_name='well_results_inf_non_inf')
+
+     # Plotting
+     fig, ax = plt.subplots(figsize=(10, 8))
+     sampled_groups = result_df['plate_row_column_field_object'].unique()
+     if num_lines is not None and 0 < num_lines < len(sampled_groups):
+         sampled_groups = np.random.choice(sampled_groups, size=num_lines, replace=False)
+
+     for group_id in sampled_groups:
+         group = result_df[result_df['plate_row_column_field_object'] == group_id]
+         ax.plot(group['time'], group['delta_' + measurement], marker='o', linestyle='-')
+
+     ax.set_xticks(sorted(df['time'].unique()))
+     ax.set_xticklabels(sorted(df['time'].unique()), rotation=45, ha="right")
+     ax.set_title(f'Normalized Delta of {measurement} Over Time (Corrected for Photobleaching)')
+     ax.set_xlabel('Time')
+     ax.set_ylabel('Normalized Delta ' + measurement)
+     plt.tight_layout()
+
+     plt.show()
+
+     save_figure(fig, src=db_loc, figure_number=1)
+
+     if pathogen:
+         infected_vs_noninfected(result_df, measurement)
+         save_figure(fig, src=db_loc, figure_number=2)
+
+         # Identify cells with and without pathogens
+         infected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') > 0]['plate_row_column_field_object'].unique()
+         noninfected_cells = result_df[result_df.groupby('plate_row_column_field_object')['parasite_count'].transform('max') == 0]['plate_row_column_field_object'].unique()
+
+         # Peaks in infected and noninfected cells
+         infected_peaks = peak_details_df[peak_details_df['ID'].isin(infected_cells)]
+         noninfected_peaks = peak_details_df[peak_details_df['ID'].isin(noninfected_cells)]
+
+         # Calculate the average number of peaks per cell
+         avg_inf_peaks_per_cell = len(infected_peaks) / len(infected_cells) if len(infected_cells) > 0 else 0
+         avg_non_inf_peaks_per_cell = len(noninfected_peaks) / len(noninfected_cells) if len(noninfected_cells) > 0 else 0
+
+         print(f'Average number of peaks per infected cell: {avg_inf_peaks_per_cell:.2f}')
+         print(f'Average number of peaks per non-infected cell: {avg_non_inf_peaks_per_cell:.2f}')
+     print('done')
+     return result_df, peak_details_df, fig
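
Note: the bulk of the timelapse.py change is the new analyze_calcium_oscillations pipeline, which corrects each trace for photobleaching with a fitted exponential decay, takes frame-to-frame deltas, and calls peaks with scipy. Below is a minimal, self-contained sketch of that scheme; the synthetic trace, the p0 guess, and the peak_height value are illustrative stand-ins, not values shipped with spacr.

import numpy as np
from scipy.optimize import curve_fit
from scipy.signal import find_peaks
from scipy.integrate import trapz

def exponential_decay(x, a, b, c):
    return a * np.exp(-b * x) + c

# Synthetic trace: slow photobleaching plus three spikes (illustrative only)
t = np.arange(60, dtype=float)
trace = exponential_decay(t, 100.0, 0.02, 20.0)
trace[[15, 30, 45]] += 25.0

# 1) Fit the global decay and divide it out, as the new function does
params, _ = curve_fit(exponential_decay, t, trace, p0=[trace.max(), 0.01, trace.min()], maxfev=10000)
corrected = trace / exponential_decay(t, *params)

# 2) Peaks are called on the frame-to-frame delta of the corrected signal
delta = np.diff(corrected, prepend=corrected[0])
peaks, props = find_peaks(delta, height=0.01)

# 3) Area under the positive deltas, mirroring the AUC_positive column
auc_positive = trapz(np.clip(delta, 0, None), t)
print(peaks, props['peak_heights'], auc_positive)
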
spacr/train.py CHANGED
@@ -194,8 +194,8 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
  def train_test_model(src, settings, custom_model=False, custom_model_path=None):

-     from .io import save_settings, _copy_missclassified
-     from .utils import pick_best_model, test_model_performance
+     from .io import _save_settings, _copy_missclassified
+     from .utils import pick_best_model
      from .core import generate_loaders

      settings['src'] = src
@@ -208,7 +208,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
          model = torch.load(custom_model_path)

      if settings['train']:
-         save_settings(settings, src)
+         _save_settings(settings, src)
          torch.cuda.empty_cache()
          torch.cuda.memory.empty_cache()
          gc.collect()
@@ -227,7 +227,9 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
              validation_split=settings['val_split'],
              pin_memory=settings['pin_memory'],
              normalize=settings['normalize'],
-             verbose=settings['verbose'])
+             channels=settings['channels'],
+             verbose=settings['verbose'])
+

      if settings['test']:
          test, _, plate_names_test = generate_loaders(src,
@@ -240,6 +242,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
              validation_split=0.0,
              pin_memory=settings['pin_memory'],
              normalize=settings['normalize'],
+             channels=settings['channels'],
              verbose=settings['verbose'])
      if model == None:
          model_path = pick_best_model(src+'/model')
@@ -330,8 +333,8 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
      None
      """

-     from .io import save_model, save_progress
-     from .utils import evaluate_model_performance, compute_irm_penalty, calculate_loss, choose_model
+     from .io import _save_model, _save_progress
+     from .utils import compute_irm_penalty, calculate_loss, choose_model  # evaluate_model_performance no longer imported

      print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')

@@ -347,6 +350,11 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
          break

      model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint)
+
+     if model is None:
+         print(f'Model {model_type} not found')
+         return
+
      model.to(device)

      if optimizer_type == 'adamw':
@@ -421,10 +429,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
          if schedule == 'step_lr':
              scheduler.step()

-         save_progress(dst, results_df, train_metrics_df)
+         _save_progress(dst, results_df, train_metrics_df)
          clear_output(wait=True)
          display(results_df)
-         save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+         _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])

      if train_mode == 'irm':
          dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -494,7 +502,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='

          clear_output(wait=True)
          display(results_df)
-         save_progress(dst, results_df, train_metrics_df)
-         save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
+         _save_progress(dst, results_df, train_metrics_df)
+         _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
      print(f'Saved model: {dst}')
      return
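
Note: both generate_loaders calls in train_test_model now receive channels=settings['channels'], so a settings dict that worked with 0.0.18 will raise a KeyError in 0.0.21 until a 'channels' entry is added. The io helper renames (save_settings to _save_settings, save_model to _save_model, save_progress to _save_progress) are internal and do not change this call. A hedged sketch of a compatible settings dict follows; only the keys visible in this diff are listed, the values are placeholders, and any other keys the function expects are unchanged from before:

settings = {
    'train': True,          # run the training branch
    'test': False,          # skip the held-out test loaders
    'val_split': 0.1,
    'pin_memory': True,
    'normalize': True,
    'channels': [1, 2, 3],  # new in 0.0.21, forwarded to generate_loaders
    'verbose': True,
}
# train_test_model(src, settings)  # src is written into settings['src'] internally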