fucciphase 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fucciphase/plot.py ADDED
@@ -0,0 +1,548 @@
1
+ from itertools import cycle
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from matplotlib import colormaps
7
+ from matplotlib import pyplot as plt
8
+ from matplotlib.collections import LineCollection
9
+ from matplotlib.figure import Figure
10
+ from scipy import interpolate
11
+
12
+ from .phase import NewColumns
13
+ from .utils import get_norm_channel_name
14
+
15
+
16
def set_phase_colors(
    df: pd.DataFrame, colordict: dict, phase_column: str = "DISCRETE_PHASE_MAX"
) -> None:
    """Label each phase by a fixed color.

    A ``COLOR`` column is added to (or overwritten in) ``df`` in place. Each
    row receives the entry of ``colordict`` that belongs to the row's value
    in ``phase_column``.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame holding the phase labels; modified in place.
    colordict : dict
        Mapping from phase label to color. Any color object is accepted,
        including tuples such as RGBA values.
    phase_column : str
        Name of the column with the phase labels.

    Raises
    ------
    ValueError
        If a phase occurring in ``phase_column`` has no entry in ``colordict``.
    """
    phases = df[phase_column].unique()
    if not all(phase in colordict for phase in phases):
        raise ValueError(f"Provide a color for every phase in: {phases}")

    # Build the column row by row rather than assigning through a boolean
    # mask: a masked ``.loc`` assignment treats a tuple color as a sequence
    # to distribute over the selected rows instead of as a single value.
    colors = [colordict[phase] for phase in df[phase_column]]
    df["COLOR"] = pd.Series(colors, index=df.index, dtype=object)
27
+
28
+
29
def plot_feature(
    df: pd.DataFrame,
    time_column: str,
    feature_name: str,
    interpolate_time: bool = False,
    track_id_name: str = "TRACK_ID",
    ylim: Optional[tuple] = None,
    yticks: Optional[list] = None,
) -> Figure:
    """Plot a feature of all individual tracks into one plot.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with one row per track and time point.
    time_column : str
        Name of the column used for the x-axis.
    feature_name : str
        Name of the column used for the y-axis.
    interpolate_time : bool
        Not used by this function; kept so that existing calls keep working.
    track_id_name : str
        Name of the column with the track identifiers. Tracks with a
        negative identifier are skipped.
    ylim : Optional[tuple]
        Limits of the y-axis.
    yticks : Optional[list]
        Tick positions of the y-axis.

    Returns
    -------
    Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If ``feature_name`` or ``time_column`` is not a column of ``df``.
    """
    if feature_name not in df:
        raise ValueError(f"Feature {feature_name} not in provided DataFrame.")
    if time_column not in df:
        raise ValueError(f"Time {time_column} not in provided DataFrame.")
    tracks = df[track_id_name].unique()
    tracks = tracks[tracks >= 0]

    fig = plt.figure()
    # Draw on an explicit Axes so the result does not depend on which
    # figure pyplot currently considers active.
    ax = fig.add_subplot()
    for track_idx in tracks:
        # Select the rows of the track once and reuse them for both columns.
        track = df.loc[df[track_id_name] == track_idx]
        ax.plot(track[time_column].to_numpy(), track[feature_name].to_numpy())
    # The axis settings are shared by all curves, so apply them only once.
    if ylim is not None:
        ax.set_ylim(ylim)
    if yticks is not None:
        ax.set_yticks(yticks)
    return fig
