celldetective 1.4.1.post1__py3-none-any.whl → 1.5.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151)
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +642 -554
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1700
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +304 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/measure_cells.py +565 -0
  81. celldetective/processes/segment_cells.py +760 -0
  82. celldetective/processes/track_cells.py +435 -0
  83. celldetective/processes/train_segmentation_model.py +694 -0
  84. celldetective/processes/train_signal_model.py +265 -0
  85. celldetective/processes/unified_process.py +292 -0
  86. celldetective/regionprops/_regionprops.py +358 -317
  87. celldetective/relative_measurements.py +987 -710
  88. celldetective/scripts/measure_cells.py +313 -212
  89. celldetective/scripts/measure_relative.py +90 -46
  90. celldetective/scripts/segment_cells.py +165 -104
  91. celldetective/scripts/segment_cells_thresholds.py +96 -68
  92. celldetective/scripts/track_cells.py +198 -149
  93. celldetective/scripts/train_segmentation_model.py +324 -201
  94. celldetective/scripts/train_signal_model.py +87 -45
  95. celldetective/segmentation.py +844 -749
  96. celldetective/signals.py +3514 -2861
  97. celldetective/tracking.py +1332 -1011
  98. celldetective/utils/__init__.py +0 -0
  99. celldetective/utils/cellpose_utils/__init__.py +133 -0
  100. celldetective/utils/color_mappings.py +42 -0
  101. celldetective/utils/data_cleaning.py +630 -0
  102. celldetective/utils/data_loaders.py +450 -0
  103. celldetective/utils/dataset_helpers.py +207 -0
  104. celldetective/utils/downloaders.py +197 -0
  105. celldetective/utils/event_detection/__init__.py +8 -0
  106. celldetective/utils/experiment.py +1782 -0
  107. celldetective/utils/image_augmenters.py +308 -0
  108. celldetective/utils/image_cleaning.py +74 -0
  109. celldetective/utils/image_loaders.py +926 -0
  110. celldetective/utils/image_transforms.py +335 -0
  111. celldetective/utils/io.py +62 -0
  112. celldetective/utils/mask_cleaning.py +348 -0
  113. celldetective/utils/mask_transforms.py +5 -0
  114. celldetective/utils/masks.py +184 -0
  115. celldetective/utils/maths.py +351 -0
  116. celldetective/utils/model_getters.py +325 -0
  117. celldetective/utils/model_loaders.py +296 -0
  118. celldetective/utils/normalization.py +380 -0
  119. celldetective/utils/parsing.py +465 -0
  120. celldetective/utils/plots/__init__.py +0 -0
  121. celldetective/utils/plots/regression.py +53 -0
  122. celldetective/utils/resources.py +34 -0
  123. celldetective/utils/stardist_utils/__init__.py +104 -0
  124. celldetective/utils/stats.py +90 -0
  125. celldetective/utils/types.py +21 -0
  126. {celldetective-1.4.1.post1.dist-info → celldetective-1.5.0b0.dist-info}/METADATA +1 -1
  127. celldetective-1.5.0b0.dist-info/RECORD +187 -0
  128. {celldetective-1.4.1.post1.dist-info → celldetective-1.5.0b0.dist-info}/WHEEL +1 -1
  129. tests/gui/test_new_project.py +129 -117
  130. tests/gui/test_project.py +127 -79
  131. tests/test_filters.py +39 -15
  132. tests/test_notebooks.py +8 -0
  133. tests/test_tracking.py +425 -144
  134. tests/test_utils.py +123 -77
  135. celldetective/gui/base_components.py +0 -23
  136. celldetective/gui/layouts.py +0 -1602
  137. celldetective/gui/processes/compute_neighborhood.py +0 -594
  138. celldetective/gui/processes/measure_cells.py +0 -360
  139. celldetective/gui/processes/segment_cells.py +0 -499
  140. celldetective/gui/processes/track_cells.py +0 -303
  141. celldetective/gui/processes/train_segmentation_model.py +0 -270
  142. celldetective/gui/processes/train_signal_model.py +0 -108
  143. celldetective/gui/table_ops/merge_groups.py +0 -118
  144. celldetective/gui/viewers.py +0 -1354
  145. celldetective/io.py +0 -3663
  146. celldetective/utils.py +0 -3108
  147. celldetective-1.4.1.post1.dist-info/RECORD +0 -123
  148. /celldetective/{gui/processes → processes}/downloader.py +0 -0
  149. {celldetective-1.4.1.post1.dist-info → celldetective-1.5.0b0.dist-info}/entry_points.txt +0 -0
  150. {celldetective-1.4.1.post1.dist-info → celldetective-1.5.0b0.dist-info}/licenses/LICENSE +0 -0
  151. {celldetective-1.4.1.post1.dist-info → celldetective-1.5.0b0.dist-info}/top_level.txt +0 -0
celldetective/utils/data_loaders.py
@@ -0,0 +1,450 @@
+ import os
+
+ import numpy as np
+
+
+ from tqdm import tqdm
+ from celldetective import get_logger
+ from celldetective.utils.image_loaders import locate_stack_and_labels
+ from celldetective.utils.experiment import (
+     get_config,
+     get_experiment_wells,
+     get_experiment_labels,
+     get_experiment_metadata,
+     extract_well_name_and_number,
+     extract_position_name,
+     interpret_wells_and_positions,
+     get_position_movie_path,
+     get_positions_in_well,
+ )
+ from celldetective.utils.parsing import (
+     config_section_to_dict,
+     _extract_labels_from_config,
+ )
+
+ logger = get_logger()
+
+
+ def get_position_table(pos, population, return_path=False):
+     """
+     Retrieves the data table for a specified population at a given position, optionally returning the table's file path.
+
+     This function locates and loads a CSV data table associated with a specific population (e.g., 'targets', 'cells')
+     from a specified position directory. The position directory should contain an 'output/tables' subdirectory where
+     the CSV file named 'trajectories_{population}.csv' is expected to be found. If the file exists, it is loaded into
+     a pandas DataFrame; otherwise, None is returned.
+
+     Parameters
+     ----------
+     pos : str
+         The path to the position directory from which to load the data table.
+     population : str
+         The name of the population for which the data table is to be retrieved. This name is used to construct the
+         file name of the CSV file to be loaded.
+     return_path : bool, optional
+         If True, returns a tuple containing the loaded data table (or None) and the path to the CSV file. If False,
+         only the loaded data table (or None) is returned (default is False).
+
+     Returns
+     -------
+     pandas.DataFrame or None, or (pandas.DataFrame or None, str)
+         If return_path is False, returns the loaded data table as a pandas DataFrame, or None if the table file does
+         not exist. If return_path is True, returns a tuple where the first element is the data table (or None) and the
+         second element is the path to the CSV file.
+
+     Examples
+     --------
+     >>> df_pos = get_position_table('/path/to/position', 'targets')
+     # This will load the 'trajectories_targets.csv' table from the specified position directory into a pandas DataFrame.
+
+     >>> df_pos, table_path = get_position_table('/path/to/position', 'targets', return_path=True)
+     # This will load the 'trajectories_targets.csv' table and also return the path to the CSV file.
+
+     """
+
+     import pandas as pd
+
+     if not pos.endswith(os.sep):
+         table = os.sep.join([pos, "output", "tables", f"trajectories_{population}.csv"])
+     else:
+         table = pos + os.sep.join(
+             ["output", "tables", f"trajectories_{population}.csv"]
+         )
+
+     if os.path.exists(table):
+         try:
+             df_pos = pd.read_csv(table, low_memory=False)
+         except Exception as e:
+             logger.error(e)
+             df_pos = None
+     else:
+         df_pos = None
+
+     if return_path:
+         return df_pos, table
+     else:
+         return df_pos
+
+
+ def get_position_pickle(pos, population, return_path=False):
+     """
+     Retrieves the pickled data table for a specified population at a given position, optionally returning the file path.
+
+     This function locates and loads a pickle file associated with a specific population (e.g., 'targets', 'cells')
+     from a specified position directory. The position directory should contain an 'output/tables' subdirectory where
+     the file named 'trajectories_{population}.pkl' is expected to be found. If the file exists, it is loaded with
+     np.load(..., allow_pickle=True); otherwise, None is returned.
+
+     Parameters
+     ----------
+     pos : str
+         The path to the position directory from which to load the pickled table.
+     population : str
+         The name of the population for which the table is to be retrieved. This name is used to construct the
+         file name of the pickle file to be loaded.
+     return_path : bool, optional
+         If True, returns a tuple containing the loaded table (or None) and the path to the pickle file. If False,
+         only the loaded table (or None) is returned (default is False).
+
+     Returns
+     -------
+     pandas.DataFrame or None, or (pandas.DataFrame or None, str)
+         If return_path is False, returns the unpickled table, or None if the pickle file does not exist. If
+         return_path is True, returns a tuple where the first element is the table (or None) and the second
+         element is the path to the pickle file.
+
+     Examples
+     --------
+     >>> df_pos = get_position_pickle('/path/to/position', 'targets')
+     # This will load the 'trajectories_targets.pkl' table from the specified position directory.
+
+     >>> df_pos, table_path = get_position_pickle('/path/to/position', 'targets', return_path=True)
+     # This will load the 'trajectories_targets.pkl' table and also return the path to the pickle file.
+
+     """
+
+     if not pos.endswith(os.sep):
+         table = os.sep.join([pos, "output", "tables", f"trajectories_{population}.pkl"])
+     else:
+         table = pos + os.sep.join(
+             ["output", "tables", f"trajectories_{population}.pkl"]
+         )
+
+     if os.path.exists(table):
+         df_pos = np.load(table, allow_pickle=True)
+     else:
+         df_pos = None
+
+     if return_path:
+         return df_pos, table
+     else:
+         return df_pos
+
+
+ def load_experiment_tables(
+     experiment,
+     population="targets",
+     well_option="*",
+     position_option="*",
+     return_pos_info=False,
+     load_pickle=False,
+ ):
+     """
+     Load tabular data for an experiment, optionally including position-level information.
+
+     This function retrieves and processes tables associated with positions in an experiment.
+     It supports filtering by wells and positions, and can load either CSV data or pickle files.
+
+     Parameters
+     ----------
+     experiment : str
+         Path to the experiment folder to load data for.
+     population : str, optional
+         The population to extract from the position tables (`'targets'` or `'effectors'`). Default is `'targets'`.
+     well_option : str or list, optional
+         Specifies which wells to include. Default is `'*'`, meaning all wells.
+     position_option : str or list, optional
+         Specifies which positions to include within selected wells. Default is `'*'`, meaning all positions.
+     return_pos_info : bool, optional
+         If `True`, also returns a DataFrame containing position-level metadata. Default is `False`.
+     load_pickle : bool, optional
+         If `True`, loads pre-processed pickle files for the positions instead of raw data. Default is `False`.
+
+     Returns
+     -------
+     df : pandas.DataFrame or None
+         A DataFrame containing aggregated data for the specified wells and positions, or `None` if no data is found.
+         The DataFrame includes metadata such as well and position identifiers, concentrations, antibodies, and other
+         experimental parameters.
+     df_pos_info : pandas.DataFrame, optional
+         A DataFrame with metadata for each position, including file paths and experimental details. Returned only
+         if `return_pos_info=True`.
+
+     Notes
+     -----
+     - The function assumes the experiment's configuration includes details about movie prefixes, concentrations,
+       cell types, antibodies, and pharmaceutical agents.
+     - Wells and positions can be filtered using `well_option` and `position_option`, respectively. If filtering
+       fails or is invalid, those specific wells/positions are skipped.
+     - Position-level metadata is assembled into `df_pos_info` and includes paths to data and movies.
+
+     Examples
+     --------
+     Load all data for an experiment:
+
+     >>> df = load_experiment_tables("path/to/experiment1")
+
+     Load data for specific wells and positions, including position metadata:
+
+     >>> df, df_pos_info = load_experiment_tables(
+     ...     "experiment_01", well_option=["A1", "B1"], position_option=[0, 1], return_pos_info=True
+     ... )
+
+     Use pickle files for faster loading:
+
+     >>> df = load_experiment_tables("experiment_01", load_pickle=True)
+
+     """
+
+     import pandas as pd
+
+     config = get_config(experiment)
+     wells = get_experiment_wells(experiment)
+
+     movie_prefix = config_section_to_dict(config, "MovieSettings")["movie_prefix"]
+
+     labels = get_experiment_labels(experiment)
+     metadata = get_experiment_metadata(experiment)  # None or dict of metadata
+     well_labels = _extract_labels_from_config(config, len(wells))
+
+     well_indices, position_indices = interpret_wells_and_positions(
+         experiment, well_option, position_option
+     )
+
+     df = []
+     df_pos_info = []
+     real_well_index = 0
+
+     for k, well_path in enumerate(tqdm(wells[well_indices])):
+
+         any_table = False  # assume no table
+
+         well_name, well_number = extract_well_name_and_number(well_path)
+         widx = well_indices[k]
+         well_alias = well_labels[widx]
+
+         positions = get_positions_in_well(well_path)
+         if position_indices is not None:
+             try:
+                 positions = positions[position_indices]
+             except Exception as e:
+                 logger.error(e)
+                 continue
+
+         real_pos_index = 0
+         for pidx, pos_path in enumerate(positions):
+
+             pos_name = extract_position_name(pos_path)
+
+             stack_path = get_position_movie_path(pos_path, prefix=movie_prefix)
+
+             if not load_pickle:
+                 df_pos, table = get_position_table(
+                     pos_path, population=population, return_path=True
+                 )
+             else:
+                 df_pos, table = get_position_pickle(
+                     pos_path, population=population, return_path=True
+                 )
+
+             if df_pos is not None:
+
+                 df_pos["position"] = pos_path
+                 df_pos["well"] = well_path
+                 df_pos["well_index"] = well_number
+                 df_pos["well_name"] = well_name
+                 df_pos["pos_name"] = pos_name
+
+                 for label_key, values in labels.items():
+                     try:
+                         df_pos[label_key] = values[widx]
+                     except Exception as e:
+                         logger.error(f"{e=}")
+
+                 if metadata is not None:
+                     for key in metadata:
+                         df_pos[key] = metadata[key]
+
+                 df.append(df_pos)
+                 any_table = True
+
+             pos_dict = {
+                 "pos_path": pos_path,
+                 "pos_index": real_pos_index,
+                 "pos_name": pos_name,
+                 "table_path": table,
+                 "stack_path": stack_path,
+                 "well_path": well_path,
+                 "well_index": real_well_index,
+                 "well_name": well_name,
+                 "well_number": well_number,
+                 "well_alias": well_alias,
+             }
+
+             df_pos_info.append(pos_dict)
+
+             real_pos_index += 1
+
+         if any_table:
+             real_well_index += 1
+
+     df_pos_info = pd.DataFrame(df_pos_info)
+     if len(df) > 0:
+         df = pd.concat(df)
+         df = df.reset_index(drop=True)
+     else:
+         df = None
+
+     if return_pos_info:
+         return df, df_pos_info
+     else:
+         return df
+
+
+ def load_tracking_data(position, prefix="Aligned", population="target"):
+     """
+     Load the tracking data, labels, and stack for a given position and population.
+
+     Parameters
+     ----------
+     position : str
+         The position or directory where the data is located.
+     prefix : str, optional
+         The prefix used in the filenames of the stack images (default is "Aligned").
+     population : str, optional
+         The population to load the data for. Options are "target" or "effector" (default is "target").
+
+     Returns
+     -------
+     trajectories : DataFrame
+         The tracking data loaded as a pandas DataFrame.
+     labels : ndarray
+         The segmentation labels loaded as a numpy ndarray.
+     stack : ndarray
+         The image stack loaded as a numpy ndarray.
+
+     Notes
+     -----
+     This function loads the tracking data, labels, and stack for a given position and population.
+     It reads the trajectories from the appropriate CSV file based on the specified population.
+     The stack and labels are located using the `locate_stack_and_labels` function.
+     The resulting tracking data is returned as a pandas DataFrame, and the labels and stack are returned as numpy ndarrays.
+
+     Examples
+     --------
+     >>> trajectories, labels, stack = load_tracking_data(position, population="target")
+     # Load the tracking data, labels, and stack for the specified position and target population.
+
+     """
+
+     import pandas as pd
+
+     position = position.replace("\\", "/")
+     if population.lower() in ("target", "targets"):
+         trajectories = pd.read_csv(
+             position + os.sep.join(["output", "tables", "trajectories_targets.csv"])
+         )
+     elif population.lower() in ("effector", "effectors"):
+         trajectories = pd.read_csv(
+             position + os.sep.join(["output", "tables", "trajectories_effectors.csv"])
+         )
+     else:
+         trajectories = pd.read_csv(
+             position
+             + os.sep.join(["output", "tables", f"trajectories_{population}.csv"])
+         )
+
+     stack, labels = locate_stack_and_labels(
+         position, prefix=prefix, population=population
+     )
+
+     return trajectories, labels, stack
+
+
+ def interpret_tracking_configuration(config):
+     """
+     Interpret and resolve the path for a tracking configuration file.
+
+     This function determines the appropriate configuration file path based on the input.
+     If the input is a string representing an existing path or a known configuration name,
+     it resolves to the correct file path. If the input is invalid or `None`, a default
+     configuration is returned.
+
+     Parameters
+     ----------
+     config : str or None
+         The input configuration, which can be:
+         - A string representing the full path to a configuration file.
+         - A short name of a configuration file without the `.json` extension.
+         - `None` to use a default configuration.
+
+     Returns
+     -------
+     str
+         The resolved path to the configuration file.
+
+     Notes
+     -----
+     - If `config` is a string and the specified path exists, it is returned as-is.
+     - If `config` is a name, the function searches in the `tracking_configs` directory
+       within the `celldetective` models folder.
+     - If the file or name is not found, or if `config` is `None`, the function falls
+       back to a default configuration using `cell_config()`.
+
+     Examples
+     --------
+     Resolve a full path:
+
+     >>> interpret_tracking_configuration("/path/to/config.json")
+     '/path/to/config.json'
+
+     Resolve a named configuration:
+
+     >>> interpret_tracking_configuration("default_tracking")
+     '/path/to/celldetective/models/tracking_configs/default_tracking.json'
+
+     Handle `None` to return the default configuration:
+
+     >>> interpret_tracking_configuration(None)
+     '/path/to/default/config.json'
+
+     """
+
+     if isinstance(config, str):
+         if os.path.exists(config):
+             return config
+         else:
+             modelpath = os.sep.join(
+                 [
+                     os.path.split(os.path.dirname(os.path.realpath(__file__)))[0],
+                     # "celldetective",
+                     "models",
+                     "tracking_configs",
+                     os.sep,
+                 ]
+             )
+             if os.path.exists(modelpath + config + ".json"):
+                 return modelpath + config + ".json"
+             else:
+                 from btrack.datasets import cell_config
+
+                 config = cell_config()
+     elif config is None:
+         from btrack.datasets import cell_config
+
+         config = cell_config()
+
+     return config
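
These loaders resolve everything from Celldetective's on-disk experiment layout (well folders containing position folders, each with an `output/tables` subdirectory). A minimal usage sketch, assuming this hunk is the new `celldetective/utils/data_loaders.py` (the only new 450-line module in the file list above) and a hypothetical experiment at `/data/exp1`:

    from celldetective.utils.data_loaders import (
        get_position_table,
        load_experiment_tables,
        interpret_tracking_configuration,
    )

    # Load one position's trajectory table (returns None if the CSV is absent).
    pos = "/data/exp1/W1/100/"  # hypothetical position directory
    df_pos = get_position_table(pos, population="targets")

    # Aggregate all wells/positions of the experiment into a single DataFrame,
    # with per-position metadata (paths, well aliases, ...) returned separately.
    df, df_pos_info = load_experiment_tables(
        "/data/exp1",
        population="targets",
        well_option="*",          # all wells
        position_option="*",      # all positions in each selected well
        return_pos_info=True,
    )

    # Resolve a tracking configuration: an existing path is returned as-is,
    # a bare name is looked up under models/tracking_configs, anything else
    # falls back to btrack's default cell_config().
    cfg = interpret_tracking_configuration("default_tracking")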
celldetective/utils/dataset_helpers.py
@@ -0,0 +1,207 @@
+ import numpy as np
+ from celldetective import get_logger
+
+ logger = get_logger()
+
+
+ def split_by_ratio(arr, *ratios):
+     """
+     Split an array into multiple chunks based on given ratios.
+
+     Parameters
+     ----------
+     arr : array-like
+         The input array to be split.
+     *ratios : float
+         Ratios specifying the proportions of each chunk. The sum of the ratios should be less than or equal to 1.
+
+     Returns
+     -------
+     list
+         A list of arrays containing the splits/chunks of the input array.
+
+     Notes
+     -----
+     This function randomly permutes the input array (`arr`) and then splits it into multiple chunks based on the
+     provided ratios. The ratios determine the relative sizes of the resulting chunks, and the accumulated ratios
+     determine the split indices.
+
+     The function returns one chunk per provided ratio. If the ratios sum to less than 1, the leftover elements
+     beyond the last split index are discarded.
+
+     Examples
+     --------
+     >>> arr = np.arange(10)
+     >>> splits = split_by_ratio(arr, 0.6, 0.2, 0.2)
+     >>> print(len(splits))
+     3
+     # Split the array into 3 chunks with ratios 0.6, 0.2, and 0.2.
+
+     >>> arr = np.arange(100)
+     >>> splits = split_by_ratio(arr, 0.5, 0.25)
+     >>> print([len(split) for split in splits])
+     [50, 25]
+     # Split the array into 2 chunks with ratios 0.5 and 0.25; the remaining 25 elements are discarded.
+
+     """
+
+     arr = np.random.permutation(arr)
+     ind = np.add.accumulate(np.array(ratios) * len(arr)).astype(int)
+     return [x.tolist() for x in np.split(arr, ind)][: len(ratios)]
+
+
+ def compute_weights(y):
+     """
+     Compute class weights based on the input labels.
+
+     Parameters
+     ----------
+     y : array-like
+         Array of labels.
+
+     Returns
+     -------
+     dict
+         A dictionary containing the computed class weights.
+
+     Notes
+     -----
+     This function calculates the class weights based on the input labels (`y`) using the "balanced" method.
+     The class weights are computed to address the class imbalance problem, where the weights are inversely
+     proportional to the class frequencies.
+
+     The function returns a dictionary (`class_weights`) where the keys represent the unique classes in `y`
+     and the values represent the computed weights for each class.
+
+     Examples
+     --------
+     >>> labels = np.array([0, 1, 0, 1, 1])
+     >>> weights = compute_weights(labels)
+     >>> print(weights)
+     {0: 1.25, 1: 0.8333333333333334}
+     # Compute class weights for the binary labels.
+
+     >>> labels = np.array([0, 1, 2, 0, 1, 2, 2])
+     >>> weights = compute_weights(labels)
+     >>> print(weights)
+     {0: 1.1666666666666667, 1: 1.1666666666666667, 2: 0.7777777777777778}
+     # Compute class weights for the multi-class labels.
+
+     """
+     from sklearn.utils import compute_class_weight
+
+     class_weights = compute_class_weight(
+         class_weight="balanced",
+         classes=np.unique(y),
+         y=y,
+     )
+     class_weights = dict(zip(np.unique(y), class_weights))
+
+     return class_weights
+
+
+ def train_test_split(
+     data_x, data_y1, data_class=None, validation_size=0.25, test_size=0, n_iterations=10
+ ):
+     """
+     Split the dataset into training, validation, and (optionally) test sets.
+
+     Parameters
+     ----------
+     data_x : array-like
+         Input features or independent variables.
+     data_y1 : array-like
+         Target variable.
+     data_class : array-like, optional
+         One-hot encoded class labels. When provided, the split is redrawn until every class appears
+         in both the training and validation sets. Default is None.
+     validation_size : float, optional
+         Proportion of the dataset to include in the validation set. Default is 0.25.
+     test_size : float, optional
+         Proportion of the dataset to include in the test set. Default is 0.
+     n_iterations : int, optional
+         Maximum number of reshuffles attempted to obtain a split in which all classes are represented
+         in each set. Default is 10.
+
+     Returns
+     -------
+     dict
+         A dictionary containing the split datasets.
+         Keys: "x_train", "x_val", "y1_train", "y1_val"; plus "y2_train" and "y2_val" when `data_class`
+         is provided. If test_size > 0, additional keys: "x_test", "y1_test" (and "y2_test" when
+         `data_class` is provided).
+
+     Raises
+     ------
+     Exception
+         If no shuffle within `n_iterations` attempts yields a split where every class appears in both
+         the training and validation sets.
+
+     Notes
+     -----
+     This function divides the dataset into training, validation, and test sets based on the specified
+     proportions. It shuffles the data and splits it according to the proportions defined by
+     `validation_size` and `test_size`; the training set receives the remainder.
+
+     The input features (`data_x`) and targets (`data_y1`, and `data_class` if provided) should be
+     arrays or array-like objects with compatible first dimensions.
+
+     """
+
+     if data_class is not None:
+         logger.info(
+             f"Unique classes: {np.sort(np.argmax(np.unique(data_class, axis=0), axis=1))}"
+         )
+
+     for i in range(n_iterations):
+
+         n_values = len(data_x)
+         randomize = np.arange(n_values)
+         np.random.shuffle(randomize)
+
+         train_percentage = 1 - validation_size - test_size
+
+         chunks = split_by_ratio(randomize, train_percentage, validation_size, test_size)
+
+         x_train = data_x[chunks[0]]
+         y1_train = data_y1[chunks[0]]
+         if data_class is not None:
+             y2_train = data_class[chunks[0]]
+
+         x_val = data_x[chunks[1]]
+         y1_val = data_y1[chunks[1]]
+         if data_class is not None:
+             y2_val = data_class[chunks[1]]
+
+         if data_class is not None:
+             logger.info(
+                 f"Classes in train set: {np.sort(np.argmax(np.unique(y2_train, axis=0), axis=1))}; "
+                 f"classes in validation set: {np.sort(np.argmax(np.unique(y2_val, axis=0), axis=1))}"
+             )
+             same_class_test = np.array_equal(
+                 np.sort(np.argmax(np.unique(y2_train, axis=0), axis=1)),
+                 np.sort(np.argmax(np.unique(y2_val, axis=0), axis=1)),
+             )
+             logger.info(f"Check that classes are found in all sets: {same_class_test}...")
+         else:
+             same_class_test = True
+
+         if same_class_test:
+
+             ds = {
+                 "x_train": x_train,
+                 "x_val": x_val,
+                 "y1_train": y1_train,
+                 "y1_val": y1_val,
+             }
+             if data_class is not None:
+                 ds.update({"y2_train": y2_train, "y2_val": y2_val})
+
+             if test_size > 0:
+                 x_test = data_x[chunks[2]]
+                 y1_test = data_y1[chunks[2]]
+                 ds.update({"x_test": x_test, "y1_test": y1_test})
+                 if data_class is not None:
+                     y2_test = data_class[chunks[2]]
+                     ds.update({"y2_test": y2_test})
+             return ds
+         else:
+             continue
+
+     raise Exception(
+         "Some classes are missing from the train or validation set... Abort."
+     )
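
A short sketch of these training-set helpers on synthetic data, assuming this hunk is the new `celldetective/utils/dataset_helpers.py` (the +207-line module in the file list above); the array shapes are illustrative only:

    import numpy as np
    from celldetective.utils.dataset_helpers import (
        compute_weights,
        train_test_split,
    )

    rng = np.random.default_rng(0)
    x = rng.normal(size=(100, 128, 3))            # e.g. 100 single-cell signals
    y1 = rng.normal(size=(100, 128))              # per-signal regression target
    classes = np.eye(3)[rng.integers(0, 3, 100)]  # one-hot class labels

    # 70/20/10 split, redrawn until train and validation contain all classes.
    ds = train_test_split(
        x, y1, data_class=classes, validation_size=0.2, test_size=0.1
    )
    print(ds["x_train"].shape, ds["x_val"].shape, ds["x_test"].shape)

    # Balanced class weights, e.g. to pass as class_weight to a Keras fit().
    weights = compute_weights(np.argmax(classes, axis=1))

Note that this `train_test_split` shadows `sklearn.model_selection.train_test_split` in name only, so it is best imported explicitly from the module rather than star-imported.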