celldetective 1.3.7.post1__py3-none-any.whl → 1.3.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/_version.py +1 -1
- celldetective/gui/btrack_options.py +8 -8
- celldetective/gui/classifier_widget.py +8 -0
- celldetective/gui/configure_new_exp.py +1 -1
- celldetective/gui/json_readers.py +2 -4
- celldetective/gui/plot_signals_ui.py +38 -29
- celldetective/gui/process_block.py +1 -0
- celldetective/gui/processes/downloader.py +108 -0
- celldetective/gui/processes/measure_cells.py +346 -0
- celldetective/gui/processes/segment_cells.py +354 -0
- celldetective/gui/processes/track_cells.py +298 -0
- celldetective/gui/processes/train_segmentation_model.py +270 -0
- celldetective/gui/processes/train_signal_model.py +108 -0
- celldetective/gui/seg_model_loader.py +71 -25
- celldetective/gui/signal_annotator2.py +10 -7
- celldetective/gui/signal_annotator_options.py +1 -1
- celldetective/gui/tableUI.py +252 -20
- celldetective/gui/viewers.py +1 -1
- celldetective/io.py +53 -20
- celldetective/measure.py +12 -144
- celldetective/relative_measurements.py +40 -43
- celldetective/segmentation.py +48 -1
- celldetective/signals.py +84 -305
- celldetective/tracking.py +23 -24
- celldetective/utils.py +1 -1
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/METADATA +11 -2
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/RECORD +31 -25
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/WHEEL +1 -1
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/LICENSE +0 -0
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/entry_points.txt +0 -0
- {celldetective-1.3.7.post1.dist-info → celldetective-1.3.8.dist-info}/top_level.txt +0 -0
celldetective/signals.py
CHANGED
|
@@ -33,6 +33,7 @@ import time
|
|
|
33
33
|
import math
|
|
34
34
|
import pandas as pd
|
|
35
35
|
from pandas.api.types import is_numeric_dtype
|
|
36
|
+
from scipy.stats import median_abs_deviation
|
|
36
37
|
|
|
37
38
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
|
|
38
39
|
|
|
@@ -325,9 +326,6 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
|
|
|
325
326
|
|
|
326
327
|
def analyze_pair_signals_at_position(pos, model, use_gpu=True):
|
|
327
328
|
|
|
328
|
-
"""
|
|
329
|
-
|
|
330
|
-
"""
|
|
331
329
|
|
|
332
330
|
pos = pos.replace('\\','/')
|
|
333
331
|
pos = rf"{pos}"
|
|
@@ -364,199 +362,8 @@ def analyze_pair_signals_at_position(pos, model, use_gpu=True):
|
|
|
364
362
|
return None
|
|
365
363
|
|
|
366
364
|
|
|
367
|
-
# def analyze_signals(trajectories, model, interpolate_na=True,
|
|
368
|
-
# selected_signals=None,
|
|
369
|
-
# model_path=None,
|
|
370
|
-
# column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
|
|
371
|
-
# plot_outcome=False, output_dir=None):
|
|
372
|
-
# """
|
|
373
|
-
# Analyzes signals from trajectory data using a specified signal detection model and configuration.
|
|
374
|
-
|
|
375
|
-
# This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
|
|
376
|
-
# model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
|
|
377
|
-
# of missing values, and plotting of analysis outcomes.
|
|
378
|
-
|
|
379
|
-
# Parameters
|
|
380
|
-
# ----------
|
|
381
|
-
# trajectories : pandas.DataFrame
|
|
382
|
-
# DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
|
|
383
|
-
# model : str
|
|
384
|
-
# The name of the signal detection model to be used for analysis.
|
|
385
|
-
# interpolate_na : bool, optional
|
|
386
|
-
# Whether to interpolate missing values in the trajectories (default is True).
|
|
387
|
-
# selected_signals : list of str, optional
|
|
388
|
-
# A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
|
|
389
|
-
# be automatically selected based on the model configuration (default is None).
|
|
390
|
-
# column_labels : dict, optional
|
|
391
|
-
# A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
|
|
392
|
-
# in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
|
|
393
|
-
# plot_outcome : bool, optional
|
|
394
|
-
# If True, generates and saves a plot of the signal analysis outcome (default is False).
|
|
395
|
-
# output_dir : str, optional
|
|
396
|
-
# The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
|
|
397
|
-
|
|
398
|
-
# Returns
|
|
399
|
-
# -------
|
|
400
|
-
# pandas.DataFrame
|
|
401
|
-
# The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
|
|
402
|
-
# corresponding colors based on status and class.
|
|
403
|
-
|
|
404
|
-
# Raises
|
|
405
|
-
# ------
|
|
406
|
-
# AssertionError
|
|
407
|
-
# If the model or its configuration file cannot be located.
|
|
408
|
-
|
|
409
|
-
# Notes
|
|
410
|
-
# -----
|
|
411
|
-
# - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
|
|
412
|
-
# - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
|
|
413
|
-
|
|
414
|
-
# """
|
|
415
|
-
|
|
416
|
-
# model_path = locate_signal_model(model, path=model_path)
|
|
417
|
-
# complete_path = model_path # +model
|
|
418
|
-
# complete_path = rf"{complete_path}"
|
|
419
|
-
# model_config_path = os.sep.join([complete_path, 'config_input.json'])
|
|
420
|
-
# model_config_path = rf"{model_config_path}"
|
|
421
|
-
# assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
|
|
422
|
-
# assert os.path.exists(
|
|
423
|
-
# model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
|
|
424
|
-
|
|
425
|
-
# available_signals = list(trajectories.columns)
|
|
426
|
-
|
|
427
|
-
# f = open(model_config_path)
|
|
428
|
-
# config = json.load(f)
|
|
429
|
-
# required_signals = config["channels"]
|
|
430
|
-
|
|
431
|
-
# try:
|
|
432
|
-
# label = config['label']
|
|
433
|
-
# if label == '':
|
|
434
|
-
# label = None
|
|
435
|
-
# except:
|
|
436
|
-
# label = None
|
|
437
|
-
|
|
438
|
-
# if selected_signals is None:
|
|
439
|
-
# selected_signals = []
|
|
440
|
-
# for s in required_signals:
|
|
441
|
-
# pattern_test = [s in a or s == a for a in available_signals]
|
|
442
|
-
# #print(f'Pattern test for signal {s}: ', pattern_test)
|
|
443
|
-
# assert np.any(
|
|
444
|
-
# pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
|
|
445
|
-
# valid_columns = np.array(available_signals)[np.array(pattern_test)]
|
|
446
|
-
# if len(valid_columns) == 1:
|
|
447
|
-
# selected_signals.append(valid_columns[0])
|
|
448
|
-
# else:
|
|
449
|
-
# # print(test_number_of_nan(trajectories, valid_columns))
|
|
450
|
-
# print(f'Found several candidate signals: {valid_columns}')
|
|
451
|
-
# for vc in natsorted(valid_columns):
|
|
452
|
-
# if 'circle' in vc:
|
|
453
|
-
# selected_signals.append(vc)
|
|
454
|
-
# break
|
|
455
|
-
# else:
|
|
456
|
-
# selected_signals.append(valid_columns[0])
|
|
457
|
-
# # do something more complicated in case of one to many columns
|
|
458
|
-
# # pass
|
|
459
|
-
# else:
|
|
460
|
-
# assert len(selected_signals) == len(
|
|
461
|
-
# required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
|
|
462
|
-
|
|
463
|
-
# print(f'The following channels will be passed to the model: {selected_signals}')
|
|
464
|
-
# trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
|
|
465
|
-
# interpolate_position_gaps=interpolate_na, column_labels=column_labels)
|
|
466
|
-
|
|
467
|
-
# max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
|
|
468
|
-
# tracks = trajectories_clean[column_labels['track']].unique()
|
|
469
|
-
# signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
|
|
470
|
-
|
|
471
|
-
# for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
|
|
472
|
-
# frames = group[column_labels['time']].to_numpy().astype(int)
|
|
473
|
-
# for j, col in enumerate(selected_signals):
|
|
474
|
-
# signal = group[col].to_numpy()
|
|
475
|
-
# signals[i, frames, j] = signal
|
|
476
|
-
# signals[i, max(frames):, j] = signal[-1]
|
|
477
|
-
|
|
478
|
-
# # for i in range(5):
|
|
479
|
-
# # print('pre model')
|
|
480
|
-
# # plt.plot(signals[i,:,0])
|
|
481
|
-
# # plt.show()
|
|
482
|
-
|
|
483
|
-
# model = SignalDetectionModel(pretrained=complete_path)
|
|
484
|
-
# print('signal shape: ', signals.shape)
|
|
485
|
-
|
|
486
|
-
# classes = model.predict_class(signals)
|
|
487
|
-
# times_recast = model.predict_time_of_interest(signals)
|
|
488
|
-
|
|
489
|
-
# if label is None:
|
|
490
|
-
# class_col = 'class'
|
|
491
|
-
# time_col = 't0'
|
|
492
|
-
# status_col = 'status'
|
|
493
|
-
# else:
|
|
494
|
-
# class_col = 'class_' + label
|
|
495
|
-
# time_col = 't_' + label
|
|
496
|
-
# status_col = 'status_' + label
|
|
497
|
-
|
|
498
|
-
# for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
|
|
499
|
-
# indices = group.index
|
|
500
|
-
# trajectories.loc[indices, class_col] = classes[i]
|
|
501
|
-
# trajectories.loc[indices, time_col] = times_recast[i]
|
|
502
|
-
# print('Done.')
|
|
503
|
-
|
|
504
|
-
# for tid, group in trajectories.groupby(column_labels['track']):
|
|
505
|
-
|
|
506
|
-
# indices = group.index
|
|
507
|
-
# t0 = group[time_col].to_numpy()[0]
|
|
508
|
-
# cclass = group[class_col].to_numpy()[0]
|
|
509
|
-
# timeline = group[column_labels['time']].to_numpy()
|
|
510
|
-
# status = np.zeros_like(timeline)
|
|
511
|
-
# if t0 > 0:
|
|
512
|
-
# status[timeline >= t0] = 1.
|
|
513
|
-
# if cclass == 2:
|
|
514
|
-
# status[:] = 2
|
|
515
|
-
# if cclass > 2:
|
|
516
|
-
# status[:] = 42
|
|
517
|
-
# status_color = [color_from_status(s) for s in status]
|
|
518
|
-
# class_color = [color_from_class(cclass) for i in range(len(status))]
|
|
519
|
-
|
|
520
|
-
# trajectories.loc[indices, status_col] = status
|
|
521
|
-
# trajectories.loc[indices, 'status_color'] = status_color
|
|
522
|
-
# trajectories.loc[indices, 'class_color'] = class_color
|
|
523
|
-
|
|
524
|
-
# if plot_outcome:
|
|
525
|
-
# fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
|
|
526
|
-
# for i, s in enumerate(selected_signals):
|
|
527
|
-
# for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
|
|
528
|
-
# cclass = group[class_col].to_numpy()[0]
|
|
529
|
-
# t0 = group[time_col].to_numpy()[0]
|
|
530
|
-
# timeline = group[column_labels['time']].to_numpy()
|
|
531
|
-
# if cclass == 0:
|
|
532
|
-
# if len(selected_signals) > 1:
|
|
533
|
-
# ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
|
|
534
|
-
# else:
|
|
535
|
-
# ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
|
|
536
|
-
# if len(selected_signals) > 1:
|
|
537
|
-
# for a, s in zip(ax, selected_signals):
|
|
538
|
-
# a.set_title(s)
|
|
539
|
-
# a.set_xlabel(r'time - t$_0$ [frame]')
|
|
540
|
-
# a.spines['top'].set_visible(False)
|
|
541
|
-
# a.spines['right'].set_visible(False)
|
|
542
|
-
# else:
|
|
543
|
-
# ax.set_title(s)
|
|
544
|
-
# ax.set_xlabel(r'time - t$_0$ [frame]')
|
|
545
|
-
# ax.spines['top'].set_visible(False)
|
|
546
|
-
# ax.spines['right'].set_visible(False)
|
|
547
|
-
# plt.tight_layout()
|
|
548
|
-
# if output_dir is not None:
|
|
549
|
-
# plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
|
|
550
|
-
# plt.pause(3)
|
|
551
|
-
# plt.close()
|
|
552
|
-
|
|
553
|
-
# return trajectories
|
|
554
|
-
|
|
555
|
-
|
|
556
365
|
def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
|
|
557
366
|
model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
|
|
558
|
-
"""
|
|
559
|
-
"""
|
|
560
367
|
|
|
561
368
|
model_path = locate_signal_model(model, path=model_path, pairs=True)
|
|
562
369
|
print(f'Looking for model in {model_path}...')
|
|
@@ -832,6 +639,7 @@ class SignalDetectionModel(object):
|
|
|
832
639
|
-----
|
|
833
640
|
- The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
|
|
834
641
|
- The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
|
|
642
|
+
|
|
835
643
|
"""
|
|
836
644
|
|
|
837
645
|
if self.pretrained.endswith(os.sep):
|
|
@@ -873,11 +681,22 @@ class SignalDetectionModel(object):
|
|
|
873
681
|
if 'label' in model_config:
|
|
874
682
|
self.label = model_config['label']
|
|
875
683
|
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
684
|
+
try:
|
|
685
|
+
self.n_channels = self.model_class.layers[0].input_shape[0][-1]
|
|
686
|
+
self.model_signal_length = self.model_class.layers[0].input_shape[0][-2]
|
|
687
|
+
self.n_classes = self.model_class.layers[-1].output_shape[-1]
|
|
688
|
+
model_class_input_shape = self.model_class.layers[0].input_shape[0]
|
|
689
|
+
model_reg_input_shape = self.model_reg.layers[0].input_shape[0]
|
|
690
|
+
except AttributeError:
|
|
691
|
+
self.n_channels = self.model_class.input_shape[-1] #self.model_class.layers[0].input.shape[0][-1]
|
|
692
|
+
self.model_signal_length = self.model_class.input_shape[-2] #self.model_class.layers[0].input[0].shape[0][-2]
|
|
693
|
+
self.n_classes = self.model_class.output_shape[-1] #self.model_class.layers[-1].output[0].shape[-1]
|
|
694
|
+
model_class_input_shape = self.model_class.input_shape
|
|
695
|
+
model_reg_input_shape = self.model_reg.input_shape
|
|
696
|
+
except Exception as e:
|
|
697
|
+
print(e)
|
|
879
698
|
|
|
880
|
-
assert
|
|
699
|
+
assert model_class_input_shape==model_reg_input_shape, f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
|
|
881
700
|
|
|
882
701
|
return True
|
|
883
702
|
|
|
@@ -893,6 +712,7 @@ class SignalDetectionModel(object):
|
|
|
893
712
|
-----
|
|
894
713
|
- The models are created using a custom ResNet architecture defined elsewhere in the codebase.
|
|
895
714
|
- The models are stored in the `model_class` and `model_reg` attributes of the class.
|
|
715
|
+
|
|
896
716
|
"""
|
|
897
717
|
|
|
898
718
|
self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
|
|
@@ -925,6 +745,7 @@ class SignalDetectionModel(object):
|
|
|
925
745
|
-----
|
|
926
746
|
- This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
|
|
927
747
|
- If no GPUs are detected, the method will pass silently.
|
|
748
|
+
|
|
928
749
|
"""
|
|
929
750
|
|
|
930
751
|
try:
|
|
@@ -987,6 +808,7 @@ class SignalDetectionModel(object):
|
|
|
987
808
|
Notes
|
|
988
809
|
-----
|
|
989
810
|
- The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
|
|
811
|
+
|
|
990
812
|
"""
|
|
991
813
|
|
|
992
814
|
if not hasattr(self, 'normalization_percentile'):
|
|
@@ -1049,6 +871,7 @@ class SignalDetectionModel(object):
|
|
|
1049
871
|
-----
|
|
1050
872
|
- This method provides an alternative way to train the model when data is already loaded into memory, offering
|
|
1051
873
|
flexibility for data preprocessing steps outside this class.
|
|
874
|
+
|
|
1052
875
|
"""
|
|
1053
876
|
|
|
1054
877
|
self.normalize = normalize
|
|
@@ -1179,6 +1002,7 @@ class SignalDetectionModel(object):
|
|
|
1179
1002
|
-----
|
|
1180
1003
|
- The method processes the input signals according to the specified options to ensure compatibility with the model's
|
|
1181
1004
|
input requirements.
|
|
1005
|
+
|
|
1182
1006
|
"""
|
|
1183
1007
|
|
|
1184
1008
|
self.x = np.copy(x)
|
|
@@ -1203,8 +1027,15 @@ class SignalDetectionModel(object):
|
|
|
1203
1027
|
# plt.plot(self.x[i,:,0])
|
|
1204
1028
|
# plt.show()
|
|
1205
1029
|
|
|
1206
|
-
|
|
1207
|
-
|
|
1030
|
+
try:
|
|
1031
|
+
n_channels = self.model_class.layers[0].input_shape[0][-1]
|
|
1032
|
+
model_signal_length = self.model_class.layers[0].input_shape[0][-2]
|
|
1033
|
+
except AttributeError:
|
|
1034
|
+
n_channels = self.model_class.input_shape[-1]
|
|
1035
|
+
model_signal_length = self.model_class.input_shape[-2]
|
|
1036
|
+
|
|
1037
|
+
assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
|
|
1038
|
+
assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
|
|
1208
1039
|
|
|
1209
1040
|
self.class_predictions_one_hot = self.model_class.predict(self.x)
|
|
1210
1041
|
self.class_predictions = self.class_predictions_one_hot.argmax(axis=1)
|
|
@@ -1240,6 +1071,7 @@ class SignalDetectionModel(object):
|
|
|
1240
1071
|
-----
|
|
1241
1072
|
- The method processes the input signals according to the specified options and uses the regression model to
|
|
1242
1073
|
predict times at which a particular event of interest occurs.
|
|
1074
|
+
|
|
1243
1075
|
"""
|
|
1244
1076
|
|
|
1245
1077
|
self.x = np.copy(x)
|
|
@@ -1259,8 +1091,15 @@ class SignalDetectionModel(object):
|
|
|
1259
1091
|
normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
|
|
1260
1092
|
)
|
|
1261
1093
|
|
|
1262
|
-
|
|
1263
|
-
|
|
1094
|
+
try:
|
|
1095
|
+
n_channels = self.model_reg.layers[0].input_shape[0][-1]
|
|
1096
|
+
model_signal_length = self.model_reg.layers[0].input_shape[0][-2]
|
|
1097
|
+
except AttributeError:
|
|
1098
|
+
n_channels = self.model_reg.input_shape[-1]
|
|
1099
|
+
model_signal_length = self.model_reg.input_shape[-2]
|
|
1100
|
+
|
|
1101
|
+
assert self.x.shape[-1] == n_channels, f"Shape mismatch between the input shape and the model input shape..."
|
|
1102
|
+
assert self.x.shape[-2] == model_signal_length, f"Shape mismatch between the input shape and the model input shape..."
|
|
1264
1103
|
|
|
1265
1104
|
if np.any(self.class_predictions==0):
|
|
1266
1105
|
self.time_predictions = self.model_reg.predict(self.x[self.class_predictions==0])*self.model_signal_length
|
|
@@ -1290,6 +1129,7 @@ class SignalDetectionModel(object):
|
|
|
1290
1129
|
-----
|
|
1291
1130
|
- This method is useful for preparing signals that have gaps or missing time points before further processing
|
|
1292
1131
|
or model training.
|
|
1132
|
+
|
|
1293
1133
|
"""
|
|
1294
1134
|
|
|
1295
1135
|
for i in range(len(x_set)):
|
|
@@ -1317,6 +1157,7 @@ class SignalDetectionModel(object):
|
|
|
1317
1157
|
- The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
|
|
1318
1158
|
- Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
|
|
1319
1159
|
- Model performance metrics and training history are saved for analysis.
|
|
1160
|
+
|
|
1320
1161
|
"""
|
|
1321
1162
|
|
|
1322
1163
|
# if pretrained model
|
|
@@ -1461,6 +1302,7 @@ class SignalDetectionModel(object):
|
|
|
1461
1302
|
- The regressor model estimates the time at which an event of interest occurs within each signal.
|
|
1462
1303
|
- Only signals predicted to have an event by the classifier model are used for regressor training.
|
|
1463
1304
|
- Model performance metrics and training history are saved for analysis.
|
|
1305
|
+
|
|
1464
1306
|
"""
|
|
1465
1307
|
|
|
1466
1308
|
|
|
@@ -1545,6 +1387,7 @@ class SignalDetectionModel(object):
|
|
|
1545
1387
|
-----
|
|
1546
1388
|
- Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
|
|
1547
1389
|
- The plots are saved as image files in the model's output directory.
|
|
1390
|
+
|
|
1548
1391
|
"""
|
|
1549
1392
|
|
|
1550
1393
|
if mode=="regressor":
|
|
@@ -1589,6 +1432,7 @@ class SignalDetectionModel(object):
|
|
|
1589
1432
|
-----
|
|
1590
1433
|
- Evaluation is performed on both test and validation datasets, if available.
|
|
1591
1434
|
- Regression plots and performance metrics are saved in the model's output directory.
|
|
1435
|
+
|
|
1592
1436
|
"""
|
|
1593
1437
|
|
|
1594
1438
|
|
|
@@ -1641,6 +1485,7 @@ class SignalDetectionModel(object):
|
|
|
1641
1485
|
-----
|
|
1642
1486
|
- Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
|
|
1643
1487
|
- The list of callbacks is stored in the class attribute `cb` and used during model training.
|
|
1488
|
+
|
|
1644
1489
|
"""
|
|
1645
1490
|
|
|
1646
1491
|
self.cb = []
|
|
@@ -1698,6 +1543,7 @@ class SignalDetectionModel(object):
|
|
|
1698
1543
|
-----
|
|
1699
1544
|
- Signal annotations are expected to be stored in .npy format and contain required channels and event information.
|
|
1700
1545
|
- The method applies specified normalization and interpolation options to prepare the signals for model training.
|
|
1546
|
+
|
|
1701
1547
|
"""
|
|
1702
1548
|
|
|
1703
1549
|
|
|
@@ -1781,6 +1627,7 @@ class SignalDetectionModel(object):
|
|
|
1781
1627
|
-----
|
|
1782
1628
|
- Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
|
|
1783
1629
|
- The augmented dataset is used for training the classifier and regressor models to improve generalization.
|
|
1630
|
+
|
|
1784
1631
|
"""
|
|
1785
1632
|
|
|
1786
1633
|
|
|
@@ -1902,89 +1749,6 @@ class SignalDetectionModel(object):
|
|
|
1902
1749
|
return signals_recast, times_of_interest
|
|
1903
1750
|
|
|
1904
1751
|
|
|
1905
|
-
# def load_and_normalize(self, subset):
|
|
1906
|
-
|
|
1907
|
-
# """
|
|
1908
|
-
# Loads a subset of signal data from an annotation file and applies normalization.
|
|
1909
|
-
|
|
1910
|
-
# Parameters
|
|
1911
|
-
# ----------
|
|
1912
|
-
# subset : str
|
|
1913
|
-
# The file path to the .npy annotation file containing signal data for a subset of observations.
|
|
1914
|
-
|
|
1915
|
-
# Notes
|
|
1916
|
-
# -----
|
|
1917
|
-
# - The method extracts required signal channels from the annotation file and applies specified normalization
|
|
1918
|
-
# and interpolation steps.
|
|
1919
|
-
# - Preprocessed signals are added to the global dataset for model training.
|
|
1920
|
-
# """
|
|
1921
|
-
|
|
1922
|
-
# set_k = np.load(subset,allow_pickle=True)
|
|
1923
|
-
# ### here do a mapping between channel option and existing signals
|
|
1924
|
-
|
|
1925
|
-
# required_signals = self.channel_option
|
|
1926
|
-
# available_signals = list(set_k[0].keys())
|
|
1927
|
-
|
|
1928
|
-
# selected_signals = []
|
|
1929
|
-
# for s in required_signals:
|
|
1930
|
-
# pattern_test = [s in a for a in available_signals]
|
|
1931
|
-
# if np.any(pattern_test):
|
|
1932
|
-
# valid_columns = np.array(available_signals)[np.array(pattern_test)]
|
|
1933
|
-
# if len(valid_columns)==1:
|
|
1934
|
-
# selected_signals.append(valid_columns[0])
|
|
1935
|
-
# else:
|
|
1936
|
-
# print(f'Found several candidate signals: {valid_columns}')
|
|
1937
|
-
# for vc in natsorted(valid_columns):
|
|
1938
|
-
# if 'circle' in vc:
|
|
1939
|
-
# selected_signals.append(vc)
|
|
1940
|
-
# break
|
|
1941
|
-
# else:
|
|
1942
|
-
# selected_signals.append(valid_columns[0])
|
|
1943
|
-
# else:
|
|
1944
|
-
# return None
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
# key_to_check = selected_signals[0] #self.channel_option[0]
|
|
1948
|
-
# signal_lengths = [len(l[key_to_check]) for l in set_k]
|
|
1949
|
-
# max_length = np.amax(signal_lengths)
|
|
1950
|
-
|
|
1951
|
-
# fluo = np.zeros((len(set_k),max_length,self.n_channels))
|
|
1952
|
-
# classes = np.zeros(len(set_k))
|
|
1953
|
-
# times_of_interest = np.zeros(len(set_k))
|
|
1954
|
-
|
|
1955
|
-
# for k in range(len(set_k)):
|
|
1956
|
-
|
|
1957
|
-
# for i in range(self.n_channels):
|
|
1958
|
-
# try:
|
|
1959
|
-
# # take into account timeline for accurate time regression
|
|
1960
|
-
# timeline = set_k[k]['FRAME'].astype(int)
|
|
1961
|
-
# fluo[k,timeline,i] = set_k[k][selected_signals[i]]
|
|
1962
|
-
# except:
|
|
1963
|
-
# print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
|
|
1964
|
-
# pass
|
|
1965
|
-
|
|
1966
|
-
# classes[k] = set_k[k]["class"]
|
|
1967
|
-
# times_of_interest[k] = set_k[k]["time_of_interest"]
|
|
1968
|
-
|
|
1969
|
-
# # Correct absurd times of interest
|
|
1970
|
-
# times_of_interest[np.nonzero(classes)] = -1
|
|
1971
|
-
# times_of_interest[(times_of_interest<=0.0)] = -1
|
|
1972
|
-
|
|
1973
|
-
# # Attempt per-set normalization
|
|
1974
|
-
# fluo = pad_to_model_length(fluo, self.model_signal_length)
|
|
1975
|
-
# if self.normalize:
|
|
1976
|
-
# fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
|
|
1977
|
-
# normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
|
|
1978
|
-
# )
|
|
1979
|
-
|
|
1980
|
-
# # Trivial normalization for time of interest
|
|
1981
|
-
# times_of_interest /= self.model_signal_length
|
|
1982
|
-
|
|
1983
|
-
# # Add to global dataset
|
|
1984
|
-
# self.x_set.extend(fluo)
|
|
1985
|
-
# self.y_time_set.extend(times_of_interest)
|
|
1986
|
-
# self.y_class_set.extend(classes)
|
|
1987
|
-
|
|
1988
1752
|
def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
|
|
1989
1753
|
|
|
1990
1754
|
"""
|
|
@@ -2028,6 +1792,7 @@ def _interpret_normalization_parameters(n_channels, normalization_percentile, no
|
|
|
2028
1792
|
>>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
|
|
2029
1793
|
>>> print(params)
|
|
2030
1794
|
# ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
|
|
1795
|
+
|
|
2031
1796
|
"""
|
|
2032
1797
|
|
|
2033
1798
|
|
|
@@ -2101,6 +1866,7 @@ def normalize_signal_set(signal_set, channel_option, percentile_alive=[0.01,99.9
|
|
|
2101
1866
|
>>> channel_option = ['alive', 'dead']
|
|
2102
1867
|
>>> normalized_signals = normalize_signal_set(signal_set, channel_option)
|
|
2103
1868
|
# Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.
|
|
1869
|
+
|
|
2104
1870
|
"""
|
|
2105
1871
|
|
|
2106
1872
|
# Check normalization params are ok
|
|
@@ -2584,6 +2350,7 @@ def ResNetModelCurrent(n_channels, n_slices, depth=2, use_pooling=True, n_classe
|
|
|
2584
2350
|
--------
|
|
2585
2351
|
>>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
|
|
2586
2352
|
# Creates a ResNet model configured for classification with 3 classes.
|
|
2353
|
+
|
|
2587
2354
|
"""
|
|
2588
2355
|
|
|
2589
2356
|
if header=="classifier":
|
|
@@ -2652,6 +2419,7 @@ def train_signal_model(config):
|
|
|
2652
2419
|
>>> config_path = '/path/to/training_config.json'
|
|
2653
2420
|
>>> train_signal_model(config_path)
|
|
2654
2421
|
# This will execute the 'train_signal_model.py' script using the parameters specified in 'training_config.json'.
|
|
2422
|
+
|
|
2655
2423
|
"""
|
|
2656
2424
|
|
|
2657
2425
|
config = config.replace('\\','/')
|
|
@@ -2698,6 +2466,7 @@ def T_MSD(x,y,dt):
|
|
|
2698
2466
|
>>> T_MSD(x, y, dt)
|
|
2699
2467
|
([6.0, 9.0, 4.666666666666667, 1.6666666666666667],
|
|
2700
2468
|
array([1., 2., 3., 4.]))
|
|
2469
|
+
|
|
2701
2470
|
"""
|
|
2702
2471
|
|
|
2703
2472
|
msd = []
|
|
@@ -2735,6 +2504,7 @@ def linear_msd(t, m):
|
|
|
2735
2504
|
>>> m = 2.0
|
|
2736
2505
|
>>> linear_msd(t, m)
|
|
2737
2506
|
array([2., 4., 6., 8.])
|
|
2507
|
+
|
|
2738
2508
|
"""
|
|
2739
2509
|
|
|
2740
2510
|
return m*t
|
|
@@ -2766,6 +2536,7 @@ def alpha_msd(t, m, alpha):
|
|
|
2766
2536
|
>>> alpha = 0.5
|
|
2767
2537
|
>>> alpha_msd(t, m, alpha)
|
|
2768
2538
|
array([2. , 4. , 6. , 8. ])
|
|
2539
|
+
|
|
2769
2540
|
"""
|
|
2770
2541
|
|
|
2771
2542
|
return m*t**alpha
|
|
@@ -2820,6 +2591,7 @@ def sliding_msd(x, y, timeline, window, mode='bi', n_points_migration=7, n_poin
|
|
|
2820
2591
|
>>> timeline = np.array([0, 1, 2, 3, 4, 5, 6])
|
|
2821
2592
|
>>> window = 3
|
|
2822
2593
|
>>> s_msd, s_alpha = sliding_msd(x, y, timeline, window, n_points_migration=2, n_points_transport=3)
|
|
2594
|
+
|
|
2823
2595
|
"""
|
|
2824
2596
|
|
|
2825
2597
|
assert window > n_points_migration,'Please set a window larger than the number of fit points...'
|
|
@@ -2956,6 +2728,7 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
|
|
|
2956
2728
|
>>> window = 11
|
|
2957
2729
|
>>> diffusion, velocity = sliding_msd_drift(x, y, timeline, window, mode='bi')
|
|
2958
2730
|
# Calculates the diffusion coefficient and drift velocity using a bidirectional sliding window.
|
|
2731
|
+
|
|
2959
2732
|
"""
|
|
2960
2733
|
|
|
2961
2734
|
assert window > n_points_migration,'Please set a window larger than the number of fit points...'
|
|
@@ -3002,12 +2775,12 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
|
|
|
3002
2775
|
|
|
3003
2776
|
return s_diffusion, s_velocity
|
|
3004
2777
|
|
|
3005
|
-
def columnwise_mean(matrix, min_nbr_values = 1):
|
|
2778
|
+
def columnwise_mean(matrix, min_nbr_values = 1, projection='mean'):
|
|
3006
2779
|
|
|
3007
2780
|
"""
|
|
3008
2781
|
Calculate the column-wise mean and standard deviation of non-NaN elements in the input matrix.
|
|
3009
2782
|
|
|
3010
|
-
Parameters
|
|
2783
|
+
Parameters
|
|
3011
2784
|
----------
|
|
3012
2785
|
matrix : numpy.ndarray
|
|
3013
2786
|
The input matrix for which column-wise mean and standard deviation are calculated.
|
|
@@ -3015,7 +2788,7 @@ def columnwise_mean(matrix, min_nbr_values = 1):
|
|
|
3015
2788
|
The minimum number of non-NaN values required in a column to calculate mean and standard deviation.
|
|
3016
2789
|
Default is 8.
|
|
3017
2790
|
|
|
3018
|
-
Returns
|
|
2791
|
+
Returns
|
|
3019
2792
|
-------
|
|
3020
2793
|
mean_line : numpy.ndarray
|
|
3021
2794
|
An array containing the column-wise mean of non-NaN elements. Elements with fewer than `min_nbr_values` non-NaN
|
|
@@ -3024,11 +2797,12 @@ def columnwise_mean(matrix, min_nbr_values = 1):
|
|
|
3024
2797
|
An array containing the column-wise standard deviation of non-NaN elements. Elements with fewer than `min_nbr_values`
|
|
3025
2798
|
non-NaN values are replaced with NaN.
|
|
3026
2799
|
|
|
3027
|
-
Notes
|
|
3028
|
-
|
|
2800
|
+
Notes
|
|
2801
|
+
-----
|
|
3029
2802
|
1. This function calculates the mean and standard deviation of non-NaN elements in each column of the input matrix.
|
|
3030
2803
|
2. Columns with fewer than `min_nbr_values` non-zero elements will have NaN as the mean and standard deviation.
|
|
3031
2804
|
3. NaN values in the input matrix are ignored during calculation.
|
|
2805
|
+
|
|
3032
2806
|
"""
|
|
3033
2807
|
|
|
3034
2808
|
mean_line = np.zeros(matrix.shape[1])
|
|
@@ -3040,17 +2814,21 @@ def columnwise_mean(matrix, min_nbr_values = 1):
|
|
|
3040
2814
|
values = matrix[:,k]
|
|
3041
2815
|
values = values[values==values]
|
|
3042
2816
|
if len(values[values==values])>min_nbr_values:
|
|
3043
|
-
|
|
3044
|
-
|
|
2817
|
+
if projection=='mean':
|
|
2818
|
+
mean_line[k] = np.nanmean(values)
|
|
2819
|
+
mean_line_std[k] = np.nanstd(values)
|
|
2820
|
+
elif projection=='median':
|
|
2821
|
+
mean_line[k] = np.nanmedian(values)
|
|
2822
|
+
mean_line_std[k] = median_abs_deviation(values, center=np.nanmedian, nan_policy='omit')
|
|
3045
2823
|
return mean_line, mean_line_std
|
|
3046
2824
|
|
|
3047
2825
|
|
|
3048
|
-
def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean'):
|
|
2826
|
+
def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], return_matrix=False, forced_max_duration=None, min_nbr_values=2,conflict_mode='mean', projection='mean'):
|
|
3049
2827
|
|
|
3050
2828
|
"""
|
|
3051
2829
|
Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.
|
|
3052
2830
|
|
|
3053
|
-
Parameters
|
|
2831
|
+
Parameters
|
|
3054
2832
|
----------
|
|
3055
2833
|
df : pandas.DataFrame
|
|
3056
2834
|
Input DataFrame containing tracking data.
|
|
@@ -3063,7 +2841,7 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
|
|
|
3063
2841
|
class_value : int, optional
|
|
3064
2842
|
Value representing the class of interest. Default is 0.
|
|
3065
2843
|
|
|
3066
|
-
Returns
|
|
2844
|
+
Returns
|
|
3067
2845
|
-------
|
|
3068
2846
|
mean_signal : numpy.ndarray
|
|
3069
2847
|
An array containing the mean signal values for tracks of the specified class. Tracks with class not equal to
|
|
@@ -3074,12 +2852,13 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
|
|
|
3074
2852
|
actual_timeline : numpy.ndarray
|
|
3075
2853
|
An array representing the time points corresponding to the mean signal values.
|
|
3076
2854
|
|
|
3077
|
-
Notes
|
|
3078
|
-
|
|
2855
|
+
Notes
|
|
2856
|
+
-----
|
|
3079
2857
|
1. This function calculates the mean and standard deviation of the specified signal for tracks of a given class.
|
|
3080
2858
|
2. Tracks with class not equal to `class_value` are excluded from the calculation.
|
|
3081
2859
|
3. Tracks with missing or NaN values in the specified signal are ignored during calculation.
|
|
3082
2860
|
4. Tracks are aligned based on their 'FRAME' values and the specified `time_col` (if provided).
|
|
2861
|
+
|
|
3083
2862
|
"""
|
|
3084
2863
|
|
|
3085
2864
|
assert signal_name in list(df.columns),"The signal you want to plot is not one of the measured features."
|
|
@@ -3135,18 +2914,18 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
|
|
|
3135
2914
|
signal_matrix[trackid,timeline_shifted.astype(int)] = signal
|
|
3136
2915
|
trackid+=1
|
|
3137
2916
|
|
|
3138
|
-
mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values)
|
|
2917
|
+
mean_signal, std_signal = columnwise_mean(signal_matrix, min_nbr_values=min_nbr_values, projection=projection)
|
|
3139
2918
|
actual_timeline = np.linspace(-max_duration, max_duration, 2*max_duration+1)
|
|
3140
2919
|
if return_matrix:
|
|
3141
2920
|
return mean_signal, std_signal, actual_timeline, signal_matrix
|
|
3142
2921
|
else:
|
|
3143
2922
|
return mean_signal, std_signal, actual_timeline
|
|
3144
2923
|
|
|
3145
|
-
if __name__ == "__main__":
|
|
2924
|
+
# if __name__ == "__main__":
|
|
3146
2925
|
|
|
3147
|
-
|
|
3148
|
-
|
|
3149
|
-
|
|
3150
|
-
|
|
3151
|
-
|
|
3152
|
-
|
|
2926
|
+
# # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
|
|
2927
|
+
# # print(model.summary())
|
|
2928
|
+
# model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
|
|
2929
|
+
# header="classifier", model_signal_length = 128)
|
|
2930
|
+
# print(model.summary())
|
|
2931
|
+
# #plot_model(model, to_file='test.png', show_shapes=True)
|