57
+
58
+
59
+ # flake8: noqa: C901
60
def plot_feature_stacked(
    df: pd.DataFrame,
    time_column: str,
    feature_name: str,
    interpolate_time: bool = False,
    track_id_name: str = "TRACK_ID",
    ylim: Optional[tuple] = None,
    yticks: Optional[list] = None,
    interpolation_steps: int = 1000,
    figsize: Optional[tuple] = None,
    selected_tracks: Optional[List[int]] = None,
) -> Figure:
    """Stack features of individual tracks.

    Every selected track gets its own subplot; all subplots share the
    x-axis. With ``interpolate_time`` an extra subplot at the bottom shows
    the mean of the feature over all tracks.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with one row per track and time point. It must contain a
        ``COLOR`` column, which is used to color the scatter points.
    time_column : str
        Name of the column used for the x-axis.
    feature_name : str
        Name of the column used for the y-axis.
    interpolate_time : bool
        If True, interpolate every track onto a common time grid and plot
        the mean curve in an additional subplot.
    track_id_name : str
        Name of the column with the track identifiers. Tracks with a
        negative identifier are skipped.
    ylim : Optional[tuple]
        Limits of the y-axis, applied to every subplot.
    yticks : Optional[list]
        Tick positions of the y-axis, applied to every subplot.
    interpolation_steps : int
        Number of points of the common time grid.
    figsize : Optional[tuple]
        Figure size; defaults to a width of 10 and a height of 2 per
        selected track.
    selected_tracks : Optional[List[int]]
        Tracks that get their own subplot; defaults to all tracks.

    Returns
    -------
    Figure
        The newly created figure.

    Raises
    ------
    ValueError
        If a required column is missing, if ``selected_tracks`` contains
        unknown tracks, or if there is no track to plot.

    Notes
    -----
    If `selected_tracks` are chosen, the averaging
    is still performed on all tracks.
    Few selected tracks are stacked to enhance visibility.
    The common time grid spans the time range of the selected tracks.
    """
    if feature_name not in df:
        raise ValueError(f"Feature {feature_name} not in provided DataFrame.")
    if time_column not in df:
        raise ValueError(f"Time {time_column} not in provided DataFrame.")
    if "COLOR" not in df:
        raise ValueError("Run set_phase_colors first on DataFrame")
    tracks = df[track_id_name].unique()
    tracks = tracks[tracks >= 0]
    if selected_tracks is None:
        selected_tracks = tracks
    elif not set(selected_tracks).issubset(tracks):
        raise ValueError("Selected tracks contain tracks that are not in track list.")
    if len(selected_tracks) == 0:
        raise ValueError("No tracks available for plotting.")

    if figsize is None:
        figsize = (10, 2 * len(selected_tracks))
    n_rows = len(selected_tracks) + (1 if interpolate_time else 0)
    # squeeze=False always yields a 2D array of Axes, so indexing also
    # works when only a single subplot is requested.
    fig, axs_grid = plt.subplots(
        n_rows, 1, sharex=True, figsize=figsize, squeeze=False
    )
    axs = axs_grid[:, 0]
    # Remove horizontal space between axes
    fig.subplots_adjust(hspace=0)

    max_frame = -np.inf
    min_frame = np.inf

    # Plot each graph, and manually set the y tick values
    for ax, track_idx in zip(axs, selected_tracks):
        track = df.loc[df[track_id_name] == track_idx]
        time = track[time_column].to_numpy()
        feature = track[feature_name].to_numpy()
        colors = track["COLOR"].to_numpy()
        ax.plot(time, feature)
        ax.scatter(time, feature, c=colors, lw=4)
        if ylim is not None:
            ax.set_ylim(ylim)
        if yticks is not None:
            ax.set_yticks(yticks)
        max_frame = max(max_frame, time.max())
        min_frame = min(min_frame, time.min())

    if interpolate_time:
        interpolated_time = np.linspace(min_frame, max_frame, num=interpolation_steps)
        interpolated_feature = np.zeros(shape=(len(interpolated_time), len(tracks)))
        for i, track_idx in enumerate(tracks):
            track = df.loc[df[track_id_name] == track_idx]
            time = track[time_column].to_numpy()
            feature = track[feature_name].to_numpy()
            # np.interp expects increasing sample points; a stable sort
            # leaves already ordered tracks untouched.
            order = np.argsort(time, kind="stable")
            # Outside the time range of a track the value is NaN, so the
            # track does not contribute to the mean there.
            interpolated_feature[:, i] = np.interp(
                interpolated_time,
                time[order],
                feature[order],
                left=np.nan,
                right=np.nan,
            )
        axs[-1].plot(
            interpolated_time,
            np.nanmean(interpolated_feature, axis=1),
            lw=5,
            color="black",
        )
        if ylim is not None:
            axs[-1].set_ylim(ylim)
        if yticks is not None:
            axs[-1].set_yticks(yticks)

    return fig
