celldetective 1.1.1.post3__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -1
- celldetective/__main__.py +17 -0
- celldetective/extra_properties.py +62 -34
- celldetective/gui/__init__.py +1 -0
- celldetective/gui/analyze_block.py +2 -1
- celldetective/gui/classifier_widget.py +18 -10
- celldetective/gui/control_panel.py +57 -6
- celldetective/gui/layouts.py +14 -11
- celldetective/gui/neighborhood_options.py +21 -13
- celldetective/gui/plot_signals_ui.py +39 -11
- celldetective/gui/process_block.py +413 -95
- celldetective/gui/retrain_segmentation_model_options.py +17 -4
- celldetective/gui/retrain_signal_model_options.py +106 -6
- celldetective/gui/signal_annotator.py +110 -30
- celldetective/gui/signal_annotator2.py +2708 -0
- celldetective/gui/signal_annotator_options.py +3 -1
- celldetective/gui/survival_ui.py +15 -6
- celldetective/gui/tableUI.py +248 -43
- celldetective/io.py +598 -416
- celldetective/measure.py +919 -969
- celldetective/models/pair_signal_detection/blank +0 -0
- celldetective/neighborhood.py +482 -340
- celldetective/preprocessing.py +81 -61
- celldetective/relative_measurements.py +648 -0
- celldetective/scripts/analyze_signals.py +1 -1
- celldetective/scripts/measure_cells.py +28 -8
- celldetective/scripts/measure_relative.py +103 -0
- celldetective/scripts/segment_cells.py +5 -5
- celldetective/scripts/track_cells.py +4 -1
- celldetective/scripts/train_segmentation_model.py +23 -18
- celldetective/scripts/train_signal_model.py +33 -0
- celldetective/segmentation.py +67 -29
- celldetective/signals.py +402 -8
- celldetective/tracking.py +8 -2
- celldetective/utils.py +144 -12
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/RECORD +42 -38
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
- tests/test_segmentation.py +1 -1
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
- {celldetective-1.1.1.post3.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
celldetective/signals.py
CHANGED
|
@@ -18,8 +18,8 @@ from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_sc
|
|
|
18
18
|
from scipy.interpolate import interp1d
|
|
19
19
|
from scipy.ndimage import shift
|
|
20
20
|
|
|
21
|
-
from celldetective.io import get_signal_models_list, locate_signal_model
|
|
22
|
-
from celldetective.tracking import clean_trajectories
|
|
21
|
+
from celldetective.io import get_signal_models_list, locate_signal_model, get_position_pickle, get_position_table
|
|
22
|
+
from celldetective.tracking import clean_trajectories, interpolate_nan_properties
|
|
23
23
|
from celldetective.utils import regression_plot, train_test_split, compute_weights
|
|
24
24
|
import matplotlib.pyplot as plt
|
|
25
25
|
from natsort import natsorted
|
|
@@ -32,6 +32,7 @@ from scipy.optimize import curve_fit
|
|
|
32
32
|
import time
|
|
33
33
|
import math
|
|
34
34
|
import pandas as pd
|
|
35
|
+
from pandas.api.types import is_numeric_dtype
|
|
35
36
|
|
|
36
37
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
|
|
37
38
|
|
|
@@ -334,6 +335,392 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
|
|
|
334
335
|
else:
|
|
335
336
|
return None
|
|
336
337
|
|
|
338
|
+
def analyze_pair_signals_at_position(pos, model, use_gpu=True):
|
|
339
|
+
|
|
340
|
+
"""
|
|
341
|
+
|
|
342
|
+
"""
|
|
343
|
+
|
|
344
|
+
pos = pos.replace('\\','/')
|
|
345
|
+
pos = rf"{pos}"
|
|
346
|
+
assert os.path.exists(pos),f'Position {pos} is not a valid path.'
|
|
347
|
+
if not pos.endswith('/'):
|
|
348
|
+
pos += '/'
|
|
349
|
+
|
|
350
|
+
df_targets = get_position_pickle(pos, population='targets')
|
|
351
|
+
df_effectors = get_position_pickle(pos, population='effectors')
|
|
352
|
+
dataframes = {
|
|
353
|
+
'targets': df_targets,
|
|
354
|
+
'effectors': df_effectors,
|
|
355
|
+
}
|
|
356
|
+
df_pairs = get_position_table(pos, population='pairs')
|
|
357
|
+
|
|
358
|
+
# Need to identify expected reference / neighbor tables
|
|
359
|
+
model_path = locate_signal_model(model, pairs=True)
|
|
360
|
+
print(f'Looking for model in {model_path}...')
|
|
361
|
+
complete_path = model_path
|
|
362
|
+
complete_path = rf"{complete_path}"
|
|
363
|
+
model_config_path = os.sep.join([complete_path, 'config_input.json'])
|
|
364
|
+
model_config_path = rf"{model_config_path}"
|
|
365
|
+
f = open(model_config_path)
|
|
366
|
+
model_config_path = json.load(f)
|
|
367
|
+
|
|
368
|
+
reference_population = model_config_path['reference_population']
|
|
369
|
+
neighbor_population = model_config_path['neighbor_population']
|
|
370
|
+
|
|
371
|
+
df = analyze_pair_signals(df_pairs, dataframes[reference_population], dataframes[neighbor_population], model=model)
|
|
372
|
+
|
|
373
|
+
table = pos + os.sep.join(["output","tables",f"trajectories_pairs.csv"])
|
|
374
|
+
df.to_csv(table, index=False)
|
|
375
|
+
|
|
376
|
+
return None
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def analyze_signals(trajectories, model, interpolate_na=True,
|
|
380
|
+
selected_signals=None,
|
|
381
|
+
model_path=None,
|
|
382
|
+
column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
|
|
383
|
+
plot_outcome=False, output_dir=None):
|
|
384
|
+
"""
|
|
385
|
+
Analyzes signals from trajectory data using a specified signal detection model and configuration.
|
|
386
|
+
|
|
387
|
+
This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
|
|
388
|
+
model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
|
|
389
|
+
of missing values, and plotting of analysis outcomes.
|
|
390
|
+
|
|
391
|
+
Parameters
|
|
392
|
+
----------
|
|
393
|
+
trajectories : pandas.DataFrame
|
|
394
|
+
DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
|
|
395
|
+
model : str
|
|
396
|
+
The name of the signal detection model to be used for analysis.
|
|
397
|
+
interpolate_na : bool, optional
|
|
398
|
+
Whether to interpolate missing values in the trajectories (default is True).
|
|
399
|
+
selected_signals : list of str, optional
|
|
400
|
+
A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
|
|
401
|
+
be automatically selected based on the model configuration (default is None).
|
|
402
|
+
column_labels : dict, optional
|
|
403
|
+
A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
|
|
404
|
+
in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
|
|
405
|
+
plot_outcome : bool, optional
|
|
406
|
+
If True, generates and saves a plot of the signal analysis outcome (default is False).
|
|
407
|
+
output_dir : str, optional
|
|
408
|
+
The directory where the outcome plot will be saved. Required if `plot_outcome` is True (default is None).
|
|
409
|
+
|
|
410
|
+
Returns
|
|
411
|
+
-------
|
|
412
|
+
pandas.DataFrame
|
|
413
|
+
The input `trajectories` DataFrame with additional columns for predicted classes, times of interest, and
|
|
414
|
+
corresponding colors based on status and class.
|
|
415
|
+
|
|
416
|
+
Raises
|
|
417
|
+
------
|
|
418
|
+
AssertionError
|
|
419
|
+
If the model or its configuration file cannot be located.
|
|
420
|
+
|
|
421
|
+
Notes
|
|
422
|
+
-----
|
|
423
|
+
- The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
|
|
424
|
+
- Signal selection and preprocessing are based on the requirements specified in the model's configuration.
|
|
425
|
+
|
|
426
|
+
"""
|
|
427
|
+
|
|
428
|
+
model_path = locate_signal_model(model, path=model_path)
|
|
429
|
+
complete_path = model_path # +model
|
|
430
|
+
complete_path = rf"{complete_path}"
|
|
431
|
+
model_config_path = os.sep.join([complete_path, 'config_input.json'])
|
|
432
|
+
model_config_path = rf"{model_config_path}"
|
|
433
|
+
assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
|
|
434
|
+
assert os.path.exists(
|
|
435
|
+
model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
|
|
436
|
+
|
|
437
|
+
available_signals = list(trajectories.columns)
|
|
438
|
+
print('The available_signals are : ', available_signals)
|
|
439
|
+
|
|
440
|
+
f = open(model_config_path)
|
|
441
|
+
config = json.load(f)
|
|
442
|
+
required_signals = config["channels"]
|
|
443
|
+
|
|
444
|
+
try:
|
|
445
|
+
label = config['label']
|
|
446
|
+
if label == '':
|
|
447
|
+
label = None
|
|
448
|
+
except:
|
|
449
|
+
label = None
|
|
450
|
+
|
|
451
|
+
if selected_signals is None:
|
|
452
|
+
selected_signals = []
|
|
453
|
+
for s in required_signals:
|
|
454
|
+
pattern_test = [s in a or s == a for a in available_signals]
|
|
455
|
+
print(f'Pattern test for signal {s}: ', pattern_test)
|
|
456
|
+
assert np.any(
|
|
457
|
+
pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
|
|
458
|
+
valid_columns = np.array(available_signals)[np.array(pattern_test)]
|
|
459
|
+
if len(valid_columns) == 1:
|
|
460
|
+
selected_signals.append(valid_columns[0])
|
|
461
|
+
else:
|
|
462
|
+
# print(test_number_of_nan(trajectories, valid_columns))
|
|
463
|
+
print(f'Found several candidate signals: {valid_columns}')
|
|
464
|
+
for vc in natsorted(valid_columns):
|
|
465
|
+
if 'circle' in vc:
|
|
466
|
+
selected_signals.append(vc)
|
|
467
|
+
break
|
|
468
|
+
else:
|
|
469
|
+
selected_signals.append(valid_columns[0])
|
|
470
|
+
# do something more complicated in case of one to many columns
|
|
471
|
+
# pass
|
|
472
|
+
else:
|
|
473
|
+
assert len(selected_signals) == len(
|
|
474
|
+
required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
|
|
475
|
+
|
|
476
|
+
print(f'The following channels will be passed to the model: {selected_signals}')
|
|
477
|
+
trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
|
|
478
|
+
interpolate_position_gaps=interpolate_na, column_labels=column_labels)
|
|
479
|
+
|
|
480
|
+
max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
|
|
481
|
+
tracks = trajectories_clean[column_labels['track']].unique()
|
|
482
|
+
signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))
|
|
483
|
+
|
|
484
|
+
for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
|
|
485
|
+
frames = group[column_labels['time']].to_numpy().astype(int)
|
|
486
|
+
for j, col in enumerate(selected_signals):
|
|
487
|
+
signal = group[col].to_numpy()
|
|
488
|
+
signals[i, frames, j] = signal
|
|
489
|
+
signals[i, max(frames):, j] = signal[-1]
|
|
490
|
+
|
|
491
|
+
# for i in range(5):
|
|
492
|
+
# print('pre model')
|
|
493
|
+
# plt.plot(signals[i,:,0])
|
|
494
|
+
# plt.show()
|
|
495
|
+
|
|
496
|
+
model = SignalDetectionModel(pretrained=complete_path)
|
|
497
|
+
print('signal shape: ', signals.shape)
|
|
498
|
+
|
|
499
|
+
classes = model.predict_class(signals)
|
|
500
|
+
times_recast = model.predict_time_of_interest(signals)
|
|
501
|
+
|
|
502
|
+
if label is None:
|
|
503
|
+
class_col = 'class'
|
|
504
|
+
time_col = 't0'
|
|
505
|
+
status_col = 'status'
|
|
506
|
+
else:
|
|
507
|
+
class_col = 'class_' + label
|
|
508
|
+
time_col = 't_' + label
|
|
509
|
+
status_col = 'status_' + label
|
|
510
|
+
|
|
511
|
+
for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
|
|
512
|
+
indices = group.index
|
|
513
|
+
trajectories.loc[indices, class_col] = classes[i]
|
|
514
|
+
trajectories.loc[indices, time_col] = times_recast[i]
|
|
515
|
+
print('Done.')
|
|
516
|
+
|
|
517
|
+
for tid, group in trajectories.groupby(column_labels['track']):
|
|
518
|
+
|
|
519
|
+
indices = group.index
|
|
520
|
+
t0 = group[time_col].to_numpy()[0]
|
|
521
|
+
cclass = group[class_col].to_numpy()[0]
|
|
522
|
+
timeline = group[column_labels['time']].to_numpy()
|
|
523
|
+
status = np.zeros_like(timeline)
|
|
524
|
+
if t0 > 0:
|
|
525
|
+
status[timeline >= t0] = 1.
|
|
526
|
+
if cclass == 2:
|
|
527
|
+
status[:] = 2
|
|
528
|
+
if cclass > 2:
|
|
529
|
+
status[:] = 42
|
|
530
|
+
status_color = [color_from_status(s) for s in status]
|
|
531
|
+
class_color = [color_from_class(cclass) for i in range(len(status))]
|
|
532
|
+
|
|
533
|
+
trajectories.loc[indices, status_col] = status
|
|
534
|
+
trajectories.loc[indices, 'status_color'] = status_color
|
|
535
|
+
trajectories.loc[indices, 'class_color'] = class_color
|
|
536
|
+
|
|
537
|
+
if plot_outcome:
|
|
538
|
+
fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
|
|
539
|
+
for i, s in enumerate(selected_signals):
|
|
540
|
+
for k, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
|
|
541
|
+
cclass = group[class_col].to_numpy()[0]
|
|
542
|
+
t0 = group[time_col].to_numpy()[0]
|
|
543
|
+
timeline = group[column_labels['time']].to_numpy()
|
|
544
|
+
if cclass == 0:
|
|
545
|
+
if len(selected_signals) > 1:
|
|
546
|
+
ax[i].plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
|
|
547
|
+
else:
|
|
548
|
+
ax.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
|
|
549
|
+
if len(selected_signals) > 1:
|
|
550
|
+
for a, s in zip(ax, selected_signals):
|
|
551
|
+
a.set_title(s)
|
|
552
|
+
a.set_xlabel(r'time - t$_0$ [frame]')
|
|
553
|
+
a.spines['top'].set_visible(False)
|
|
554
|
+
a.spines['right'].set_visible(False)
|
|
555
|
+
else:
|
|
556
|
+
ax.set_title(s)
|
|
557
|
+
ax.set_xlabel(r'time - t$_0$ [frame]')
|
|
558
|
+
ax.spines['top'].set_visible(False)
|
|
559
|
+
ax.spines['right'].set_visible(False)
|
|
560
|
+
plt.tight_layout()
|
|
561
|
+
if output_dir is not None:
|
|
562
|
+
plt.savefig(output_dir + 'signal_collapse.png', bbox_inches='tight', dpi=300)
|
|
563
|
+
plt.pause(3)
|
|
564
|
+
plt.close()
|
|
565
|
+
|
|
566
|
+
return trajectories
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def analyze_pair_signals(trajectories_pairs,trajectories_reference,trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
|
|
570
|
+
model_path=None, plot_outcome=False, output_dir=None, column_labels = {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
|
|
571
|
+
"""
|
|
572
|
+
"""
|
|
573
|
+
|
|
574
|
+
model_path = locate_signal_model(model, path=model_path, pairs=True)
|
|
575
|
+
print(f'Looking for model in {model_path}...')
|
|
576
|
+
complete_path = model_path
|
|
577
|
+
complete_path = rf"{complete_path}"
|
|
578
|
+
model_config_path = os.sep.join([complete_path, 'config_input.json'])
|
|
579
|
+
model_config_path = rf"{model_config_path}"
|
|
580
|
+
assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
|
|
581
|
+
assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'
|
|
582
|
+
|
|
583
|
+
trajectories_pairs = trajectories_pairs.rename(columns=lambda x: 'pair_' + x)
|
|
584
|
+
trajectories_reference = trajectories_reference.rename(columns=lambda x: 'reference_' + x)
|
|
585
|
+
trajectories_neighbors = trajectories_neighbors.rename(columns=lambda x: 'neighbor_' + x)
|
|
586
|
+
|
|
587
|
+
if 'pair_position' in list(trajectories_pairs.columns):
|
|
588
|
+
pair_groupby_cols = ['pair_position', 'pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
|
|
589
|
+
else:
|
|
590
|
+
pair_groupby_cols = ['pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
|
|
591
|
+
|
|
592
|
+
if 'reference_position' in list(trajectories_reference.columns):
|
|
593
|
+
reference_groupby_cols = ['reference_position', 'reference_TRACK_ID']
|
|
594
|
+
else:
|
|
595
|
+
reference_groupby_cols = ['reference_TRACK_ID']
|
|
596
|
+
|
|
597
|
+
if 'neighbor_position' in list(trajectories_neighbors.columns):
|
|
598
|
+
neighbor_groupby_cols = ['neighbor_position', 'neighbor_TRACK_ID']
|
|
599
|
+
else:
|
|
600
|
+
neighbor_groupby_cols = ['neighbor_TRACK_ID']
|
|
601
|
+
|
|
602
|
+
available_signals = [] #list(trajectories_pairs.columns) + list(trajectories_reference.columns) + list(trajectories_neighbors.columns)
|
|
603
|
+
for col in list(trajectories_pairs.columns):
|
|
604
|
+
if is_numeric_dtype(trajectories_pairs[col]):
|
|
605
|
+
available_signals.append(col)
|
|
606
|
+
for col in list(trajectories_reference.columns):
|
|
607
|
+
if is_numeric_dtype(trajectories_reference[col]):
|
|
608
|
+
available_signals.append(col)
|
|
609
|
+
for col in list(trajectories_neighbors.columns):
|
|
610
|
+
if is_numeric_dtype(trajectories_neighbors[col]):
|
|
611
|
+
available_signals.append(col)
|
|
612
|
+
|
|
613
|
+
print('The available signals are : ', available_signals)
|
|
614
|
+
|
|
615
|
+
f = open(model_config_path)
|
|
616
|
+
config = json.load(f)
|
|
617
|
+
required_signals = config["channels"]
|
|
618
|
+
|
|
619
|
+
try:
|
|
620
|
+
label = config['label']
|
|
621
|
+
if label=='':
|
|
622
|
+
label = None
|
|
623
|
+
except:
|
|
624
|
+
label = None
|
|
625
|
+
|
|
626
|
+
if selected_signals is None:
|
|
627
|
+
selected_signals = []
|
|
628
|
+
for s in required_signals:
|
|
629
|
+
pattern_test = [s in a or s==a for a in available_signals]
|
|
630
|
+
print(f'Pattern test for signal {s}: ', pattern_test)
|
|
631
|
+
assert np.any(pattern_test),f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
|
|
632
|
+
valid_columns = np.array(available_signals)[np.array(pattern_test)]
|
|
633
|
+
if len(valid_columns)==1:
|
|
634
|
+
selected_signals.append(valid_columns[0])
|
|
635
|
+
else:
|
|
636
|
+
#print(test_number_of_nan(trajectories, valid_columns))
|
|
637
|
+
print(f'Found several candidate signals: {valid_columns}')
|
|
638
|
+
for vc in natsorted(valid_columns):
|
|
639
|
+
if 'circle' in vc:
|
|
640
|
+
selected_signals.append(vc)
|
|
641
|
+
break
|
|
642
|
+
else:
|
|
643
|
+
selected_signals.append(valid_columns[0])
|
|
644
|
+
# do something more complicated in case of one to many columns
|
|
645
|
+
#pass
|
|
646
|
+
else:
|
|
647
|
+
assert len(selected_signals)==len(required_signals),f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'
|
|
648
|
+
|
|
649
|
+
print(f'The following channels will be passed to the model: {selected_signals}')
|
|
650
|
+
trajectories_reference_clean = interpolate_nan_properties(trajectories_reference, track_label=reference_groupby_cols)
|
|
651
|
+
trajectories_neighbors_clean = interpolate_nan_properties(trajectories_neighbors, track_label=neighbor_groupby_cols)
|
|
652
|
+
trajectories_pairs_clean = interpolate_nan_properties(trajectories_pairs, track_label=pair_groupby_cols)
|
|
653
|
+
print(f'{trajectories_pairs_clean.columns=}')
|
|
654
|
+
|
|
655
|
+
max_signal_size = int(trajectories_pairs_clean['pair_FRAME'].max()) + 2
|
|
656
|
+
pair_tracks = trajectories_pairs_clean.groupby(pair_groupby_cols).size()
|
|
657
|
+
signals = np.zeros((len(pair_tracks),max_signal_size, len(selected_signals)))
|
|
658
|
+
print(f'{max_signal_size=} {len(pair_tracks)=} {signals.shape=}')
|
|
659
|
+
|
|
660
|
+
for i,(pair,group) in enumerate(trajectories_pairs_clean.groupby(pair_groupby_cols)):
|
|
661
|
+
|
|
662
|
+
if 'pair_position' not in list(trajectories_pairs_clean.columns):
|
|
663
|
+
pos_mode = False
|
|
664
|
+
reference_cell = pair[0]; neighbor_cell = pair[1]
|
|
665
|
+
else:
|
|
666
|
+
pos_mode = True
|
|
667
|
+
reference_cell = pair[1]; neighbor_cell = pair[2]; pos = pair[0]
|
|
668
|
+
|
|
669
|
+
if pos_mode and 'reference_position' in list(trajectories_reference_clean.columns) and 'neighbor_position' in list(trajectories_neighbors_clean.columns):
|
|
670
|
+
reference_filter = (trajectories_reference_clean['reference_TRACK_ID']==reference_cell)&(trajectories_reference_clean['reference_position']==pos)
|
|
671
|
+
neighbor_filter = (trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell)&(trajectories_neighbors_clean['neighbor_position']==pos)
|
|
672
|
+
else:
|
|
673
|
+
reference_filter = trajectories_reference_clean['reference_TRACK_ID']==reference_cell
|
|
674
|
+
neighbor_filter = trajectories_neighbors_clean['neighbor_TRACK_ID']==neighbor_cell
|
|
675
|
+
|
|
676
|
+
pair_frames = group['pair_FRAME'].to_numpy().astype(int)
|
|
677
|
+
|
|
678
|
+
for j,col in enumerate(selected_signals):
|
|
679
|
+
if col.startswith('pair_'):
|
|
680
|
+
signal = group[col].to_numpy()
|
|
681
|
+
signals[i,pair_frames,j] = signal
|
|
682
|
+
signals[i,max(pair_frames):,j] = signal[-1]
|
|
683
|
+
elif col.startswith('reference_'):
|
|
684
|
+
signal = trajectories_reference_clean.loc[reference_filter, col].to_numpy()
|
|
685
|
+
timeline = trajectories_reference_clean.loc[reference_filter, 'reference_FRAME'].to_numpy()
|
|
686
|
+
signals[i,timeline,j] = signal
|
|
687
|
+
signals[i,max(timeline):,j] = signal[-1]
|
|
688
|
+
elif col.startswith('neighbor_'):
|
|
689
|
+
signal = trajectories_neighbors_clean.loc[neighbor_filter, col].to_numpy()
|
|
690
|
+
timeline = trajectories_neighbors_clean.loc[neighbor_filter, 'neighbor_FRAME'].to_numpy()
|
|
691
|
+
signals[i,timeline,j] = signal
|
|
692
|
+
signals[i,max(timeline):,j] = signal[-1]
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
model = SignalDetectionModel(pretrained=complete_path)
|
|
696
|
+
print('signal shape: ', signals.shape)
|
|
697
|
+
|
|
698
|
+
classes = model.predict_class(signals)
|
|
699
|
+
times_recast = model.predict_time_of_interest(signals)
|
|
700
|
+
|
|
701
|
+
if label is None:
|
|
702
|
+
class_col = 'pair_class'
|
|
703
|
+
time_col = 'pair_t0'
|
|
704
|
+
status_col = 'pair_status'
|
|
705
|
+
else:
|
|
706
|
+
class_col = 'pair_class_'+label
|
|
707
|
+
time_col = 'pair_t_'+label
|
|
708
|
+
status_col = 'pair_status_'+label
|
|
709
|
+
|
|
710
|
+
for i,(pair,group) in enumerate(trajectories_pairs.groupby(pair_groupby_cols)):
|
|
711
|
+
indices = group.index
|
|
712
|
+
trajectories_pairs.loc[indices,class_col] = classes[i]
|
|
713
|
+
trajectories_pairs.loc[indices,time_col] = times_recast[i]
|
|
714
|
+
print('Done.')
|
|
715
|
+
|
|
716
|
+
# At the end rename cols again
|
|
717
|
+
trajectories_pairs = trajectories_pairs.rename(columns=lambda x: x.replace('pair_',''))
|
|
718
|
+
trajectories_reference = trajectories_pairs.rename(columns=lambda x: x.replace('reference_',''))
|
|
719
|
+
trajectories_neighbors = trajectories_pairs.rename(columns=lambda x: x.replace('neighbor_',''))
|
|
720
|
+
invalid_cols = [c for c in list(trajectories_pairs.columns) if c.startswith('Unnamed')]
|
|
721
|
+
trajectories_pairs = trajectories_pairs.drop(columns=invalid_cols)
|
|
722
|
+
|
|
723
|
+
return trajectories_pairs
|
|
337
724
|
|
|
338
725
|
class SignalDetectionModel(object):
|
|
339
726
|
|
|
@@ -1258,10 +1645,10 @@ class SignalDetectionModel(object):
|
|
|
1258
1645
|
csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
|
|
1259
1646
|
self.cb.append(csv_logger)
|
|
1260
1647
|
checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
|
|
1261
|
-
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1648
|
+
cp_callback = ModelCheckpoint(checkpoint_path, monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1262
1649
|
self.cb.append(cp_callback)
|
|
1263
1650
|
|
|
1264
|
-
callback_stop = EarlyStopping(monitor='val_iou',
|
|
1651
|
+
callback_stop = EarlyStopping(monitor='val_iou',mode='max',patience=100)
|
|
1265
1652
|
self.cb.append(callback_stop)
|
|
1266
1653
|
|
|
1267
1654
|
elif mode=="regressor":
|
|
@@ -1278,7 +1665,7 @@ class SignalDetectionModel(object):
|
|
|
1278
1665
|
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1279
1666
|
self.cb.append(cp_callback)
|
|
1280
1667
|
|
|
1281
|
-
callback_stop = EarlyStopping(monitor='val_loss', patience=200)
|
|
1668
|
+
callback_stop = EarlyStopping(monitor='val_loss', mode='min', patience=200)
|
|
1282
1669
|
self.cb.append(callback_stop)
|
|
1283
1670
|
|
|
1284
1671
|
log_dir = self.model_folder+os.sep
|
|
@@ -1312,7 +1699,7 @@ class SignalDetectionModel(object):
|
|
|
1312
1699
|
if isinstance(self.list_of_sets[0],str):
|
|
1313
1700
|
# Case 1: a list of npy files to be loaded
|
|
1314
1701
|
for s in self.list_of_sets:
|
|
1315
|
-
|
|
1702
|
+
|
|
1316
1703
|
signal_dataset = self.load_set(s)
|
|
1317
1704
|
selected_signals, max_length = self.find_best_signal_match(signal_dataset)
|
|
1318
1705
|
signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
|
|
@@ -1416,7 +1803,6 @@ class SignalDetectionModel(object):
|
|
|
1416
1803
|
x_train_aug.append(aug[0])
|
|
1417
1804
|
y_time_train_aug.append(aug[1])
|
|
1418
1805
|
y_class_train_aug.append(aug[2])
|
|
1419
|
-
print('per class counts ',counts)
|
|
1420
1806
|
|
|
1421
1807
|
# Save augmented training set
|
|
1422
1808
|
self.x_train = np.array(x_train_aug)
|
|
@@ -1469,7 +1855,15 @@ class SignalDetectionModel(object):
|
|
|
1469
1855
|
for i in range(self.n_channels):
|
|
1470
1856
|
try:
|
|
1471
1857
|
# take into account timeline for accurate time regression
|
|
1472
|
-
|
|
1858
|
+
|
|
1859
|
+
if selected_signals[i].startswith('pair_'):
|
|
1860
|
+
timeline = signal_dataset[k]['pair_FRAME'].astype(int)
|
|
1861
|
+
elif selected_signals[i].startswith('reference_'):
|
|
1862
|
+
timeline = signal_dataset[k]['reference_FRAME'].astype(int)
|
|
1863
|
+
elif selected_signals[i].startswith('neighbor_'):
|
|
1864
|
+
timeline = signal_dataset[k]['neighbor_FRAME'].astype(int)
|
|
1865
|
+
else:
|
|
1866
|
+
timeline = signal_dataset[k]['FRAME'].astype(int)
|
|
1473
1867
|
signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
|
|
1474
1868
|
except:
|
|
1475
1869
|
print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
|
celldetective/tracking.py
CHANGED
|
@@ -130,7 +130,7 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
130
130
|
tracking_updates = ["motion"]
|
|
131
131
|
|
|
132
132
|
tracker.append(new_btrack_objects)
|
|
133
|
-
tracker.volume = ((0,volume[0]), (0,volume[1])) #(-1e5, 1e5)
|
|
133
|
+
tracker.volume = ((0,volume[0]), (0,volume[1]), (-1e5, 1e5)) #(-1e5, 1e5)
|
|
134
134
|
#print(tracker.volume)
|
|
135
135
|
tracker.track(tracking_updates=tracking_updates, **track_kwargs)
|
|
136
136
|
tracker.optimize(options=optimizer_options)
|
|
@@ -138,7 +138,11 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
138
138
|
data, properties, graph = tracker.to_napari() #ndim=2
|
|
139
139
|
|
|
140
140
|
# do the table post processing and napari options
|
|
141
|
-
|
|
141
|
+
if data.shape[1]==4:
|
|
142
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
|
|
143
|
+
elif data.shape[1]==5:
|
|
144
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],"z",column_labels['y'],column_labels['x']])
|
|
145
|
+
df = df.drop(columns=['z'])
|
|
142
146
|
df[column_labels['x']+'_um'] = df[column_labels['x']]*spatial_calibration
|
|
143
147
|
df[column_labels['y']+'_um'] = df[column_labels['y']]*spatial_calibration
|
|
144
148
|
|
|
@@ -160,6 +164,8 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
160
164
|
if clean_trajectories_kwargs is not None:
|
|
161
165
|
df = clean_trajectories(df.copy(),**clean_trajectories_kwargs)
|
|
162
166
|
|
|
167
|
+
df['ID'] = np.arange(len(df)).astype(int)
|
|
168
|
+
|
|
163
169
|
if view_on_napari:
|
|
164
170
|
view_on_napari_btrack(data,properties,graph,stack=stack,labels=labels,relabel=True)
|
|
165
171
|
|
celldetective/utils.py
CHANGED
|
@@ -23,7 +23,85 @@ from tqdm import tqdm
|
|
|
23
23
|
import shutil
|
|
24
24
|
import tempfile
|
|
25
25
|
from scipy.interpolate import griddata
|
|
26
|
+
import re
|
|
27
|
+
from scipy.ndimage.morphology import distance_transform_edt
|
|
28
|
+
from scipy import ndimage
|
|
29
|
+
from skimage.morphology import disk
|
|
30
|
+
|
|
31
|
+
def contour_of_instance_segmentation(label, distance):
|
|
32
|
+
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
Generate an instance mask containing the contour of the segmented objects.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
label : ndarray
|
|
40
|
+
The instance segmentation labels.
|
|
41
|
+
distance : int, float, list, or tuple
|
|
42
|
+
The distance or range of distances from the edge of each instance to include in the contour.
|
|
43
|
+
If a single value is provided, it represents the maximum distance. If a tuple or list is provided,
|
|
44
|
+
it represents the minimum and maximum distances.
|
|
45
|
+
|
|
46
|
+
Returns
|
|
47
|
+
-------
|
|
48
|
+
border_label : ndarray
|
|
49
|
+
An instance mask containing the contour of the segmented objects.
|
|
50
|
+
|
|
51
|
+
Notes
|
|
52
|
+
-----
|
|
53
|
+
This function generates an instance mask representing the contour of the segmented instances in the label image.
|
|
54
|
+
It use the distance_transform_edt function from the scipy.ndimage module to compute the Euclidean distance transform.
|
|
55
|
+
The contour is defined based on the specified distance(s) from the edge of each instance.
|
|
56
|
+
The resulting mask, `border_label`, contains the contour regions, while the interior regions are set to zero.
|
|
57
|
+
|
|
58
|
+
Examples
|
|
59
|
+
--------
|
|
60
|
+
>>> border_label = contour_of_instance_segmentation(label, distance=3)
|
|
61
|
+
# Generate a binary mask containing the contour of the segmented instances with a maximum distance of 3 pixels.
|
|
62
|
+
|
|
63
|
+
"""
|
|
64
|
+
if isinstance(distance,(list,tuple)) or distance >= 0 :
|
|
65
|
+
|
|
66
|
+
edt = distance_transform_edt(label)
|
|
67
|
+
|
|
68
|
+
if isinstance(distance, list) or isinstance(distance, tuple):
|
|
69
|
+
min_distance = distance[0]; max_distance = distance[1]
|
|
70
|
+
|
|
71
|
+
elif isinstance(distance, (int, float)):
|
|
72
|
+
min_distance = 0
|
|
73
|
+
max_distance = distance
|
|
74
|
+
|
|
75
|
+
thresholded = (edt <= max_distance) * (edt > min_distance)
|
|
76
|
+
border_label = np.copy(label)
|
|
77
|
+
border_label[np.where(thresholded == 0)] = 0
|
|
78
|
+
|
|
79
|
+
else:
|
|
80
|
+
size = (2*abs(int(distance))+1, 2*abs(int(distance))+1)
|
|
81
|
+
dilated_image = ndimage.grey_dilation(label, footprint=disk(int(abs(distance)))) #size=size,
|
|
82
|
+
border_label=np.copy(dilated_image)
|
|
83
|
+
matching_cells = np.logical_and(dilated_image != 0, label == dilated_image)
|
|
84
|
+
border_label[np.where(matching_cells == True)] = 0
|
|
85
|
+
border_label[label!=0] = 0.
|
|
86
|
+
|
|
87
|
+
return border_label
|
|
88
|
+
|
|
89
|
+
def extract_identity_col(trajectories):
	"""Return the name of the column that uniquely identifies cells in a table.

	Prefers 'TRACK_ID' when the column exists and holds at least one non-null
	value (i.e. the cells were actually tracked); otherwise falls back to 'ID'.

	Parameters
	----------
	trajectories : pandas.DataFrame
		Cell measurement/trajectory table.

	Returns
	-------
	str or None
		'TRACK_ID', 'ID', or None when neither usable column is found.

	Notes
	-----
	Fixes an UnboundLocalError in the previous implementation: when 'TRACK_ID'
	existed but was entirely null and no 'ID' column was present, no value was
	ever assigned before the return statement.
	"""

	columns = list(trajectories.columns)

	# A fully-null TRACK_ID column means tracking was not performed; ignore it.
	if 'TRACK_ID' in columns and not np.all(trajectories['TRACK_ID'].isnull()):
		return 'TRACK_ID'

	if 'ID' in columns:
		return 'ID'

	print('ID or TRACK ID column could not be found in the table...')
	return None
|
|
28
106
|
def derivative(x, timeline, window, mode='bi'):
|
|
29
107
|
|
|
@@ -430,6 +508,31 @@ def mask_edges(binary_mask, border_size):
|
|
|
430
508
|
return binary_mask
|
|
431
509
|
|
|
432
510
|
|
|
511
|
+
def extract_cols_from_query(query: str):
	"""Return the column names referenced by a pandas query string.

	Repeatedly parses the query against a growing dictionary of known names:
	each parse failure reveals one undefined name, which is added as a dummy
	entry until the expression compiles cleanly. Technique adapted from:
	https://stackoverflow.com/questions/64576913/extract-pandas-dataframe-column-names-from-query-string

	Parameters
	----------
	query : str
		A pandas-style query expression (as accepted by ``DataFrame.query``).

	Returns
	-------
	list of str
		The variable names found in the query, in discovery order.
	"""

	known_names = {}

	while True:
		# Build an evaluation scope from the names discovered so far.
		scope = pd.core.computation.scope.ensure_scope(level=0, global_dict=known_names)
		try:
			pd.core.computation.eval.Expr(query, env=scope)
		except pd.errors.UndefinedVariableError as err:
			# Message format documented here:
			# https://github.com/pandas-dev/pandas/blob/965ceca9fd796940050d6fc817707bba1c4f9bff/pandas/errors/__init__.py#L401
			missing = re.findall("name '(.+?)' is not defined", str(err))[0]
			known_names[missing] = None
		else:
			# Every name resolved: the query parses with this set of variables.
			return list(known_names.keys())
|
+
|
|
535
|
+
|
|
433
536
|
def create_patch_mask(h, w, center=None, radius=None):
|
|
434
537
|
|
|
435
538
|
"""
|
|
@@ -564,6 +667,22 @@ def rename_intensity_column(df, channels):
|
|
|
564
667
|
else:
|
|
565
668
|
new_name = new_name.replace('centre_of_mass', "centre_of_mass_orientation")
|
|
566
669
|
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
670
|
+
elif sections[-2] == "2":
|
|
671
|
+
new_name = np.delete(measure, -1)
|
|
672
|
+
new_name = '_'.join(list(new_name))
|
|
673
|
+
if 'edge' in intensity_columns[k]:
|
|
674
|
+
new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_x")
|
|
675
|
+
else:
|
|
676
|
+
new_name = new_name.replace('centre_of_mass', "centre_of_mass_x")
|
|
677
|
+
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
678
|
+
elif sections[-2] == "3":
|
|
679
|
+
new_name = np.delete(measure, -1)
|
|
680
|
+
new_name = '_'.join(list(new_name))
|
|
681
|
+
if 'edge' in intensity_columns[k]:
|
|
682
|
+
new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_y")
|
|
683
|
+
else:
|
|
684
|
+
new_name = new_name.replace('centre_of_mass', "centre_of_mass_y")
|
|
685
|
+
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
567
686
|
if 'radial_gradient' in intensity_columns[k]:
|
|
568
687
|
# sections = np.array(re.split('-|_', intensity_columns[k]))
|
|
569
688
|
measure = np.array(re.split('-|_', new_name))
|
|
@@ -1058,20 +1177,33 @@ def _extract_channel_indices(channels, required_channels):
|
|
|
1058
1177
|
# [0, 1]
|
|
1059
1178
|
"""
|
|
1060
1179
|
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1180
|
+
channel_indices = []
|
|
1181
|
+
for c in required_channels:
|
|
1182
|
+
if c!='None' and c is not None:
|
|
1065
1183
|
try:
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1184
|
+
ch_idx = channels.index(c)
|
|
1185
|
+
channel_indices.append(ch_idx)
|
|
1186
|
+
except Exception as e:
|
|
1187
|
+
print(f"Error {e}. The channel required by the model is not available in your data... Check the configuration file.")
|
|
1188
|
+
channels = None
|
|
1189
|
+
break
|
|
1190
|
+
else:
|
|
1191
|
+
channel_indices.append(None)
|
|
1070
1192
|
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1193
|
+
# if channels is not None:
|
|
1194
|
+
# channel_indices = []
|
|
1195
|
+
# for ch in required_channels:
|
|
1196
|
+
|
|
1197
|
+
# try:
|
|
1198
|
+
# idx = channels.index(ch)
|
|
1199
|
+
# except ValueError:
|
|
1200
|
+
# print('Mismatch between the channels required by the model and the provided channels.')
|
|
1201
|
+
# return None
|
|
1202
|
+
|
|
1203
|
+
# channel_indices.append(idx)
|
|
1204
|
+
# channel_indices = np.array(channel_indices)
|
|
1205
|
+
# else:
|
|
1206
|
+
# channel_indices = np.arange(len(required_channels))
|
|
1075
1207
|
|
|
1076
1208
|
return channel_indices
|
|
1077
1209
|
|