celldetective 1.0.2.post1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. celldetective/__main__.py +2 -2
  2. celldetective/events.py +2 -44
  3. celldetective/filters.py +4 -5
  4. celldetective/gui/__init__.py +1 -1
  5. celldetective/gui/analyze_block.py +37 -10
  6. celldetective/gui/btrack_options.py +24 -23
  7. celldetective/gui/classifier_widget.py +62 -19
  8. celldetective/gui/configure_new_exp.py +32 -35
  9. celldetective/gui/control_panel.py +115 -81
  10. celldetective/gui/gui_utils.py +674 -396
  11. celldetective/gui/json_readers.py +7 -6
  12. celldetective/gui/layouts.py +755 -0
  13. celldetective/gui/measurement_options.py +168 -487
  14. celldetective/gui/neighborhood_options.py +322 -270
  15. celldetective/gui/plot_measurements.py +1114 -0
  16. celldetective/gui/plot_signals_ui.py +20 -20
  17. celldetective/gui/process_block.py +449 -169
  18. celldetective/gui/retrain_segmentation_model_options.py +27 -26
  19. celldetective/gui/retrain_signal_model_options.py +25 -24
  20. celldetective/gui/seg_model_loader.py +31 -27
  21. celldetective/gui/signal_annotator.py +2326 -2295
  22. celldetective/gui/signal_annotator_options.py +18 -16
  23. celldetective/gui/styles.py +16 -1
  24. celldetective/gui/survival_ui.py +61 -39
  25. celldetective/gui/tableUI.py +60 -23
  26. celldetective/gui/thresholds_gui.py +68 -66
  27. celldetective/gui/viewers.py +596 -0
  28. celldetective/io.py +234 -23
  29. celldetective/measure.py +37 -32
  30. celldetective/neighborhood.py +495 -27
  31. celldetective/preprocessing.py +683 -0
  32. celldetective/scripts/analyze_signals.py +7 -0
  33. celldetective/scripts/measure_cells.py +12 -0
  34. celldetective/scripts/segment_cells.py +5 -0
  35. celldetective/scripts/track_cells.py +11 -0
  36. celldetective/signals.py +221 -98
  37. celldetective/tracking.py +0 -1
  38. celldetective/utils.py +178 -36
  39. celldetective-1.1.0.dist-info/METADATA +305 -0
  40. celldetective-1.1.0.dist-info/RECORD +80 -0
  41. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.0.dist-info}/top_level.txt +1 -0
  42. tests/__init__.py +0 -0
  43. tests/test_events.py +28 -0
  44. tests/test_filters.py +24 -0
  45. tests/test_io.py +70 -0
  46. tests/test_measure.py +141 -0
  47. tests/test_neighborhood.py +70 -0
  48. tests/test_segmentation.py +93 -0
  49. tests/test_signals.py +135 -0
  50. tests/test_tracking.py +164 -0
  51. tests/test_utils.py +71 -0
  52. celldetective-1.0.2.post1.dist-info/METADATA +0 -221
  53. celldetective-1.0.2.post1.dist-info/RECORD +0 -66
  54. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.0.dist-info}/LICENSE +0 -0
  55. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.0.dist-info}/WHEEL +0 -0
  56. {celldetective-1.0.2.post1.dist-info → celldetective-1.1.0.dist-info}/entry_points.txt +0 -0
celldetective/scripts/measure_cells.py CHANGED
@@ -20,6 +20,7 @@ from natsort import natsorted
  from art import tprint
  from tifffile import imread
  import threading
+ import datetime
 
  tprint("Measure")
 
@@ -195,6 +196,17 @@ if trajectories is None:
  print('Use features as a substitute for the trajectory table.')
  if 'label' not in features:
      features.append('label')
+ features_log=f'features: {features}'
+ border_distances_log=f'border_distances: {border_distances}'
+ haralick_options_log=f'haralick_options: {haralick_options}'
+ background_correction_log=f'background_correction: {background_correction}'
+ spot_detection_log=f'spot_detection: {spot_detection}'
+ intensity_measurement_radii_log=f'intensity_measurement_radii: {intensity_measurement_radii}'
+ isotropic_options_log=f'isotropic_operations: {isotropic_operations} \n'
+ log='\n'.join([features_log,border_distances_log,haralick_options_log,background_correction_log,spot_detection_log,intensity_measurement_radii_log,isotropic_options_log])
+ with open(pos + f'log_{mode}.json', 'a') as f:
+     f.write(f'{datetime.datetime.now()} MEASURE \n')
+     f.write(log+'\n')
 
 
  def measure_index(indices):
celldetective/scripts/segment_cells.py CHANGED
@@ -3,6 +3,7 @@ Copright © 2022 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
  """
 
  import argparse
+ import datetime
  import os
  import json
  from stardist.models import StarDist2D
@@ -128,6 +129,10 @@ if os.path.exists(os.sep.join([pos,label_folder])):
      rmtree(os.sep.join([pos,label_folder]))
  os.mkdir(os.sep.join([pos,label_folder]))
  print(f'Folder {os.sep.join([pos,label_folder])} successfully generated.')
+ log=f'segmentation model: {modelname}\n'
+ with open(pos+f'log_{mode}.json', 'a') as f:
+     f.write(f'{datetime.datetime.now()} SEGMENT \n')
+     f.write(log)
 
 
  # Loop over all frames and segment
celldetective/scripts/track_cells.py CHANGED
@@ -3,6 +3,7 @@ Copright © 2022 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
  """
 
  import argparse
+ import datetime
  import os
  import json
  from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
@@ -139,6 +140,16 @@ img_num_channels = _get_img_num_per_channel(channel_indices, len_movie, nbr_chan
  #######################################
 
  timestep_dataframes = []
+ features_log=f'features: {features}'
+ mask_channels_log=f'mask_channels: {mask_channels}'
+ haralick_option_log=f'haralick_options: {haralick_options}'
+ post_processing_option_log=f'post_processing_options: {post_processing_options}'
+ log_list=[features_log, mask_channels_log, haralick_option_log, post_processing_option_log]
+ log='\n'.join(log_list)
+
+ with open(pos+f'log_{mode}.json', 'a') as f:
+     f.write(f'{datetime.datetime.now()} TRACK \n')
+     f.write(log+"\n")
 
  def measure_index(indices):
      for t in tqdm(indices,desc="frame"):
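All three processing scripts (measure, segment, track) now append a timestamped, human-readable record of their parameters to a per-position log file before processing starts. A minimal sketch of the shared pattern; the helper name append_log and the example parameters are hypothetical, only the file naming and write format come from the hunks above:

import datetime

def append_log(pos, mode, step, parameters):
    # One "name: value" line per parameter, as in the scripts above.
    log = '\n'.join([f'{name}: {value}' for name, value in parameters.items()])
    # Append a timestamped entry to the per-position log file.
    with open(pos + f'log_{mode}.json', 'a') as f:
        f.write(f'{datetime.datetime.now()} {step} \n')
        f.write(log + '\n')

append_log('ExperimentFolder/W1/100/', 'targets', 'MEASURE',
           {'features': ['area', 'intensity_mean'], 'spot_detection': None})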
celldetective/signals.py CHANGED
@@ -6,7 +6,7 @@ import json
  from tensorflow.keras.optimizers import Adam
  from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, CSVLogger
  from tensorflow.keras.losses import CategoricalCrossentropy, MeanSquaredError, MeanAbsoluteError
- from tensorflow.keras.metrics import Precision, Recall
+ from tensorflow.keras.metrics import Precision, Recall, MeanIoU
  from tensorflow.keras.models import load_model,clone_model
  from tensorflow.config.experimental import list_physical_devices, set_memory_growth
  from tensorflow.keras.utils import to_categorical, plot_model
@@ -92,6 +92,7 @@ class TimeHistory(Callback):
 
  def analyze_signals(trajectories, model, interpolate_na=True,
      selected_signals=None,
+     model_path=None,
      column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
      plot_outcome=False, output_dir=None):
 
@@ -139,7 +140,7 @@ def analyze_signals(trajectories, model, interpolate_na=True,
 
  """
 
- model_path = locate_signal_model(model)
+ model_path = locate_signal_model(model, path=model_path)
  complete_path = model_path #+model
  complete_path = rf"{complete_path}"
  model_config_path = os.sep.join([complete_path,'config_input.json'])
@@ -196,6 +197,7 @@ def analyze_signals(trajectories, model, interpolate_na=True,
  for j,col in enumerate(selected_signals):
      signal = group[col].to_numpy()
      signals[i,frames,j] = signal
+     signals[i,max(frames):,j] = signal[-1]
 
  # for i in range(5):
  # print('pre model')
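The added line forward-fills every frame after a track's last observation with its final measured value, so the tail of a short track is no longer left at the zeros the array was initialized with. A toy illustration of the indexing (shapes and values are made up):

import numpy as np

signals = np.zeros((1, 10, 1))            # (tracks, frames, channels)
frames = np.array([0, 1, 2, 3])           # frames where the track was observed
signal = np.array([5., 6., 7., 8.])

signals[0, frames, 0] = signal
signals[0, max(frames):, 0] = signal[-1]  # forward-fill the tail with the last value
print(signals[0, :, 0])                   # [5. 6. 7. 8. 8. 8. 8. 8. 8. 8.]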
@@ -422,6 +424,7 @@ class SignalDetectionModel(object):
  self.dense_collection = dense_collection
  self.dropout_rate = dropout_rate
  self.label = label
+ self.show_plots = True
 
 
  if self.pretrained is not None:
@@ -430,6 +433,7 @@ class SignalDetectionModel(object):
  else:
      print("Create models from scratch...")
      self.create_models_from_scratch()
+ print("Models successfully created.")
 
 
  def load_pretrained_model(self):
@@ -545,10 +549,10 @@ class SignalDetectionModel(object):
  except:
      pass
 
- def fit_from_directory(self, ds_folders, normalize=True, normalization_percentile=None, normalization_values = None,
+ def fit_from_directory(self, datasets, normalize=True, normalization_percentile=None, normalization_values = None,
      normalization_clip = None, channel_option=["live_nuclei_channel"], model_name=None, target_directory=None,
      augment=True, augmentation_factor=2, validation_split=0.20, test_split=0.0, batch_size = 64, epochs=300,
-     recompile_pretrained=False, learning_rate=0.01, loss_reg="mse", loss_class = CategoricalCrossentropy(from_logits=False)):
+     recompile_pretrained=False, learning_rate=0.01, loss_reg="mse", loss_class = CategoricalCrossentropy(from_logits=False), show_plots=True):
 
  """
  Trains the model using data from specified directories.
@@ -600,19 +604,17 @@ class SignalDetectionModel(object):
  - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
  """
 
-
  if not hasattr(self, 'normalization_percentile'):
      self.normalization_percentile = normalization_percentile
  if not hasattr(self, 'normalization_values'):
      self.normalization_values = normalization_values
  if not hasattr(self, 'normalization_clip'):
      self.normalization_clip = normalization_clip
- print('Actual clip option:', self.normalization_clip)
 
  self.normalize = normalize
  self.normalization_percentile, self. normalization_values, self.normalization_clip = _interpret_normalization_parameters(self.n_channels, self.normalization_percentile, self.normalization_values, self.normalization_clip)
 
- self.ds_folders = [rf'{d}' for d in ds_folders]
+ self.datasets = [rf'{d}' if isinstance(d,str) else d for d in datasets]
  self.batch_size = batch_size
  self.epochs = epochs
  self.validation_split = validation_split
@@ -626,29 +628,24 @@ class SignalDetectionModel(object):
  self.learning_rate = learning_rate
  self.loss_reg = loss_reg
  self.loss_class = loss_class
-
-
- if not os.path.exists(self.model_folder):
-     #shutil.rmtree(self.model_folder)
-     os.mkdir(self.model_folder)
-
+ self.show_plots = show_plots
  self.channel_option = channel_option
  assert self.n_channels==len(self.channel_option), f'Mismatch between the channel option and the number of channels of the model...'
 
- self.list_of_sets = []
- print(self.ds_folders)
- for f in self.ds_folders:
-     self.list_of_sets.extend(glob(os.sep.join([f,"*.npy"])))
- print(f"Found {len(self.list_of_sets)} annotation files...")
- self.generate_sets()
+ if isinstance(self.datasets[0], dict):
+     self.datasets = [self.datasets]
 
- self.train_classifier()
- self.train_regressor()
+ self.list_of_sets = []
+ for ds in self.datasets:
+     if isinstance(ds,str):
+         self.list_of_sets.extend(glob(os.sep.join([ds,"*.npy"])))
+     else:
+         self.list_of_sets.append(ds)
+
+ print(f"Found {len(self.list_of_sets)} datasets...")
 
- config_input = {"channels": self.channel_option, "model_signal_length": self.model_signal_length, 'label': self.label, 'normalize': self.normalize, 'normalization_percentile': self.normalization_percentile, 'normalization_values': self.normalization_values, 'normalization_clip': self.normalization_clip}
- json_string = json.dumps(config_input)
- with open(os.sep.join([self.model_folder,"config_input.json"]), 'w') as outfile:
-     outfile.write(json_string)
+ self.prepare_sets()
+ self.train_generic()
 
  def fit(self, x_train, y_time_train, y_class_train, normalize=True, normalization_percentile=None, normalization_values = None, normalization_clip = None, pad=True, validation_data=None, test_data=None, channel_option=["live_nuclei_channel","dead_nuclei_channel"], model_name=None,
      target_directory=None, augment=True, augmentation_factor=3, validation_split=0.25, batch_size = 64, epochs=300,
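fit_from_directory thus accepts either folders of .npy annotation files (the previous behavior) or datasets already loaded in memory as lists of per-cell dictionaries. A hedged usage sketch; the constructor arguments are not shown in this diff and are assumptions:

import numpy as np
from celldetective.signals import SignalDetectionModel

model = SignalDetectionModel(pretrained=None)  # illustrative construction

# Case 1: folders of .npy annotation files.
model.fit_from_directory(["path/to/annotations/"],
                         channel_option=["live_nuclei_channel"],
                         model_name="my_signal_model",
                         target_directory="path/to/models/",
                         show_plots=False)

# Case 2: a dataset already in memory, as a list of per-cell dicts.
dataset = [{"live_nuclei_channel": np.array([1.0, 1.2, 0.4]),
            "FRAME": np.array([0, 1, 2]),
            "class": 0, "time_of_interest": 2}]
model.fit_from_directory([dataset],
                         channel_option=["live_nuclei_channel"],
                         model_name="my_signal_model",
                         target_directory="path/to/models/")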
@@ -691,7 +688,7 @@ class SignalDetectionModel(object):
 
  # If y-class is not one-hot encoded, encode it
  if self.y_class_train.shape[-1] != self.n_classes:
-     self.class_weights = compute_weights(self.y_class_train)
+     self.class_weights = compute_weights(y=self.y_class_train,class_weight="balanced", classes=np.unique(self.y_class_train))
      self.y_class_train = to_categorical(self.y_class_train)
 
  if self.normalize:
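The new keyword arguments match scikit-learn's compute_class_weight signature, suggesting compute_weights now delegates the 'balanced' heuristic (weight = n_samples / (n_classes * count_c)) to it; that delegation is an assumption, but the heuristic itself is standard:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])  # imbalanced labels
weights = compute_class_weight(class_weight="balanced", classes=np.unique(y), y=y)
print(weights)  # [0.5 1.5 3. ], i.e. n_samples / (n_classes * count_c)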
@@ -753,13 +750,21 @@ class SignalDetectionModel(object):
  self.loss_reg = loss_reg
  self.loss_class = loss_class
 
- if os.path.exists(self.model_folder):
-     shutil.rmtree(self.model_folder)
- os.mkdir(self.model_folder)
+ self.train_generic()
+
+ def train_generic(self):
+
+     if not os.path.exists(self.model_folder):
+         os.mkdir(self.model_folder)
 
  self.train_classifier()
  self.train_regressor()
 
+ config_input = {"channels": self.channel_option, "model_signal_length": self.model_signal_length, 'label': self.label, 'normalize': self.normalize, 'normalization_percentile': self.normalization_percentile, 'normalization_values': self.normalization_values, 'normalization_clip': self.normalization_clip}
+ json_string = json.dumps(config_input)
+ with open(os.sep.join([self.model_folder,"config_input.json"]), 'w') as outfile:
+     outfile.write(json_string)
+
  def predict_class(self, x, normalize=True, pad=True, return_one_hot=False, interpolate=True):
 
  """
@@ -936,21 +941,21 @@ class SignalDetectionModel(object):
  self.model_class.set_weights(clone_model(self.model_class).get_weights())
  self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
      loss=self.loss_class,
-     metrics=['accuracy', Precision(), Recall()])
+     metrics=['accuracy', Precision(), Recall(), MeanIoU(num_classes=self.n_classes, name='iou', dtype=float, sparse_y_true=False, sparse_y_pred=False)])
  else:
      self.initial_model = clone_model(self.model_class)
      self.model_class.set_weights(self.initial_model.get_weights())
      # Recompile to avoid crash
      self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
          loss=self.loss_class,
-         metrics=['accuracy', Precision(), Recall()])
+         metrics=['accuracy', Precision(), Recall(),MeanIoU(num_classes=self.n_classes, name='iou', dtype=float, sparse_y_true=False, sparse_y_pred=False)])
      # Reset weights
      self.model_class.set_weights(self.initial_model.get_weights())
  else:
      print("Compiling the classifier...")
      self.model_class.compile(optimizer=Adam(learning_rate=self.learning_rate),
          loss=self.loss_class,
-         metrics=['accuracy', Precision(), Recall()])
+         metrics=['accuracy', Precision(), Recall(),MeanIoU(num_classes=self.n_classes, name='iou', dtype=float, sparse_y_true=False, sparse_y_pred=False)])
 
  self.gather_callbacks("classifier")
 
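Because the MeanIoU metric is registered under name='iou', Keras exposes a val_iou quantity in the training logs, which is what the retuned callbacks further down now monitor instead of val_precision. A minimal, self-contained sketch of such a compile step (the toy model is an assumption; the sparse_y_true/sparse_y_pred keywords require TensorFlow/Keras >= 2.12):

import tensorflow as tf
from tensorflow.keras.metrics import Precision, Recall, MeanIoU

# Toy 3-class classifier with one-hot labels, hence sparse_y_true/sparse_y_pred=False.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(16,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy', Precision(), Recall(),
                       MeanIoU(num_classes=3, name='iou',
                               sparse_y_true=False, sparse_y_pred=False)])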
@@ -979,7 +984,8 @@ class SignalDetectionModel(object):
      validation_split = self.validation_split,
      verbose=1)
 
- self.plot_model_history(mode="classifier")
+ if self.show_plots:
+     self.plot_model_history(mode="classifier")
 
  # Set current classification model as the best model
  self.model_class = load_model(os.sep.join([self.model_folder,"classifier.h5"]))
@@ -1008,10 +1014,12 @@ class SignalDetectionModel(object):
  results = confusion_matrix(ground_truth,predictions)
  self.dico.update({"test_IoU": IoU_score, "test_balanced_accuracy": balanced_accuracy, "test_confusion": results, 'test_precision': precision, 'test_recall': recall})
 
- try:
-     plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
- except:
-     pass
+ if self.show_plots:
+     try:
+         plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
+     except Exception as e:
+         print(e)
+         pass
  print("Test set: ",classification_report(ground_truth,predictions))
 
  if hasattr(self, 'x_val'):
@@ -1035,10 +1043,11 @@ class SignalDetectionModel(object):
  results = confusion_matrix(ground_truth,predictions)
  self.dico.update({"val_IoU": IoU_score, "val_balanced_accuracy": balanced_accuracy, "val_confusion": results, 'val_precision': precision, 'val_recall': recall})
 
- try:
-     plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
- except:
-     pass
+ if self.show_plots:
+     try:
+         plot_confusion_matrix(results, ["dead","alive","miscellaneous"], output_dir=self.model_folder+os.sep, title=title)
+     except:
+         pass
  print("Validation set: ",classification_report(ground_truth,predictions))
 
 
@@ -1110,7 +1119,8 @@ class SignalDetectionModel(object):
      validation_split = self.validation_split,
      verbose=1)
 
- self.plot_model_history(mode="regressor")
+ if self.show_plots:
+     self.plot_model_history(mode="regressor")
  self.dico.update({"history_regressor": self.history_regressor, "execution_time_regressor": self.cb[-1].times})
 
 
@@ -1200,7 +1210,8 @@ class SignalDetectionModel(object):
  test_mae = mae(ground_truth, predictions).numpy()
  print(f"MSE on test set: {test_mse}...")
  print(f"MAE on test set: {test_mae}...")
- regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"test_regression.png"]))
+ if self.show_plots:
+     regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"test_regression.png"]))
  self.dico.update({"test_mse": test_mse, "test_mae": test_mae})
 
  if hasattr(self, 'x_val'):
@@ -1212,7 +1223,8 @@ class SignalDetectionModel(object):
  val_mse = mse(ground_truth, predictions).numpy()
  val_mae = mae(ground_truth, predictions).numpy()
 
- regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"validation_regression.png"]))
+ if self.show_plots:
+     regression_plot(predictions, ground_truth, savepath=os.sep.join([self.model_folder,"validation_regression.png"]))
  print(f"MSE on validation set: {val_mse}...")
  print(f"MAE on validation set: {val_mae}...")
 
@@ -1239,17 +1251,17 @@ class SignalDetectionModel(object):
 
  if mode=="classifier":
 
-     reduce_lr = ReduceLROnPlateau(monitor='val_precision', factor=0.5, patience=30,
+     reduce_lr = ReduceLROnPlateau(monitor='val_iou', factor=0.5, patience=30,
          cooldown=10, min_lr=5e-10, min_delta=1.0E-10,
          verbose=1,mode="max")
      self.cb.append(reduce_lr)
      csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
      self.cb.append(csv_logger)
      checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
-     cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_precision",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
+     cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
      self.cb.append(cp_callback)
 
-     callback_stop = EarlyStopping(monitor='val_precision', patience=100)
+     callback_stop = EarlyStopping(monitor='val_iou', patience=100)
      self.cb.append(callback_stop)
 
  elif mode=="regressor":
@@ -1278,7 +1290,7 @@ class SignalDetectionModel(object):
 
 
 
- def generate_sets(self):
+ def prepare_sets(self):
 
  """
  Generates and preprocesses training, validation, and test sets from loaded annotations.
@@ -1297,8 +1309,30 @@ class SignalDetectionModel(object):
  self.y_time_set = []
  self.y_class_set = []
 
- for s in self.list_of_sets:
-     self.load_and_normalize(s)
+ if isinstance(self.list_of_sets[0],str):
+     # Case 1: a list of npy files to be loaded
+     for s in self.list_of_sets:
+
+         signal_dataset = self.load_set(s)
+         selected_signals, max_length = self.find_best_signal_match(signal_dataset)
+         signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
+         signals_recast, times_of_interest = self.normalize_signals(signals_recast, times_of_interest)
+
+         self.x_set.extend(signals_recast)
+         self.y_time_set.extend(times_of_interest)
+         self.y_class_set.extend(classes)
+
+ elif isinstance(self.list_of_sets[0],list):
+     # Case 2: a list of sets (already loaded)
+     for signal_dataset in self.list_of_sets:
+
+         selected_signals, max_length = self.find_best_signal_match(signal_dataset)
+         signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
+         signals_recast, times_of_interest = self.normalize_signals(signals_recast, times_of_interest)
+
+         self.x_set.extend(signals_recast)
+         self.y_time_set.extend(times_of_interest)
+         self.y_class_set.extend(classes)
 
  self.x_set = np.array(self.x_set).astype(np.float32)
  self.x_set = self.interpolate_signals(self.x_set)
@@ -1325,7 +1359,6 @@ class SignalDetectionModel(object):
  self.x_train = ds["x_train"]
  self.x_val = ds["x_val"]
  self.y_time_train = ds["y1_train"].astype(np.float32)
- print(np.amax(self.y_time_train),np.amin(self.y_time_train))
  self.y_time_val = ds["y1_val"].astype(np.float32)
  self.y_class_train = ds["y2_train"]
  self.y_class_val = ds["y2_val"]
@@ -1357,13 +1390,24 @@ class SignalDetectionModel(object):
 
  nbr_augment = self.augmentation_factor*len(self.x_train)
  randomize = np.arange(len(self.x_train))
- indices = random.choices(randomize,k=nbr_augment)
+
+ unique, counts = np.unique(self.y_class_train.argmax(axis=1),return_counts=True)
+ frac = counts/sum(counts)
+ weights = [frac[0]/f for f in frac]
+ weights[0] = weights[0]*3
+
+ self.pre_augment_weights = weights/sum(weights)
+ weights_array = [self.pre_augment_weights[a.argmax()] for a in self.y_class_train]
+
+ indices = random.choices(randomize,k=nbr_augment, weights=weights_array)
 
  x_train_aug = []
  y_time_train_aug = []
  y_class_train_aug = []
 
+ counts = [0.,0.,0.]
  for k in indices:
+     counts[self.y_class_train[k].argmax()] += 1
      aug = augmenter(self.x_train[k],
          self.y_time_train[k],
          self.y_class_train[k],
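Augmentation indices are no longer drawn uniformly: each training sample's draw probability scales with the inverse frequency of its class (the diff additionally triples the weight of the first class). A compact toy illustration of weighted resampling with random.choices:

import random
import numpy as np

y_class_train = np.array([[1, 0, 0]] * 8 + [[0, 1, 0]] * 2)  # 80/20 imbalance

unique, counts = np.unique(y_class_train.argmax(axis=1), return_counts=True)
frac = counts / sum(counts)                      # [0.8, 0.2]
weights = np.array([frac[0] / f for f in frac])  # inverse-frequency weights
weights = weights / sum(weights)                 # per-class draw probabilities
weights_array = [weights[a.argmax()] for a in y_class_train]

indices = random.choices(np.arange(len(y_class_train)), k=1000, weights=weights_array)
# In aggregate, minority-class samples are now drawn as often as majority-class ones.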
@@ -1372,36 +1416,23 @@ class SignalDetectionModel(object):
  x_train_aug.append(aug[0])
  y_time_train_aug.append(aug[1])
  y_class_train_aug.append(aug[2])
+ print('per class counts ',counts)
 
  # Save augmented training set
  self.x_train = np.array(x_train_aug)
  self.y_time_train = np.array(y_time_train_aug)
  self.y_class_train = np.array(y_class_train_aug)
-
-
 
- def load_and_normalize(self, subset):
+ self.class_weights = compute_weights(self.y_class_train.argmax(axis=1))
+ print(f"New class weights: {self.class_weights}...")
 
- """
- Loads a subset of signal data from an annotation file and applies normalization.
-
- Parameters
- ----------
- subset : str
-     The file path to the .npy annotation file containing signal data for a subset of observations.
-
- Notes
- -----
- - The method extracts required signal channels from the annotation file and applies specified normalization
-   and interpolation steps.
- - Preprocessed signals are added to the global dataset for model training.
- """
-
- set_k = np.load(subset,allow_pickle=True)
- ### here do a mapping between channel option and existing signals
+ def load_set(self, signal_dataset):
+     return np.load(signal_dataset,allow_pickle=True)
 
+ def find_best_signal_match(self, signal_dataset):
+
  required_signals = self.channel_option
- available_signals = list(set_k[0].keys())
+ available_signals = list(signal_dataset[0].keys())
 
  selected_signals = []
  for s in required_signals:
@@ -1421,47 +1452,134 @@ class SignalDetectionModel(object):
  else:
      return None
 
-
  key_to_check = selected_signals[0] #self.channel_option[0]
- signal_lengths = [len(l[key_to_check]) for l in set_k]
- max_length = np.amax(signal_lengths)
+ signal_lengths = [len(l[key_to_check]) for l in signal_dataset]
+ max_length = np.amax(signal_lengths)
+
+ return selected_signals, max_length
 
- fluo = np.zeros((len(set_k),max_length,self.n_channels))
- classes = np.zeros(len(set_k))
- times_of_interest = np.zeros(len(set_k))
+ def cast_signals_into_training_data(self, signal_dataset, selected_signals, max_length):
 
- for k in range(len(set_k)):
+     signals_recast = np.zeros((len(signal_dataset),max_length,self.n_channels))
+     classes = np.zeros(len(signal_dataset))
+     times_of_interest = np.zeros(len(signal_dataset))
+
+     for k in range(len(signal_dataset)):
 
      for i in range(self.n_channels):
          try:
              # take into account timeline for accurate time regression
-             timeline = set_k[k]['FRAME'].astype(int)
-             fluo[k,timeline,i] = set_k[k][selected_signals[i]]
+             timeline = signal_dataset[k]['FRAME'].astype(int)
+             signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
          except:
              print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
              pass
 
-     classes[k] = set_k[k]["class"]
-     times_of_interest[k] = set_k[k]["time_of_interest"]
+     classes[k] = signal_dataset[k]["class"]
+     times_of_interest[k] = signal_dataset[k]["time_of_interest"]
 
  # Correct absurd times of interest
  times_of_interest[np.nonzero(classes)] = -1
  times_of_interest[(times_of_interest<=0.0)] = -1
 
- # Attempt per-set normalization
- fluo = pad_to_model_length(fluo, self.model_signal_length)
+ return signals_recast, classes, times_of_interest
+
+ def normalize_signals(self, signals_recast, times_of_interest):
+
+     signals_recast = pad_to_model_length(signals_recast, self.model_signal_length)
  if self.normalize:
-     fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
+     signals_recast = normalize_signal_set(signals_recast, self.channel_option, normalization_percentile=self.normalization_percentile,
          normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
          )
 
  # Trivial normalization for time of interest
  times_of_interest /= self.model_signal_length
+
+ return signals_recast, times_of_interest
+
+
+ # def load_and_normalize(self, subset):
+
+ # """
+ # Loads a subset of signal data from an annotation file and applies normalization.
+
+ # Parameters
+ # ----------
+ # subset : str
+ #     The file path to the .npy annotation file containing signal data for a subset of observations.
+
+ # Notes
+ # -----
+ # - The method extracts required signal channels from the annotation file and applies specified normalization
+ # and interpolation steps.
+ # - Preprocessed signals are added to the global dataset for model training.
+ # """
+
+ # set_k = np.load(subset,allow_pickle=True)
+ # ### here do a mapping between channel option and existing signals
+
+ # required_signals = self.channel_option
+ # available_signals = list(set_k[0].keys())
+
+ # selected_signals = []
+ # for s in required_signals:
+ #     pattern_test = [s in a for a in available_signals]
+ #     if np.any(pattern_test):
+ #         valid_columns = np.array(available_signals)[np.array(pattern_test)]
+ #         if len(valid_columns)==1:
+ #             selected_signals.append(valid_columns[0])
+ #         else:
+ #             print(f'Found several candidate signals: {valid_columns}')
+ #             for vc in natsorted(valid_columns):
+ #                 if 'circle' in vc:
+ #                     selected_signals.append(vc)
+ #                     break
+ #             else:
+ #                 selected_signals.append(valid_columns[0])
+ #     else:
+ #         return None
 
- # Add to global dataset
- self.x_set.extend(fluo)
- self.y_time_set.extend(times_of_interest)
- self.y_class_set.extend(classes)
+
+ # key_to_check = selected_signals[0] #self.channel_option[0]
+ # signal_lengths = [len(l[key_to_check]) for l in set_k]
+ # max_length = np.amax(signal_lengths)
+
+ # fluo = np.zeros((len(set_k),max_length,self.n_channels))
+ # classes = np.zeros(len(set_k))
+ # times_of_interest = np.zeros(len(set_k))
+
+ # for k in range(len(set_k)):
+
+ #     for i in range(self.n_channels):
+ #         try:
+ #             # take into account timeline for accurate time regression
+ #             timeline = set_k[k]['FRAME'].astype(int)
+ #             fluo[k,timeline,i] = set_k[k][selected_signals[i]]
+ #         except:
+ #             print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
+ #             pass
+
+ #     classes[k] = set_k[k]["class"]
+ #     times_of_interest[k] = set_k[k]["time_of_interest"]
+
+ # # Correct absurd times of interest
+ # times_of_interest[np.nonzero(classes)] = -1
+ # times_of_interest[(times_of_interest<=0.0)] = -1
+
+ # # Attempt per-set normalization
+ # fluo = pad_to_model_length(fluo, self.model_signal_length)
+ # if self.normalize:
+ #     fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
+ #         normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
+ #         )
+
+ # # Trivial normalization for time of interest
+ # times_of_interest /= self.model_signal_length
+
+ # # Add to global dataset
+ # self.x_set.extend(fluo)
+ # self.y_time_set.extend(times_of_interest)
+ # self.y_class_set.extend(classes)
 
  def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
 
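The former monolithic load_and_normalize (kept above as a commented-out reference) is thus split into load_set, find_best_signal_match, cast_signals_into_training_data, and normalize_signals. A hypothetical walk-through of the per-dataset flow that prepare_sets now runs, assuming model is an initialized SignalDetectionModel and the annotation path exists:

ds = model.load_set("path/to/annotations/set_0.npy")  # thin np.load wrapper
signals, max_len = model.find_best_signal_match(ds)   # map channel_option to available columns
x, y_class, y_time = model.cast_signals_into_training_data(ds, signals, max_len)
x, y_time = model.normalize_signals(x, y_time)        # pad, normalize, rescale times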
@@ -1655,7 +1773,7 @@ def pad_to_model_length(signal_set, model_signal_length):
 
  """
 
- padded = np.pad(signal_set, [(0,0),(0,model_signal_length - signal_set.shape[1]),(0,0)])
+ padded = np.pad(signal_set, [(0,0),(0,model_signal_length - signal_set.shape[1]),(0,0)],mode="edge")
 
  return padded
 
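Padding now repeats each track's last value instead of filling with zeros, consistent with the forward-fill added in analyze_signals. A toy comparison of the two np.pad modes:

import numpy as np

signal_set = np.array([[[1.], [2.], [3.]]])  # (1 track, 3 frames, 1 channel)

zero_pad = np.pad(signal_set, [(0, 0), (0, 2), (0, 0)])               # old behavior
edge_pad = np.pad(signal_set, [(0, 0), (0, 2), (0, 0)], mode="edge")  # new behavior

print(zero_pad[0, :, 0])  # [1. 2. 3. 0. 0.]
print(edge_pad[0, :, 0])  # [1. 2. 3. 3. 3.]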
@@ -1767,13 +1885,16 @@ def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
 
  """
 
+ min_time = 3
  max_time = model_signal_length
+
  return_target = False
  if time_of_interest != -1:
      return_target = True
-     max_time = model_signal_length - 3 # to prevent approaching too much to the edge
+     max_time = model_signal_length + 1/3*model_signal_length # bias to have a third of event class becoming no event
+     min_time = -model_signal_length*1/3
 
- times = np.linspace(-max_time,max_time,2000) # symmetrize to create left-censored events
+ times = np.linspace(min_time,max_time,2000) # symmetrize to create left-censored events
  target_time = np.random.choice(times)
 
  delta_t = target_time - time_of_interest
@@ -1782,13 +1903,16 @@ def random_time_shift(signal, time_of_interest, cclass, model_signal_length):
  if target_time<=0 and np.argmax(cclass)==0:
      target_time = -1
      cclass = np.array([0.,0.,1.]).astype(np.float32)
+ if target_time>=model_signal_length and np.argmax(cclass)==0:
+     target_time = -1
+     cclass = np.array([0.,1.,0.]).astype(np.float32)
 
  if return_target:
      return signal,target_time, cclass
  else:
      return signal, time_of_interest, cclass
 
- def augmenter(signal, time_of_interest, cclass, model_signal_length, time_shift=True, probability=0.8):
+ def augmenter(signal, time_of_interest, cclass, model_signal_length, time_shift=True, probability=0.95):
 
  """
  Randomly augments single-cell signals to simulate variations in noise, intensity ratios, and event times.
@@ -1839,9 +1963,8 @@ def augmenter(signal, time_of_interest, cclass, model_signal_length, time_shift=
 
  if time_shift:
      # do not time shift miscellaneous cells
-     if cclass.argmax()!=2.:
-         assert time_of_interest is not None, f"Please provide valid lysis times"
-         signal,time_of_interest,cclass = random_time_shift(signal, time_of_interest, cclass, model_signal_length)
+     assert time_of_interest is not None, f"Please provide valid lysis times"
+     signal,time_of_interest,cclass = random_time_shift(signal, time_of_interest, cclass, model_signal_length)
 
  #signal = random_intensity_change(signal) #maybe bad idea for non percentile-normalized signals
  signal = gauss_noise(signal)