147
+
148
+
149
def plot_raw_intensities(
    df: pd.DataFrame,
    channel1: str,
    channel2: str,
    color1: str = "cyan",
    color2: str = "magenta",
    time_column: str = "FRAME",
    time_label: str = "Frame #",
    **plot_kwargs: bool,
) -> None:
    """Plot the raw intensities of a two-channel sensor over time.

    Both channels are drawn into a new figure. The first channel uses the
    left y-axis and the second channel a twin y-axis on the right; each
    axis label and its tick labels take the color of the channel.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with the intensity and time columns.
    channel1 : str
        Column name of the first channel.
    channel2 : str
        Column name of the second channel.
    color1 : str
        Color of the first channel.
    color2 : str
        Color of the second channel.
    time_column : str
        Column name of the time axis.
    time_label : str
        Label of the x-axis.
    plot_kwargs : dict
        Kwargs passed on to both plot calls.
    """

    def _decorate_y_axis(axis, label, color):
        # Axis label and tick labels share the color of the curve.
        axis.set_ylabel(label, color=color)
        axis.tick_params(axis="y", labelcolor=color)

    signal1 = df[channel1]
    signal2 = df[channel2]
    time = df[time_column]

    fig, left_axis = plt.subplots()

    # prepare axes
    left_axis.set_xlabel(time_label)
    _decorate_y_axis(left_axis, channel1, color1)
    right_axis = left_axis.twinx()
    _decorate_y_axis(right_axis, channel2, color2)

    # plot signal
    curves = ((left_axis, signal1, color1), (right_axis, signal2, color2))
    for axis, signal, color in curves:
        axis.plot(time, signal, color=color, **plot_kwargs)
    fig.tight_layout()
179
+
180
+
181
def plot_normalized_intensities(
    df: pd.DataFrame,
    channel1: str,
    channel2: str,
    color1: str = "cyan",
    color2: str = "magenta",
    time_column: str = "FRAME",
    time_label: str = "Frame #",
    **plot_kwargs: bool,
) -> None:
    """Plot the normalised intensities of a two-channel sensor over time.

    The columns holding the normalised intensities are looked up through
    ``get_norm_channel_name``. Both curves are drawn with pyplot into the
    current axes and labelled with the plain channel names.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with the normalised intensity and time columns.
    channel1 : str
        Name of the first channel.
    channel2 : str
        Name of the second channel.
    color1 : str
        Color of the first channel.
    color2 : str
        Color of the second channel.
    time_column : str
        Column name of the time axis.
    time_label : str
        Label of the x-axis.
    plot_kwargs : dict
        Kwargs passed on to both plot calls.
    """
    norm_signal1 = df[get_norm_channel_name(channel1)]
    norm_signal2 = df[get_norm_channel_name(channel2)]
    time = df[time_column]

    curves = (
        (norm_signal1, color1, channel1),
        (norm_signal2, color2, channel2),
    )
    for signal, color, label in curves:
        plt.plot(time, signal, color=color, label=label, **plot_kwargs)
    plt.xlabel(time_label)
    plt.ylabel("Normalised intensity")
200
+
201
+
202
def plot_phase(df: pd.DataFrame, channel1: str, channel2: str) -> None:
    """Plot the two normalised channels and the cell cycle percentage.

    All three curves are drawn over the FRAME column with pyplot, i.e.
    into the current axes.

    The dataframe must be preprocessed with one of the available phase
    computation functions and must contain the following columns:

    - normalised channels (named by ``get_norm_channel_name``)
    - cell cycle percentage (named by ``NewColumns.cell_cycle()``)
    - FRAME

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe
    channel1 : str
        First channel
    channel2 : str
        Second channel

    Raises
    ------
    ValueError
        If the dataframe does not contain the FRAME or the cell cycle
        percentage column.
    """
    # The frame column is checked first, then the cell cycle column.
    if "FRAME" not in df.columns:
        raise ValueError("Column FRAME not found")

    cycle_column = NewColumns.cell_cycle()
    if cycle_column not in df.columns:
        raise ValueError(f"Column {cycle_column} not found")

    frames = df["FRAME"].to_numpy()
    # Missing normalised columns are not checked explicitly; the column
    # lookup itself fails in that case.
    curves = [
        (df[get_norm_channel_name(channel1)], channel1),
        (df[get_norm_channel_name(channel2)], channel2),
        (df[cycle_column], "unique intensity"),
    ]

    for values, label in curves:
        plt.plot(frames, values, label=label)
