celldetective 1.3.7.post1__py3-none-any.whl → 1.3.7.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
celldetective/signals.py CHANGED
@@ -325,9 +325,6 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
325
325
 
326
326
  def analyze_pair_signals_at_position(pos, model, use_gpu=True):
327
327
 
328
- """
329
-
330
- """
331
328
 
332
329
  pos = pos.replace('\\','/')
333
330
  pos = rf"{pos}"
@@ -364,199 +361,8 @@ def analyze_pair_signals_at_position(pos, model, use_gpu=True):
364
361
  return None
365
362
 
366
363
 
367
- # def analyze_signals(trajectories, model, interpolate_na=True,
368
- # selected_signals=None,
369
- # model_path=None,
370
- # column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
371
- # plot_outcome=False, output_dir=None):
372
- # """
373
- # Analyzes signals from trajectory data using a specified signal detection model and configuration.
374
-
375
- # This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
376
- # model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
377
- # of missing values, and plotting of analysis outcomes.
378
-
379
- # Parameters
380
- # ----------
381
- # trajectories : pandas.DataFrame
382
- # DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
383
- # model : str
384
- # The name of the signal detection model to be used for analysis.
385
- # interpolate_na : bool, optional
386
- # Whether to interpolate missing values in the trajectories (default is True).
387
- # selected_signals : list of str, optional
388
- # A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
389
- # be automatically selected based on the model configuration (default is None).
390
- # column_labels : dict, optional
391
- # A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
392
- # in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
393
- # plot_outcome : bool, optional
394
- # If True, generates and saves a plot of the signal analysis outcome (default is False).
395
- # output_dir : str, optional
396
- # The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
397
-
398
- # Returns
399
- # -------
400
- # pandas.DataFrame
401
- # The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
402
- # corresponding colors based on status and class.
403
-
404
- # Raises
405
- # ------
406
- # AssertionError
407
- # If the model or its configuration file cannot be located.
408
-
409
- # Notes
410
- # -----
411
- # - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
412
- # - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
413
-
414
- # """
415
-
416
- # model_path = locate_signal_model(model, path=model_path)
417
- # complete_path = model_path # +model
418
- # complete_path = rf"{complete_path}"
419
- # model_config_path = os.sep.join([complete_path, 'config_input.json'])
420
- # model_config_path = rf"{model_config_path}"
421
- # assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
422
- # assert os.path.exists(
423
- # model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
424
-
425
- # available_signals = list(trajectories.columns)
426
-
427
- # f = open(model_config_path)
428
- # config = json.load(f)
429
- # required_signals = config["channels"]
430
-
431
- # try:
432
- # label = config['label']
433
- # if label == '':
434
- # label = None
435
- # except:
436
- # label = None
437
-
438
- # if selected_signals is None:
439
- # selected_signals = []
440
- # for s in required_signals:
441
- # pattern_test = [s in a or s == a for a in available_signals]
442
- # #print(f'Pattern test for signal {s}: ', pattern_test)
443
- # assert np.any(
444
- # pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
445
- # valid_columns = np.array(available_signals)[np.array(pattern_test)]
446
- # if len(valid_columns) == 1:
447
- # selected_signals.append(valid_columns[0])
448
- # else:
449
- # # print(test_number_of_nan(trajectories, valid_columns))
450
- # print(f'Found several candidate signals: {valid_columns}')
451
- # for vc in natsorted(valid_columns):
452
- # if 'circle' in vc:
453
- # selected_signals.append(vc)
454
- # break
455
- # else:
456
- # selected_signals.append(valid_columns[0])
457
- # # do something more complicated in case of one to many columns
458
- # # pass
459
- # else:
460
- # assert len(selected_signals) == len(
461
- # required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
462
-
463
- # print(f'The following channels will be passed to the model: {selected_signals}')
464
- # trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
465
- # interpolate_position_gaps=interpolate_na, column_labels=column_labels)
466
-
467
- # max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
468
- # tracks = trajectories_clean[column_labels['track']].unique()
469
- # signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
470
-
471
- # for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
472
- # frames = group[column_labels['time']].to_numpy().astype(int)
473
- # for j, col in enumerate(selected_signals):
474
- # signal = group[col].to_numpy()
475
- # signals[i, frames, j] = signal
476
- # signals[i, max(frames):, j] = signal[-1]
477
-
478
- # # for i in range(5):
479
- # # print('pre model')
480
- # # plt.plot(signals[i,:,0])
481
- # # plt.show()
482
-
483
- # model = SignalDetectionModel(pretrained=complete_path)
484
- # print('signal shape: ', signals.shape)
485
-
486
- # classes = model.predict_class(signals)
487
- # times_recast = model.predict_time_of_interest(signals)
488
-
489
- # if label is None:
490
- # class_col = 'class'
491
- # time_col = 't0'
492
- # status_col = 'status'
493
- # else:
494
- # class_col = 'class_' + label
495
- # time_col = 't_' + label
496
- # status_col = 'status_' + label
497
-
498
- # for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
499
- # indices = group.index
500
- # trajectories.loc[indices, class_col] = classes[i]
501
- # trajectories.loc[indices, time_col] = times_recast[i]
502
- # print('Done.')
503
-
504
- # for tid, group in trajectories.groupby(column_labels['track']):
505
-
506
- # indices = group.index
507
- # t0 = group[time_col].to_numpy()[0]
508
- # cclass = group[class_col].to_numpy()[0]
509
- # timeline = group[column_labels['time']].to_numpy()
510
- # status = np.zeros_like(timeline)
511
- # if t0 > 0:
512
- # status[timeline >= t0] = 1.
513
- # if cclass == 2:
514
- # status[:] = 2
515
- # if cclass > 2:
516
- # status[:] = 42
517
- # status_color = [color_from_status(s) for s in status]
518
- # class_color = [color_from_class(cclass) for i in range(len(status))]
519
-
520
- # trajectories.loc[indices, status_col] = status
521
- # trajectories.loc[indices, 'status_color'] = status_color
522
- # trajectories.loc[indices, 'class_color'] = class_color
523
-
524
- # if plot_outcome:
525
- # fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
526
- # for i, s in enumerate(selected_signals):
527
- # for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
528
- # cclass = group[class_col].to_numpy()[0]
529
- # t0 = group[time_col].to_numpy()[0]
530
- # timeline = group[column_labels['time']].to_numpy()
531
- # if cclass == 0:
532
- # if len(selected_signals) > 1:
533
- # ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
534
- # else:
535
- # ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
536
- # if len(selected_signals) > 1:
537
- # for a, s in zip(ax, selected_signals):
538
- # a.set_title(s)
539
- # a.set_xlabel(r'time - t$_0$ [frame]')
540
- # a.spines['top'].set_visible(False)
541
- # a.spines['right'].set_visible(False)
542
- # else:
543
- # ax.set_title(s)
544
- # ax.set_xlabel(r'time - t$_0$ [frame]')
545
- # ax.spines['top'].set_visible(False)
546
- # ax.spines['right'].set_visible(False)
547
- # plt.tight_layout()
548
- # if output_dir is not None:
549
- # plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
550
- # plt.pause(3)
551
- # plt.close()
552
-
553
- # return trajectories
554
-
555
-
556
364
  def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
