celldetective 1.1.1.post4__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -1
- celldetective/extra_properties.py +62 -34
- celldetective/gui/__init__.py +1 -0
- celldetective/gui/analyze_block.py +2 -1
- celldetective/gui/classifier_widget.py +15 -9
- celldetective/gui/control_panel.py +50 -6
- celldetective/gui/layouts.py +5 -4
- celldetective/gui/neighborhood_options.py +13 -9
- celldetective/gui/plot_signals_ui.py +39 -11
- celldetective/gui/process_block.py +413 -95
- celldetective/gui/retrain_segmentation_model_options.py +17 -4
- celldetective/gui/retrain_signal_model_options.py +106 -6
- celldetective/gui/signal_annotator.py +29 -9
- celldetective/gui/signal_annotator2.py +2708 -0
- celldetective/gui/signal_annotator_options.py +3 -1
- celldetective/gui/survival_ui.py +15 -6
- celldetective/gui/tableUI.py +222 -60
- celldetective/io.py +536 -420
- celldetective/measure.py +919 -969
- celldetective/models/pair_signal_detection/blank +0 -0
- celldetective/models/segmentation_effectors/ricm-bimodal/config_input.json +130 -0
- celldetective/models/segmentation_effectors/ricm-bimodal/ricm-bimodal +0 -0
- celldetective/models/segmentation_effectors/ricm-bimodal/training_instructions.json +37 -0
- celldetective/neighborhood.py +428 -354
- celldetective/relative_measurements.py +648 -0
- celldetective/scripts/analyze_signals.py +1 -1
- celldetective/scripts/measure_cells.py +28 -8
- celldetective/scripts/measure_relative.py +103 -0
- celldetective/scripts/segment_cells.py +5 -5
- celldetective/scripts/track_cells.py +4 -1
- celldetective/scripts/train_segmentation_model.py +23 -18
- celldetective/scripts/train_signal_model.py +33 -0
- celldetective/signals.py +405 -8
- celldetective/tracking.py +8 -2
- celldetective/utils.py +178 -17
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/METADATA +8 -8
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/RECORD +41 -34
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/WHEEL +1 -1
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/LICENSE +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/entry_points.txt +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.1.dist-info}/top_level.txt +0 -0
celldetective/signals.py
CHANGED
|
@@ -18,8 +18,8 @@ from sklearn.metrics import jaccard_score, balanced_accuracy_score, precision_sc
|
|
|
18
18
|
from scipy.interpolate import interp1d
|
|
19
19
|
from scipy.ndimage import shift
|
|
20
20
|
|
|
21
|
-
from celldetective.io import get_signal_models_list, locate_signal_model
|
|
22
|
-
from celldetective.tracking import clean_trajectories
|
|
21
|
+
from celldetective.io import get_signal_models_list, locate_signal_model, get_position_pickle, get_position_table
|
|
22
|
+
from celldetective.tracking import clean_trajectories, interpolate_nan_properties
|
|
23
23
|
from celldetective.utils import regression_plot, train_test_split, compute_weights
|
|
24
24
|
import matplotlib.pyplot as plt
|
|
25
25
|
from natsort import natsorted
|
|
@@ -32,6 +32,7 @@ from scipy.optimize import curve_fit
|
|
|
32
32
|
import time
|
|
33
33
|
import math
|
|
34
34
|
import pandas as pd
|
|
35
|
+
from pandas.api.types import is_numeric_dtype
|
|
35
36
|
|
|
36
37
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],'celldetective'])
|
|
37
38
|
|
|
@@ -154,6 +155,7 @@ def analyze_signals(trajectories, model, interpolate_na=True,
|
|
|
154
155
|
f = open(model_config_path)
|
|
155
156
|
config = json.load(f)
|
|
156
157
|
required_signals = config["channels"]
|
|
158
|
+
model_signal_length = config['model_signal_length']
|
|
157
159
|
|
|
158
160
|
try:
|
|
159
161
|
label = config['label']
|
|
@@ -189,6 +191,8 @@ def analyze_signals(trajectories, model, interpolate_na=True,
|
|
|
189
191
|
trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na, interpolate_position_gaps=interpolate_na, column_labels=column_labels)
|
|
190
192
|
|
|
191
193
|
max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2
|
|
194
|
+
assert max_signal_size <= model_signal_length,f'The current signals are longer ({max_signal_size}) than the maximum expected input ({model_signal_length}) for this signal analysis model. Abort...'
|
|
195
|
+
|
|
192
196
|
tracks = trajectories_clean[column_labels['track']].unique()
|
|
193
197
|
signals = np.zeros((len(tracks),max_signal_size, len(selected_signals)))
|
|
194
198
|
|
|
@@ -334,6 +338,392 @@ def analyze_signals_at_position(pos, model, mode, use_gpu=True, return_table=Fal
|
|
|
334
338
|
else:
|
|
335
339
|
return None
|
|
336
340
|
|
|
341
|
+
def analyze_pair_signals_at_position(pos, model, use_gpu=True):
    """
    Run a pair signal detection model on one position and save the annotated pairs table.

    The 'targets' and 'effectors' populations are loaded with ``get_position_pickle`` and the
    'pairs' table with ``get_position_table``. The model configuration (``config_input.json``
    in the folder returned by ``locate_signal_model``) names which population is the
    reference (``reference_population``) and which is the neighbor (``neighbor_population``).
    The table returned by ``analyze_pair_signals`` is written to
    ``<pos>/output/tables/trajectories_pairs.csv``.

    Parameters
    ----------
    pos : str
        Path to the position folder. Backslashes are converted to forward slashes.
    model : str
        Name of the pair signal detection model.
    use_gpu : bool, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If `pos` does not exist.
    """

    pos = pos.replace('\\', '/')
    assert os.path.exists(pos), f'Position {pos} is not a valid path.'
    if not pos.endswith('/'):
        pos += '/'

    dataframes = {
        'targets': get_position_pickle(pos, population='targets'),
        'effectors': get_position_pickle(pos, population='effectors'),
    }
    df_pairs = get_position_table(pos, population='pairs')

    # The model configuration tells which population plays the reference / neighbor roles.
    model_path = locate_signal_model(model, pairs=True)
    print(f'Looking for model in {model_path}...')
    model_config_path = os.sep.join([model_path, 'config_input.json'])
    with open(model_config_path) as config_file:
        config = json.load(config_file)

    reference_population = config['reference_population']
    neighbor_population = config['neighbor_population']

    df = analyze_pair_signals(df_pairs, dataframes[reference_population], dataframes[neighbor_population], model=model)

    table = pos + os.sep.join(["output", "tables", "trajectories_pairs.csv"])
    df.to_csv(table, index=False)

    return None
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def analyze_signals(trajectories, model, interpolate_na=True,
                    selected_signals=None,
                    model_path=None,
                    column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'},
                    plot_outcome=False, output_dir=None):
    """
    Analyzes signals from trajectory data using a specified signal detection model and configuration.

    This function preprocesses trajectory data, selects specified signals, and applies a pretrained signal detection
    model to predict classes and times of interest for each trajectory. It supports custom column labeling, interpolation
    of missing values, and plotting of analysis outcomes.

    Parameters
    ----------
    trajectories : pandas.DataFrame
        DataFrame containing trajectory data with columns for track ID, frame, position, and signals.
    model : str
        The name of the signal detection model to be used for analysis.
    interpolate_na : bool, optional
        Whether to interpolate missing values (and position gaps) in the trajectories (default is True).
    selected_signals : list of str, optional
        A list of column names from `trajectories` representing the signals to be analyzed. If None, signals will
        be automatically selected based on the model configuration (default is None).
    model_path : str, optional
        Passed to ``locate_signal_model`` as ``path`` when looking for the model (default is None).
    column_labels : dict, optional
        A dictionary mapping the default column names ('track', 'time', 'x', 'y') to the corresponding column names
        in `trajectories` (default is {'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}).
    plot_outcome : bool, optional
        If True, plots the class-0 signals aligned on their predicted time (default is False).
    output_dir : str, optional
        Directory where the outcome plot is saved as 'signal_collapse.png' when `plot_outcome` is True
        (default is None, in which case the plot is not saved).

    Returns
    -------
    pandas.DataFrame
        The input `trajectories` DataFrame, modified in place, with additional columns for predicted classes,
        times of interest, status, and colors based on status and class.

    Raises
    ------
    AssertionError
        If the model or its configuration file cannot be located, if no column matches a required channel,
        if `selected_signals` does not have as many entries as the model channels, or if the signals are longer
        than the 'model_signal_length' declared in the model configuration.

    Notes
    -----
    - The function relies on an external model configuration file (`config_input.json`) located in the model's directory.
    - Signal selection and preprocessing are based on the requirements specified in the model's configuration.
    - Predictions are written back by enumerating the tracks of `trajectories` in groupby order; this assumes
      the cleaned table used to build the signals yields the same tracks in the same order.

    """

    model_path = locate_signal_model(model, path=model_path)
    complete_path = rf"{model_path}"
    model_config_path = os.sep.join([complete_path, 'config_input.json'])
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    available_signals = list(trajectories.columns)
    print('The available_signals are : ', available_signals)

    with open(model_config_path) as config_file:
        config = json.load(config_file)
    required_signals = config["channels"]

    # A missing or empty label selects the unlabelled output columns ('class', 't0', 'status').
    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                # Several columns match: prefer the first 'circle' column in natural order,
                # otherwise fall back to the first matching column.
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    trajectories_clean = clean_trajectories(trajectories, interpolate_na=interpolate_na,
                                            interpolate_position_gaps=interpolate_na, column_labels=column_labels)

    max_signal_size = int(trajectories_clean[column_labels['time']].max()) + 2

    # Same guard as the earlier definition of this function; configurations that do not
    # declare 'model_signal_length' are accepted without the check.
    model_signal_length = config.get('model_signal_length')
    if model_signal_length is not None:
        assert max_signal_size <= model_signal_length, f'The current signals are longer ({max_signal_size}) than the maximum expected input ({model_signal_length}) for this signal analysis model. Abort...'

    tracks = trajectories_clean[column_labels['track']].unique()
    signals = np.zeros((len(tracks), max_signal_size, len(selected_signals)))

    for i, (_, group) in enumerate(trajectories_clean.groupby(column_labels['track'])):
        frames = group[column_labels['time']].to_numpy().astype(int)
        for j, col in enumerate(selected_signals):
            signal = group[col].to_numpy()
            signals[i, frames, j] = signal
            # Hold the last value from the last frame until the end of the signal.
            signals[i, max(frames):, j] = signal[-1]

    detection_model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = detection_model.predict_class(signals)
    times_recast = detection_model.predict_time_of_interest(signals)

    if label is None:
        class_col, time_col, status_col = 'class', 't0', 'status'
    else:
        class_col, time_col, status_col = 'class_' + label, 't_' + label, 'status_' + label

    for i, (_, group) in enumerate(trajectories.groupby(column_labels['track'])):
        indices = group.index
        trajectories.loc[indices, class_col] = classes[i]
        trajectories.loc[indices, time_col] = times_recast[i]
    print('Done.')

    for _, group in trajectories.groupby(column_labels['track']):
        indices = group.index
        t0 = group[time_col].to_numpy()[0]
        cclass = group[class_col].to_numpy()[0]
        timeline = group[column_labels['time']].to_numpy()
        status = np.zeros_like(timeline)
        if t0 > 0:
            status[timeline >= t0] = 1.
        if cclass == 2:
            status[:] = 2
        if cclass > 2:
            status[:] = 42
        status_color = [color_from_status(s) for s in status]
        class_color = [color_from_class(cclass) for _ in range(len(status))]

        trajectories.loc[indices, status_col] = status
        trajectories.loc[indices, 'status_color'] = status_color
        trajectories.loc[indices, 'class_color'] = class_color

    if plot_outcome:
        # squeeze=False always yields a 2D array of axes, so one and several signals share one code path.
        fig, axes = plt.subplots(1, len(selected_signals), figsize=(10, 5), squeeze=False)
        for a, s in zip(axes[0], selected_signals):
            for _, group in trajectories.groupby(column_labels['track']):
                cclass = group[class_col].to_numpy()[0]
                t0 = group[time_col].to_numpy()[0]
                timeline = group[column_labels['time']].to_numpy()
                if cclass == 0:
                    a.plot(timeline - t0, group[s].to_numpy(), c='tab:blue', alpha=0.1)
            a.set_title(s)
            a.set_xlabel(r'time - t$_0$ [frame]')
            a.spines['top'].set_visible(False)
            a.spines['right'].set_visible(False)
        plt.tight_layout()
        if output_dir is not None:
            plt.savefig(os.path.join(output_dir, 'signal_collapse.png'), bbox_inches='tight', dpi=300)
        plt.pause(3)
        plt.close()

    return trajectories
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
def analyze_pair_signals(trajectories_pairs, trajectories_reference, trajectories_neighbors, model, interpolate_na=True, selected_signals=None,
                         model_path=None, plot_outcome=False, output_dir=None, column_labels={'track': "TRACK_ID", 'time': 'FRAME', 'x': 'POSITION_X', 'y': 'POSITION_Y'}):
    """
    Apply a pair signal detection model to every (reference, neighbor) pair of a pairs table.

    Columns of the three input tables are prefixed with 'pair_', 'reference_' and 'neighbor_'
    respectively, so model channels can come from any of them. For each pair, one multichannel
    signal is built. Pair channels are indexed by 'pair_FRAME'. Reference and neighbor channels
    come from the matching cell's track and are indexed by its 'reference_FRAME' / 'neighbor_FRAME'.
    After each channel's last frame, the last value is held until the end of the signal.

    Parameters
    ----------
    trajectories_pairs : pandas.DataFrame
        Pairs table with 'REFERENCE_ID', 'NEIGHBOR_ID' and 'FRAME' columns, and optionally 'position'.
    trajectories_reference : pandas.DataFrame
        Reference population table with 'TRACK_ID' and 'FRAME' columns, and optionally 'position'.
    trajectories_neighbors : pandas.DataFrame
        Neighbor population table with 'TRACK_ID' and 'FRAME' columns, and optionally 'position'.
    model : str
        Name of the pair signal detection model.
    interpolate_na : bool, optional
        If True (default), run ``interpolate_nan_properties`` on the three tables before building the signals.
    selected_signals : list of str, optional
        Prefixed column names to feed to the model. If None, they are matched against the model channels.
    model_path : str, optional
        Passed to ``locate_signal_model`` as ``path``.
    plot_outcome, output_dir, column_labels : optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    pandas.DataFrame
        Copy of the pairs table, with 'pair_' prefixes removed and 'Unnamed*' columns dropped.
        It gains a class column and a time column: 'class' and 't0', or 'class_<label>' and
        't_<label>' when the model configuration defines a label. No status column is computed.

    Raises
    ------
    AssertionError
        If the model or its configuration cannot be located, if no column matches a required channel,
        or if `selected_signals` does not have as many entries as the model channels.
    """

    model_path = locate_signal_model(model, path=model_path, pairs=True)
    print(f'Looking for model in {model_path}...')
    complete_path = rf"{model_path}"
    model_config_path = os.sep.join([complete_path, 'config_input.json'])
    assert os.path.exists(complete_path), f'Model {model} could not be located in folder {model_path}... Abort.'
    assert os.path.exists(model_config_path), f'Model configuration could not be located in folder {model_path}... Abort.'

    # rename() returns new frames, so the caller's tables are not modified.
    trajectories_pairs = trajectories_pairs.rename(columns=lambda x: 'pair_' + x)
    trajectories_reference = trajectories_reference.rename(columns=lambda x: 'reference_' + x)
    trajectories_neighbors = trajectories_neighbors.rename(columns=lambda x: 'neighbor_' + x)

    # With a 'position' column, group keys are (position, reference, neighbor); otherwise (reference, neighbor).
    pos_mode = 'pair_position' in trajectories_pairs.columns
    pair_groupby_cols = ['pair_REFERENCE_ID', 'pair_NEIGHBOR_ID']
    if pos_mode:
        pair_groupby_cols = ['pair_position'] + pair_groupby_cols

    reference_groupby_cols = ['reference_TRACK_ID']
    if 'reference_position' in trajectories_reference.columns:
        reference_groupby_cols = ['reference_position'] + reference_groupby_cols

    neighbor_groupby_cols = ['neighbor_TRACK_ID']
    if 'neighbor_position' in trajectories_neighbors.columns:
        neighbor_groupby_cols = ['neighbor_position'] + neighbor_groupby_cols

    available_signals = [
        col
        for table in (trajectories_pairs, trajectories_reference, trajectories_neighbors)
        for col in table.columns
        if is_numeric_dtype(table[col])
    ]
    print('The available signals are : ', available_signals)

    with open(model_config_path) as config_file:
        config = json.load(config_file)
    required_signals = config["channels"]

    label = config.get('label')
    if label == '':
        label = None

    if selected_signals is None:
        selected_signals = []
        for s in required_signals:
            pattern_test = [s in a or s == a for a in available_signals]
            print(f'Pattern test for signal {s}: ', pattern_test)
            assert np.any(pattern_test), f'No signal matches with the requirements of the model {required_signals}. Please pass the signals manually with the argument selected_signals or add measurements. Abort.'
            valid_columns = np.array(available_signals)[np.array(pattern_test)]
            if len(valid_columns) == 1:
                selected_signals.append(valid_columns[0])
            else:
                print(f'Found several candidate signals: {valid_columns}')
                # Prefer the first 'circle' column in natural order, else the first match.
                for vc in natsorted(valid_columns):
                    if 'circle' in vc:
                        selected_signals.append(vc)
                        break
                else:
                    selected_signals.append(valid_columns[0])
    else:
        assert len(selected_signals) == len(required_signals), f'Mismatch between the number of required signals {required_signals} and the provided signals {selected_signals}... Abort.'

    print(f'The following channels will be passed to the model: {selected_signals}')
    if interpolate_na:
        trajectories_reference_clean = interpolate_nan_properties(trajectories_reference, track_label=reference_groupby_cols)
        trajectories_neighbors_clean = interpolate_nan_properties(trajectories_neighbors, track_label=neighbor_groupby_cols)
        trajectories_pairs_clean = interpolate_nan_properties(trajectories_pairs, track_label=pair_groupby_cols)
    else:
        trajectories_reference_clean = trajectories_reference
        trajectories_neighbors_clean = trajectories_neighbors
        trajectories_pairs_clean = trajectories_pairs

    # Size the signals so the frames of every table feeding a channel fit, not only the pair frames.
    max_frame = trajectories_pairs_clean['pair_FRAME'].max()
    for prefix, table in (('reference_', trajectories_reference_clean), ('neighbor_', trajectories_neighbors_clean)):
        if any(col.startswith(prefix) for col in selected_signals):
            max_frame = max(max_frame, table[prefix + 'FRAME'].max())
    max_signal_size = int(max_frame) + 2

    pair_groups = trajectories_pairs_clean.groupby(pair_groupby_cols)
    signals = np.zeros((pair_groups.ngroups, max_signal_size, len(selected_signals)))
    print(f'{max_signal_size=} {pair_groups.ngroups=} {signals.shape=}')

    filter_by_position = (pos_mode
                          and 'reference_position' in trajectories_reference_clean.columns
                          and 'neighbor_position' in trajectories_neighbors_clean.columns)

    for i, (pair, group) in enumerate(pair_groups):

        if pos_mode:
            pos, reference_cell, neighbor_cell = pair
        else:
            reference_cell, neighbor_cell = pair

        reference_filter = trajectories_reference_clean['reference_TRACK_ID'] == reference_cell
        neighbor_filter = trajectories_neighbors_clean['neighbor_TRACK_ID'] == neighbor_cell
        if filter_by_position:
            reference_filter &= trajectories_reference_clean['reference_position'] == pos
            neighbor_filter &= trajectories_neighbors_clean['neighbor_position'] == pos

        pair_frames = group['pair_FRAME'].to_numpy()

        for j, col in enumerate(selected_signals):
            if col.startswith('pair_'):
                values = group[col].to_numpy()
                frames = pair_frames
            elif col.startswith('reference_'):
                values = trajectories_reference_clean.loc[reference_filter, col].to_numpy()
                frames = trajectories_reference_clean.loc[reference_filter, 'reference_FRAME'].to_numpy()
            elif col.startswith('neighbor_'):
                values = trajectories_neighbors_clean.loc[neighbor_filter, col].to_numpy()
                frames = trajectories_neighbors_clean.loc[neighbor_filter, 'neighbor_FRAME'].to_numpy()
            else:
                # Columns without a known table prefix are left at zero.
                continue

            if len(values) == 0:
                print(f'No values of {col} found for pair {pair}; this channel is left at zero.')
                continue

            # Frames may be stored as floats; numpy fancy indexing needs integers.
            frames = frames.astype(int)
            signals[i, frames, j] = values
            # Hold the last value from the last frame until the end of the signal.
            signals[i, frames.max():, j] = values[-1]

    detection_model = SignalDetectionModel(pretrained=complete_path)
    print('signal shape: ', signals.shape)

    classes = detection_model.predict_class(signals)
    times_recast = detection_model.predict_time_of_interest(signals)

    if label is None:
        class_col, time_col = 'pair_class', 'pair_t0'
    else:
        class_col, time_col = 'pair_class_' + label, 'pair_t_' + label

    # Assumes the uncleaned pairs table yields the same groups in the same order as the cleaned one.
    for i, (_, group) in enumerate(trajectories_pairs.groupby(pair_groupby_cols)):
        indices = group.index
        trajectories_pairs.loc[indices, class_col] = classes[i]
        trajectories_pairs.loc[indices, time_col] = times_recast[i]
    print('Done.')

    # Remove only the leading 'pair_' added above; str.replace would also strip inner occurrences.
    prefix = 'pair_'
    trajectories_pairs = trajectories_pairs.rename(
        columns=lambda x: x[len(prefix):] if x.startswith(prefix) else x)
    invalid_cols = [c for c in trajectories_pairs.columns if c.startswith('Unnamed')]
    trajectories_pairs = trajectories_pairs.drop(columns=invalid_cols)

    return trajectories_pairs
|
|
337
727
|
|
|
338
728
|
class SignalDetectionModel(object):
|
|
339
729
|
|
|
@@ -1258,10 +1648,10 @@ class SignalDetectionModel(object):
|
|
|
1258
1648
|
csv_logger = CSVLogger(os.sep.join([self.model_folder,'log_classifier.csv']), append=True, separator=';')
|
|
1259
1649
|
self.cb.append(csv_logger)
|
|
1260
1650
|
checkpoint_path = os.sep.join([self.model_folder,"classifier.h5"])
|
|
1261
|
-
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1651
|
+
cp_callback = ModelCheckpoint(checkpoint_path, monitor="val_iou",mode="max",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1262
1652
|
self.cb.append(cp_callback)
|
|
1263
1653
|
|
|
1264
|
-
callback_stop = EarlyStopping(monitor='val_iou',
|
|
1654
|
+
callback_stop = EarlyStopping(monitor='val_iou',mode='max',patience=100)
|
|
1265
1655
|
self.cb.append(callback_stop)
|
|
1266
1656
|
|
|
1267
1657
|
elif mode=="regressor":
|
|
@@ -1278,7 +1668,7 @@ class SignalDetectionModel(object):
|
|
|
1278
1668
|
cp_callback = ModelCheckpoint(checkpoint_path,monitor="val_loss",mode="min",verbose=1,save_best_only=True,save_weights_only=False,save_freq="epoch")
|
|
1279
1669
|
self.cb.append(cp_callback)
|
|
1280
1670
|
|
|
1281
|
-
callback_stop = EarlyStopping(monitor='val_loss', patience=200)
|
|
1671
|
+
callback_stop = EarlyStopping(monitor='val_loss', mode='min', patience=200)
|
|
1282
1672
|
self.cb.append(callback_stop)
|
|
1283
1673
|
|
|
1284
1674
|
log_dir = self.model_folder+os.sep
|
|
@@ -1312,7 +1702,7 @@ class SignalDetectionModel(object):
|
|
|
1312
1702
|
if isinstance(self.list_of_sets[0],str):
|
|
1313
1703
|
# Case 1: a list of npy files to be loaded
|
|
1314
1704
|
for s in self.list_of_sets:
|
|
1315
|
-
|
|
1705
|
+
|
|
1316
1706
|
signal_dataset = self.load_set(s)
|
|
1317
1707
|
selected_signals, max_length = self.find_best_signal_match(signal_dataset)
|
|
1318
1708
|
signals_recast, classes, times_of_interest = self.cast_signals_into_training_data(signal_dataset, selected_signals, max_length)
|
|
@@ -1416,7 +1806,6 @@ class SignalDetectionModel(object):
|
|
|
1416
1806
|
x_train_aug.append(aug[0])
|
|
1417
1807
|
y_time_train_aug.append(aug[1])
|
|
1418
1808
|
y_class_train_aug.append(aug[2])
|
|
1419
|
-
print('per class counts ',counts)
|
|
1420
1809
|
|
|
1421
1810
|
# Save augmented training set
|
|
1422
1811
|
self.x_train = np.array(x_train_aug)
|
|
@@ -1469,7 +1858,15 @@ class SignalDetectionModel(object):
|
|
|
1469
1858
|
for i in range(self.n_channels):
|
|
1470
1859
|
try:
|
|
1471
1860
|
# take into account timeline for accurate time regression
|
|
1472
|
-
|
|
1861
|
+
|
|
1862
|
+
if selected_signals[i].startswith('pair_'):
|
|
1863
|
+
timeline = signal_dataset[k]['pair_FRAME'].astype(int)
|
|
1864
|
+
elif selected_signals[i].startswith('reference_'):
|
|
1865
|
+
timeline = signal_dataset[k]['reference_FRAME'].astype(int)
|
|
1866
|
+
elif selected_signals[i].startswith('neighbor_'):
|
|
1867
|
+
timeline = signal_dataset[k]['neighbor_FRAME'].astype(int)
|
|
1868
|
+
else:
|
|
1869
|
+
timeline = signal_dataset[k]['FRAME'].astype(int)
|
|
1473
1870
|
signals_recast[k,timeline,i] = signal_dataset[k][selected_signals[i]]
|
|
1474
1871
|
except:
|
|
1475
1872
|
print(f"Attribute {selected_signals[i]} matched to {self.channel_option[i]} not found in annotation...")
|
celldetective/tracking.py
CHANGED
|
@@ -130,7 +130,7 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
130
130
|
tracking_updates = ["motion"]
|
|
131
131
|
|
|
132
132
|
tracker.append(new_btrack_objects)
|
|
133
|
-
tracker.volume = ((0,volume[0]), (0,volume[1])) #(-1e5, 1e5)
|
|
133
|
+
tracker.volume = ((0,volume[0]), (0,volume[1]), (-1e5, 1e5)) #(-1e5, 1e5)
|
|
134
134
|
#print(tracker.volume)
|
|
135
135
|
tracker.track(tracking_updates=tracking_updates, **track_kwargs)
|
|
136
136
|
tracker.optimize(options=optimizer_options)
|
|
@@ -138,7 +138,11 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
138
138
|
data, properties, graph = tracker.to_napari() #ndim=2
|
|
139
139
|
|
|
140
140
|
# do the table post processing and napari options
|
|
141
|
-
|
|
141
|
+
if data.shape[1]==4:
|
|
142
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],column_labels['y'],column_labels['x']])
|
|
143
|
+
elif data.shape[1]==5:
|
|
144
|
+
df = pd.DataFrame(data, columns=[column_labels['track'],column_labels['time'],"z",column_labels['y'],column_labels['x']])
|
|
145
|
+
df = df.drop(columns=['z'])
|
|
142
146
|
df[column_labels['x']+'_um'] = df[column_labels['x']]*spatial_calibration
|
|
143
147
|
df[column_labels['y']+'_um'] = df[column_labels['y']]*spatial_calibration
|
|
144
148
|
|
|
@@ -160,6 +164,8 @@ def track(labels, configuration=None, stack=None, spatial_calibration=1, feature
|
|
|
160
164
|
if clean_trajectories_kwargs is not None:
|
|
161
165
|
df = clean_trajectories(df.copy(),**clean_trajectories_kwargs)
|
|
162
166
|
|
|
167
|
+
df['ID'] = np.arange(len(df)).astype(int)
|
|
168
|
+
|
|
163
169
|
if view_on_napari:
|
|
164
170
|
view_on_napari_btrack(data,properties,graph,stack=stack,labels=labels,relabel=True)
|
|
165
171
|
|