246
+
247
+
248
def plot_dtw_query_vs_reference(
    reference_df: pd.DataFrame,
    df: pd.DataFrame,
    channels: List[str],
    ref_percentage_column: str = "percentage",
    est_percentage_column: str = "CELL_CYCLE_PERC_DTW",
    ground_truth: Optional[pd.DataFrame] = None,
    colors: Optional[List[str]] = None,
    **plot_kwargs: bool,
) -> None:
    """Plot query and alignment to reference curve.

    One subplot per channel is created. Each shows the query curve, the
    reference curve, the reference curve evaluated at the estimated
    percentages ("Match") and, if given, the ground truth.

    Parameters
    ----------
    reference_df: pd.DataFrame
        DataFrame with reference curve data
    df: pd.DataFrame
        DataFrame used for query
    channels: List[str]
        Name of the channels
    ref_percentage_column: str
        Name of column with percentages of reference curve
    est_percentage_column: str
        Name of column with estimated percentages
    ground_truth: pd.DataFrame
        DataFrame with ground truth data, needs to be named as reference_df
    colors: List[str]
        Colors of the reference curves, defaults to cyan and magenta.
        The colors are reused cyclically if there are more channels
        than colors.
    plot_kwargs: dict
        Kwargs to be passed to matplotlib

    Raises
    ------
    ValueError
        If no channel is given or if a channel or percentage column is
        missing. The interpolation of the reference curve uses the default
        settings of ``interp1d``, which reject estimated percentages
        outside the range of the reference percentages.
    """
    if len(channels) == 0:
        raise ValueError("Provide at least one channel")
    for channel in channels:
        if channel not in reference_df.columns:
            raise ValueError(f"Channel {channel} not in reference DataFrame")
        if channel not in df.columns:
            raise ValueError(f"Channel {channel} not in query DataFrame")
    if est_percentage_column not in df.columns:
        raise ValueError(
            "Percentage column not found in query DataFrame"
            f", available options {df.columns}"
        )
    if colors is None or len(colors) == 0:
        colors = ["cyan", "magenta"]
    if ref_percentage_column not in reference_df.columns:
        raise ValueError(
            "Percentage column not found in reference DataFrame"
            f", available options {reference_df.columns}"
        )
    # squeeze=False always yields a 2D array of Axes, so a single channel
    # can be handled like several channels.
    fig, axes = plt.subplots(1, len(channels), squeeze=False)
    for idx, (ax, channel) in enumerate(zip(axes[0], channels)):
        ax.plot(df[est_percentage_column], df[channel], label="Query", **plot_kwargs)
        ax.plot(
            reference_df[ref_percentage_column],
            reference_df[channel],
            # Wrap around so that more channels than colors do not fail.
            color=colors[idx % len(colors)],
            **plot_kwargs,
        )
        # Evaluate the reference curve at the estimated percentages to
        # show where the query was matched to.
        reference_curve = interpolate.interp1d(
            reference_df[ref_percentage_column], reference_df[channel]
        )
        ax.plot(
            df[est_percentage_column],
            reference_curve(df[est_percentage_column]),
            lw=6,
            alpha=0.5,
            color="red",
            label="Match",
        )
        if ground_truth is not None:
            ax.plot(
                ground_truth[ref_percentage_column],
                ground_truth[channel],
                label="Ground truth",
                lw=3,
            )

        ax.set_ylabel(f"{channel.capitalize()} intensity / arb. u.")
        ax.set_xlabel("Cell cycle percentage")
        # One legend is enough, all subplots use the same labels.
        if idx == 0:
            ax.legend()
    fig.tight_layout()
