celldetective 1.1.1.post4-py3-none-any.whl → 1.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. celldetective/__init__.py +2 -1
  2. celldetective/extra_properties.py +62 -34
  3. celldetective/gui/__init__.py +1 -0
  4. celldetective/gui/analyze_block.py +2 -1
  5. celldetective/gui/classifier_widget.py +15 -9
  6. celldetective/gui/control_panel.py +50 -6
  7. celldetective/gui/layouts.py +5 -4
  8. celldetective/gui/neighborhood_options.py +13 -9
  9. celldetective/gui/plot_signals_ui.py +39 -11
  10. celldetective/gui/process_block.py +413 -95
  11. celldetective/gui/retrain_segmentation_model_options.py +17 -4
  12. celldetective/gui/retrain_signal_model_options.py +106 -6
  13. celldetective/gui/signal_annotator.py +29 -9
  14. celldetective/gui/signal_annotator2.py +2708 -0
  15. celldetective/gui/signal_annotator_options.py +3 -1
  16. celldetective/gui/survival_ui.py +15 -6
  17. celldetective/gui/tableUI.py +222 -60
  18. celldetective/io.py +536 -420
  19. celldetective/measure.py +919 -969
  20. celldetective/models/pair_signal_detection/blank +0 -0
  21. celldetective/models/segmentation_effectors/ricm-bimodal/config_input.json +130 -0
  22. celldetective/models/segmentation_effectors/ricm-bimodal/ricm-bimodal +0 -0
  23. celldetective/models/segmentation_effectors/ricm-bimodal/training_instructions.json +37 -0
  24. celldetective/neighborhood.py +428 -354
  25. celldetective/relative_measurements.py +648 -0
  26. celldetective/scripts/analyze_signals.py +1 -1
  27. celldetective/scripts/measure_cells.py +28 -8
  28. celldetective/scripts/measure_relative.py +103 -0
  29. celldetective/scripts/segment_cells.py +5 -5
  30. celldetective/scripts/track_cells.py +4 -1
  31. celldetective/scripts/train_segmentation_model.py +23 -18
  32. celldetective/scripts/train_signal_model.py +33 -0
  33. celldetective/signals.py +405 -8
  34. celldetective/tracking.py +8 -2
  35. celldetective/utils.py +178 -17
  36. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/METADATA +8 -8
  37. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/RECORD +41 -34
  38. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/WHEEL +1 -1
  39. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/LICENSE +0 -0
  40. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/entry_points.txt +0 -0
  41. {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/top_level.txt +0 -0
celldetective/signals.py CHANGED
@@ -18,8 +18,8 @@ from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_sc
  from scipy.interpolate import interp1d
  from scipy.ndimage import shift
 
- from celldetective.io import get_signal_models_list, locate_signal_model
- from celldetective.tracking import clean_trajectories
+ from celldetective.io import get_signal_models_list, locate_signal_model, get_position_pickle, get_position_table
+ from celldetective.tracking import clean_trajectories, interpolate_nan_properties
  from celldetective.utils import regression_plot, train_test_split, compute_weights
  import matplotlib.pyplot as plt
  from natsort import natsorted
@@ -154,6 +155,7 @@ def analyze_signals(trajectories, model, interpolate_na=True,
154
155
  f = open(model_config_path)
155
156
  config = json.load(f)
156
157
  required_signals = config["channels"]
158
+ model_signal_length = config['model_signal_length']
157
159
 
158
160
  try:
159
161
  label = config['label']
@@ -189,6 +191,8 @@ def analyze_signals(trajectories, model, interpolate_na=True,
  trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na, interpolate_position_gaps=interpolate_na, column_labels=column_labels)
 
  max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
+ assert max_signal_size <= model_signal_length,f'The current signals are longer ({max_signal_size}) than the maximum expected input ({model_signal_length}) for this signal analysis model. Abort...'
+
  tracks = trajectories_clean[column_labels['track']].unique()
  signals = np.zeros((len(tracks),max_signal_size, len(selected_signals)))
 
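The two hunks above read a new model_signal_length entry from the model's config_input.json and refuse to run on movies longer than the model's input window. A minimal, self-contained sketch of that guard, with illustrative values standing in for the real config and trajectory table:

    import pandas as pd

    # Illustrative stand-ins for config_input.json ("channels", "label", "model_signal_length")
    # and for the cleaned trajectory table produced by clean_trajectories().
    model_signal_length = 128
    trajectories_clean = pd.DataFrame({"TRACK_ID": [0, 0, 1], "FRAME": [0, 1, 1]})

    max_signal_size = int(trajectories_clean["FRAME"].max()) + 2

    # Same guard as the new assert: movies longer than the model's input window are rejected up front.
    assert max_signal_size <= model_signal_length, (
        f"The current signals are longer ({max_signal_size}) than the maximum "
        f"expected input ({model_signal_length}) for this signal analysis model."
    )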
@@ -334,6 +338,392 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
  else:
  return None
 
+ def analyze_pair_signals_at_position(pos, model, use_gpu=True):
+
+ """
+
+ """
+
+ pos = pos.replace('\\','/')
+ pos = rf"{pos}"
+ assert os.path.exists(pos),f'Position {pos} is not a valid path.'
+ if not pos.endswith('/'):
+ pos += '/'
+
+ df_targets = get_position_pickle(pos, population='targets')
+ df_effectors = get_position_pickle(pos, population='effectors')
+ dataframes = {
+ 'targets': df_targets,
+ 'effectors': df_effectors,
+ }
+ df_pairs = get_position_table(pos, population='pairs')
+
+ # Need to identify expected reference / neighbor tables
+ model_path = locate_signal_model(model, pairs=True)
+ print(f'Looking for model in {model_path}...')
+ complete_path = model_path
+ complete_path = rf"{complete_path}"
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
+ model_config_path = rf"{model_config_path}"
+ f = open(model_config_path)
+ model_config_path = json.load(f)
+
+ reference_population = model_config_path['reference_population']
+ neighbor_population = model_config_path['neighbor_population']
+
+ df = analyze_pair_signals(df_pairs, dataframes[reference_population], dataframes[neighbor_population], model=model)
+
+ table = pos + os.sep.join(["output","tables",f"trajectories_pairs.csv"])
+ df.to_csv(table, index=False)
+
+ return None
+
+
+ def analyze_signals(trajectories, model, interpolate_na=True,
+ selected_signals=None,
+ model_path=None,
+ column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
+ plot_outcome=False, output_dir=None):
+ """
+ Analyzes signals from trajectory data using a specified signal detection model and configuration.
+
+ This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
+ model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
+ of missing values, and plotting of analysis outcomes.
+
+ Parameters
+ ----------
+ trajectories : pandas.DataFrame
+ DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
+ model : str
+ The name of the signal detection model to be used for analysis.
+ interpolate_na : bool, optional
+ Whether to interpolate missing values in the trajectories (default is True).
+ selected_signals : list of str, optional
+ A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
+ be automatically selected based on the model configuration (default is None).
+ column_labels : dict, optional
+ A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
+ in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
+ plot_outcome : bool, optional
+ If True, generates and saves a plot of the signal analysis outcome (default is False).
+ output_dir : str, optional
+ The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
+
+ Returns
+ -------
+ pandas.DataFrame
+ The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
+ corresponding colors based on status and class.
+
+ Raises
+ ------
+ AssertionError
+ If the model or its configuration file cannot be located.
+
+ Notes
+ -----
+ - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
+ - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
+
+ """
+
+ model_path = locate_signal_model(model, path=model_path)
+ complete_path = model_path # +model
+ complete_path = rf"{complete_path}"
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
+ model_config_path = rf"{model_config_path}"
+ assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
+ assert os.path.exists(
+ model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
+
+ available_signals = list(trajectories.columns)
+ print('The available_signals are : ', available_signals)
+
+ f = open(model_config_path)
+ config = json.load(f)
+ required_signals = config["channels"]
+
+ try:
+ label = config['label']
+ if label == '':
+ label = None
+ except:
+ label = None
+
+ if selected_signals is None:
+ selected_signals = []
+ for s in required_signals:
+ pattern_test = [s in a or s == a for a in available_signals]
+ print(f'Pattern test for signal {s}: ', pattern_test)
+ assert np.any(
+ pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
+ if len(valid_columns) == 1:
+ selected_signals.append(valid_columns[0])
+ else:
+ # print(test_number_of_nan(trajectories, valid_columns))
+ print(f'Found several candidate signals: {valid_columns}')
+ for vc in natsorted(valid_columns):
+ if 'circle' in vc:
+ selected_signals.append(vc)
+ break
+ else:
+ selected_signals.append(valid_columns[0])
+ # do something more complicated in case of one to many columns
+ # pass
+ else:
+ assert len(selected_signals) == len(
+ required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
+
+ print(f'The following channels will be passed to the model: {selected_signals}')
+ trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
+ interpolate_position_gaps=interpolate_na, column_labels=column_labels)
+
+ max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
+ tracks = trajectories_clean[column_labels['track']].unique()
+ signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
+
+ for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
+ frames = group[column_labels['time']].to_numpy().astype(int)
+ for j, col in enumerate(selected_signals):
+ signal = group[col].to_numpy()
+ signals[i, frames, j] = signal
+ signals[i, max(frames):, j] = signal[-1]
+
+ # for i in range(5):
+ # print('pre model')
+ # plt.plot(signals[i,:,0])
+ # plt.show()
+
+ model = SignalDetectionModel(pretrained=complete_path)
+ print('signal shape: ', signals.shape)
+
+ classes = model.predict_class(signals)
+ times_recast = model.predict_time_of_interest(signals)
+
+ if label is None:
+ class_col = 'class'
+ time_col = 't0'
+ status_col = 'status'
+ else:
+ class_col = 'class_' + label
+ time_col = 't_' + label
+ status_col = 'status_' + label
+
+ for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
+ indices = group.index
+ trajectories.loc[indices, class_col] = classes[i]
+ trajectories.loc[indices, time_col] = times_recast[i]
+ print('Done.')
+
+ for tid, group in trajectories.groupby(column_labels['track']):
+
+ indices = group.index
+ t0 = group[time_col].to_numpy()[0]
+ cclass = group[class_col].to_numpy()[0]
+ timeline = group[column_labels['time']].to_numpy()
+ status = np.zeros_like(timeline)
+ if t0 > 0:
+ status[timeline >= t0] = 1.
+ if cclass == 2:
+ status[:] = 2
+ if cclass > 2:
+ status[:] = 42
+ status_color = [color_from_status(s) for s in status]
+ class_color = [color_from_class(cclass) for i in range(len(status))]
+
+ trajectories.loc[indices, status_col] = status
+ trajectories.loc[indices, 'status_color'] = status_color
+ trajectories.loc[indices, 'class_color'] = class_color
+
+ if plot_outcome:
+ fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
+ for i, s in enumerate(selected_signals):
+ for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
+ cclass = group[class_col].to_numpy()[0]
+ t0 = group[time_col].to_numpy()[0]
+ timeline = group[column_labels['time']].to_numpy()
+ if cclass == 0:
+ if len(selected_signals) > 1:
+ ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
+ else:
+ ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
+ if len(selected_signals) > 1:
+ for a, s in zip(ax, selected_signals):
+ a.set_title(s)
+ a.set_xlabel(r'time - t$_0$ [frame]')
+ a.spines['top'].set_visible(False)
+ a.spines['right'].set_visible(False)
+ else:
+ ax.set_title(s)
+ ax.set_xlabel(r'time - t$_0$ [frame]')
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ plt.tight_layout()
+ if output_dir is not None:
+ plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
+ plt.pause(3)
+ plt.close()
+
+ return trajectories
+
+
+ def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
+ model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
+ """
+ """
+
+ model_path = locate_signal_model(model, path=model_path, pairs=True)
+ print(f'Looking for model in {model_path}...')
+ complete_path = model_path
+ complete_path = rf"{complete_path}"
+ model_config_path = os.sep.join([complete_path, 'config_input.json'])
+ model_config_path = rf"{model_config_path}"
+ assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
+ assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
+
+ trajectories_pairs = trajectories_pairs.rename(columns=lambda x: 'pair_' + x)
+ trajectories_reference = trajectories_reference.rename(columns=lambda x: 'reference_' + x)
+ trajectories_neighbors = trajectories_neighbors.rename(columns=lambda x: 'neighbor_' + x)
+
+ if 'pair_position' in list(trajectories_pairs.columns):
+ pair_groupby_cols = ['pair_position', 'pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
+ else:
+ pair_groupby_cols = ['pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
+
+ if 'reference_position' in list(trajectories_reference.columns):
+ reference_groupby_cols = ['reference_position', 'reference_TRACK_ID']
+ else:
+ reference_groupby_cols = ['reference_TRACK_ID']
+
+ if 'neighbor_position' in list(trajectories_neighbors.columns):
+ neighbor_groupby_cols = ['neighbor_position', 'neighbor_TRACK_ID']
+ else:
+ neighbor_groupby_cols = ['neighbor_TRACK_ID']
+
+ available_signals = [] #list(trajectories_pairs.columns) + list(trajectories_reference.columns) + list(trajectories_neighbors.columns)
+ for col in list(trajectories_pairs.columns):
+ if is_numeric_dtype(trajectories_pairs[col]):
+ available_signals.append(col)
+ for col in list(trajectories_reference.columns):
+ if is_numeric_dtype(trajectories_reference[col]):
+ available_signals.append(col)
+ for col in list(trajectories_neighbors.columns):
+ if is_numeric_dtype(trajectories_neighbors[col]):
+ available_signals.append(col)
+
+ print('The available signals are : ', available_signals)
+
+ f = open(model_config_path)
+ config = json.load(f)
+ required_signals = config["channels"]
+
+ try:
+ label = config['label']
+ if label=='':
+ label = None
+ except:
+ label = None
+
+ if selected_signals is None:
+ selected_signals = []
+ for s in required_signals:
+ pattern_test = [s in a or s==a for a in available_signals]
+ print(f'Pattern test for signal {s}: ', pattern_test)
+ assert np.any(pattern_test),f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
+ valid_columns = np.array(available_signals)[np.array(pattern_test)]
+ if len(valid_columns)==1:
+ selected_signals.append(valid_columns[0])
+ else:
+ #print(test_number_of_nan(trajectories, valid_columns))
+ print(f'Found several candidate signals: {valid_columns}')
+ for vc in natsorted(valid_columns):
+ if 'circle' in vc:
+ selected_signals.append(vc)
+ break
+ else:
+ selected_signals.append(valid_columns[0])
+ # do something more complicated in case of one to many columns
+ #pass
+ else:
+ assert len(selected_signals)==len(required_signals),f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
+
+ print(f'The following channels will be passed to the model: {selected_signals}')
+ trajectories_reference_clean = interpolate_nan_properties(trajectories_reference, track_label=reference_groupby_cols)
+ trajectories_neighbors_clean = interpolate_nan_properties(trajectories_neighbors, track_label=neighbor_groupby_cols)
+ trajectories_pairs_clean = interpolate_nan_properties(trajectories_pairs, track_label=pair_groupby_cols)
+ print(f'{trajectories_pairs_clean.columns=}')
+
+ max_signal_size = int(trajectories_pairs_clean['pair_FRAME'].max()) + 2
+ pair_tracks = trajectories_pairs_clean.groupby(pair_groupby_cols).size()
+ signals = np.zeros((len(pair_tracks),max_signal_size, len(selected_signals)))
+ print(f'{max_signal_size=} {len(pair_tracks)=} {signals.shape=}')
+
+ for i,(pair,group) in enumerate(trajectories_pairs_clean.groupby(pair_groupby_cols)):
+
+ if 'pair_position' not in list(trajectories_pairs_clean.columns):
+ pos_mode = False
+ reference_cell = pair[0]; neighbor_cell = pair[1]
+ else:
+ pos_mode = True
+ reference_cell = pair[1]; neighbor_cell = pair[2]; pos = pair[0]
+
+ if pos_mode and 'reference_position' in list(trajectories_reference_clean.columns) and 'neighbor_position' in list(trajectories_neighbors_clean.columns):
+ reference_filter = (trajectories_reference_clean['reference_TRACK_ID']==reference_cell)&(trajectories_reference_clean['reference_position']==pos)
+ neighbor_filter = (trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell)&(trajectories_neighbors_clean['neighbor_position']==pos)
+ else:
+ reference_filter = trajectories_reference_clean['reference_TRACK_ID']==reference_cell
+ neighbor_filter = trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell
+
+ pair_frames = group['pair_FRAME'].to_numpy().astype(int)
+
+ for j,col in enumerate(selected_signals):
+ if col.startswith('pair_'):
+ signal = group[col].to_numpy()
+ signals[i,pair_frames,j] = signal
+ signals[i,max(pair_frames):,j] = signal[-1]
+ elif col.startswith('reference_'):
+ signal = trajectories_reference_clean.loc[reference_filter, col].to_numpy()
+ timeline = trajectories_reference_clean.loc[reference_filter, 'reference_FRAME'].to_numpy()
+ signals[i,timeline,j] = signal
+ signals[i,max(timeline):,j] = signal[-1]
+ elif col.startswith('neighbor_'):
+ signal = trajectories_neighbors_clean.loc[neighbor_filter, col].to_numpy()
+ timeline = trajectories_neighbors_clean.loc[neighbor_filter, 'neighbor_FRAME'].to_numpy()
+ signals[i,timeline,j] = signal
+ signals[i,max(timeline):,j] = signal[-1]
+
+
+ model = SignalDetectionModel(pretrained=complete_path)
+ print('signal shape: ', signals.shape)
+
+ classes = model.predict_class(signals)
+ times_recast = model.predict_time_of_interest(signals)
+
+ if label is None:
+ class_col = 'pair_class'
+ time_col = 'pair_t0'
+ status_col = 'pair_status'
+ else:
+ class_col = 'pair_class_'+label
+ time_col = 'pair_t_'+label
+ status_col = 'pair_status_'+label
+
+ for i,(pair,group) in enumerate(trajectories_pairs.groupby(pair_groupby_cols)):
+ indices = group.index
+ trajectories_pairs.loc[indices,class_col] = classes[i]
+ trajectories_pairs.loc[indices,time_col] = times_recast[i]
+ print('Done.')
+
+ # At the end rename cols again
+ trajectories_pairs = trajectories_pairs.rename(columns=lambda x: x.replace('pair_',''))
+ trajectories_reference = trajectories_pairs.rename(columns=lambda x: x.replace('reference_',''))
+ trajectories_neighbors = trajectories_pairs.rename(columns=lambda x: x.replace('neighbor_',''))
+ invalid_cols = [c for c in list(trajectories_pairs.columns) if c.startswith('Unnamed')]
+ trajectories_pairs = trajectories_pairs.drop(columns=invalid_cols)
+
+ return trajectories_pairs
 
  class SignalDetectionModel(object):
 
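The hunk above introduces pair-signal analysis: analyze_pair_signals_at_position() loads the targets and effectors tables plus the pair table of a position, reads reference_population / neighbor_population from the pair model's config_input.json, runs analyze_pair_signals(), and rewrites output/tables/trajectories_pairs.csv. A minimal usage sketch; the position folder and model name below are placeholders, not shipped defaults:

    from celldetective.signals import analyze_pair_signals_at_position

    # Placeholder experiment layout and model name.
    pos = "/data/experiment/W1/100/"    # position folder containing output/tables/
    model = "pair_lysis_detection"      # any pair model resolvable by locate_signal_model(..., pairs=True)

    # Applies the pair model and overwrites output/tables/trajectories_pairs.csv for this position.
    analyze_pair_signals_at_position(pos, model)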
@@ -1258,10 +1648,10 @@ class SignalDetectionModel(object):
  csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
  self.cb.append(csv_logger)
  checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
- cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
+ cp_callback = ModelCheckpoint(checkpoint_path, monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
  self.cb.append(cp_callback)
 
- callback_stop = EarlyStopping(monitor='val_iou', patience=100)
+ callback_stop = EarlyStopping(monitor='val_iou',mode='max',patience=100)
  self.cb.append(callback_stop)
 
  elif mode=="regressor":
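The one-line changes above pin the monitoring direction of the classifier callbacks. With Keras, an unspecified mode falls back to 'auto', which infers the direction from the metric name and typically treats an unrecognized name such as val_iou as something to minimize, so checkpointing and early stopping would track the worst IoU rather than the best. A short sketch of the corrected setup, assuming the standard tf.keras callbacks and an illustrative file name:

    from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

    # Explicit mode='max': keep the checkpoint and stop training based on the *highest* validation IoU.
    cp_callback = ModelCheckpoint("classifier.h5", monitor="val_iou", mode="max", verbose=1,
                                  save_best_only=True, save_weights_only=False, save_freq="epoch")
    callback_stop = EarlyStopping(monitor="val_iou", mode="max", patience=100)

    callbacks = [cp_callback, callback_stop]  # passed to model.fit(..., callbacks=callbacks)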
@@ -1278,7 +1668,7 @@ class SignalDetectionModel(object):
  cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
  self.cb.append(cp_callback)
 
- callback_stop = EarlyStopping(monitor='val_loss', patience=200)
+ callback_stop = EarlyStopping(monitor='val_loss', mode='min', patience=200)
  self.cb.append(callback_stop)
 
  log_dir = self.model_folder+os.sep
@@ -1312,7 +1702,7 @@ class SignalDetectionModel(object):
  if isinstance(self.list_of_sets[0],str):
  # Case 1: a list of npy files to be loaded
  for s in self.list_of_sets:
-
+
  signal_dataset = self.load_set(s)
  selected_signals, max_length = self.find_best_signal_match(signal_dataset)
  signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
@@ -1416,7 +1806,6 @@ class SignalDetectionModel(object):
  x_train_aug.append(aug[0])
  y_time_train_aug.append(aug[1])
  y_class_train_aug.append(aug[2])
- print('per class counts ',counts)
 
  # Save augmented training set
  self.x_train = np.array(x_train_aug)
@@ -1469,7 +1858,15 @@ class SignalDetectionModel(object):
  for i in range(self.n_channels):
  try:
  # take into account timeline for accurate time regression
- timeline = signal_dataset[k]['FRAME'].astype(int)
+
+ if selected_signals[i].startswith('pair_'):
+ timeline = signal_dataset[k]['pair_FRAME'].astype(int)
+ elif selected_signals[i].startswith('reference_'):
+ timeline = signal_dataset[k]['reference_FRAME'].astype(int)
+ elif selected_signals[i].startswith('neighbor_'):
+ timeline = signal_dataset[k]['neighbor_FRAME'].astype(int)
+ else:
+ timeline = signal_dataset[k]['FRAME'].astype(int)
  signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
  except:
  print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
celldetective/tracking.py CHANGED
@@ -130,7 +130,7 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
  tracking_updates = ["motion"]
 
  tracker.append(new_btrack_objects)
- tracker.volume = ((0,volume[0]), (0,volume[1])) #(-1e5, 1e5)
+ tracker.volume = ((0,volume[0]), (0,volume[1]), (-1e5, 1e5)) #(-1e5, 1e5)
  #print(tracker.volume)
  tracker.track(tracking_updates=tracking_updates, **track_kwargs)
  tracker.optimize(options=optimizer_options)
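The volume change above appends an effectively unbounded z-interval so that 2D detections (z = 0) always fall inside the btrack tracking volume. A minimal sketch of the surrounding btrack session under the assumption of a standard setup; the configuration file, image size, and objects list are placeholders, and the configure call may differ between btrack versions:

    import btrack

    # Placeholders: 'objects' would be btrack objects derived from the segmentation masks,
    # and 'cell_config.json' a btrack motion/hypothesis configuration file.
    objects = []

    with btrack.BayesianTracker() as tracker:
        tracker.configure("cell_config.json")
        tracker.append(objects)
        # x/y bounds from the image shape, plus a very wide z-range so that objects
        # sitting at z = 0 are never considered out of volume.
        tracker.volume = ((0, 2048), (0, 2048), (-1e5, 1e5))
        tracker.track(tracking_updates=["motion"])
        tracker.optimize()
        data, properties, graph = tracker.to_napari()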
@@ -138,7 +138,11 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
  data, properties, graph = tracker.to_napari() #ndim=2
 
  # do the table post processing and napari options
- df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
+ if data.shape[1]==4:
+ df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
+ elif data.shape[1]==5:
+ df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],"z",column_labels['y'],column_labels['x']])
+ df = df.drop(columns=['z'])
  df[column_labels['x']+'_um'] = df[column_labels['x']]*spatial_calibration
  df[column_labels['y']+'_um'] = df[column_labels['y']]*spatial_calibration
 
@@ -160,6 +164,8 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
  if clean_trajectories_kwargs is not None:
  df = clean_trajectories(df.copy(),**clean_trajectories_kwargs)
 
+ df['ID'] = np.arange(len(df)).astype(int)
+
  if view_on_napari:
  view_on_napari_btrack(data,properties,graph,stack=stack,labels=labels,relabel=True)
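The last two hunks make track() accept the five-column track array that tracker.to_napari() returns when btrack works in 3D, and add a unique per-row ID to the final table. The same post-processing on a toy array, with placeholder values:

    import numpy as np
    import pandas as pd

    column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}

    # Toy napari-style track array with 5 columns (track, frame, z, y, x); z is 0 for a 2D movie.
    data = np.array([[0, 0, 0.0, 10.0, 12.0],
                     [0, 1, 0.0, 11.0, 12.5]])

    if data.shape[1] == 4:
        df = pd.DataFrame(data, columns=[column_labels['track'], column_labels['time'],
                                         column_labels['y'], column_labels['x']])
    elif data.shape[1] == 5:
        df = pd.DataFrame(data, columns=[column_labels['track'], column_labels['time'], "z",
                                         column_labels['y'], column_labels['x']])
        df = df.drop(columns=['z'])

    # New in 1.2.1: a unique row identifier on the exported trajectory table.
    df['ID'] = np.arange(len(df)).astype(int)
    print(df)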