557
365
  model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
558
- """
559
- """
560
366
 
561
367
  model_path = locate_signal_model(model, path=model_path, pairs=True)
562
368
  print(f'Looking for model in {model_path}...')
@@ -832,6 +638,7 @@ class SignalDetectionModel(object):
832
638
  -----
833
639
  - The models are expected to be saved in .h5 format with the filenames "classifier.h5" and "regressor.h5".
834
640
  - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
641
+
835
642
  """
836
643
 
837
644
  if self.pretrained.endswith(os.sep):
@@ -893,6 +700,7 @@ class SignalDetectionModel(object):
893
700
  -----
894
701
  - The models are created using a custom ResNet architecture defined elsewhere in the codebase.
895
702
  - The models are stored in the `model_class` and `model_reg` attributes of the class.
703
+
896
704
  """
897
705
 
898
706
  self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
@@ -925,6 +733,7 @@ class SignalDetectionModel(object):
925
733
  -----
926
734
  - This method should be called before any TensorFlow/Keras operations that might allocate GPU memory.
927
735
  - If no GPUs are detected, the method will pass silently.
736
+
928
737
  """
929
738
 
930
739
  try:
@@ -987,6 +796,7 @@ class SignalDetectionModel(object):
987
796
  Notes
988
797
  -----
989
798
  - The method automatically splits the dataset into training, validation, and test sets according to the specified splits.
799
+
990
800
  """