331
+
332
+
333
def plot_query_vs_reference_in_time(
    reference_df: pd.DataFrame,
    df: pd.DataFrame,
    channels: List[str],
    ref_time_column: str = "time",
    query_time_column: str = "time",
    colors: Optional[List[str]] = None,
    channel_titles: Optional[List[str]] = None,
    fig_title: Optional[str] = None,
    **plot_kwargs: bool,
) -> None:
    """Plot query and reference curve over time.

    One subplot per channel is created, showing the query curve in blue
    and the reference curve in the color of the channel.

    Parameters
    ----------
    reference_df: pd.DataFrame
        DataFrame with reference curve data
    df: pd.DataFrame
        DataFrame used for query
    channels: List[str]
        Name of the channels
    ref_time_column: str
        Name of column with times of reference curve
    query_time_column: str
        Name of column with times in query
    colors: List[str]
        Colors of the reference curves, defaults to cyan and magenta.
        The colors are reused cyclically if there are more channels
        than colors.
    plot_kwargs: dict
        Kwargs to be passed to matplotlib
    channel_titles: Optional[List]
        titles for each channel
    fig_title: Optional[str]
        Figure title

    Raises
    ------
    ValueError
        If no channel is given, if a channel or time column is missing,
        or if the number of channel titles differs from the number of
        channels.
    """
    if len(channels) == 0:
        raise ValueError("Provide at least one channel")
    for channel in channels:
        if channel not in reference_df.columns:
            raise ValueError(f"Channel {channel} not in reference DataFrame")
        if channel not in df.columns:
            raise ValueError(f"Channel {channel} not in query DataFrame")
    if query_time_column not in df.columns:
        raise ValueError(
            f"Time column not found in query DataFrame, available options {df.columns}"
        )
    if channel_titles is not None and len(channels) != len(channel_titles):
        raise ValueError("Provide a channel name for each channel")
    if ref_time_column not in reference_df.columns:
        raise ValueError(
            "Time column not found in reference DataFrame"
            f", available options {reference_df.columns}"
        )
    if colors is None or len(colors) == 0:
        colors = ["cyan", "magenta"]
    # squeeze=False always yields a 2D array of Axes, so a single channel
    # can be handled like several channels.
    fig, axes = plt.subplots(1, len(channels), squeeze=False)
    if fig_title is not None:
        fig.suptitle(fig_title)
    for idx, (ax, channel) in enumerate(zip(axes[0], channels)):
        ax.plot(
            df[query_time_column],
            df[channel],
            label="Query",
            color="blue",
            **plot_kwargs,
        )
        ax.plot(
            reference_df[ref_time_column],
            reference_df[channel],
            # Wrap around so that more channels than colors do not fail.
            color=colors[idx % len(colors)],
            **plot_kwargs,
        )
        ax.set_yticks([])
        ax.set_xlabel("Time / h")
        # Only the first subplot carries the y label and the legend.
        if idx == 0:
            ax.set_ylabel("Intensity / arb. u.")
            ax.legend()
        if channel_titles is not None:
            ax.set_title(channel_titles[idx])
    fig.tight_layout()
411
+
412
+
413
def get_phase_color(phase: str) -> tuple:
    """Return the RGBA color that represents a cell cycle phase.

    ``"G1"`` and ``"S/G2/M"`` have dedicated colors; every other value is
    mapped to gray.
    """
    known_phases = (
        ("G1", (0.09019607843137255, 0.7450980392156863, 0.8117647058823529, 1.0)),
        ("S/G2/M", (0.75, 0.0, 0.75, 1.0)),
    )
    for name, rgba in known_phases:
        if phase == name:
            return rgba
    # Fallback for anything that is not one of the known phases.
    return (0.5019607843137255, 0.5019607843137255, 0.5019607843137255, 1.0)
421
+
422
+
423
def get_percentage_color(percentage: float) -> tuple:
    """Get color corresponding to percentage.

    The percentage is divided by 100 and looked up in the ``cool``
    colormap; the returned color is fully opaque. For NaN a warning is
    printed and a fully transparent color is returned.

    Parameters
    ----------
    percentage : float
        Percentage value, may be NaN.

    Returns
    -------
    tuple
        RGBA color.
    """
    if np.isnan(percentage):
        print("WARNING: NaN value detected, plot will be transparent")
        # Return the transparent color directly; forcing the alpha value
        # to 1.0 here would turn it into opaque black.
        return (0.0, 0.0, 0.0, 0.0)
    cmap = colormaps.get("cool")
    rgba_value = cmap(percentage / 100.0)
    return (rgba_value[0], rgba_value[1], rgba_value[2], 1.0)
