celldetective 1.1.1.post3__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. celldetective/__init__.py +2 -1
  2. celldetective/__main__.py +17 -0
  3. celldetective/extra_properties.py +62 -34
  4. celldetective/gui/__init__.py +1 -0
  5. celldetective/gui/analyze_block.py +2 -1
  6. celldetective/gui/classifier_widget.py +18 -10
  7. celldetective/gui/control_panel.py +57 -6
  8. celldetective/gui/layouts.py +14 -11
  9. celldetective/gui/neighborhood_options.py +21 -13
  10. celldetective/gui/plot_signals_ui.py +39 -11
  11. celldetective/gui/process_block.py +413 -95
  12. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  13. celldetective/gui/retrain_signal_model_options.py +106 -6
  14. celldetective/gui/signal_annotator.py +110 -30
  15. celldetective/gui/signal_annotator2.py +2708 -0
  16. celldetective/gui/signal_annotator_options.py +3 -1
  17. celldetective/gui/survival_ui.py +15 -6
  18. celldetective/gui/tableUI.py +248 -43
  19. celldetective/io.py +598 -416
  20. celldetective/measure.py +919 -969
  21. celldetective/models/pair_signal_detection/blank +0 -0
  22. celldetective/neighborhood.py +482 -340
  23. celldetective/preprocessing.py +81 -61
  24. celldetective/relative_measurements.py +648 -0
  25. celldetective/scripts/analyze_signals.py +1 -1
  26. celldetective/scripts/measure_cells.py +28 -8
  27. celldetective/scripts/measure_relative.py +103 -0
  28. celldetective/scripts/segment_cells.py +5 -5
  29. celldetective/scripts/track_cells.py +4 -1
  30. celldetective/scripts/train_segmentation_model.py +23 -18
  31. celldetective/scripts/train_signal_model.py +33 -0
  32. celldetective/segmentation.py +67 -29
  33. celldetective/signals.py +402 -8
  34. celldetective/tracking.py +8 -2
  35. celldetective/utils.py +144 -12
  36. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
  37. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/RECORD +42 -38
  38. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
  39. tests/test_segmentation.py +1 -1
  40. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
  41. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
  42. {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
celldetective/signals.py CHANGED
@@ -18,8 +18,8 @@ from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_sc
18
18
  from scipy.interpolate import interp1d
19
19
  from scipy.ndimage import shift
20
20
 
21
- from celldetective.io import get_signal_models_list, locate_signal_model
22
- from celldetective.tracking import clean_trajectories
21
+ from celldetective.io import get_signal_models_list, locate_signal_model, get_position_pickle, get_position_table
22
+ from celldetective.tracking import clean_trajectories, interpolate_nan_properties
23
23
  from celldetective.utils import regression_plot, train_test_split, compute_weights
24
24
  import matplotlib.pyplot as plt
25
25
  from natsort import natsorted
@@ -32,6 +32,7 @@ from scipy.optimize import curve_fit
32
32
  import time
33
33
  import math
34
34
  import pandas as pd
35
+ from pandas.api.types import is_numeric_dtype
35
36
 
36
37
  abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
37
38
 
@@ -334,6 +335,392 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
334
335
  else:
335
336
  return None
336
337
 
338
+ def analyze_pair_signals_at_position(pos, model, use_gpu=True):
339
+
340
+ """
341
+
342
+ """
343
+
344
+ pos = pos.replace('\\','/')
345
+ pos = rf"{pos}"
346
+ assert os.path.exists(pos),f'Position {pos} is not a valid path.'
347
+ if not pos.endswith('/'):
348
+ pos += '/'
349
+
350
+ df_targets = get_position_pickle(pos, population='targets')
351
+ df_effectors = get_position_pickle(pos, population='effectors')
352
+ dataframes = {
353
+ 'targets': df_targets,
354
+ 'effectors': df_effectors,
355
+ }
356
+ df_pairs = get_position_table(pos, population='pairs')
357
+
358
+ # Need to identify expected reference / neighbor tables
359
+ model_path = locate_signal_model(model, pairs=True)
360
+ print(f'Looking for model in {model_path}...')
361
+ complete_path = model_path
362
+ complete_path = rf"{complete_path}"
363
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
364
+ model_config_path = rf"{model_config_path}"
365
+ f = open(model_config_path)
366
+ model_config_path = json.load(f)
367
+
368
+ reference_population = model_config_path['reference_population']
369
+ neighbor_population = model_config_path['neighbor_population']
370
+
371
+ df = analyze_pair_signals(df_pairs, dataframes[reference_population], dataframes[neighbor_population], model=model)
372
+
373
+ table = pos + os.sep.join(["output","tables",f"trajectories_pairs.csv"])
374
+ df.to_csv(table, index=False)
375
+
376
+ return None
377
+
378
+
379
+ def analyze_signals(trajectories, model, interpolate_na=True,
380
+ selected_signals=None,
381
+ model_path=None,
382
+ column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
383
+ plot_outcome=False, output_dir=None):
384
+ """
385
+ Analyzes signals from trajectory data using a specified signal detection model and configuration.
386
+
387
+ This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
388
+ model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
389
+ of missing values, and plotting of analysis outcomes.
390
+
391
+ Parameters
392
+ ----------
393
+ trajectories : pandas.DataFrame
394
+ DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
395
+ model : str
396
+ The name of the signal detection model to be used for analysis.
397
+ interpolate_na : bool, optional
398
+ Whether to interpolate missing values in the trajectories (default is True).
399
+ selected_signals : list of str, optional
400
+ A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
401
+ be automatically selected based on the model configuration (default is None).
402
+ column_labels : dict, optional
403
+ A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
404
+ in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
405
+ plot_outcome : bool, optional
406
+ If True, generates and saves a plot of the signal analysis outcome (default is False).
407
+ output_dir : str, optional
408
+ The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
409
+
410
+ Returns
411
+ -------
412
+ pandas.DataFrame
413
+ The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
414
+ corresponding colors based on status and class.
415
+
416
+ Raises
417
+ ------
418
+ AssertionError
419
+ If the model or its configuration file cannot be located.
420
+
421
+ Notes
422
+ -----
423
+ - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
424
+ - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
425
+
426
+ """
427
+
428
+ model_path = locate_signal_model(model, path=model_path)
429
+ complete_path = model_path # +model
430
+ complete_path = rf"{complete_path}"
431
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
432
+ model_config_path = rf"{model_config_path}"
433
+ assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
434
+ assert os.path.exists(
435
+ model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
436
+
437
+ available_signals = list(trajectories.columns)
438
+ print('The available_signals are : ', available_signals)
439
+
440
+ f = open(model_config_path)
441
+ config = json.load(f)
442
+ required_signals = config["channels"]
443
+
444
+ try:
445
+ label = config['label']
446
+ if label == '':
447
+ label = None
448
+ except:
449
+ label = None
450
+
451
+ if selected_signals is None:
452
+ selected_signals = []
453
+ for s in required_signals:
454
+ pattern_test = [s in a or s == a for a in available_signals]
455
+ print(f'Pattern test for signal {s}: ', pattern_test)
456
+ assert np.any(
457
+ pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
458
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
459
+ if len(valid_columns) == 1:
460
+ selected_signals.append(valid_columns[0])
461
+ else:
462
+ # print(test_number_of_nan(trajectories, valid_columns))
463
+ print(f'Found several candidate signals: {valid_columns}')
464
+ for vc in natsorted(valid_columns):
465
+ if 'circle' in vc:
466
+ selected_signals.append(vc)
467
+ break
468
+ else:
469
+ selected_signals.append(valid_columns[0])
470
+ # do something more complicated in case of one to many columns
471
+ # pass
472
+ else:
473
+ assert len(selected_signals) == len(
474
+ required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
475
+
476
+ print(f'The following channels will be passed to the model: {selected_signals}')
477
+ trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
478
+ interpolate_position_gaps=interpolate_na, column_labels=column_labels)
479
+
480
+ max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
481
+ tracks = trajectories_clean[column_labels['track']].unique()
482
+ signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
483
+
484
+ for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
485
+ frames = group[column_labels['time']].to_numpy().astype(int)
486
+ for j, col in enumerate(selected_signals):
487
+ signal = group[col].to_numpy()
488
+ signals[i, frames, j] = signal
489
+ signals[i, max(frames):, j] = signal[-1]
490
+
491
+ # for i in range(5):
492
+ # print('pre model')
493
+ # plt.plot(signals[i,:,0])
494
+ # plt.show()
495
+
496
+ model = SignalDetectionModel(pretrained=complete_path)
497
+ print('signal shape: ', signals.shape)
498
+
499
+ classes = model.predict_class(signals)
500
+ times_recast = model.predict_time_of_interest(signals)
501
+
502
+ if label is None:
503
+ class_col = 'class'
504
+ time_col = 't0'
505
+ status_col = 'status'
506
+ else:
507
+ class_col = 'class_' + label
508
+ time_col = 't_' + label
509
+ status_col = 'status_' + label
510
+
511
+ for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
512
+ indices = group.index
513
+ trajectories.loc[indices, class_col] = classes[i]
514
+ trajectories.loc[indices, time_col] = times_recast[i]
515
+ print('Done.')
516
+
517
+ for tid, group in trajectories.groupby(column_labels['track']):
518
+
519
+ indices = group.index
520
+ t0 = group[time_col].to_numpy()[0]
521
+ cclass = group[class_col].to_numpy()[0]
522
+ timeline = group[column_labels['time']].to_numpy()
523
+ status = np.zeros_like(timeline)
524
+ if t0 > 0:
525
+ status[timeline >= t0] = 1.
526
+ if cclass == 2:
527
+ status[:] = 2
528
+ if cclass > 2:
529
+ status[:] = 42
530
+ status_color = [color_from_status(s) for s in status]
531
+ class_color = [color_from_class(cclass) for i in range(len(status))]
532
+
533
+ trajectories.loc[indices, status_col] = status
534
+ trajectories.loc[indices, 'status_color'] = status_color
535
+ trajectories.loc[indices, 'class_color'] = class_color
536
+
537
+ if plot_outcome:
538
+ fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
539
+ for i, s in enumerate(selected_signals):
540
+ for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
541
+ cclass = group[class_col].to_numpy()[0]
542
+ t0 = group[time_col].to_numpy()[0]
543
+ timeline = group[column_labels['time']].to_numpy()
544
+ if cclass == 0:
545
+ if len(selected_signals) > 1:
546
+ ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
547
+ else:
548
+ ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
549
+ if len(selected_signals) > 1:
550
+ for a, s in zip(ax, selected_signals):
551
+ a.set_title(s)
552
+ a.set_xlabel(r'time - t$_0$ [frame]')
553
+ a.spines['top'].set_visible(False)
554
+ a.spines['right'].set_visible(False)
555
+ else:
556
+ ax.set_title(s)
557
+ ax.set_xlabel(r'time - t$_0$ [frame]')
558
+ ax.spines['top'].set_visible(False)
559
+ ax.spines['right'].set_visible(False)
560
+ plt.tight_layout()
561
+ if output_dir is not None:
562
+ plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
563
+ plt.pause(3)
564
+ plt.close()
565
+
566
+ return trajectories
567
+
568
+
569
+ def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
570
+ model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
571
+ """
572
+ """
573
+
574
+ model_path = locate_signal_model(model, path=model_path, pairs=True)
575
+ print(f'Looking for model in {model_path}...')
576
+ complete_path = model_path
577
+ complete_path = rf"{complete_path}"
578
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
579
+ model_config_path = rf"{model_config_path}"
580
+ assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
581
+ assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
582
+
583
+ trajectories_pairs = trajectories_pairs.rename(columns=lambda x: 'pair_' + x)
584
+ trajectories_reference = trajectories_reference.rename(columns=lambda x: 'reference_' + x)
585
+ trajectories_neighbors = trajectories_neighbors.rename(columns=lambda x: 'neighbor_' + x)
586
+
587
+ if 'pair_position' in list(trajectories_pairs.columns):
588
+ pair_groupby_cols = ['pair_position', 'pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
589
+ else:
590
+ pair_groupby_cols = ['pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
591
+
592
+ if 'reference_position' in list(trajectories_reference.columns):
593
+ reference_groupby_cols = ['reference_position', 'reference_TRACK_ID']
594
+ else:
595
+ reference_groupby_cols = ['reference_TRACK_ID']
596
+
597
+ if 'neighbor_position' in list(trajectories_neighbors.columns):
598
+ neighbor_groupby_cols = ['neighbor_position', 'neighbor_TRACK_ID']
599
+ else:
600
+ neighbor_groupby_cols = ['neighbor_TRACK_ID']
601
+
602
+ available_signals = [] #list(trajectories_pairs.columns) + list(trajectories_reference.columns) + list(trajectories_neighbors.columns)
603
+ for col in list(trajectories_pairs.columns):
604
+ if is_numeric_dtype(trajectories_pairs[col]):
605
+ available_signals.append(col)
606
+ for col in list(trajectories_reference.columns):
607
+ if is_numeric_dtype(trajectories_reference[col]):
608
+ available_signals.append(col)
609
+ for col in list(trajectories_neighbors.columns):
610
+ if is_numeric_dtype(trajectories_neighbors[col]):
611
+ available_signals.append(col)
612
+
613
+ print('The available signals are : ', available_signals)
614
+
615
+ f = open(model_config_path)
616
+ config = json.load(f)
617
+ required_signals = config["channels"]
618
+
619
+ try:
620
+ label = config['label']
621
+ if label=='':
622
+ label = None
623
+ except:
624
+ label = None
625
+
626
+ if selected_signals is None:
627
+ selected_signals = []
628
+ for s in required_signals:
629
+ pattern_test = [s in a or s==a for a in available_signals]
630
+ print(f'Pattern test for signal {s}: ', pattern_test)
631
+ assert np.any(pattern_test),f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
632
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
633
+ if len(valid_columns)==1:
634
+ selected_signals.append(valid_columns[0])
635
+ else:
636
+ #print(test_number_of_nan(trajectories, valid_columns))
637
+ print(f'Found several candidate signals: {valid_columns}')
638
+ for vc in natsorted(valid_columns):
639
+ if 'circle' in vc:
640
+ selected_signals.append(vc)
641
+ break
642
+ else:
643
+ selected_signals.append(valid_columns[0])
644
+ # do something more complicated in case of one to many columns
645
+ #pass
646
+ else:
647
+ assert len(selected_signals)==len(required_signals),f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
648
+
649
+ print(f'The following channels will be passed to the model: {selected_signals}')
650
+ trajectories_reference_clean = interpolate_nan_properties(trajectories_reference, track_label=reference_groupby_cols)
651
+ trajectories_neighbors_clean = interpolate_nan_properties(trajectories_neighbors, track_label=neighbor_groupby_cols)
652
+ trajectories_pairs_clean = interpolate_nan_properties(trajectories_pairs, track_label=pair_groupby_cols)
653
+ print(f'{trajectories_pairs_clean.columns=}')
654
+
655
+ max_signal_size = int(trajectories_pairs_clean['pair_FRAME'].max()) + 2
656
+ pair_tracks = trajectories_pairs_clean.groupby(pair_groupby_cols).size()
657
+ signals = np.zeros((len(pair_tracks),max_signal_size, len(selected_signals)))
658
+ print(f'{max_signal_size=} {len(pair_tracks)=} {signals.shape=}')
659
+
660
+ for i,(pair,group) in enumerate(trajectories_pairs_clean.groupby(pair_groupby_cols)):
661
+
662
+ if 'pair_position' not in list(trajectories_pairs_clean.columns):
663
+ pos_mode = False
664
+ reference_cell = pair[0]; neighbor_cell = pair[1]
665
+ else:
666
+ pos_mode = True
667
+ reference_cell = pair[1]; neighbor_cell = pair[2]; pos = pair[0]
668
+
669
+ if pos_mode and 'reference_position' in list(trajectories_reference_clean.columns) and 'neighbor_position' in list(trajectories_neighbors_clean.columns):
670
+ reference_filter = (trajectories_reference_clean['reference_TRACK_ID']==reference_cell)&(trajectories_reference_clean['reference_position']==pos)
671
+ neighbor_filter = (trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell)&(trajectories_neighbors_clean['neighbor_position']==pos)
672
+ else:
673
+ reference_filter = trajectories_reference_clean['reference_TRACK_ID']==reference_cell
674
+ neighbor_filter = trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell
675
+
676
+ pair_frames = group['pair_FRAME'].to_numpy().astype(int)
677
+
678
+ for j,col in enumerate(selected_signals):
679
+ if col.startswith('pair_'):
680
+ signal = group[col].to_numpy()
681
+ signals[i,pair_frames,j] = signal
682
+ signals[i,max(pair_frames):,j] = signal[-1]
683
+ elif col.startswith('reference_'):
684
+ signal = trajectories_reference_clean.loc[reference_filter, col].to_numpy()
685
+ timeline = trajectories_reference_clean.loc[reference_filter, 'reference_FRAME'].to_numpy()
686
+ signals[i,timeline,j] = signal
687
+ signals[i,max(timeline):,j] = signal[-1]
688
+ elif col.startswith('neighbor_'):
689
+ signal = trajectories_neighbors_clean.loc[neighbor_filter, col].to_numpy()
690
+ timeline = trajectories_neighbors_clean.loc[neighbor_filter, 'neighbor_FRAME'].to_numpy()
691
+ signals[i,timeline,j] = signal
692
+ signals[i,max(timeline):,j] = signal[-1]
693
+
694
+
695
+ model = SignalDetectionModel(pretrained=complete_path)
696
+ print('signal shape: ', signals.shape)
697
+
698
+ classes = model.predict_class(signals)
699
+ times_recast = model.predict_time_of_interest(signals)
700
+
701
+ if label is None:
702
+ class_col = 'pair_class'
703
+ time_col = 'pair_t0'
704
+ status_col = 'pair_status'
705
+ else:
706
+ class_col = 'pair_class_'+label
707
+ time_col = 'pair_t_'+label
708
+ status_col = 'pair_status_'+label
709
+
710
+ for i,(pair,group) in enumerate(trajectories_pairs.groupby(pair_groupby_cols)):
711
+ indices = group.index
712
+ trajectories_pairs.loc[indices,class_col] = classes[i]
713
+ trajectories_pairs.loc[indices,time_col] = times_recast[i]
714
+ print('Done.')
715
+
716
+ # At the end rename cols again
717
+ trajectories_pairs = trajectories_pairs.rename(columns=lambda x: x.replace('pair_',''))
718
+ trajectories_reference = trajectories_pairs.rename(columns=lambda x: x.replace('reference_',''))
719
+ trajectories_neighbors = trajectories_pairs.rename(columns=lambda x: x.replace('neighbor_',''))
720
+ invalid_cols = [c for c in list(trajectories_pairs.columns) if c.startswith('Unnamed')]
721
+ trajectories_pairs = trajectories_pairs.drop(columns=invalid_cols)
722
+
723
+ return trajectories_pairs
337
724
 
338
725
  class SignalDetectionModel(object):
339
726
 
@@ -1258,10 +1645,10 @@ class SignalDetectionModel(object):
1258
1645
  csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
1259
1646
  self.cb.append(csv_logger)
1260
1647
  checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
1261
- cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
1648
+ cp_callback = ModelCheckpoint(checkpoint_path, monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
1262
1649
  self.cb.append(cp_callback)
1263
1650
 
1264
- callback_stop = EarlyStopping(monitor='val_iou', patience=100)
1651
+ callback_stop = EarlyStopping(monitor='val_iou',mode='max',patience=100)
1265
1652
  self.cb.append(callback_stop)
1266
1653
 
1267
1654
  elif mode=="regressor":
@@ -1278,7 +1665,7 @@ class SignalDetectionModel(object):
1278
1665
  cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
1279
1666
  self.cb.append(cp_callback)
1280
1667
 
1281
- callback_stop = EarlyStopping(monitor='val_loss', patience=200)
1668
+ callback_stop = EarlyStopping(monitor='val_loss', mode='min', patience=200)
1282
1669
  self.cb.append(callback_stop)
1283
1670
 
1284
1671
  log_dir = self.model_folder+os.sep
@@ -1312,7 +1699,7 @@ class SignalDetectionModel(object):
1312
1699
  if isinstance(self.list_of_sets[0],str):
1313
1700
  # Case 1: a list of npy files to be loaded
1314
1701
  for s in self.list_of_sets:
1315
-
1702
+
1316
1703
  signal_dataset = self.load_set(s)
1317
1704
  selected_signals, max_length = self.find_best_signal_match(signal_dataset)
1318
1705
  signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
@@ -1416,7 +1803,6 @@ class SignalDetectionModel(object):
1416
1803
  x_train_aug.append(aug[0])
1417
1804
  y_time_train_aug.append(aug[1])
1418
1805
  y_class_train_aug.append(aug[2])
1419
- print('per class counts ',counts)
1420
1806
 
1421
1807
  # Save augmented training set
1422
1808
  self.x_train = np.array(x_train_aug)
@@ -1469,7 +1855,15 @@ class SignalDetectionModel(object):
1469
1855
  for i in range(self.n_channels):
1470
1856
  try:
1471
1857
  # take into account timeline for accurate time regression
1472
- timeline = signal_dataset[k]['FRAME'].astype(int)
1858
+
1859
+ if selected_signals[i].startswith('pair_'):
1860
+ timeline = signal_dataset[k]['pair_FRAME'].astype(int)
1861
+ elif selected_signals[i].startswith('reference_'):
1862
+ timeline = signal_dataset[k]['reference_FRAME'].astype(int)
1863
+ elif selected_signals[i].startswith('neighbor_'):
1864
+ timeline = signal_dataset[k]['neighbor_FRAME'].astype(int)
1865
+ else:
1866
+ timeline = signal_dataset[k]['FRAME'].astype(int)
1473
1867
  signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
1474
1868
  except:
1475
1869
  print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
celldetective/tracking.py CHANGED
@@ -130,7 +130,7 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
130
130
  tracking_updates = ["motion"]
131
131
 
132
132
  tracker.append(new_btrack_objects)
133
- tracker.volume = ((0,volume[0]), (0,volume[1])) #(-1e5, 1e5)
133
+ tracker.volume = ((0,volume[0]), (0,volume[1]), (-1e5, 1e5)) #(-1e5, 1e5)
134
134
  #print(tracker.volume)
135
135
  tracker.track(tracking_updates=tracking_updates, **track_kwargs)
136
136
  tracker.optimize(options=optimizer_options)
@@ -138,7 +138,11 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
138
138
  data, properties, graph = tracker.to_napari() #ndim=2
139
139
 
140
140
  # do the table post processing and napari options
141
- df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
141
+ if data.shape[1]==4:
142
+ df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
143
+ elif data.shape[1]==5:
144
+ df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],"z",column_labels['y'],column_labels['x']])
145
+ df = df.drop(columns=['z'])
142
146
  df[column_labels['x']+'_um'] = df[column_labels['x']]*spatial_calibration
143
147
  df[column_labels['y']+'_um'] = df[column_labels['y']]*spatial_calibration
144
148
 
@@ -160,6 +164,8 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
160
164
  if clean_trajectories_kwargs is not None:
161
165
  df = clean_trajectories(df.copy(),**clean_trajectories_kwargs)
162
166
 
167
+ df['ID'] = np.arange(len(df)).astype(int)
168
+
163
169
  if view_on_napari:
164
170
  view_on_napari_btrack(data,properties,graph,stack=stack,labels=labels,relabel=True)
165
171
 
celldetective/utils.py CHANGED
@@ -23,7 +23,85 @@ from tqdm import tqdm
23
23
  import shutil
24
24
  import tempfile
25
25
  from scipy.interpolate import griddata
26
+ import re
27
+ from scipy.ndimage.morphology import distance_transform_edt
28
+ from scipy import ndimage
29
+ from skimage.morphology import disk
30
+
31
+ def contour_of_instance_segmentation(label, distance):
32
+
33
+ """
34
+
35
+ Generate an instance mask containing the contour of the segmented objects.
36
+
37
+ Parameters
38
+ ----------
39
+ label : ndarray
40
+ The instance segmentation labels.
41
+ distance : int, float, list, or tuple
42
+ The distance or range of distances from the edge of each instance to include in the contour.
43
+ If a single value is provided, it represents the maximum distance. If a tuple or list is provided,
44
+ it represents the minimum and maximum distances.
45
+
46
+ Returns
47
+ -------
48
+ border_label : ndarray
49
+ An instance mask containing the contour of the segmented objects.
50
+
51
+ Notes
52
+ -----
53
+ This function generates an instance mask representing the contour of the segmented instances in the label image.
54
+ It use the distance_transform_edt function from the scipy.ndimage module to compute the Euclidean distance transform.
55
+ The contour is defined based on the specified distance(s) from the edge of each instance.
56
+ The resulting mask, `border_label`, contains the contour regions, while the interior regions are set to zero.
57
+
58
+ Examples
59
+ --------
60
+ >>> border_label = contour_of_instance_segmentation(label, distance=3)
61
+ # Generate a binary mask containing the contour of the segmented instances with a maximum distance of 3 pixels.
62
+
63
+ """
64
+ if isinstance(distance,(list,tuple)) or distance >= 0 :
65
+
66
+ edt = distance_transform_edt(label)
67
+
68
+ if isinstance(distance, list) or isinstance(distance, tuple):
69
+ min_distance = distance[0]; max_distance = distance[1]
70
+
71
+ elif isinstance(distance, (int, float)):
72
+ min_distance = 0
73
+ max_distance = distance
74
+
75
+ thresholded = (edt <= max_distance) * (edt > min_distance)
76
+ border_label = np.copy(label)
77
+ border_label[np.where(thresholded == 0)] = 0
78
+
79
+ else:
80
+ size = (2*abs(int(distance))+1, 2*abs(int(distance))+1)
81
+ dilated_image = ndimage.grey_dilation(label, footprint=disk(int(abs(distance)))) #size=size,
82
+ border_label=np.copy(dilated_image)
83
+ matching_cells = np.logical_and(dilated_image != 0, label == dilated_image)
84
+ border_label[np.where(matching_cells == True)] = 0
85
+ border_label[label!=0] = 0.
86
+
87
+ return border_label
88
+
89
+ def extract_identity_col(trajectories):
90
+
91
+ if 'TRACK_ID' in list(trajectories.columns):
92
+ if not np.all(trajectories['TRACK_ID'].isnull()):
93
+ id_col = 'TRACK_ID'
94
+ else:
95
+ if 'ID' in list(trajectories.columns):
96
+ id_col = 'ID'
97
+ elif 'ID' in list(trajectories.columns):
98
+
99
+ id_col = 'ID'
100
+ else:
101
+ print('ID or TRACK ID column could not be found in the table...')
102
+ id_col = None
26
103
 
104
+ return id_col
27
105
 
28
106
  def derivative(x, timeline, window, mode='bi'):
29
107
 
@@ -430,6 +508,31 @@ def mask_edges(binary_mask, border_size):
430
508
  return binary_mask
431
509
 
432
510
 
511
+ def extract_cols_from_query(query: str):
512
+
513
+ # Track variables in a dictionary to be used as a dictionary of globals. From: https://stackoverflow.com/questions/64576913/extract-pandas-dataframe-column-names-from-query-string
514
+
515
+ variables = {}
516
+
517
+ while True:
518
+ try:
519
+ # Try creating a Expr object with the query string and dictionary of globals.
520
+ # This will raise an error as long as the dictionary of globals is incomplete.
521
+ env = pd.core.computation.scope.ensure_scope(level=0, global_dict=variables)
522
+ pd.core.computation.eval.Expr(query, env=env)
523
+
524
+ # Exit the loop when evaluation is successful.
525
+ break
526
+ except pd.errors.UndefinedVariableError as e:
527
+ # This relies on the format defined here: https://github.com/pandas-dev/pandas/blob/965ceca9fd796940050d6fc817707bba1c4f9bff/pandas/errors/__init__.py#L401
528
+ name = re.findall("name '(.+?)' is not defined", str(e))[0]
529
+
530
+ # Add the name to the globals dictionary with a dummy value.
531
+ variables[name] = None
532
+
533
+ return list(variables.keys())
534
+
535
+
433
536
  def create_patch_mask(h, w, center=None, radius=None):
434
537
 
435
538
  """
@@ -564,6 +667,22 @@ def rename_intensity_column(df, channels):
564
667
  else:
565
668
  new_name = new_name.replace('centre_of_mass', "centre_of_mass_orientation")
566
669
  to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
670
+ elif sections[-2] == "2":
671
+ new_name = np.delete(measure, -1)
672
+ new_name = '_'.join(list(new_name))
673
+ if 'edge' in intensity_columns[k]:
674
+ new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_x")
675
+ else:
676
+ new_name = new_name.replace('centre_of_mass', "centre_of_mass_x")
677
+ to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
678
+ elif sections[-2] == "3":
679
+ new_name = np.delete(measure, -1)
680
+ new_name = '_'.join(list(new_name))
681
+ if 'edge' in intensity_columns[k]:
682
+ new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_y")
683
+ else:
684
+ new_name = new_name.replace('centre_of_mass', "centre_of_mass_y")
685
+ to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
567
686
  if 'radial_gradient' in intensity_columns[k]:
568
687
  # sections = np.array(re.split('-|_', intensity_columns[k]))
569
688
  measure = np.array(re.split('-|_', new_name))
@@ -1058,20 +1177,33 @@ def _extract_channel_indices(channels, required_channels):
1058
1177
  # [0, 1]
1059
1178
  """
1060
1179
 
1061
- if channels is not None:
1062
- channel_indices = []
1063
- for ch in required_channels:
1064
-
1180
+ channel_indices = []
1181
+ for c in required_channels:
1182
+ if c!='None' and c is not None:
1065
1183
  try:
1066
- idx = channels.index(ch)
1067
- except ValueError:
1068
- print('Mismatch between the channels required by the model and the provided channels.')
1069
- return None
1184
+ ch_idx = channels.index(c)
1185
+ channel_indices.append(ch_idx)
1186
+ except Exception as e:
1187
+ print(f"Error {e}. The channel required by the model is not available in your data... Check the configuration file.")
1188
+ channels = None
1189
+ break
1190
+ else:
1191
+ channel_indices.append(None)
1070
1192
 
1071
- channel_indices.append(idx)
1072
- channel_indices = np.array(channel_indices)
1073
- else:
1074
- channel_indices = np.arange(len(required_channels))
1193
+ # if channels is not None:
1194
+ # channel_indices = []
1195
+ # for ch in required_channels:
1196
+
1197
+ # try:
1198
+ # idx = channels.index(ch)
1199
+ # except ValueError:
1200
+ # print('Mismatch between the channels required by the model and the provided channels.')
1201
+ # return None
1202
+
1203
+ # channel_indices.append(idx)
1204
+ # channel_indices = np.array(channel_indices)
1205
+ # else:
1206
+ # channel_indices = np.arange(len(required_channels))
1075
1207
 
1076
1208
  return channel_indices
1077
1209