991
801
 
992
802
  if not hasattr(self, 'normalization_percentile'):
@@ -1049,6 +859,7 @@ class SignalDetectionModel(object):
1049
859
  -----
1050
860
  - This method provides an alternative way to train the model when data is already loaded into memory, offering
1051
861
  flexibility for data preprocessing steps outside this class.
862
+
1052
863
  """
1053
864
 
1054
865
  self.normalize = normalize
@@ -1179,6 +990,7 @@ class SignalDetectionModel(object):
1179
990
  -----
1180
991
  - The method processes the input signals according to the specified options to ensure compatibility with the model's
1181
992
  input requirements.
993
+
1182
994
  """
1183
995
 
1184
996
  self.x = np.copy(x)
@@ -1240,6 +1052,7 @@ class SignalDetectionModel(object):
1240
1052
  -----
1241
1053
  - The method processes the input signals according to the specified options and uses the regression model to
1242
1054
  predict times at which a particular event of interest occurs.
1055
+
1243
1056
  """
1244
1057
 
1245
1058
  self.x = np.copy(x)
@@ -1290,6 +1103,7 @@ class SignalDetectionModel(object):
1290
1103
  -----
1291
1104
  - This method is useful for preparing signals that have gaps or missing time points before further processing
1292
1105
  or model training.
1106
+
1293
1107
  """
1294
1108
 
1295
1109
  for i in range(len(x_set)):
@@ -1317,6 +1131,7 @@ class SignalDetectionModel(object):
1317
1131
  - The classifier model predicts the class of each signal, such as live, dead, or miscellaneous.
1318
1132
  - Training parameters such as epochs, batch size, and learning rate are specified during class instantiation.
1319
1133
  - Model performance metrics and training history are saved for analysis.
1134
+
1320
1135
  """
1321
1136
 
1322
1137
  # if pretrained model
@@ -1461,6 +1276,7 @@ class SignalDetectionModel(object):
1461
1276
  - The regressor model estimates the time at which an event of interest occurs within each signal.
1462
1277
  - Only signals predicted to have an event by the classifier model are used for regressor training.
1463
1278
  - Model performance metrics and training history are saved for analysis.
1279
+
1464
1280
  """
1465
1281
 
1466
1282
 
@@ -1545,6 +1361,7 @@ class SignalDetectionModel(object):
1545
1361
  -----
1546
1362
  - Plots include loss and accuracy metrics over epochs for the classifier, and loss metrics for the regressor.
1547
1363
  - The plots are saved as image files in the model's output directory.
1364
+
1548
1365
  """
1549
1366
 
1550
1367
  if mode=="regressor":
@@ -1589,6 +1406,7 @@ class SignalDetectionModel(object):
1589
1406
  -----
1590
1407
  - Evaluation is performed on both test and validation datasets, if available.
1591
1408
  - Regression plots and performance metrics are saved in the model's output directory.
1409
+
1592
1410
  """
1593
1411
 
1594
1412
 
@@ -1641,6 +1459,7 @@ class SignalDetectionModel(object):
1641
1459
  -----
1642
1460
  - Callbacks include learning rate reduction on plateau, early stopping, model checkpointing, and TensorBoard logging.
1643
1461
  - The list of callbacks is stored in the class attribute `cb` and used during model training.
1462
+
1644
1463
  """
1645
1464
 
1646
1465
  self.cb = []
@@ -1698,6 +1517,7 @@ class SignalDetectionModel(object):
1698
1517
  -----
1699
1518
  - Signal annotations are expected to be stored in .npy format and contain required channels and event information.
1700
1519
  - The method applies specified normalization and interpolation options to prepare the signals for model training.
1520
+
1701
1521
  """
1702
1522
 
1703
1523
 
@@ -1781,6 +1601,7 @@ class SignalDetectionModel(object):
1781
1601
  -----
1782
1602
  - Augmentation strategies include random time shifting and signal modifications to simulate variations in real data.