433
+
434
+
435
def plot_cell_trajectory(
    track_df: pd.DataFrame,
    track_id_name: str,
    min_track_length: int = 30,
    centroid0_name: str = "centroid-0",
    centroid1_name: str = "centroid-1",
    phase_column: Optional[str] = None,
    percentage_column: Optional[str] = None,
    coloring_mode: str = "phase",
    line_cycle: Optional[list] = None,
    **kwargs: int,
) -> None:
    """Plot cell migration trajectories with phase or percentage-based coloring.

    Every track is shifted so that it starts at (0, 0) and is drawn as a
    collection of line segments into a new figure.

    Parameters
    ----------
    track_df : pandas.DataFrame
        DataFrame containing cell tracking data.
    track_id_name : str
        Column name containing unique track identifiers.
    min_track_length : int, optional
        Minimum number of timepoints required to include a track, default is 30.
    centroid0_name : str, optional
        Column name for x-coordinate of cell centroid, default is "centroid-0".
    centroid1_name : str, optional
        Column name for y-coordinate of cell centroid, default is "centroid-1".
    phase_column : str, optional
        Column name containing cell cycle phase information, default is None.
    percentage_column : str, optional
        Column name containing percentage values for coloring, default is None.
    coloring_mode : str, optional
        Color tracks by cell cycle phase (`phase`) or by percentage (`percentage`)
    line_cycle: list
        Cycle through the list, can help with visualization
    kwargs: dict, optional
        Kwargs are directly passed to the LineCollection, use it to adjust
        the linewidth for example

    Raises
    ------
    ValueError
        If ``coloring_mode`` is unknown, if the column needed for the
        chosen coloring mode is not provided, or if a linestyle is passed
        via ``kwargs``.

    Notes
    -----
    Phase or percentage columns need to be provided for the respective coloring.
    If not, an error will be raised.
    If both columns are provided, ``coloring_mode`` decides which one is used.

    """
    # initial checks
    possible_coloring = ["phase", "percentage"]
    if coloring_mode not in possible_coloring:
        raise ValueError(f"coloring_mode needs to be one {possible_coloring}")

    if phase_column is None and coloring_mode == "phase":
        raise ValueError("No phase column value provided but phase coloring required.")
    if percentage_column is None and coloring_mode == "percentage":
        raise ValueError(
            "No percentage column value provided but percentage coloring required."
        )

    # The linestyle is controlled by line_cycle, so reject every spelling
    # of it in kwargs instead of letting it clash with the ``ls`` argument.
    if any(alias in kwargs for alias in ("ls", "linestyle", "linestyles", "dashes")):
        raise ValueError("Set the linestyles via line_cycle argument.")
    # default: all curves solid
    if line_cycle is None:
        line_cycle = ["solid"]
    linecycler = cycle(line_cycle)

    # Let the coloring mode pick the column and the color lookup.
    if coloring_mode == "phase":
        color_column, color_lookup = phase_column, get_phase_color
    else:
        color_column, color_lookup = percentage_column, get_percentage_color

    line_collections = []
    for index in track_df[track_id_name].unique():
        # Select the rows of the track once and reuse them below.
        track = track_df.loc[track_df[track_id_name] == index]
        if len(track) < min_track_length:
            continue
        centroids = track[[centroid0_name, centroid1_name]].to_numpy()
        # set start location to (0, 0); the subtraction creates a new
        # array, so the data of the DataFrame is never written to
        centroids = centroids - centroids[0]
        # Segment i connects point i with point i + 1.
        segments = np.stack([centroids[:-1], centroids[1:]], axis=1)
        segment_colors = track[color_column].map(color_lookup).to_list()

        line_collections.append(
            LineCollection(
                segments,
                colors=segment_colors,
                ls=next(linecycler),
                **kwargs,
            )
        )
    fig, ax = plt.subplots()
    for line_collection in line_collections:
        ax.add_collection(line_collection)
    ax.margins(0.05)
    ax.set_xlabel(r"X in $\mu$m")
    ax.set_ylabel(r"Y in $\mu$m")
fucciphase/py.typed ADDED
@@ -0,0 +1,5 @@
1
+ You may remove this file if you don't intend to add types to your package
2
+
3
+ Details at:
4
+
5
+ https://mypy.readthedocs.io/en/stable/installed_packages.html#creating-pep-561-compatible-packages