celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +25 -0
- celldetective/__main__.py +62 -43
- celldetective/_version.py +1 -1
- celldetective/extra_properties.py +477 -399
- celldetective/filters.py +192 -97
- celldetective/gui/InitWindow.py +541 -411
- celldetective/gui/__init__.py +0 -15
- celldetective/gui/about.py +44 -39
- celldetective/gui/analyze_block.py +120 -84
- celldetective/gui/base/__init__.py +0 -0
- celldetective/gui/base/channel_norm_generator.py +335 -0
- celldetective/gui/base/components.py +249 -0
- celldetective/gui/base/feature_choice.py +92 -0
- celldetective/gui/base/figure_canvas.py +52 -0
- celldetective/gui/base/list_widget.py +133 -0
- celldetective/gui/{styles.py → base/styles.py} +92 -36
- celldetective/gui/base/utils.py +33 -0
- celldetective/gui/base_annotator.py +900 -767
- celldetective/gui/classifier_widget.py +6 -22
- celldetective/gui/configure_new_exp.py +777 -671
- celldetective/gui/control_panel.py +635 -524
- celldetective/gui/dynamic_progress.py +449 -0
- celldetective/gui/event_annotator.py +2023 -1662
- celldetective/gui/generic_signal_plot.py +1292 -944
- celldetective/gui/gui_utils.py +899 -1289
- celldetective/gui/interactions_block.py +658 -0
- celldetective/gui/interactive_timeseries_viewer.py +447 -0
- celldetective/gui/json_readers.py +48 -15
- celldetective/gui/layouts/__init__.py +5 -0
- celldetective/gui/layouts/background_model_free_layout.py +537 -0
- celldetective/gui/layouts/channel_offset_layout.py +134 -0
- celldetective/gui/layouts/local_correction_layout.py +91 -0
- celldetective/gui/layouts/model_fit_layout.py +372 -0
- celldetective/gui/layouts/operation_layout.py +68 -0
- celldetective/gui/layouts/protocol_designer_layout.py +96 -0
- celldetective/gui/pair_event_annotator.py +3130 -2435
- celldetective/gui/plot_measurements.py +586 -267
- celldetective/gui/plot_signals_ui.py +724 -506
- celldetective/gui/preprocessing_block.py +395 -0
- celldetective/gui/process_block.py +1678 -1831
- celldetective/gui/seg_model_loader.py +580 -473
- celldetective/gui/settings/__init__.py +0 -7
- celldetective/gui/settings/_cellpose_model_params.py +181 -0
- celldetective/gui/settings/_event_detection_model_params.py +95 -0
- celldetective/gui/settings/_segmentation_model_params.py +159 -0
- celldetective/gui/settings/_settings_base.py +77 -65
- celldetective/gui/settings/_settings_event_model_training.py +752 -526
- celldetective/gui/settings/_settings_measurements.py +1133 -964
- celldetective/gui/settings/_settings_neighborhood.py +574 -488
- celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
- celldetective/gui/settings/_settings_signal_annotator.py +329 -305
- celldetective/gui/settings/_settings_tracking.py +1304 -1094
- celldetective/gui/settings/_stardist_model_params.py +98 -0
- celldetective/gui/survival_ui.py +422 -312
- celldetective/gui/tableUI.py +1665 -1701
- celldetective/gui/table_ops/_maths.py +295 -0
- celldetective/gui/table_ops/_merge_groups.py +140 -0
- celldetective/gui/table_ops/_merge_one_hot.py +95 -0
- celldetective/gui/table_ops/_query_table.py +43 -0
- celldetective/gui/table_ops/_rename_col.py +44 -0
- celldetective/gui/thresholds_gui.py +382 -179
- celldetective/gui/viewers/__init__.py +0 -0
- celldetective/gui/viewers/base_viewer.py +700 -0
- celldetective/gui/viewers/channel_offset_viewer.py +331 -0
- celldetective/gui/viewers/contour_viewer.py +394 -0
- celldetective/gui/viewers/size_viewer.py +153 -0
- celldetective/gui/viewers/spot_detection_viewer.py +341 -0
- celldetective/gui/viewers/threshold_viewer.py +309 -0
- celldetective/gui/workers.py +403 -126
- celldetective/log_manager.py +92 -0
- celldetective/measure.py +1895 -1478
- celldetective/napari/__init__.py +0 -0
- celldetective/napari/utils.py +1025 -0
- celldetective/neighborhood.py +1914 -1448
- celldetective/preprocessing.py +1620 -1220
- celldetective/processes/__init__.py +0 -0
- celldetective/processes/background_correction.py +271 -0
- celldetective/processes/compute_neighborhood.py +894 -0
- celldetective/processes/detect_events.py +246 -0
- celldetective/processes/downloader.py +137 -0
- celldetective/processes/measure_cells.py +565 -0
- celldetective/processes/segment_cells.py +760 -0
- celldetective/processes/track_cells.py +435 -0
- celldetective/processes/train_segmentation_model.py +694 -0
- celldetective/processes/train_signal_model.py +265 -0
- celldetective/processes/unified_process.py +292 -0
- celldetective/regionprops/_regionprops.py +358 -317
- celldetective/relative_measurements.py +987 -710
- celldetective/scripts/measure_cells.py +313 -212
- celldetective/scripts/measure_relative.py +90 -46
- celldetective/scripts/segment_cells.py +165 -104
- celldetective/scripts/segment_cells_thresholds.py +96 -68
- celldetective/scripts/track_cells.py +198 -149
- celldetective/scripts/train_segmentation_model.py +324 -201
- celldetective/scripts/train_signal_model.py +87 -45
- celldetective/segmentation.py +844 -749
- celldetective/signals.py +3514 -2861
- celldetective/tracking.py +30 -15
- celldetective/utils/__init__.py +0 -0
- celldetective/utils/cellpose_utils/__init__.py +133 -0
- celldetective/utils/color_mappings.py +42 -0
- celldetective/utils/data_cleaning.py +630 -0
- celldetective/utils/data_loaders.py +450 -0
- celldetective/utils/dataset_helpers.py +207 -0
- celldetective/utils/downloaders.py +235 -0
- celldetective/utils/event_detection/__init__.py +8 -0
- celldetective/utils/experiment.py +1782 -0
- celldetective/utils/image_augmenters.py +308 -0
- celldetective/utils/image_cleaning.py +74 -0
- celldetective/utils/image_loaders.py +926 -0
- celldetective/utils/image_transforms.py +335 -0
- celldetective/utils/io.py +62 -0
- celldetective/utils/mask_cleaning.py +348 -0
- celldetective/utils/mask_transforms.py +5 -0
- celldetective/utils/masks.py +184 -0
- celldetective/utils/maths.py +351 -0
- celldetective/utils/model_getters.py +325 -0
- celldetective/utils/model_loaders.py +296 -0
- celldetective/utils/normalization.py +380 -0
- celldetective/utils/parsing.py +465 -0
- celldetective/utils/plots/__init__.py +0 -0
- celldetective/utils/plots/regression.py +53 -0
- celldetective/utils/resources.py +34 -0
- celldetective/utils/stardist_utils/__init__.py +104 -0
- celldetective/utils/stats.py +90 -0
- celldetective/utils/types.py +21 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
- celldetective-1.5.0b1.dist-info/RECORD +187 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
- tests/gui/test_new_project.py +129 -117
- tests/gui/test_project.py +127 -79
- tests/test_filters.py +39 -15
- tests/test_notebooks.py +8 -0
- tests/test_tracking.py +232 -13
- tests/test_utils.py +123 -77
- celldetective/gui/base_components.py +0 -23
- celldetective/gui/layouts.py +0 -1602
- celldetective/gui/processes/compute_neighborhood.py +0 -594
- celldetective/gui/processes/downloader.py +0 -111
- celldetective/gui/processes/measure_cells.py +0 -360
- celldetective/gui/processes/segment_cells.py +0 -499
- celldetective/gui/processes/track_cells.py +0 -303
- celldetective/gui/processes/train_segmentation_model.py +0 -270
- celldetective/gui/processes/train_signal_model.py +0 -108
- celldetective/gui/table_ops/merge_groups.py +0 -118
- celldetective/gui/viewers.py +0 -1354
- celldetective/io.py +0 -3663
- celldetective/utils.py +0 -3108
- celldetective-1.4.2.dist-info/RECORD +0 -123
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
- {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,450 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
from celldetective import get_logger
|
|
8
|
+
from celldetective.utils.image_loaders import locate_stack_and_labels
|
|
9
|
+
from celldetective.utils.experiment import (
|
|
10
|
+
get_config,
|
|
11
|
+
get_experiment_wells,
|
|
12
|
+
get_experiment_labels,
|
|
13
|
+
get_experiment_metadata,
|
|
14
|
+
extract_well_name_and_number,
|
|
15
|
+
extract_position_name,
|
|
16
|
+
interpret_wells_and_positions,
|
|
17
|
+
get_position_movie_path,
|
|
18
|
+
get_positions_in_well,
|
|
19
|
+
)
|
|
20
|
+
from celldetective.utils.parsing import (
|
|
21
|
+
config_section_to_dict,
|
|
22
|
+
_extract_labels_from_config,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Module-level logger shared by the loaders below, obtained from celldetective's get_logger().
logger = get_logger()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_position_table(pos, population, return_path=False):
    """
    Load the trajectory table of one population from a position folder.

    The table is expected at ``<pos>/output/tables/trajectories_<population>.csv``.
    When the file is missing, or when pandas fails to read it (the error is
    logged), ``None`` is returned in place of the DataFrame.

    Parameters
    ----------
    pos : str
        Path to the position directory, with or without a trailing separator.
    population : str
        Population name inserted in the CSV file name (e.g. 'targets').
    return_path : bool, optional
        If True, the expected path of the CSV file is returned together with
        the table (default is False).

    Returns
    -------
    pandas.DataFrame or None, or (pandas.DataFrame or None, str)
        The loaded table (or None). With ``return_path=True`` a tuple whose
        second element is the path that was looked up, whether or not the file
        exists.

    Examples
    --------
    >>> df_pos = get_position_table('/path/to/position', 'targets')
    # Reads '/path/to/position/output/tables/trajectories_targets.csv'.

    >>> df_pos, table_path = get_position_table('/path/to/position', 'targets', return_path=True)
    # Same, and also returns the path of the CSV file.

    """

    import pandas as pd

    # Append a separator only when the position path does not already end with one.
    position_root = pos if pos.endswith(os.sep) else pos + os.sep
    table = position_root + os.sep.join(
        ("output", "tables", f"trajectories_{population}.csv")
    )

    df_pos = None
    if os.path.exists(table):
        try:
            df_pos = pd.read_csv(table, low_memory=False)
        except Exception as e:
            # Best effort: an unreadable table is reported and treated as absent.
            logger.error(e)

    return (df_pos, table) if return_path else df_pos
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_position_pickle(pos, population, return_path=False):
    """
    Load the pickled trajectory table of one population from a position folder.

    The file is expected at ``<pos>/output/tables/trajectories_<population>.pkl``
    and is read with ``numpy.load(..., allow_pickle=True)``. When the file is
    missing, or when it cannot be read (the error is logged, as in
    `get_position_table`), ``None`` is returned.

    Parameters
    ----------
    pos : str
        Path to the position directory, with or without a trailing separator.
    population : str
        Population name inserted in the pickle file name (e.g. 'targets').
    return_path : bool, optional
        If True, the expected path of the pickle file is returned together
        with the loaded object (default is False).

    Returns
    -------
    object or None, or (object or None, str)
        The unpickled object (presumably a pandas DataFrame -- this depends on
        what was written to the file) or None. With ``return_path=True`` a
        tuple whose second element is the path that was looked up.

    Warnings
    --------
    Unpickling can execute arbitrary code. Only load files that were produced
    by a trusted source.

    Examples
    --------
    >>> df_pos = get_position_pickle('/path/to/position', 'targets')
    # Reads '/path/to/position/output/tables/trajectories_targets.pkl'.

    >>> df_pos, table_path = get_position_pickle('/path/to/position', 'targets', return_path=True)
    # Same, and also returns the path of the pickle file.

    """

    # Append a separator only when the position path does not already end with one.
    position_root = pos if pos.endswith(os.sep) else pos + os.sep
    table = position_root + os.sep.join(
        ("output", "tables", f"trajectories_{population}.pkl")
    )

    df_pos = None
    if os.path.exists(table):
        try:
            # NOTE(security): allow_pickle=True unpickles the file content;
            # never point this at files from an untrusted origin.
            df_pos = np.load(table, allow_pickle=True)
        except Exception as e:
            # Same best-effort policy as get_position_table: report and skip.
            logger.error(e)

    return (df_pos, table) if return_path else df_pos
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def load_experiment_tables(
    experiment,
    population="targets",
    well_option="*",
    position_option="*",
    return_pos_info=False,
    load_pickle=False,
):
    """
    Load and concatenate the position tables of an experiment.

    Every selected position is read with `get_position_table` (CSV) or
    `get_position_pickle` (pickle), annotated with well/position identifiers,
    the experiment labels and the experiment metadata, then all tables are
    concatenated into a single DataFrame.

    Parameters
    ----------
    experiment : str
        Path to the experiment folder to load data for.
    population : str, optional
        Population whose table is read in each position (used in the file name
        ``trajectories_<population>``). Default is `'targets'`.
    well_option : str or list, optional
        Specifies which wells to include. Default is `'*'`, meaning all wells.
    position_option : str or list, optional
        Specifies which positions to include within selected wells. Default is `'*'`, meaning all positions.
    return_pos_info : bool, optional
        If `True`, also returns a DataFrame containing position-level metadata. Default is `False`.
    load_pickle : bool, optional
        If `True`, reads the ``.pkl`` table of each position instead of the ``.csv`` one. Default is `False`.

    Returns
    -------
    df : pandas.DataFrame or None
        The concatenated tables, with the added columns ``position``, ``well``,
        ``well_index``, ``well_name``, ``pos_name``, one column per experiment
        label and one per metadata key. `None` if no table could be loaded.
    df_pos_info : pandas.DataFrame, optional
        One row per position whose table was loaded, with the columns
        ``pos_path``, ``pos_index``, ``pos_name``, ``table_path``,
        ``stack_path``, ``well_path``, ``well_index``, ``well_name``,
        ``well_number`` and ``well_alias``. Returned only if `return_pos_info=True`.

    Notes
    -----
    - The ``movie_prefix`` is read from the ``MovieSettings`` section of the experiment configuration.
    - If the position selection cannot be applied to a well, the error is logged and the well is skipped.
    - A label that cannot be assigned for a well is logged and skipped; loading continues.
    - ``pos_index`` and ``well_index`` in `df_pos_info` count only positions/wells for which a table was found.

    Examples
    --------
    Load all data for an experiment:

    >>> df = load_experiment_tables("path/to/experiment1")

    Load data for specific wells and positions, including position metadata:

    >>> df, df_pos_info = load_experiment_tables(
    ...     "experiment_01", well_option=["A1", "B1"], position_option=[0, 1], return_pos_info=True
    ... )

    Read the pickled tables instead of the CSV ones:

    >>> df = load_experiment_tables("experiment_01", load_pickle=True)

    """

    import pandas as pd

    config = get_config(experiment)
    wells = get_experiment_wells(experiment)

    movie_prefix = config_section_to_dict(config, "MovieSettings")["movie_prefix"]

    labels = get_experiment_labels(experiment)
    metadata = get_experiment_metadata(experiment)  # None or dict of metadata
    well_labels = _extract_labels_from_config(config, len(wells))

    well_indices, position_indices = interpret_wells_and_positions(
        experiment, well_option, position_option
    )

    # The two loaders share a signature and differ only by file format: pick once.
    load_table = get_position_pickle if load_pickle else get_position_table

    df = []
    df_pos_info = []
    real_well_index = 0  # counts only wells that contributed at least one table

    for well_rank, well_path in enumerate(tqdm(wells[well_indices])):

        any_table = False  # assume no table

        well_name, well_number = extract_well_name_and_number(well_path)
        widx = well_indices[well_rank]
        well_alias = well_labels[widx]

        positions = get_positions_in_well(well_path)
        if position_indices is not None:
            try:
                positions = positions[position_indices]
            except Exception as e:
                # Selection not applicable to this well (e.g. index out of range): skip it.
                logger.error(e)
                continue

        real_pos_index = 0  # counts only positions of this well that have a table
        for pos_path in positions:

            pos_name = extract_position_name(pos_path)
            stack_path = get_position_movie_path(pos_path, prefix=movie_prefix)

            df_pos, table = load_table(
                pos_path, population=population, return_path=True
            )
            if df_pos is None:
                continue

            df_pos["position"] = pos_path
            df_pos["well"] = well_path
            df_pos["well_index"] = well_number
            df_pos["well_name"] = well_name
            df_pos["pos_name"] = pos_name

            # Distinct loop names: the label key must not clobber the well counter.
            for label_name, label_values in labels.items():
                try:
                    df_pos[label_name] = label_values[widx]
                except Exception as e:
                    logger.error(f"{e=}")

            if metadata is not None:
                for key, value in metadata.items():
                    df_pos[key] = value

            df.append(df_pos)
            any_table = True

            # NOTE(review): position info is recorded only when a table was
            # loaded, which is what the "real_*" counters imply -- confirm.
            df_pos_info.append(
                {
                    "pos_path": pos_path,
                    "pos_index": real_pos_index,
                    "pos_name": pos_name,
                    "table_path": table,
                    "stack_path": stack_path,
                    "well_path": well_path,
                    "well_index": real_well_index,
                    "well_name": well_name,
                    "well_number": well_number,
                    "well_alias": well_alias,
                }
            )
            real_pos_index += 1

        if any_table:
            real_well_index += 1

    df_pos_info = pd.DataFrame(df_pos_info)
    if len(df) > 0:
        df = pd.concat(df)
        df = df.reset_index(drop=True)
    else:
        df = None

    if return_pos_info:
        return df, df_pos_info
    return df
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def load_tracking_data(position, prefix="Aligned", population="target"):
    """
    Load the tracking table, the labels and the stack of a position.

    Parameters
    ----------
    position : str
        Path to the position directory, with or without a trailing separator.
        Backslashes are converted to forward slashes before use.
    prefix : str, optional
        The prefix used in the filenames of the stack images (default is "Aligned").
    population : str, optional
        The population to load the data for. "target"/"targets" and
        "effector"/"effectors" (case-insensitive) map to the
        ``trajectories_targets.csv`` and ``trajectories_effectors.csv``
        tables; any other value is inserted as-is in
        ``trajectories_<population>.csv`` (default is "target").

    Returns
    -------
    trajectories : pandas.DataFrame
        The content of ``<position>/output/tables/trajectories_<population>.csv``.
    labels
        Second value returned by `locate_stack_and_labels`.
    stack
        First value returned by `locate_stack_and_labels`.

    Raises
    ------
    FileNotFoundError
        Raised by ``pandas.read_csv`` if the trajectory table does not exist.

    Notes
    -----
    The table path is built with ``os.path.join`` so that a position path
    given without a trailing separator still resolves to the right file.
    The stack and labels are located with `locate_stack_and_labels`, which
    receives the original (unmapped) ``population`` value.

    Examples
    --------
    >>> trajectories, labels, stack = load_tracking_data(position, population="target")
    # Load the tracking data, labels, and stack for the specified position and target population.

    """

    import pandas as pd

    position = position.replace("\\", "/")

    # Singular names are accepted as aliases of the plural table names.
    table_aliases = {
        "target": "targets",
        "targets": "targets",
        "effector": "effectors",
        "effectors": "effectors",
    }
    table_population = table_aliases.get(population.lower(), population)

    # os.path.join inserts the separator only when `position` lacks one.
    table_path = os.path.join(
        position, "output", "tables", f"trajectories_{table_population}.csv"
    )
    trajectories = pd.read_csv(table_path)

    stack, labels = locate_stack_and_labels(
        position, prefix=prefix, population=population
    )

    return trajectories, labels, stack
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def interpret_tracking_configuration(config):
    """
    Resolve a tracking configuration to something usable by the tracker.

    Parameters
    ----------
    config : str or None or object
        The input configuration, which can be:
        - A string holding the path to an existing configuration file.
        - The short name of a bundled configuration, without the `.json` extension.
        - `None` to use the default configuration.
        - Any other object, which is returned unchanged.

    Returns
    -------
    str or object
        The resolved configuration path when one is found; otherwise the value
        returned by ``btrack.datasets.cell_config()``; or the input itself when
        it is neither a string nor `None`.

    Notes
    -----
    - If `config` is a string and the specified path exists, it is returned as-is.
    - Otherwise `config` is treated as a name and ``<name>.json`` is searched
      in ``models/tracking_configs`` under the parent of this module's directory.
    - If neither exists, or if `config` is `None`, the function falls back to
      ``cell_config()`` from ``btrack.datasets`` (imported lazily, only when needed).

    Examples
    --------
    Resolve a full path:

    >>> interpret_tracking_configuration("/path/to/config.json")
    '/path/to/config.json'

    Resolve a named configuration:

    >>> interpret_tracking_configuration("default_tracking")
    '/path/to/celldetective/models/tracking_configs/default_tracking.json'

    Handle `None` to return the default configuration:

    >>> interpret_tracking_configuration(None)
    '/path/to/default/config.json'

    """

    if isinstance(config, str):
        if os.path.exists(config):
            return config

        # Parent of the directory that contains this module.
        package_root = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
        # Plain separator join (not os.path.join) so that a name starting with
        # a separator cannot discard the bundled-config prefix.
        candidate = os.sep.join(
            [package_root, "models", "tracking_configs", config + ".json"]
        )
        if os.path.exists(candidate):
            return candidate

        # Neither a path nor a known name: use the default below.
        config = None

    if config is None:
        from btrack.datasets import cell_config

        config = cell_config()

    return config
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from celldetective import get_logger
|
|
3
|
+
|
|
4
|
+
# Module-level logger used by the dataset helpers below, obtained from celldetective's get_logger().
logger = get_logger()
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def split_by_ratio(arr, *ratios):
    """

    Randomly partition an array into consecutive chunks of given proportions.

    Parameters
    ----------
    arr : array-like
        The input array to be split.
    *ratios : float
        Proportion of the input assigned to each chunk. The sum of ratios should be less than or equal to 1.

    Returns
    -------
    list
        One Python list per ratio, holding the elements of the corresponding chunk.

    Notes
    -----
    The input is first shuffled with ``np.random.permutation``. The cut
    positions are the cumulative ratios multiplied by the input length and
    truncated to integers. Elements left over when the ratios sum to less
    than 1 are dropped, so exactly ``len(ratios)`` chunks are returned.

    Examples
    --------
    >>> arr = np.arange(10)
    >>> splits = split_by_ratio(arr, 0.6, 0.2, 0.2)
    >>> print(len(splits))
    3
    # Split the array into 3 chunks with ratios 0.6, 0.2, and 0.2.

    >>> arr = np.arange(100)
    >>> splits = split_by_ratio(arr, 0.5, 0.25)
    >>> print([len(split) for split in splits])
    [50, 25]
    # Split the array into 2 chunks with ratios 0.5 and 0.25.

    """

    shuffled = np.random.permutation(arr)
    # Cumulative end index of each chunk, truncated toward zero.
    boundaries = np.cumsum(np.array(ratios) * len(shuffled)).astype(int)
    pieces = np.split(shuffled, boundaries)
    # np.split yields one extra trailing piece (the remainder); keep one per ratio.
    return [piece.tolist() for piece in pieces][: len(ratios)]
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def compute_weights(y):
    """

    Compute balanced class weights from an array of labels.

    Parameters
    ----------
    y : array-like
        Array of labels.

    Returns
    -------
    dict
        Maps each unique label of `y` to its weight.

    Notes
    -----
    The weights come from ``sklearn.utils.compute_class_weight`` with
    ``class_weight="balanced"``, i.e. ``n_samples / (n_classes * count)`` for
    each class, so that rarer classes receive larger weights.

    The keys of the returned dictionary are the values of ``np.unique(y)``.

    Examples
    --------
    >>> labels = np.array([0, 1, 0, 1, 1])
    >>> weights = compute_weights(labels)
    >>> [float(w) for w in weights.values()]
    [1.25, 0.8333333333333334]
    # 5 samples, 2 classes: 5 / (2 * 2) and 5 / (2 * 3).

    >>> labels = np.array([0, 1, 2, 0, 1, 2, 2])
    >>> weights = compute_weights(labels)
    >>> [float(w) for w in weights.values()]
    [1.1666666666666667, 1.1666666666666667, 0.7777777777777778]
    # 7 samples, 3 classes: 7 / (3 * 2), 7 / (3 * 2) and 7 / (3 * 3).

    """
    from sklearn.utils import compute_class_weight

    # Computed once and reused for both the weight computation and the keys.
    classes = np.unique(y)
    weights = compute_class_weight(
        class_weight="balanced",
        classes=classes,
        y=y,
    )

    return dict(zip(classes, weights))
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _classes_in(encoded):
    # Sorted class indices present in `encoded`: argmax of each distinct row
    # (assumes rows are one-hot / score vectors -- TODO confirm with callers).
    return np.sort(np.argmax(np.unique(encoded, axis=0), axis=1))


def train_test_split(
    data_x, data_y1, data_class=None, validation_size=0.25, test_size=0, n_iterations=10
):
    """

    Split the dataset into training, validation, and test sets.

    Parameters
    ----------
    data_x : numpy.ndarray
        Input features or independent variables. Must support indexing with a
        list of integer indices.
    data_y1 : numpy.ndarray
        Target variable 1, aligned with `data_x`.
    data_class : numpy.ndarray, optional
        Target variable 2 (class encoding, one row per sample), aligned with
        `data_x`. When given, a split is accepted only if the training and
        validation sets contain the same classes. Default is None.
    validation_size : float, optional
        Proportion of the dataset to include in the validation set. Default is 0.25.
    test_size : float, optional
        Proportion of the dataset to include in the test set. Default is 0.
    n_iterations : int, optional
        Maximum number of random splits to try before giving up. Default is 10.

    Returns
    -------
    dict
        A dictionary containing the split datasets.
        Keys: "x_train", "x_val", "y1_train", "y1_val".
        If `data_class` is given, additional keys: "y2_train", "y2_val".
        If test_size > 0, additional keys: "x_test", "y1_test" (and "y2_test"
        if `data_class` is given).

    Raises
    ------
    ValueError
        If no split with matching classes in the training and validation sets
        was found within `n_iterations` attempts.

    Notes
    -----
    Sample indices are shuffled and cut with `split_by_ratio` using the
    proportions ``1 - validation_size - test_size``, `validation_size` and
    `test_size`. The class check compares the training and validation sets
    only; the test set is not checked.

    """

    check_classes = data_class is not None
    if check_classes:
        logger.info(f"Unique classes: {_classes_in(data_class)}")

    # Loop invariants.
    n_values = len(data_x)
    train_percentage = 1 - validation_size - test_size

    for _ in range(n_iterations):

        randomize = np.arange(n_values)
        np.random.shuffle(randomize)

        chunks = split_by_ratio(randomize, train_percentage, validation_size, test_size)
        train_idx, val_idx = chunks[0], chunks[1]

        if check_classes:
            y2_train = data_class[train_idx]
            y2_val = data_class[val_idx]

            train_classes = _classes_in(y2_train)
            val_classes = _classes_in(y2_val)
            logger.info(
                f"classes in train set: {train_classes}; classes in validation set: {val_classes}"
            )
            same_class_test = np.array_equal(train_classes, val_classes)
            logger.info(
                f"Check that classes are found in all sets: {same_class_test}..."
            )
            if not same_class_test:
                # A class is missing on one side: draw another split.
                continue

        ds = {
            "x_train": data_x[train_idx],
            "x_val": data_x[val_idx],
            "y1_train": data_y1[train_idx],
            "y1_val": data_y1[val_idx],
        }
        if check_classes:
            ds.update({"y2_train": y2_train, "y2_val": y2_val})

        if test_size > 0:
            test_idx = chunks[2]
            ds.update({"x_test": data_x[test_idx], "y1_test": data_y1[test_idx]})
            if check_classes:
                ds.update({"y2_test": data_class[test_idx]})

        return ds

    raise ValueError(
        "Some classes are missing from the train or validation set... Abort."
    )
|