1783
1603
  - The augmented dataset is used for training the classifier and regressor models to improve generalization.
1604
+
1784
1605
  """
1785
1606
 
1786
1607
 
@@ -1902,89 +1723,6 @@ class SignalDetectionModel(object):
1902
1723
  return signals_recast, times_of_interest
1903
1724
 
1904
1725
 
1905
- # def load_and_normalize(self, subset):
1906
-
1907
- # """
1908
- # Loads a subset of signal data from an annotation file and applies normalization.
1909
-
1910
- # Parameters
1911
- # ----------
1912
- # subset : str
1913
- # The file path to the .npy annotation file containing signal data for a subset of observations.
1914
-
1915
- # Notes
1916
- # -----
1917
- # - The method extracts required signal channels from the annotation file and applies specified normalization
1918
- # and interpolation steps.
1919
- # - Preprocessed signals are added to the global dataset for model training.
1920
- # """
1921
-
1922
- # set_k = np.load(subset,allow_pickle=True)
1923
- # ### here do a mapping between channel option and existing signals
1924
-
1925
- # required_signals = self.channel_option
1926
- # available_signals = list(set_k[0].keys())
1927
-
1928
- # selected_signals = []
1929
- # for s in required_signals:
1930
- # pattern_test = [s in a for a in available_signals]
1931
- # if np.any(pattern_test):
1932
- # valid_columns = np.array(available_signals)[np.array(pattern_test)]
1933
- # if len(valid_columns)==1:
1934
- # selected_signals.append(valid_columns[0])
1935
- # else:
1936
- # print(f'Found several candidate signals: {valid_columns}')
1937
- # for vc in natsorted(valid_columns):
1938
- # if 'circle' in vc:
1939
- # selected_signals.append(vc)
1940
- # break
1941
- # else:
1942
- # selected_signals.append(valid_columns[0])
1943
- # else:
1944
- # return None
1945
-
1946
-
1947
- # key_to_check = selected_signals[0] #self.channel_option[0]
1948
- # signal_lengths = [len(l[key_to_check]) for l in set_k]
1949
- # max_length = np.amax(signal_lengths)
1950
-
1951
- # fluo = np.zeros((len(set_k),max_length,self.n_channels))
1952
- # classes = np.zeros(len(set_k))
1953
- # times_of_interest = np.zeros(len(set_k))
1954
-
1955
- # for k in range(len(set_k)):
1956
-
1957
- # for i in range(self.n_channels):
1958
- # try:
1959
- # # take into account timeline for accurate time regression
1960
- # timeline = set_k[k]['FRAME'].astype(int)
1961
- # fluo[k,timeline,i] = set_k[k][selected_signals[i]]
1962
- # except:
1963
- # print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
1964
- # pass
1965
-
1966
- # classes[k] = set_k[k]["class"]
1967
- # times_of_interest[k] = set_k[k]["time_of_interest"]
1968
-
1969
- # # Correct absurd times of interest
1970
- # times_of_interest[np.nonzero(classes)] = -1
1971
- # times_of_interest[(times_of_interest<=0.0)] = -1
1972
-
1973
- # # Attempt per-set normalization
1974
- # fluo = pad_to_model_length(fluo, self.model_signal_length)
1975
- # if self.normalize:
1976
- # fluo = normalize_signal_set(fluo, self.channel_option, normalization_percentile=self.normalization_percentile,
1977
- # normalization_values=self.normalization_values, normalization_clip=self.normalization_clip,
1978
- # )
1979
-
1980
- # # Trivial normalization for time of interest
1981
- # times_of_interest /= self.model_signal_length
1982
-
1983
- # # Add to global dataset
1984
- # self.x_set.extend(fluo)
1985
- # self.y_time_set.extend(times_of_interest)
1986
- # self.y_class_set.extend(classes)
1987
-
1988
1726
  def _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip):
1989
1727
 
1990
1728
  """
@@ -2028,6 +1766,7 @@ def _interpret_normalization_parameters(n_channels, normalization_percentile, no
2028
1766
  >>> params = _interpret_normalization_parameters(n_channels, normalization_percentile, normalization_values, normalization_clip)
2029
1767
  >>> print(params)
2030
1768
  # ([True, True], [[0.1, 99.9], [0.1, 99.9]], [False, False])
1769
+
2031
1770
  """
