celldetective 1.1.1.post4__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -1
- celldetective/extra_properties.py +62 -34
- celldetective/gui/__init__.py +1 -0
- celldetective/gui/analyze_block.py +2 -1
- celldetective/gui/classifier_widget.py +8 -7
- celldetective/gui/control_panel.py +50 -6
- celldetective/gui/layouts.py +5 -4
- celldetective/gui/neighborhood_options.py +10 -8
- celldetective/gui/plot_signals_ui.py +39 -11
- celldetective/gui/process_block.py +413 -95
- celldetective/gui/retrain_segmentation_model_options.py +17 -4
- celldetective/gui/retrain_signal_model_options.py +106 -6
- celldetective/gui/signal_annotator.py +25 -5
- celldetective/gui/signal_annotator2.py +2708 -0
- celldetective/gui/signal_annotator_options.py +3 -1
- celldetective/gui/survival_ui.py +15 -6
- celldetective/gui/tableUI.py +235 -39
- celldetective/io.py +537 -421
- celldetective/measure.py +919 -969
- celldetective/models/pair_signal_detection/blank +0 -0
- celldetective/neighborhood.py +426 -354
- celldetective/relative_measurements.py +648 -0
- celldetective/scripts/analyze_signals.py +1 -1
- celldetective/scripts/measure_cells.py +28 -8
- celldetective/scripts/measure_relative.py +103 -0
- celldetective/scripts/segment_cells.py +5 -5
- celldetective/scripts/track_cells.py +4 -1
- celldetective/scripts/train_segmentation_model.py +23 -18
- celldetective/scripts/train_signal_model.py +33 -0
- celldetective/signals.py +402 -8
- celldetective/tracking.py +8 -2
- celldetective/utils.py +93 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/RECORD +38 -34
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
celldetective/signals.py
CHANGED
|
@@ -18,8 +18,8 @@ from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_sc
|
|
|
18
18
|
from scipy.interpolate import interp1d
|
|
19
19
|
from scipy.ndimage import shift
|
|
20
20
|
|
|
21
|
-
from celldetective.io import get_signal_models_list, locate_signal_model
|
|
22
|
-
from celldetective.tracking import clean_trajectories
|
|
21
|
+
from celldetective.io import get_signal_models_list, locate_signal_model, get_position_pickle, get_position_table
|
|
22
|
+
from celldetective.tracking import clean_trajectories, interpolate_nan_properties
|
|
23
23
|
from celldetective.utils import regression_plot, train_test_split, compute_weights
|
|
24
24
|
import matplotlib.pyplot as plt
|
|
25
25
|
from natsort import natsorted
|
|
@@ -32,6 +32,7 @@ from scipy.optimize import curve_fit
|
|
|
32
32
|
import time
|
|
33
33
|
import math
|
|
34
34
|
import pandas as pd
|
|
35
|
+
from pandas.api.types import is_numeric_dtype
|
|
35
36
|
|
|
36
37
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
|
|
37
38
|
|
|
@@ -334,6 +335,392 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
|
|
|
334
335
|
else:
|
|
335
336
|
return None
|
|
336
337
|
|
|
338
|
+
def analyze_pair_signals_at_position(pos, model, use_gpu=True):

    """
    Run a pair signal detection model on the cell pairs of one position and save the result.

    The reference and neighbor populations are read from the model configuration
    (``config_input.json`` keys ``'reference_population'`` and ``'neighbor_population'``),
    and must be one of ``'targets'`` or ``'effectors'``.

    Parameters
    ----------
    pos : str
        Path to the position folder. Backslashes are converted to forward slashes and a
        trailing '/' is appended if missing.
    model : str
        Name of the pair signal detection model, resolved with
        ``locate_signal_model(model, pairs=True)``.
    use_gpu : bool, optional
        Not used by this function (default is True).

    Returns
    -------
    None
        The annotated pair table is written to ``<pos>output/tables/trajectories_pairs.csv``.

    Raises
    ------
    AssertionError
        If `pos` does not exist.
    FileNotFoundError
        If the model configuration file cannot be opened.
    KeyError
        If the configured reference/neighbor population is neither 'targets' nor 'effectors'.
    """

    pos = pos.replace('\\', '/')
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'
    if not pos.endswith('/'):
        pos += '/'

    dataframes = {
        'targets': get_position_pickle(pos, population='targets'),
        'effectors': get_position_pickle(pos, population='effectors'),
    }
    df_pairs = get_position_table(pos, population='pairs')

    # The model configuration tells which population plays the reference role
    # and which plays the neighbor role.
    model_path = locate_signal_model(model, pairs=True)
    print(f'Looking for model in {model_path}...')
    model_config_path = os.sep.join([str(model_path), 'config_input.json'])
    with open(model_config_path) as config_file:
        model_config = json.load(config_file)

    reference_population = model_config['reference_population']
    neighbor_population = model_config['neighbor_population']

    df = analyze_pair_signals(df_pairs, dataframes[reference_population], dataframes[neighbor_population], model=model)

    table = pos + os.sep.join(["output", "tables", "trajectories_pairs.csv"])
    df.to_csv(table, index=False)

    return None
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def analyze_signals(trajectories, model, interpolate_na=True,
                    selected_signals=None,
                    model_path=None,
                    column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
                    plot_outcome=False, output_dir=None):
    """
    Analyzes signals from trajectory data using a specified signal detection model and configuration.

    This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
    model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
    of missing values, and plotting of analysis outcomes.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
        It is modified in place (new class, time, status and color columns are written into it).
    model : str
        The name of the signal detection model to be used for analysis.
    interpolate_na : bool, optional
        Whether to interpolate missing values in the trajectories (default is True).
    selected_signals : list of str, optional
        A list of column names from `trajectories` representing the signals to be analyzed. If None, one column
        is selected automatically for each channel required by the model configuration (default is None).
    model_path : str, optional
        Passed to ``locate_signal_model`` to help locate the model (default is None).
    column_labels : dict, optional
        A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
        in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
    plot_outcome : bool, optional
        If True, plots the class-0 signals aligned on their predicted time (default is False).
    output_dir : str, optional
        Directory where the outcome plot is saved as ``signal_collapse.png`` when `plot_outcome` is True
        (default is None, plot not saved).

    Returns
    -------
    pandas.DataFrame
        The input `trajectories` DataFrame with additional columns for predicted classes, times of interest,
        status, and corresponding colors based on status and class.

    Raises
    ------
    AssertionError
        If the model or its configuration file cannot be located, if no column matches a required channel,
        or if the number of `selected_signals` differs from the number of model channels.

    Notes
    -----
    - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
    - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
    - When several columns match a channel, the first (natural order) containing 'circle' is preferred,
      otherwise the first match is used.
    """

    model_path = locate_signal_model(model, path=model_path)
    complete_path = str(model_path)
    model_config_path = os.sep.join([complete_path, 'config_input.json'])
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    available_signals = list(trajectories.columns)
    print('The available_signals are : ', available_signals)

    with open(model_config_path) as config_file:
        config = json.load(config_file)
    required_signals = config["channels"]

    # A missing or empty label means the default (unlabelled) output column names.
    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                # Exactly one column per required channel: prefer a 'circle' measurement,
                # fall back to the first candidate only if none exists (for-else).
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
                                            interpolate_position_gaps=interpolate_na, column_labels=column_labels)

    # One row per track, indexed by frame; after the last observed frame each channel keeps its last value.
    max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
    tracks = trajectories_clean[column_labels['track']].unique()
    signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))

    for i, (tid, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
        frames = group[column_labels['time']].to_numpy().astype(int)
        for j, col in enumerate(selected_signals):
            signal = group[col].to_numpy()
            signals[i, frames, j] = signal
            signals[i, max(frames):, j] = signal[-1]

    model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = model.predict_class(signals)
    times_recast = model.predict_time_of_interest(signals)

    if label is None:
        class_col, time_col, status_col = 'class', 't0', 'status'
    else:
        class_col, time_col, status_col = 'class_' + label, 't_' + label, 'status_' + label

    # NOTE(review): row i of `signals` was built from `trajectories_clean` grouped by track; this assumes
    # clean_trajectories keeps the same set of tracks so that both groupbys enumerate tracks identically.
    for i, (tid, group) in enumerate(trajectories.groupby(column_labels['track'])):
        indices = group.index
        trajectories.loc[indices, class_col] = classes[i]
        trajectories.loc[indices, time_col] = times_recast[i]
    print('Done.')

    for tid, group in trajectories.groupby(column_labels['track']):
        indices = group.index
        t0 = group[time_col].to_numpy()[0]
        cclass = group[class_col].to_numpy()[0]
        timeline = group[column_labels['time']].to_numpy()

        # Status: 0 before the event, 1 from t0 onwards; class 2 and classes > 2 override the whole track.
        status = np.zeros_like(timeline)
        if t0 > 0:
            status[timeline >= t0] = 1.
        if cclass == 2:
            status[:] = 2
        if cclass > 2:
            status[:] = 42
        status_color = [color_from_status(s) for s in status]
        class_color = [color_from_class(cclass) for _ in range(len(status))]

        trajectories.loc[indices, status_col] = status
        trajectories.loc[indices, 'status_color'] = status_color
        trajectories.loc[indices, 'class_color'] = class_color

    if plot_outcome:
        fig, ax = plt.subplots(1, len(selected_signals), figsize=(10, 5))
        # plt.subplots returns a bare Axes for a single panel; wrap it so both cases are handled alike.
        axes = ax if len(selected_signals) > 1 else [ax]

        # Only class-0 tracks are drawn, with time shifted by their predicted event time.
        aligned_tracks = []
        for tid, group in trajectories.groupby(column_labels['track']):
            if group[class_col].to_numpy()[0] == 0:
                t0 = group[time_col].to_numpy()[0]
                aligned_tracks.append((group, group[column_labels['time']].to_numpy() - t0))

        for a, s in zip(axes, selected_signals):
            for group, shifted_time in aligned_tracks:
                a.plot(shifted_time, group[s].to_numpy(), c='tab:blue', alpha=0.1)
            a.set_title(s)
            a.set_xlabel(r'time - t$_0$ [frame]')
            a.spines['top'].set_visible(False)
            a.spines['right'].set_visible(False)

        plt.tight_layout()
        if output_dir is not None:
            # os.path.join works whether or not output_dir ends with a separator.
            plt.savefig(os.path.join(output_dir, 'signal_collapse.png'), bbox_inches='tight', dpi=300)
        plt.pause(3)
        plt.close()

    return trajectories
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
def analyze_pair_signals(trajectories_pairs, trajectories_reference, trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
                         model_path=None, plot_outcome=False, output_dir=None, column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
    """
    Predict a class and an event time for every reference/neighbor cell pair with a pair signal model.

    The pair, reference and neighbor tables are prefixed with 'pair_', 'reference_' and 'neighbor_'
    so that the model channels (listed in the model's ``config_input.json``) can be matched against the
    numeric columns of any of the three tables. For each pair, every selected channel is written into a
    (n_pairs, max_pair_frame + 2, n_channels) array indexed by frame; after the last observed frame a
    channel keeps its last value. Reference/neighbor channels are read from the track of the pair's
    reference/neighbor cell (matched on position too when all tables have a 'position' column).

    Parameters
    ----------
    trajectories_pairs : pandas.DataFrame
        Pair table with REFERENCE_ID, NEIGHBOR_ID and FRAME columns (and optionally 'position').
    trajectories_reference : pandas.DataFrame
        Table of the reference population with TRACK_ID and FRAME columns (and optionally 'position').
    trajectories_neighbors : pandas.DataFrame
        Table of the neighbor population with TRACK_ID and FRAME columns (and optionally 'position').
    model : str
        Name of the pair signal detection model.
    interpolate_na : bool, optional
        Not used by this function; all three tables are always cleaned per track with
        ``interpolate_nan_properties`` (default is True).
    selected_signals : list of str, optional
        Prefixed column names to pass to the model, one per model channel. If None, one numeric column
        is selected automatically for each channel required by the model configuration (default is None).
    model_path : str, optional
        Passed to ``locate_signal_model`` to help locate the model (default is None).
    plot_outcome, output_dir, column_labels : optional
        Not used by this function.

    Returns
    -------
    pandas.DataFrame
        A copy of `trajectories_pairs` (original column names) with added class and event-time columns
        ('class' and 't0', or 'class_<label>' and 't_<label>' when the model configuration defines a label).
        Columns whose name starts with 'Unnamed' are dropped. No status column is computed.

    Raises
    ------
    AssertionError
        If the model or its configuration cannot be located, if no column matches a required channel,
        or if the number of `selected_signals` differs from the number of model channels.
    """

    model_path = locate_signal_model(model, path=model_path, pairs=True)
    print(f'Looking for model in {model_path}...')
    complete_path = str(model_path)
    model_config_path = os.sep.join([complete_path, 'config_input.json'])
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    # Prefix every column with its table of origin so channel names are unambiguous.
    trajectories_pairs = trajectories_pairs.rename(columns=lambda x: 'pair_' + x)
    trajectories_reference = trajectories_reference.rename(columns=lambda x: 'reference_' + x)
    trajectories_neighbors = trajectories_neighbors.rename(columns=lambda x: 'neighbor_' + x)

    # When tables pool several positions, a cell is identified by (position, track ID).
    if 'pair_position' in trajectories_pairs.columns:
        pair_groupby_cols = ['pair_position', 'pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
    else:
        pair_groupby_cols = ['pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']

    if 'reference_position' in trajectories_reference.columns:
        reference_groupby_cols = ['reference_position', 'reference_TRACK_ID']
    else:
        reference_groupby_cols = ['reference_TRACK_ID']

    if 'neighbor_position' in trajectories_neighbors.columns:
        neighbor_groupby_cols = ['neighbor_position', 'neighbor_TRACK_ID']
    else:
        neighbor_groupby_cols = ['neighbor_TRACK_ID']

    # Only numeric columns can be fed to the model.
    available_signals = [
        col
        for table in (trajectories_pairs, trajectories_reference, trajectories_neighbors)
        for col in table.columns
        if is_numeric_dtype(table[col])
    ]
    print('The available signals are : ', available_signals)

    with open(model_config_path) as config_file:
        config = json.load(config_file)
    required_signals = config["channels"]

    # A missing or empty label means the default (unlabelled) output column names.
    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                # Exactly one column per required channel: prefer a 'circle' measurement,
                # fall back to the first candidate only if none exists (for-else).
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    trajectories_reference_clean = interpolate_nan_properties(trajectories_reference, track_label=reference_groupby_cols)
    trajectories_neighbors_clean = interpolate_nan_properties(trajectories_neighbors, track_label=neighbor_groupby_cols)
    trajectories_pairs_clean = interpolate_nan_properties(trajectories_pairs, track_label=pair_groupby_cols)
    print(f'{trajectories_pairs_clean.columns=}')

    max_signal_size = int(trajectories_pairs_clean['pair_FRAME'].max()) + 2
    pair_tracks = trajectories_pairs_clean.groupby(pair_groupby_cols).size()
    signals = np.zeros((len(pair_tracks), max_signal_size, len(selected_signals)))
    print(f'{max_signal_size=} {len(pair_tracks)=} {signals.shape=}')

    # These flags are the same for every pair, so decide them once.
    pos_mode = 'pair_position' in trajectories_pairs_clean.columns
    match_on_position = (pos_mode
                         and 'reference_position' in trajectories_reference_clean.columns
                         and 'neighbor_position' in trajectories_neighbors_clean.columns)

    for i, (pair, group) in enumerate(trajectories_pairs_clean.groupby(pair_groupby_cols)):

        if pos_mode:
            pos, reference_cell, neighbor_cell = pair[0], pair[1], pair[2]
        else:
            reference_cell, neighbor_cell = pair[0], pair[1]

        reference_filter = trajectories_reference_clean['reference_TRACK_ID'] == reference_cell
        neighbor_filter = trajectories_neighbors_clean['neighbor_TRACK_ID'] == neighbor_cell
        if match_on_position:
            reference_filter &= trajectories_reference_clean['reference_position'] == pos
            neighbor_filter &= trajectories_neighbors_clean['neighbor_position'] == pos

        pair_frames = group['pair_FRAME'].to_numpy().astype(int)

        for j, col in enumerate(selected_signals):
            if col.startswith('pair_'):
                signal = group[col].to_numpy()
                timeline = pair_frames
            elif col.startswith('reference_'):
                signal = trajectories_reference_clean.loc[reference_filter, col].to_numpy()
                # Frames may be stored as floats; array indices must be integers.
                timeline = trajectories_reference_clean.loc[reference_filter, 'reference_FRAME'].to_numpy().astype(int)
            elif col.startswith('neighbor_'):
                signal = trajectories_neighbors_clean.loc[neighbor_filter, col].to_numpy()
                timeline = trajectories_neighbors_clean.loc[neighbor_filter, 'neighbor_FRAME'].to_numpy().astype(int)
            else:
                continue

            if len(timeline) == 0:
                # Partner cell not found in its table: leave this channel at zero for this pair.
                continue
            signals[i, timeline, j] = signal
            signals[i, timeline.max():, j] = signal[-1]

    model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = model.predict_class(signals)
    times_recast = model.predict_time_of_interest(signals)

    if label is None:
        class_col, time_col = 'pair_class', 'pair_t0'
    else:
        class_col, time_col = 'pair_class_' + label, 'pair_t_' + label

    for i, (pair, group) in enumerate(trajectories_pairs.groupby(pair_groupby_cols)):
        indices = group.index
        trajectories_pairs.loc[indices, class_col] = classes[i]
        trajectories_pairs.loc[indices, time_col] = times_recast[i]
    print('Done.')

    # Restore the original column names: strip the 'pair_' prefix only at the start of each name,
    # so names that contain 'pair_' elsewhere are left intact.
    prefix = 'pair_'
    trajectories_pairs = trajectories_pairs.rename(columns=lambda x: x[len(prefix):] if x.startswith(prefix) else x)
    invalid_cols = [c for c in trajectories_pairs.columns if c.startswith('Unnamed')]
    trajectories_pairs = trajectories_pairs.drop(columns=invalid_cols)

    return trajectories_pairs
|
|
337
724
|
|
|
338
725
|
class SignalDetectionModel(object):
|
|
339
726
|
|
|
@@ -1258,10 +1645,10 @@ class SignalDetectionModel(object):
|
|
|
1258
1645
|
csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
|
|
1259
1646
|
self.cb.append(csv_logger)
|
|
1260
1647
|
checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
|
|
1261
|
-
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1648
|
+
cp_callback = ModelCheckpoint(checkpoint_path, monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1262
1649
|
self.cb.append(cp_callback)
|
|
1263
1650
|
|
|
1264
|
-
callback_stop = EarlyStopping(monitor='val_iou',
|
|
1651
|
+
callback_stop = EarlyStopping(monitor='val_iou',mode='max',patience=100)
|
|
1265
1652
|
self.cb.append(callback_stop)
|
|
1266
1653
|
|
|
1267
1654
|
elif mode=="regressor":
|
|
@@ -1278,7 +1665,7 @@ class SignalDetectionModel(object):
|
|
|
1278
1665
|
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1279
1666
|
self.cb.append(cp_callback)
|
|
1280
1667
|
|
|
1281
|
-
callback_stop = EarlyStopping(monitor='val_loss', patience=200)
|
|
1668
|
+
callback_stop = EarlyStopping(monitor='val_loss', mode='min', patience=200)
|
|
1282
1669
|
self.cb.append(callback_stop)
|
|
1283
1670
|
|
|
1284
1671
|
log_dir = self.model_folder+os.sep
|
|
@@ -1312,7 +1699,7 @@ class SignalDetectionModel(object):
|
|
|
1312
1699
|
if isinstance(self.list_of_sets[0],str):
|
|
1313
1700
|
# Case 1: a list of npy files to be loaded
|
|
1314
1701
|
for s in self.list_of_sets:
|
|
1315
|
-
|
|
1702
|
+
|
|
1316
1703
|
signal_dataset = self.load_set(s)
|
|
1317
1704
|
selected_signals, max_length = self.find_best_signal_match(signal_dataset)
|
|
1318
1705
|
signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
|
|
@@ -1416,7 +1803,6 @@ class SignalDetectionModel(object):
|
|
|
1416
1803
|
x_train_aug.append(aug[0])
|
|
1417
1804
|
y_time_train_aug.append(aug[1])
|
|
1418
1805
|
y_class_train_aug.append(aug[2])
|
|
1419
|
-
print('per class counts ',counts)
|
|
1420
1806
|
|
|
1421
1807
|
# Save augmented training set
|
|
1422
1808
|
self.x_train = np.array(x_train_aug)
|
|
@@ -1469,7 +1855,15 @@ class SignalDetectionModel(object):
|
|
|
1469
1855
|
for i in range(self.n_channels):
|
|
1470
1856
|
try:
|
|
1471
1857
|
# take into account timeline for accurate time regression
|
|
1472
|
-
|
|
1858
|
+
|
|
1859
|
+
if selected_signals[i].startswith('pair_'):
|
|
1860
|
+
timeline = signal_dataset[k]['pair_FRAME'].astype(int)
|
|
1861
|
+
elif selected_signals[i].startswith('reference_'):
|
|
1862
|
+
timeline = signal_dataset[k]['reference_FRAME'].astype(int)
|
|
1863
|
+
elif selected_signals[i].startswith('neighbor_'):
|
|
1864
|
+
timeline = signal_dataset[k]['neighbor_FRAME'].astype(int)
|
|
1865
|
+
else:
|
|
1866
|
+
timeline = signal_dataset[k]['FRAME'].astype(int)
|
|
1473
1867
|
signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
|
|
1474
1868
|
except:
|
|
1475
1869
|
print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
|
celldetective/tracking.py
CHANGED
|
@@ -130,7 +130,7 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
130
130
|
tracking_updates = ["motion"]
|
|
131
131
|
|
|
132
132
|
tracker.append(new_btrack_objects)
|
|
133
|
-
tracker.volume = ((0,volume[0]), (0,volume[1])) #(-1e5, 1e5)
|
|
133
|
+
tracker.volume = ((0,volume[0]), (0,volume[1]), (-1e5, 1e5)) #(-1e5, 1e5)
|
|
134
134
|
#print(tracker.volume)
|
|
135
135
|
tracker.track(tracking_updates=tracking_updates, **track_kwargs)
|
|
136
136
|
tracker.optimize(options=optimizer_options)
|
|
@@ -138,7 +138,11 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
138
138
|
data, properties, graph = tracker.to_napari() #ndim=2
|
|
139
139
|
|
|
140
140
|
# do the table post processing and napari options
|
|
141
|
-
|
|
141
|
+
if data.shape[1]==4:
|
|
142
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
|
|
143
|
+
elif data.shape[1]==5:
|
|
144
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],"z",column_labels['y'],column_labels['x']])
|
|
145
|
+
df = df.drop(columns=['z'])
|
|
142
146
|
df[column_labels['x']+'_um'] = df[column_labels['x']]*spatial_calibration
|
|
143
147
|
df[column_labels['y']+'_um'] = df[column_labels['y']]*spatial_calibration
|
|
144
148
|
|
|
@@ -160,6 +164,8 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
160
164
|
if clean_trajectories_kwargs is not None:
|
|
161
165
|
df = clean_trajectories(df.copy(),**clean_trajectories_kwargs)
|
|
162
166
|
|
|
167
|
+
df['ID'] = np.arange(len(df)).astype(int)
|
|
168
|
+
|
|
163
169
|
if view_on_napari:
|
|
164
170
|
view_on_napari_btrack(data,properties,graph,stack=stack,labels=labels,relabel=True)
|
|
165
171
|
|
celldetective/utils.py
CHANGED
|
@@ -24,7 +24,84 @@ import shutil
|
|
|
24
24
|
import tempfile
|
|
25
25
|
from scipy.interpolate import griddata
|
|
26
26
|
import re
|
|
27
|
+
from scipy.ndimage.morphology import distance_transform_edt
|
|
28
|
+
from scipy import ndimage
|
|
29
|
+
from skimage.morphology import disk
|
|
27
30
|
|
|
31
|
+
def contour_of_instance_segmentation(label, distance):

    """
    Generate an instance mask containing the contour of the segmented objects.

    Parameters
    ----------
    label : ndarray
        The instance segmentation labels (0 is background).
    distance : int, float, list, or tuple
        The distance or range of distances from the edge of each instance to include in the contour.
        If a single non-negative value (Python or NumPy scalar) is provided, it represents the maximum
        distance inside the objects. If a tuple or list is provided, it represents the minimum (exclusive)
        and maximum (inclusive) distances. If a single negative value is provided, a band of width
        ``abs(distance)`` outside the objects is returned instead.

    Returns
    -------
    border_label : ndarray
        An instance mask containing the contour of the segmented objects, each contour pixel carrying
        the label of its object.

    Notes
    -----
    For non-negative distances, this function uses the distance_transform_edt function from the
    scipy.ndimage module to compute the Euclidean distance of each labelled pixel to the background, and
    keeps the pixels whose distance lies in the requested range; other pixels are set to zero.
    For a negative distance, each instance is grey-dilated with a disk footprint and only the dilated
    ring outside the original objects is kept.

    Examples
    --------
    >>> border_label = contour_of_instance_segmentation(label, distance=3)
    # Keep, for each instance, the pixels at most 3 pixels away from its edge.

    """
    if isinstance(distance, (list, tuple)) or distance >= 0:

        edt = distance_transform_edt(label)

        if isinstance(distance, (list, tuple)):
            min_distance, max_distance = distance[0], distance[1]
        else:
            # Any scalar (including NumPy integer types) is a maximum distance from the edge.
            min_distance, max_distance = 0, distance

        thresholded = (edt <= max_distance) & (edt > min_distance)
        border_label = np.copy(label)
        border_label[~thresholded] = 0

    else:
        # Outer band: grow the instances, then erase everything covered by the original objects.
        dilated_image = ndimage.grey_dilation(label, footprint=disk(int(abs(distance))))
        border_label = np.copy(dilated_image)
        border_label[label != 0] = 0

    return border_label
|
|
88
|
+
|
|
89
|
+
def extract_identity_col(trajectories):

    """
    Determine which column identifies individual cells in a table.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        Table of single-cell measurements.

    Returns
    -------
    str or None
        'TRACK_ID' if that column exists and holds at least one non-null value;
        otherwise 'ID' if that column exists; otherwise None (a message is printed).
    """

    columns = list(trajectories.columns)

    if 'TRACK_ID' in columns and not np.all(trajectories['TRACK_ID'].isnull()):
        return 'TRACK_ID'
    # Fall back to per-object IDs, e.g. when tracking was not performed (TRACK_ID absent or all null).
    if 'ID' in columns:
        return 'ID'

    print('ID or TRACK ID column could not be found in the table...')
    return None
|
|
28
105
|
|
|
29
106
|
def derivative(x, timeline, window, mode='bi'):
|
|
30
107
|
|
|
@@ -590,6 +667,22 @@ def rename_intensity_column(df, channels):
|
|
|
590
667
|
else:
|
|
591
668
|
new_name = new_name.replace('centre_of_mass', "centre_of_mass_orientation")
|
|
592
669
|
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
670
|
+
elif sections[-2] == "2":
|
|
671
|
+
new_name = np.delete(measure, -1)
|
|
672
|
+
new_name = '_'.join(list(new_name))
|
|
673
|
+
if 'edge' in intensity_columns[k]:
|
|
674
|
+
new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_x")
|
|
675
|
+
else:
|
|
676
|
+
new_name = new_name.replace('centre_of_mass', "centre_of_mass_x")
|
|
677
|
+
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
678
|
+
elif sections[-2] == "3":
|
|
679
|
+
new_name = np.delete(measure, -1)
|
|
680
|
+
new_name = '_'.join(list(new_name))
|
|
681
|
+
if 'edge' in intensity_columns[k]:
|
|
682
|
+
new_name = new_name.replace('centre_of_mass_displacement', "edge_centre_of_mass_y")
|
|
683
|
+
else:
|
|
684
|
+
new_name = new_name.replace('centre_of_mass', "centre_of_mass_y")
|
|
685
|
+
to_rename.update({intensity_columns[k]: new_name.replace('-', '_')})
|
|
593
686
|
if 'radial_gradient' in intensity_columns[k]:
|
|
594
687
|
# sections = np.array(re.split('-|_', intensity_columns[k]))
|
|
595
688
|
measure = np.array(re.split('-|_', new_name))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: celldetective
|
|
3
|
-
Version: 1.1.1.post4
|
|
3
|
+
Version: 1.2.0
|
|
4
4
|
Summary: description
|
|
5
5
|
Home-page: http://github.com/remyeltorro/celldetective
|
|
6
6
|
Author: Rémy Torro
|
|
@@ -12,14 +12,14 @@ Requires-Dist: wheel
|
|
|
12
12
|
Requires-Dist: nbsphinx
|
|
13
13
|
Requires-Dist: nbsphinx-link
|
|
14
14
|
Requires-Dist: sphinx-rtd-theme
|
|
15
|
-
Requires-Dist: sphinx
|
|
16
|
-
Requires-Dist: jinja2
|
|
15
|
+
Requires-Dist: sphinx==5.0.2
|
|
16
|
+
Requires-Dist: jinja2<3.1
|
|
17
17
|
Requires-Dist: ipykernel
|
|
18
18
|
Requires-Dist: stardist
|
|
19
|
-
Requires-Dist: cellpose
|
|
19
|
+
Requires-Dist: cellpose<3
|
|
20
20
|
Requires-Dist: scikit-learn
|
|
21
21
|
Requires-Dist: btrack
|
|
22
|
-
Requires-Dist: tensorflow
|
|
22
|
+
Requires-Dist: tensorflow<=2.12.1
|
|
23
23
|
Requires-Dist: napari
|
|
24
24
|
Requires-Dist: tqdm
|
|
25
25
|
Requires-Dist: mahotas
|
|
@@ -29,11 +29,11 @@ Requires-Dist: lifelines
|
|
|
29
29
|
Requires-Dist: setuptools
|
|
30
30
|
Requires-Dist: scipy
|
|
31
31
|
Requires-Dist: seaborn
|
|
32
|
-
Requires-Dist: opencv-python-headless
|
|
32
|
+
Requires-Dist: opencv-python-headless==4.7.0.72
|
|
33
33
|
Requires-Dist: liblapack
|
|
34
34
|
Requires-Dist: gputools
|
|
35
|
-
Requires-Dist: lmfit
|
|
36
|
-
Requires-Dist: superqt[cmap]
|
|
35
|
+
Requires-Dist: lmfit~=1.2.2
|
|
36
|
+
Requires-Dist: superqt[cmap]>=0.6.1
|
|
37
37
|
Requires-Dist: matplotlib-scalebar
|
|
38
38
|
|
|
39
39
|
# Celldetective
|