celldetective 1.3.7.post1__py3-none-any.whl → 1.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. celldetective/_version.py +1 -1
  2. celldetective/gui/btrack_options.py +8 -8
  3. celldetective/gui/classifier_widget.py +8 -0
  4. celldetective/gui/configure_new_exp.py +1 -1
  5. celldetective/gui/json_readers.py +2 -4
  6. celldetective/gui/plot_signals_ui.py +38 -29
  7. celldetective/gui/process_block.py +1 -0
  8. celldetective/gui/processes/downloader.py +108 -0
  9. celldetective/gui/processes/measure_cells.py +346 -0
  10. celldetective/gui/processes/segment_cells.py +354 -0
  11. celldetective/gui/processes/track_cells.py +298 -0
  12. celldetective/gui/processes/train_segmentation_model.py +270 -0
  13. celldetective/gui/processes/train_signal_model.py +108 -0
  14. celldetective/gui/seg_model_loader.py +71 -25
  15. celldetective/gui/signal_annotator2.py +10 -7
  16. celldetective/gui/signal_annotator_options.py +1 -1
  17. celldetective/gui/tableUI.py +252 -20
  18. celldetective/gui/viewers.py +1 -1
  19. celldetective/io.py +53 -20
  20. celldetective/measure.py +12 -144
  21. celldetective/relative_measurements.py +40 -43
  22. celldetective/segmentation.py +48 -1
  23. celldetective/signals.py +84 -305
  24. celldetective/tracking.py +23 -24
  25. celldetective/utils.py +1 -1
  26. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/METADATA +11 -2
  27. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/RECORD +31 -25
  28. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/WHEEL +1 -1
  29. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/LICENSE +0 -0
  30. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/entry_points.txt +0 -0
  31. {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/top_level.txt +0 -0
celldetective/signals.py CHANGED
@@ -33,6 +33,7 @@ import time
33
33
  import math
34
34
  import pandas as pd
35
35
  from pandas.api.types import is_numeric_dtype
36
+ from scipy.stats import median_abs_deviation
36
37
 
37
38
  abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
38
39
 
@@ -325,9 +326,6 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
325
326
 
326
327
  def analyze_pair_signals_at_position(pos, model, use_gpu=True):
327
328
 
328
- """
329
-
330
- """
331
329
 
332
330
  pos = pos.replace('\\','/')
333
331
  pos = rf"{pos}"
@@ -364,199 +362,8 @@ def analyze_pair_signals_at_position(pos, model, use_gpu=True):
364
362
  return None
365
363
 
366
364
 
367
- # def analyze_signals(trajectories, model, interpolate_na=True,
368
- # selected_signals=None,
369
- # model_path=None,
370
- # column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
371
- # plot_outcome=False, output_dir=None):
372
- # """
373
- # Analyzes signals from trajectory data using a specified signal detection model and configuration.
374
-
375
- # This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
376
- # model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
377
- # of missing values, and plotting of analysis outcomes.
378
-
379
- # Parameters
380
- # ----------
381
- # trajectories : pandas.DataFrame
382
- # DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
383
- # model : str
384
- # The name of the signal detection model to be used for analysis.
385
- # interpolate_na : bool, optional
386
- # Whether to interpolate missing values in the trajectories (default is True).
387
- # selected_signals : list of str, optional
388
- # A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
389
- # be automatically selected based on the model configuration (default is None).
390
- # column_labels : dict, optional
391
- # A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
392
- # in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
393
- # plot_outcome : bool, optional
394
- # If True, generates and saves a plot of the signal analysis outcome (default is False).
395
- # output_dir : str, optional
396
- # The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
397
-
398
- # Returns
399
- # -------
400
- # pandas.DataFrame
401
- # The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
402
- # corresponding colors based on status and class.
403
-
404
- # Raises
405
- # ------
406
- # AssertionError
407
- # If the model or its configuration file cannot be located.
408
-
409
- # Notes
410
- # -----
411
- # - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
412
- # - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
413
-
414
- # """
415
-
416
- # model_path = locate_signal_model(model, path=model_path)
417
- # complete_path = model_path # +model
418
- # complete_path = rf"{complete_path}"
419
- # model_config_path = os.sep.join([complete_path, 'config_input.json'])
420
- # model_config_path = rf"{model_config_path}"
421
- # assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
422
- # assert os.path.exists(
423
- # model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
424
-
425
- # available_signals = list(trajectories.columns)
426
-
427
- # f = open(model_config_path)
428
- # config = json.load(f)
429
- # required_signals = config["channels"]
430
-
431
- # try:
432
- # label = config['label']
433
- # if label == '':
434
- # label = None
435
- # except:
436
- # label = None
437
-
438
- # if selected_signals is None:
439
- # selected_signals = []
440
- # for s in required_signals:
441
- # pattern_test = [s in a or s == a for a in available_signals]
442
- # #print(f'Pattern test for signal {s}: ', pattern_test)
443
- # assert np.any(
444
- # pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
445
- # valid_columns = np.array(available_signals)[np.array(pattern_test)]
446
- # if len(valid_columns) == 1:
447
- # selected_signals.append(valid_columns[0])
448
- # else:
449
- # # print(test_number_of_nan(trajectories, valid_columns))
450
- # print(f'Found several candidate signals: {valid_columns}')
451
- # for vc in natsorted(valid_columns):
452
- # if 'circle' in vc:
453
- # selected_signals.append(vc)
454
- # break
455
- # else:
456
- # selected_signals.append(valid_columns[0])
457
- # # do something more complicated in case of one to many columns
458
- # # pass
459
- # else:
460
- # assert len(selected_signals) == len(
461
- # required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
462
-
463
- # print(f'The following channels will be passed to the model: {selected_signals}')
464
- # trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
465
- # interpolate_position_gaps=interpolate_na, column_labels=column_labels)
466
-
467
- # max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
468
- # tracks = trajectories_clean[column_labels['track']].unique()
469
- # signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
470
-
471
- # for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
472
- # frames = group[column_labels['time']].to_numpy().astype(int)
473
- # for j, col in enumerate(selected_signals):
474
- # signal = group[col].to_numpy()
475
- # signals[i, frames, j] = signal
476
- # signals[i, max(frames):, j] = signal[-1]
477
-
478
- # # for i in range(5):
479
- # # print('pre model')
480
- # # plt.plot(signals[i,:,0])
481
- # # plt.show()
482
-
483
- # model = SignalDetectionModel(pretrained=complete_path)
484
- # print('signal shape: ', signals.shape)
485
-
486
- # classes = model.predict_class(signals)
487
- # times_recast = model.predict_time_of_interest(signals)
488
-
489
- # if label is None:
490
- # class_col = 'class'
491
- # time_col = 't0'
492
- # status_col = 'status'
493
- # else:
494
- # class_col = 'class_' + label
495
- # time_col = 't_' + label
496
- # status_col = 'status_' + label
497
-
498
- # for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
499
- # indices = group.index
500
- # trajectories.loc[indices, class_col] = classes[i]
501
- # trajectories.loc[indices, time_col] = times_recast[i]
502
- # print('Done.')
503
-
504
- # for tid, group in trajectories.groupby(column_labels['track']):
505
-
506
- # indices = group.index
507
- # t0 = group[time_col].to_numpy()[0]
508
- # cclass = group[class_col].to_numpy()[0]
509
- # timeline = group[column_labels['time']].to_numpy()
510
- # status = np.zeros_like(timeline)
511
- # if t0 > 0:
512
- # status[timeline >= t0] = 1.
513
- # if cclass == 2:
514
- # status[:] = 2
515
- # if cclass > 2:
516
- # status[:] = 42
517
- # status_color = [color_from_status(s) for s in status]
518
- # class_color = [color_from_class(cclass) for i in range(len(status))]
519
-
520
- # trajectories.loc[indices, status_col] = status
521
- # trajectories.loc[indices, 'status_color'] = status_color
522
- # trajectories.loc[indices, 'class_color'] = class_color
523
-
524
- # if plot_outcome:
525
- # fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
526
- # for i, s in enumerate(selected_signals):
527
- # for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
528
- # cclass = group[class_col].to_numpy()[0]
529
- # t0 = group[time_col].to_numpy()[0]
530
- # timeline = group[column_labels['time']].to_numpy()
531
- # if cclass == 0:
532
- # if len(selected_signals) > 1:
533
- # ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
534
- # else:
535
- # ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
536
- # if len(selected_signals) > 1:
537
- # for a, s in zip(ax, selected_signals):
538
- # a.set_title(s)
539
- # a.set_xlabel(r'time - t$_0$ [frame]')
540
- # a.spines['top'].set_visible(False)
541
- # a.spines['right'].set_visible(False)
542
- # else:
543
- # ax.set_title(s)
544
- # ax.set_xlabel(r'time - t$_0$ [frame]')
545
- # ax.spines['top'].set_visible(False)
546
- # ax.spines['right'].set_visible(False)
547
- # plt.tight_layout()
548
- # if output_dir is not None:
549
- # plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
550
- # plt.pause(3)
551
- # plt.close()
552
-
553
- # return trajectories
554
-
555
-
556
365
  def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
557
366
  model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
558
- """
559
- """
560
367
 
561
368
  model_path = locate_signal_model(model, path=model_path, pairs=True)
562
369
  print(f'Looking for model in {model_path}...')
@@ -832,6 +639,7 @@ class SignalDetectionModel(object):
832
639
  -----
833
640
  - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
834
641
  - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
642
+
835
643
  """
836
644
 
837
645
  if self.pretrained.endswith(os.sep):
@@ -873,11 +681,22 @@ class SignalDetectionModel(object):
873
681
  if 'label' in model_config:
874
682
  self.label = model_config['label']
875
683
 
876
- self.n_channels = self.model_class.layers[0].input_shape[0][-1]
877
- self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
878
- self.n_classes = self.model_class.layers[-1].output_shape[-1]
684
+ try:
685
+ self.n_channels = self.model_class.layers[0].input_shape[0][-1]
686
+ self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
687
+ self.n_classes = self.model_class.layers[-1].output_shape[-1]
688
+ model_class_input_shape = self.model_class.layers[0].input_shape[0]
689
+ model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
690
+ except AttributeError:
691
+ self.n_channels = self.model_class.input_shape[-1] #self.model_class.layers[0].input.shape[0][-1]
692
+ self.model_signal_length = self.model_class.input_shape[-2] #self.model_class.layers[0].input[0].shape[0][-2]
693
+ self.n_classes = self.model_class.output_shape[-1] #self.model_class.layers[-1].output[0].shape[-1]
694
+ model_class_input_shape = self.model_class.input_shape
695
+ model_reg_input_shape = self.model_reg.input_shape
696
+ except Exception as e:
697
+ print(e)
879
698
 
880
- assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
699
+ assert model_class_input_shape==model_reg_input_shape, f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
881
700
 
882
701
  return True
883
702
 
@@ -893,6 +712,7 @@ class SignalDetectionModel(object):
893
712
  -----
894
713
  - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
895
714
  - The models are stored in the `model_class` and `model_reg` attributes of the class.
715
+
896
716
  """
897
717
 
898
718
  self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
@@ -925,6 +745,7 @@ class SignalDetectionModel(object):
925
745
  -----
926
746
  - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
927
747
  - If no GPUs are detected, the method will pass silently.
748
+
928
749
  """
929
750
 
930
751
  try:
@@ -987,6 +808,7 @@ class SignalDetectionModel(object):
987
808
  Notes
988
809
  -----
989
810
  - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
811
+
990
812
  """
991
813
 
992
814
  if not hasattr(self, 'normalization_percentile'):
@@ -1049,6 +871,7 @@ class SignalDetectionModel(object):
1049
871
  -----
1050
872
  - This method provides an alternative way to train the model when data is already loaded into memory, offering
1051
873
  flexibility for data preprocessing steps outside this class.
874
+
1052
875
  """
1053
876
 
1054
877
  self.normalize = normalize
@@ -1179,6 +1002,7 @@ class SignalDetectionModel(object):
1179
1002
  -----
1180
1003
  - The method processes the input signals according to the specified options to ensure compatibility with the model's
1181
1004
  input requirements.
1005
+
1182
1006
  """
1183
1007
 
1184
1008
  self.x = np.copy(x)
@@ -1203,8 +1027,15 @@ class SignalDetectionModel(object):
1203
1027
  # plt.plot(self.x[i,:,0])
1204
1028
  # plt.show()
1205
1029
 
1206
- assert self.x.shape[-1] == self.model_class.layers[0].input_shape[0][-1], f"Shape mismatch between the input shape and the model input shape..."
1207
- assert self.x.shape[-2] == self.model_class.layers[0].input_shape[0][-2], f"Shape mismatch between the input shape and the model input shape..."
1030
+ try:
1031
+ n_channels = self.model_class.layers[0].input_shape[0][-1]
1032
+ model_signal_length = self.model_class.layers[0].input_shape[0][-2]
1033
+ except AttributeError:
1034
+ n_channels = self.model_class.input_shape[-1]
1035
+ model_signal_length = self.model_class.input_shape[-2]
1036
+
1037
+ assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
1038
+ assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
1208
1039
 
1209
1040
  self.class_predictions_one_hot = self.model_class.predict(self.x)
1210
1041
  self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
@@ -1240,6 +1071,7 @@ class SignalDetectionModel(object):
1240
1071
  -----
1241
1072
  - The method processes the input signals according to the specified options and uses the regression model to
1242
1073
  predict times at which a particular event of interest occurs.
1074
+
1243
1075
  """
1244
1076
 
1245
1077
  self.x = np.copy(x)
@@ -1259,8 +1091,15 @@ class SignalDetectionModel(object):
1259
1091
  normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
1260
1092
  )
1261
1093
 
1262
- assert self.x.shape[-1] == self.model_reg.layers[0].input_shape[0][-1], f"Shape mismatch between the input shape and the model input shape..."
1263
- assert self.x.shape[-2] == self.model_reg.layers[0].input_shape[0][-2], f"Shape mismatch between the input shape and the model input shape..."
1094
+ try:
1095
+ n_channels = self.model_reg.layers[0].input_shape[0][-1]
1096
+ model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
1097
+ except AttributeError:
1098
+ n_channels = self.model_reg.input_shape[-1]
1099
+ model_signal_length = self.model_reg.input_shape[-2]
1100
+
1101
+ assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
1102
+ assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
1264
1103
 
1265
1104
  if np.any(self.class_predictions==0):
1266
1105
  self.time_predictions = self.model_reg.predict(self.x[self.class_predictions==0])*self.model_signal_length
@@ -1290,6 +1129,7 @@ class SignalDetectionModel(object):
1290
1129
  -----
1291
1130
  - This method is useful for preparing signals that have gaps or missing time points before further processing
1292
1131
  or model training.
1132
+
1293
1133
  """
1294
1134
 
1295
1135
  for i in range(len(x_set)):
@@ -1317,6 +1157,7 @@ class SignalDetectionModel(object):
1317
1157
  - The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
1318
1158
  - Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
1319
1159
  - Model performance metrics and training history are saved for analysis.
1160
+
1320
1161
  """
1321
1162
 
1322
1163
  # if pretrained model
@@ -1461,6 +1302,7 @@ class SignalDetectionModel(object):
1461
1302
  - The regressor model estimates the time at which an event of interest occurs within each signal.
1462
1303
  - Only signals predicted to have an event by the classifier model are used for regressor training.
1463
1304
  - Model performance metrics and training history are saved for analysis.
1305
+
1464
1306
  """
1465
1307
 
1466
1308
 
@@ -1545,6 +1387,7 @@ class SignalDetectionModel(object):
1545
1387
  -----
1546
1388
  - Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
1547
1389
  - The plots are saved as image files in the model's output directory.
1390
+
1548
1391
  """
1549
1392
 
1550
1393
  if mode=="regressor":
@@ -1589,6 +1432,7 @@ class SignalDetectionModel(object):
1589
1432
  -----
1590
1433
  - Evaluation is performed on both test and validation datasets, if available.
1591
1434
  - Regression plots and performance metrics are saved in the model's output directory.
1435
+
1592
1436
  """
1593
1437
 
1594
1438
 
@@ -1641,6 +1485,7 @@ class SignalDetectionModel(object):
1641
1485
  -----
1642
1486
  - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
1643
1487
  - The list of callbacks is stored in the class attribute `cb` and used during model training.
1488
+
1644
1489
  """
1645
1490
 
1646
1491
  self.cb = []
@@ -1698,6 +1543,7 @@ class SignalDetectionModel(object):
1698
1543
  -----
1699
1544
  - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
1700
1545
  - The method applies specified normalization and interpolation options to prepare the signals for model training.
1546
+
1701
1547
  """
1702
1548
 
1703
1549
 
@@ -1781,6 +1627,7 @@ class SignalDetectionModel(object):
1781
1627
  -----
1782
1628
  - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
1783
1629
  - The augmented dataset is used for training the classifier and regressor models to improve generalization.
1630
+
1784
1631
  """
1785
1632
 
1786
1633
 
@@ -1902,89 +1749,6 @@ class SignalDetectionModel(object):
1902
1749
  return signals_recast, times_of_interest
1903
1750
 
1904
1751
 
1905
- # def load_and_normalize(self, subset):
1906
-
1907
- # """
1908
- # Loads a subset of signal data from an annotation file and applies normalization.
1909
-
1910
- # Parameters
1911
- # ----------
1912
- # subset : str
1913
- # The file path to the .npy annotation file containing signal data for a subset of observations.
1914
-
1915
- # Notes
1916
- # -----
1917
- # - The method extracts required signal channels from the annotation file and applies specified normalization
1918
- # and interpolation steps.
1919
- # - Preprocessed signals are added to the global dataset for model training.
1920
- # """
1921
-
1922
- # set_k = np.load(subset,allow_pickle=True)
1923
- # ### here do a mapping between channel option and existing signals
1924
-
1925
- # required_signals = self.channel_option
1926
- # available_signals = list(set_k[0].keys())
1927
-
1928
- # selected_signals = []
1929
- # for s in required_signals:
1930
- # pattern_test = [s in a for a in available_signals]
1931
- # if np.any(pattern_test):
1932
- # valid_columns = np.array(available_signals)[np.array(pattern_test)]
1933
- # if len(valid_columns)==1:
1934
- # selected_signals.append(valid_columns[0])
1935
- # else:
1936
- # print(f'Found several candidate signals: {valid_columns}')
1937
- # for vc in natsorted(valid_columns):
1938
- # if 'circle' in vc:
1939
- # selected_signals.append(vc)
1940
- # break
1941
- # else:
1942
- # selected_signals.append(valid_columns[0])
1943
- # else:
1944
- # return None
1945
-
1946
-
1947
- # key_to_check = selected_signals[0] #self.channel_option[0]
1948
- # signal_lengths = [len(l[key_to_check]) for l in set_k]
1949
- # max_length = np.amax(signal_lengths)
1950
-
1951
- # fluo = np.zeros((len(set_k),max_length,self.n_channels))
1952
- # classes = np.zeros(len(set_k))
1953
- # times_of_interest = np.zeros(len(set_k))
1954
-
1955
- # for k in range(len(set_k)):
1956
-
1957
- # for i in range(self.n_channels):
1958
- # try:
1959
- # # take into account timeline for accurate time regression
1960
- # timeline = set_k[k]['FRAME'].astype(int)
1961
- # fluo[k,timeline,i] = set_k[k][selected_signals[i]]
1962
- # except:
1963
- # print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
1964
- # pass
1965
-
1966
- # classes[k] = set_k[k]["class"]
1967
- # times_of_interest[k] = set_k[k]["time_of_interest"]
1968
-
1969
- # # Correct absurd times of interest
1970
- # times_of_interest[np.nonzero(classes)] = -1
1971
- # times_of_interest[(times_of_interest<=0.0)] = -1
1972
-
1973
- # # Attempt per-set normalization
1974
- # fluo = pad_to_model_length(fluo, self.model_signal_length)
1975
- # if self.normalize:
1976
- # fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
1977
- # normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
1978
- # )
1979
-
1980
- # # Trivial normalization for time of interest
1981
- # times_of_interest /= self.model_signal_length
1982
-
1983
- # # Add to global dataset
1984
- # self.x_set.extend(fluo)
1985
- # self.y_time_set.extend(times_of_interest)
1986
- # self.y_class_set.extend(classes)
1987
-
1988
1752
  def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
1989
1753
 
1990
1754
  """
@@ -2028,6 +1792,7 @@ def _interpret_normalization_parameters(n_channels, normalization_percentile, no
2028
1792
  >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
2029
1793
  >>> print(params)
2030
1794
  # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
1795
+
2031
1796
  """
2032
1797
 
2033
1798
 
@@ -2101,6 +1866,7 @@ def normalize_signal_set(signal_set, channel_option, percentile_alive=[0.01,99.9
2101
1866
  >>> channel_option = ['alive', 'dead']
2102
1867
  >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
2103
1868
  # Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.
1869
+
2104
1870
  """
2105
1871
 
2106
1872
  # Check normalization params are ok
@@ -2584,6 +2350,7 @@ def ResNetModelCurrent(n_channels, n_slices, depth=2, use_pooling=True, n_classe
2584
2350
  --------
2585
2351
  >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
2586
2352
  # Creates a ResNet model configured for classification with 3 classes.
2353
+
2587
2354
  """
2588
2355
 
2589
2356
  if header=="classifier":
@@ -2652,6 +2419,7 @@ def train_signal_model(config):
2652
2419
  >>> config_path = '/path/to/training_config.json'
2653
2420
  >>> train_signal_model(config_path)
2654
2421
  # This will execute the 'train_signal_model.py' script using the parameters specified in 'training_config.json'.
2422
+
2655
2423
  """
2656
2424
 
2657
2425
  config = config.replace('\\','/')
@@ -2698,6 +2466,7 @@ def T_MSD(x,y,dt):
2698
2466
  >>> T_MSD(x, y, dt)
2699
2467
  ([6.0, 9.0, 4.666666666666667, 1.6666666666666667],
2700
2468
  array([1., 2., 3., 4.]))
2469
+
2701
2470
  """
2702
2471
 
2703
2472
  msd = []
@@ -2735,6 +2504,7 @@ def linear_msd(t, m):
2735
2504
  >>> m = 2.0
2736
2505
  >>> linear_msd(t, m)
2737
2506
  array([2., 4., 6., 8.])
2507
+
2738
2508
  """
2739
2509
 
2740
2510
  return m*t
@@ -2766,6 +2536,7 @@ def alpha_msd(t, m, alpha):
2766
2536
  >>> alpha = 0.5
2767
2537
  >>> alpha_msd(t, m, alpha)
2768
2538
  array([2. , 4. , 6. , 8. ])
2539
+
2769
2540
  """
2770
2541
 
2771
2542
  return m*t**alpha
@@ -2820,6 +2591,7 @@ def sliding_msd(x, y, timeline, window, mode='bi', n_points_migration=7, n_poin
2820
2591
  >>> timeline = np.array([0, 1, 2, 3, 4, 5, 6])
2821
2592
  >>> window = 3
2822
2593
  >>> s_msd, s_alpha = sliding_msd(x, y, timeline, window, n_points_migration=2, n_points_transport=3)
2594
+
2823
2595
  """
2824
2596
 
2825
2597
  assert window > n_points_migration,'Please set a window larger than the number of fit points...'
@@ -2956,6 +2728,7 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
2956
2728
  >>> window = 11
2957
2729
  >>> diffusion, velocity = sliding_msd_drift(x, y, timeline, window, mode='bi')
2958
2730
  # Calculates the diffusion coefficient and drift velocity using a bidirectional sliding window.
2731
+
2959
2732
  """
2960
2733
 
2961
2734
  assert window > n_points_migration,'Please set a window larger than the number of fit points...'
@@ -3002,12 +2775,12 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
3002
2775
 
3003
2776
  return s_diffusion, s_velocity
3004
2777
 
3005
- def columnwise_mean(matrix, min_nbr_values = 1):
2778
+ def columnwise_mean(matrix, min_nbr_values = 1, projection='mean'):
3006
2779
 
3007
2780
  """
3008
2781
  Calculate the column-wise mean and standard deviation of non-NaN elements in the input matrix.
3009
2782
 
3010
- Parameters:
2783
+ Parameters
3011
2784
  ----------
3012
2785
  matrix : numpy.ndarray
3013
2786
  The input matrix for which column-wise mean and standard deviation are calculated.
@@ -3015,7 +2788,7 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3015
2788
  The minimum number of non-NaN values required in a column to calculate mean and standard deviation.
3016
2789
  Default is 8.
3017
2790
 
3018
- Returns:
2791
+ Returns
3019
2792
  -------
3020
2793
  mean_line : numpy.ndarray
3021
2794
  An array containing the column-wise mean of non-NaN elements. Elements with fewer than `min_nbr_values` non-NaN
@@ -3024,11 +2797,12 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3024
2797
  An array containing the column-wise standard deviation of non-NaN elements. Elements with fewer than `min_nbr_values`
3025
2798
  non-NaN values are replaced with NaN.
3026
2799
 
3027
- Notes:
3028
- ------
2800
+ Notes
2801
+ -----
3029
2802
  1. This function calculates the mean and standard deviation of non-NaN elements in each column of the input matrix.
3030
2803
  2. Columns with fewer than `min_nbr_values` non-zero elements will have NaN as the mean and standard deviation.
3031
2804
  3. NaN values in the input matrix are ignored during calculation.
2805
+
3032
2806
  """
3033
2807
 
3034
2808
  mean_line = np.zeros(matrix.shape[1])
@@ -3040,17 +2814,21 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3040
2814
  values = matrix[:,k]
3041
2815
  values = values[values==values]
3042
2816
  if len(values[values==values])>min_nbr_values:
3043
- mean_line[k] = np.nanmean(values)
3044
- mean_line_std[k] = np.nanstd(values)
2817
+ if projection=='mean':
2818
+ mean_line[k] = np.nanmean(values)
2819
+ mean_line_std[k] = np.nanstd(values)
2820
+ elif projection=='median':
2821
+ mean_line[k] = np.nanmedian(values)
2822
+ mean_line_std[k] = median_abs_deviation(values, center=np.nanmedian, nan_policy='omit')
3045
2823
  return mean_line, mean_line_std
3046
2824
 
3047
2825
 
3048
- def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean'):
2826
+ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean', projection='mean'):
3049
2827
 
3050
2828
  """
3051
2829
  Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.
3052
2830
 
3053
- Parameters:
2831
+ Parameters
3054
2832
  ----------
3055
2833
  df : pandas.DataFrame
3056
2834
  Input DataFrame containing tracking data.
@@ -3063,7 +2841,7 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3063
2841
  class_value : int, optional
3064
2842
  Value representing the class of interest. Default is 0.
3065
2843
 
3066
- Returns:
2844
+ Returns
3067
2845
  -------
3068
2846
  mean_signal : numpy.ndarray
3069
2847
  An array containing the mean signal values for tracks of the specified class. Tracks with class not equal to
@@ -3074,12 +2852,13 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3074
2852
  actual_timeline : numpy.ndarray
3075
2853
  An array representing the time points corresponding to the mean signal values.
3076
2854
 
3077
- Notes:
3078
- ------
2855
+ Notes
2856
+ -----
3079
2857
  1. This function calculates the mean and standard deviation of the specified signal for tracks of a given class.
3080
2858
  2. Tracks with class not equal to `class_value` are excluded from the calculation.
3081
2859
  3. Tracks with missing or NaN values in the specified signal are ignored during calculation.
3082
2860
  4. Tracks are aligned based on their 'FRAME' values and the specified `time_col` (if provided).
2861
+
3083
2862
  """
3084
2863
 
3085
2864
  assert signal_name in list(df.columns),"The signal you want to plot is not one of the measured features."
@@ -3135,18 +2914,18 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3135
2914
  signal_matrix[trackid,timeline_shifted.astype(int)] = signal
3136
2915
  trackid+=1
3137
2916
 
3138
- mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values)
2917
+ mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values, projection=projection)
3139
2918
  actual_timeline = np.linspace(-max_duration, max_duration, 2*max_duration+1)
3140
2919
  if return_matrix:
3141
2920
  return mean_signal, std_signal, actual_timeline, signal_matrix
3142
2921
  else:
3143
2922
  return mean_signal, std_signal, actual_timeline
3144
2923
 
3145
- if __name__ == "__main__":
2924
+ # if __name__ == "__main__":
3146
2925
 
3147
- # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
3148
- # print(model.summary())
3149
- model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
3150
- header="classifier", model_signal_length = 128)
3151
- print(model.summary())
3152
- #plot_model(model, to_file='test.png', show_shapes=True)
2926
+ # # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
2927
+ # # print(model.summary())
2928
+ # model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
2929
+ # header="classifier", model_signal_length = 128)
2930
+ # print(model.summary())
2931
+ # #plot_model(model, to_file='test.png', show_shapes=True)