2032
1771
 
2033
1772
 
@@ -2101,6 +1840,7 @@ def normalize_signal_set(signal_set, channel_option, percentile_alive=[0.01,99.9
2101
1840
  >>> channel_option = ['alive', 'dead']
2102
1841
  >>> normalized_signals = normalize_signal_set(signal_set, channel_option)
2103
1842
  # Normalizes the signal set based on the default percentile values for 'alive' and 'dead' channels.
1843
+
2104
1844
  """
2105
1845
 
2106
1846
  # Check normalization params are ok
@@ -2584,6 +2324,7 @@ def ResNetModelCurrent(n_channels, n_slices, depth=2, use_pooling=True, n_classe
2584
2324
  --------
2585
2325
  >>> model = ResNetModelCurrent(n_channels=1, n_slices=2, depth=2, use_pooling=True, n_classes=3, dropout_rate=0.1, dense_collection=512, header="classifier", model_signal_length=128)
2586
2326
  # Creates a ResNet model configured for classification with 3 classes.
2327
+
2587
2328
  """
2588
2329
 
2589
2330
  if header=="classifier":
@@ -2652,6 +2393,7 @@ def train_signal_model(config):
2652
2393
  >>> config_path = '/path/to/training_config.json'
2653
2394
  >>> train_signal_model(config_path)
2654
2395
  # This will execute the 'train_signal_model.py' script using the parameters specified in 'training_config.json'.
2396
+
2655
2397
  """
2656
2398
 
2657
2399
  config = config.replace('\\','/')
@@ -2698,6 +2440,7 @@ def T_MSD(x,y,dt):
2698
2440
  >>> T_MSD(x, y, dt)
2699
2441
  ([6.0, 9.0, 4.666666666666667, 1.6666666666666667],
2700
2442
  array([1., 2., 3., 4.]))
2443
+
2701
2444
  """
2702
2445
 
2703
2446
  msd = []
@@ -2735,6 +2478,7 @@ def linear_msd(t, m):
2735
2478
  >>> m = 2.0
2736
2479
  >>> linear_msd(t, m)
2737
2480
  array([2., 4., 6., 8.])
2481
+
2738
2482
  """
2739
2483
 
2740
2484
  return m*t
@@ -2766,6 +2510,7 @@ def alpha_msd(t, m, alpha):
2766
2510
  >>> alpha = 0.5
2767
2511
  >>> alpha_msd(t, m, alpha)
2768
2512
  array([2. , 4. , 6. , 8. ])
2513
+
2769
2514
  """
2770
2515
 
2771
2516
  return m*t**alpha
@@ -2820,6 +2565,7 @@ def sliding_msd(x, y, timeline, window, mode='bi', n_points_migration=7, n_poin
2820
2565
  >>> timeline = np.array([0, 1, 2, 3, 4, 5, 6])
2821
2566
  >>> window = 3
2822
2567
  >>> s_msd, s_alpha = sliding_msd(x, y, timeline, window, n_points_migration=2, n_points_transport=3)
2568
+
2823
2569
  """
2824
2570
 
2825
2571
  assert window > n_points_migration,'Please set a window larger than the number of fit points...'
@@ -2956,6 +2702,7 @@ def sliding_msd_drift(x, y, timeline, window, mode='bi', n_points_migration=7,
2956
2702
  >>> window = 11
2957
2703
  >>> diffusion, velocity = sliding_msd_drift(x, y, timeline, window, mode='bi')
2958
2704
  # Calculates the diffusion coefficient and drift velocity using a bidirectional sliding window.
2705
+
2959
2706
  """
2960
2707
 
2961
2708
  assert window > n_points_migration,'Please set a window larger than the number of fit points...'
@@ -3007,7 +2754,7 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3007
2754
  """
3008
2755
  Calculate the column-wise mean and standard deviation of non-NaN elements in the input matrix.
3009
2756
 
3010
- Parameters:
2757
+ Parameters
3011
2758
  ----------
3012
2759
  matrix : numpy.ndarray
3013
2760
  The input matrix for which column-wise mean and standard deviation are calculated.
@@ -3015,7 +2762,7 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3015
2762
  The minimum number of non-NaN values required in a column to calculate mean and standard deviation.
3016
2763
  Default is 8.
3017
2764
 
3018
- Returns:
2765
+ Returns
3019
2766
  -------
3020
2767
  mean_line : numpy.ndarray
3021
2768
  An array containing the column-wise mean of non-NaN elements. Elements with fewer than `min_nbr_values` non-NaN
@@ -3024,11 +2771,12 @@ def columnwise_mean(matrix, min_nbr_values = 1):
3024
2771
  An array containing the column-wise standard deviation of non-NaN elements. Elements with fewer than `min_nbr_values`
3025
2772
  non-NaN values are replaced with NaN.
3026
2773
 
3027
- Notes:
3028
- ------
2774
+ Notes
2775
+ -----
3029
2776
  1. This function calculates the mean and standard deviation of non-NaN elements in each column of the input matrix.
3030
2777
  2. Columns with fewer than `min_nbr_values` non-zero elements will have NaN as the mean and standard deviation.
3031
2778
  3. NaN values in the input matrix are ignored during calculation.
2779
+
3032
2780
  """
3033
2781
 
3034
2782
  mean_line = np.zeros(matrix.shape[1])
@@ -3050,7 +2798,7 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3050
2798
  """
3051
2799
  Calculate the mean and standard deviation of a specified signal for tracks of a given class in the input DataFrame.
3052
2800
 
3053
- Parameters:
2801
+ Parameters
3054
2802
  ----------
3055
2803
  df : pandas.DataFrame
3056
2804
  Input DataFrame containing tracking data.
@@ -3063,7 +2811,7 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3063
2811
  class_value : int, optional
3064
2812
  Value representing the class of interest. Default is 0.
3065
2813
 
3066
- Returns:
2814
+ Returns
3067
2815
  -------
3068
2816
  mean_signal : numpy.ndarray
3069
2817
  An array containing the mean signal values for tracks of the specified class. Tracks with class not equal to
@@ -3074,12 +2822,13 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3074
2822
  actual_timeline : numpy.ndarray
3075
2823
  An array representing the time points corresponding to the mean signal values.
3076
2824
 
3077
- Notes:
3078
- ------
2825
+ Notes
2826
+ -----
3079
2827
  1. This function calculates the mean and standard deviation of the specified signal for tracks of a given class.
3080
2828
  2. Tracks with class not equal to `class_value` are excluded from the calculation.
3081
2829
  3. Tracks with missing or NaN values in the specified signal are ignored during calculation.
3082
2830
  4. Tracks are aligned based on their 'FRAME' values and the specified `time_col` (if provided).
2831
+
3083
2832
  """
3084
2833
 
3085
2834
  assert signal_name in list(df.columns),"The signal you want to plot is not one of the measured features."
@@ -3142,11 +2891,11 @@ def mean_signal(df, signal_name, class_col, time_col=None, class_value=[0], retu
3142
2891
  else:
3143
2892
  return mean_signal, std_signal, actual_timeline
3144
2893
 
3145
- if __name__ == "__main__":
2894
+ # if __name__ == "__main__":
3146
2895
 
3147
- # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
3148
- # print(model.summary())
3149
- model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
3150
- header="classifier", model_signal_length = 128)
3151
- print(model.summary())
3152
- #plot_model(model, to_file='test.png', show_shapes=True)
2896
+ # # model = MultiScaleResNetModel(3, n_classes = 3, dropout_rate=0, dense_collection=1024, header="classifier", model_signal_length = 128)
2897
+ # # print(model.summary())
2898
+ # model = ResNetModelCurrent(1, 2, depth=2, use_pooling=True, n_classes = 3, dropout_rate=0.1, dense_collection=512,
2899
+ # header="classifier", model_signal_length = 128)
2900
+ # print(model.summary())
2901
+ # #plot_model(model, to_file='test.png', show_shapes=True)
celldetective/tracking.py CHANGED
@@ -937,18 +937,6 @@ def track_at_position(pos, mode, return_tracks=False, view_on_napari=False, thre
937
937
  return df
938
938
  else:
939
939
  return None
940
-
941
- # # if return_labels or view_on_napari:
942
- # # labels = locate_labels(pos, population=mode)
943
- # # if view_on_napari:
944
- # # if stack_prefix is None:
945
- # # stack_prefix = ''
946
- # # stack = locate_stack(pos, prefix=stack_prefix)
947
- # # _view_on_napari(tracks=None, stack=stack, labels=labels)
948
- # # if return_labels:
949
- # # return labels
950
- # # else:
951
- # return None
952
940
 
953
941
  def write_first_detection_class(df, img_shape=None, edge_threshold=20, column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
954
942
 
@@ -972,27 +960,33 @@ def write_first_detection_class(df, img_shape=None, edge_threshold=20, column_la
972
960
 
973
961
  column_labels : dict, optional
974
962
  A dictionary mapping logical column names to actual column names in `tab`. Keys include:
975
- - `'track'`: The column indicating the track ID (default: `"TRACK_ID"`).
976
- - `'time'`: The column indicating the frame/time (default: `"FRAME"`).
977
- - `'x'`: The column indicating the X-coordinate (default: `"POSITION_X"`).
978
- - `'y'`: The column indicating the Y-coordinate (default: `"POSITION_Y"`).
963
+
964
+ - `'track'`: The column indicating the track ID (default: `"TRACK_ID"`).
965
+ - `'time'`: The column indicating the frame/time (default: `"FRAME"`).
966
+ - `'x'`: The column indicating the X-coordinate (default: `"POSITION_X"`).
967
+ - `'y'`: The column indicating the Y-coordinate (default: `"POSITION_Y"`).
979
968
 
980
969
  Returns
981
970
  -------
982
971
  pandas.DataFrame
983
972
  The input DataFrame `df` with two additional columns:
984
- - `'class_firstdetection'`: A class assigned based on detection status:
985
- - `0`: Valid detection not near the edge and not at the initial frame.
986
- - `2`: Detection near the edge, at the initial frame, or no detection available.
987
- - `'t_firstdetection'`: The adjusted first detection time (in frame units):
988
- - `-1`: Indicates no valid detection or detection near the edge.
989
- - A float value representing the adjusted first detection time otherwise.
973
+
974
+ - `'class_firstdetection'`: A class assigned based on detection status:
975
+
976
+ - `0`: Valid detection not near the edge and not at the initial frame.
977
+ - `2`: Detection near the edge, at the initial frame, or no detection available.
978
+
979
+ - `'t_firstdetection'`: The adjusted first detection time (in frame units):
980
+
981
+ - `-1`: Indicates no valid detection or detection near the edge.
982
+ - A float value representing the adjusted first detection time otherwise.
990
983
 
991
984
  Notes
992
985
  -----
993
986
  - The function assumes that tracks are grouped and sorted by track ID and frame.
994
987
  - Detections near the edge or at the initial frame (frame 0) are considered invalid and assigned special values.
995
988
  - If `img_shape` is not provided, edge checks are skipped.
989
+
996
990
  """
997
991
 
998
992
  df = df.sort_values(by=[column_labels['track'],column_labels['time']])
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: celldetective
3
- Version: 1.3.7.post1
3
+ Version: 1.3.7.post2
4
4
  Summary: description
5
5
  Home-page: http://github.com/remyeltorro/celldetective
6
6
  Author: Rémy Torro
@@ -30,6 +30,7 @@ Requires-Dist: setuptools
30
30
  Requires-Dist: scipy
31
31
  Requires-Dist: seaborn
32
32
  Requires-Dist: opencv-python-headless==4.7.0.72
33
+ Requires-Dist: PyQt5
33
34
  Requires-Dist: liblapack
34
35
  Requires-Dist: gputools
35
36
  Requires-Dist: lmfit
@@ -43,6 +44,14 @@ Requires-Dist: h5py
43
44
  Requires-Dist: cliffs_delta
44
45
  Requires-Dist: requests
45
46
  Requires-Dist: trackpy
47
+ Dynamic: author
48
+ Dynamic: author-email
49
+ Dynamic: description
50
+ Dynamic: description-content-type
51
+ Dynamic: home-page
52
+ Dynamic: license
53
+ Dynamic: requires-dist
54
+ Dynamic: summary
46
55
 
47
56
  # Celldetective
48
57