msreport 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- msreport/__init__.py +4 -6
- msreport/aggregate/condense.py +1 -1
- msreport/aggregate/pivot.py +1 -0
- msreport/aggregate/summarize.py +2 -2
- msreport/analyze.py +171 -38
- msreport/errors.py +1 -2
- msreport/export.py +16 -13
- msreport/fasta.py +2 -1
- msreport/helper/__init__.py +7 -7
- msreport/helper/calc.py +29 -24
- msreport/helper/maxlfq.py +2 -2
- msreport/helper/table.py +5 -6
- msreport/impute.py +7 -8
- msreport/isobar.py +10 -9
- msreport/normalize.py +54 -36
- msreport/peptidoform.py +6 -4
- msreport/plot/__init__.py +41 -0
- msreport/plot/_partial_plots.py +159 -0
- msreport/plot/comparison.py +490 -0
- msreport/plot/distribution.py +253 -0
- msreport/plot/multivariate.py +355 -0
- msreport/plot/quality.py +431 -0
- msreport/plot/style.py +286 -0
- msreport/plot/style_sheets/msreport-notebook.mplstyle +57 -0
- msreport/plot/style_sheets/seaborn-whitegrid.mplstyle +45 -0
- msreport/qtable.py +109 -17
- msreport/reader.py +73 -79
- msreport/rinterface/__init__.py +2 -1
- msreport/rinterface/limma.py +2 -1
- msreport/rinterface/rinstaller.py +3 -3
- {msreport-0.0.26.dist-info → msreport-0.0.28.dist-info}/METADATA +7 -3
- msreport-0.0.28.dist-info/RECORD +38 -0
- msreport/plot.py +0 -1132
- msreport-0.0.26.dist-info/RECORD +0 -30
- {msreport-0.0.26.dist-info → msreport-0.0.28.dist-info}/WHEEL +0 -0
- {msreport-0.0.26.dist-info → msreport-0.0.28.dist-info}/licenses/LICENSE.txt +0 -0
- {msreport-0.0.26.dist-info → msreport-0.0.28.dist-info}/top_level.txt +0 -0
msreport/plot/quality.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import warnings
|
|
3
|
+
|
|
4
|
+
import matplotlib.colors as mcolors
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import seaborn as sns
|
|
8
|
+
|
|
9
|
+
import msreport.helper
|
|
10
|
+
from msreport.qtable import Qtable
|
|
11
|
+
|
|
12
|
+
from ._partial_plots import box_and_bars
|
|
13
|
+
from .style import ColorWheelDict, with_active_style
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@with_active_style
def missing_values_vertical(
    qtable: Qtable,
    exclude_invalid: bool = True,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Vertical bar plot to analyze the completeness of quantification.

    Deprecated: use `missing_values_horizontal` instead.

    Requires the columns "Missing experiment_name" and "Events experiment_name", which
    are added by calling msreport.analyze.analyze_missingness(qtable: Qtable).

    One subplot is drawn per experiment, each with three bars counting the rows with
    no missing values ("Missing" == 0), with some missing values ("Missing" > 0 and
    "Events" > 0), and with all values missing ("Events" == 0).

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.

    Returns:
        A matplotlib Figure and an array of Axes objects (one per experiment)
        containing the missing values plots.
    """
    warnings.warn(
        (
            "The function `missing_values_vertical` is deprecated. Use "
            "`missing_values_horizontal` instead."
        ),
        DeprecationWarning,
        # Skip this function and the `with_active_style` wrapper frame so the warning
        # is attributed to the caller.
        stacklevel=3,
    )

    experiments = qtable.get_experiments()
    num_experiments = len(experiments)
    qtable_data = qtable.get_data(exclude_invalid=exclude_invalid)

    barwidth = 0.8
    barcolors = ["#31A590", "#FAB74E", "#EB3952"]
    figwidth = (num_experiments * 1.2) + 0.5
    figsize = (figwidth, 3.5)
    xtick_labels = ["No missing", "Some missing", "All missing"]
    x_positions = np.arange(len(xtick_labels))

    # squeeze=False always yields a 2D array, so a single experiment does not collapse
    # into a bare Axes object; take the only row to keep returning a 1D array.
    fig, axes = plt.subplots(
        1, num_experiments, figsize=figsize, sharey=True, squeeze=False
    )
    axes = axes[0]
    for exp_num, (exp, ax) in enumerate(zip(experiments, axes)):
        exp_missing = qtable_data[f"Missing {exp}"]
        exp_values = qtable_data[f"Events {exp}"]
        missing_none = (exp_missing == 0).sum()
        missing_some = ((exp_missing > 0) & (exp_values > 0)).sum()
        missing_all = (exp_values == 0).sum()

        y = [missing_none, missing_some, missing_all]
        ax.bar(x_positions, y, width=barwidth, color=barcolors)
        if exp_num == 0:
            ax.set_ylabel("# Proteins")
        ax.set_title(exp)
        # Bars are center-aligned on x_positions, so the ticks must sit there as well.
        ax.set_xticks(x_positions)
        ax.set_xticklabels(xtick_labels, rotation=45, va="top", ha="right")
        ax.grid(False, axis="x")
    sns.despine(fig=fig, top=True, right=True)
    fig.tight_layout()
    return fig, axes
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@with_active_style
def missing_values_horizontal(
    qtable: Qtable,
    exclude_invalid: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
    """Horizontal bar plot to analyze the completeness of quantification.

    Requires the columns "Missing experiment_name" and "Events experiment_name", which
    are added by calling msreport.analyze.analyze_missingness(qtable: Qtable).

    For each experiment three bars are overlaid on the same row:
    - red, spanning all rows (its visible part counts rows where all replicates
      are missing),
    - orange, spanning all rows minus those with all replicates missing,
    - green, spanning the rows with no missing values.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.

    Returns:
        A matplotlib Figure and Axes object, containing the missing values plot.
    """
    experiments = qtable.get_experiments()
    num_experiments = len(experiments)
    qtable_data = qtable.get_data(exclude_invalid=exclude_invalid)
    # Every experiment is evaluated on the same rows, so the total is shared.
    num_rows = len(qtable_data)

    data: dict[str, list] = {"exp": [], "max": [], "some": [], "min": []}
    for exp in experiments:
        exp_missing = qtable_data[f"Missing {exp}"]
        num_replicates = len(qtable.get_samples(exp))
        missing_all = (exp_missing == num_replicates).sum()
        missing_none = (exp_missing == 0).sum()
        with_missing_some = num_rows - missing_all

        data["exp"].append(exp)
        data["max"].append(num_rows)
        data["some"].append(with_missing_some)
        data["min"].append(missing_none)

    bar_width = 0.35

    suptitle_space_inch = 0.4
    ax_height_inch = num_experiments * bar_width
    ax_width_inch = 4
    fig_height = ax_height_inch + suptitle_space_inch
    fig_width = ax_width_inch
    fig_size = (fig_width, fig_height)

    subplot_top = 1 - (suptitle_space_inch / fig_height)

    fig, ax = plt.subplots(figsize=fig_size)
    fig.subplots_adjust(bottom=0, top=subplot_top, left=0, right=1)
    fig.suptitle("Completeness of quantification per experiment", y=1)

    # Drawn from widest to narrowest so each later bar is overlaid on the previous one.
    sns.barplot(
        y="exp", x="max", data=data, label="All missing", color="#EB3952", ax=ax
    )
    sns.barplot(
        y="exp", x="some", data=data, label="Some missing", color="#FAB74E", ax=ax
    )
    sns.barplot(
        y="exp", x="min", data=data, label="None missing", color="#31A590", ax=ax
    )

    ax.set_ylabel("")
    ax.set_xlabel("")
    ax.set_xlim(0, num_rows)

    # Replace the axes legend with a figure legend placed below the plot.
    ax.legend().remove()
    handles, labels = ax.get_legend_handles_labels()
    legend_ygap_inches = 0.3
    legend_bbox_y = 0 - (legend_ygap_inches / fig.get_figheight())

    fig.legend(
        handles[::-1],
        labels[::-1],
        bbox_to_anchor=(0.5, legend_bbox_y),
        loc="upper center",
        ncol=3,
        frameon=False,
        borderaxespad=0,
        handlelength=0.95,
        handleheight=1,
    )

    ax.tick_params(axis="y", labelsize=plt.rcParams["axes.labelsize"])
    ax.grid(axis="x", linestyle="solid")
    sns.despine(fig=fig, top=True, right=True, bottom=True)

    return fig, ax
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@with_active_style
def contaminants(
    qtable: Qtable, tag: str = "iBAQ intensity"
) -> tuple[plt.Figure, plt.Axes]:
    """Bar plot of the summed relative contaminant intensity for every sample.

    Each sample's intensities (columns selected by `tag`) are converted to percent of
    that sample's total, and the percentages of all rows flagged in the
    "Potential contaminant" column are summed:
        sum of contaminant intensities / sum of all intensities * 100

    Intensities may be supplied either log-transformed or not. If
    `msreport.helper.intensities_in_logspace` reports log-space values, they are
    converted back with 2 ** value before the percentages are computed.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        tag: A string that is used to extract iBAQ intensity containing columns.
            Default "iBAQ intensity".

    Raises:
        ValueError: If the "Potential contaminant" column is missing in the Qtable data,
            or if the Qtable does not contain any columns for the specified 'tag'.

    Returns:
        A matplotlib Figure and an Axes object, containing the contaminants plot.
    """
    if "Potential contaminant" not in qtable.data.columns:
        raise ValueError(
            "The 'Potential contaminant' column is missing in the Qtable data."
        )
    intensities = qtable.make_sample_table(tag, samples_as_columns=True)
    if intensities.empty:
        raise ValueError(f"The Qtable does not contain any '{tag}' columns.")
    if msreport.helper.intensities_in_logspace(intensities):
        intensities = np.power(2, intensities)

    percent_of_total = intensities / intensities.sum() * 100
    sample_names = intensities.columns.to_list()
    sample_count = len(sample_names)

    # One face color and a darkened edge color per sample, keyed by its experiment.
    sample_experiments = list(qtable.get_experiments(sample_names))
    color_wheel = ColorWheelDict()
    face_colors = [color_wheel[experiment] for experiment in sample_experiments]
    edge_colors = [
        color_wheel.modified_color(experiment, 0.4)
        for experiment in sample_experiments
    ]

    is_contaminant = qtable["Potential contaminant"]
    x_positions = range(percent_of_total.shape[1])
    contaminant_percent = percent_of_total[is_contaminant].sum(axis=0)

    # Figure geometry: heights and per-bar width in inches.
    title_height = 0.4
    axes_height = 1.6
    inch_per_bar = 0.24
    x_margin = 0.24
    fig_height = axes_height + title_height
    fig_width = (sample_count + (2 * x_margin)) * inch_per_bar
    top_fraction = 1 - (title_height / fig_height)

    # Axis limits in data units; the y-axis always spans at least 0-5 %.
    bar_width = 0.8
    half_slot = 0.5
    x_lower = (0 - half_slot) - x_margin
    x_upper = (sample_count - 1) + half_slot + x_margin
    min_y_upper = 5

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    fig.subplots_adjust(bottom=0, top=top_fraction, left=0, right=1)
    fig.suptitle("Relative amount of contaminants", y=1)

    ax.bar(
        x_positions,
        contaminant_percent,
        width=bar_width,
        color=face_colors,
        edgecolor=edge_colors,
        zorder=3,
    )
    ax.set_xticks(x_positions)
    ax.set_xticklabels(
        sample_names, fontsize=plt.rcParams["axes.labelsize"], rotation=90
    )
    ax.set_ylabel(f"Sum contaminant\n{tag} [%]")

    ax.grid(False, axis="x")
    sns.despine(top=True, right=True)

    y_upper = max(min_y_upper, ax.get_ylim()[1])
    ax.set_ylim(0, y_upper)
    ax.set_xlim(x_lower, x_upper)
    return fig, ax
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
@with_active_style
def sample_intensities(
    qtable: Qtable, tag: str = "Intensity", exclude_invalid: bool = True
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Figure to compare the overall quantitative similarity of samples.

    The figure is produced by `box_and_bars` and has two parts:
    - Top: log2 ratios of each sample to a pseudo reference, shown as one box plot
      per sample. The pseudo reference is the per-row mean of the log2 intensities
      across all samples. Only rows without any missing value are used.
    - Bottom: the summed (non-log) intensity of all rows per sample, shown as bars.

    Zero intensities are treated as missing values. Intensity columns may be given in
    log space or not; `msreport.helper.intensities_in_logspace` decides which, and
    the other representation is derived from it.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        tag: A string that is used to extract intensity containing columns.
            Default "Intensity".
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.

    Returns:
        A matplotlib Figure and a list of Axes objects, containing the intensity plots.
    """
    intensities = qtable.make_sample_table(
        tag, samples_as_columns=True, exclude_invalid=exclude_invalid
    ).replace({0: np.nan})

    # Keep both a linear and a log2 version of the same values.
    if not msreport.helper.intensities_in_logspace(intensities):
        log2_intensities = np.log2(intensities)
    else:
        log2_intensities = intensities
        intensities = np.power(2, log2_intensities)
    sample_names = intensities.columns.tolist()

    complete_rows = log2_intensities.notna().all(axis=1)
    complete_log2 = log2_intensities[complete_rows]
    pseudo_reference = np.nanmean(complete_log2, axis=1)
    log2_ratios = complete_log2.subtract(pseudo_reference, axis=0)

    total_intensities = intensities.sum()
    ratio_distributions = [log2_ratios[column] for column in log2_ratios.columns]

    sample_experiments = list(qtable.get_experiments(sample_names))
    color_wheel = ColorWheelDict()
    face_colors = [color_wheel[experiment] for experiment in sample_experiments]
    edge_colors = [
        color_wheel.modified_color(experiment, 0.4)
        for experiment in sample_experiments
    ]

    fig, axes = box_and_bars(
        ratio_distributions,
        total_intensities,
        sample_names,
        colors=face_colors,
        edge_colors=edge_colors,
    )
    fig.suptitle(f'Comparison of "{tag}" values', y=1)
    box_ax, bar_ax = axes[0], axes[1]
    box_ax.set_ylabel("Ratio [log2]\nto pseudo reference")
    bar_ax.set_ylabel("Total intensity")
    return fig, axes
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
@with_active_style
def sample_correlation(
    qtable: Qtable, exclude_invalid: bool = True, labels: bool = False
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Generates a pair-wise correlation matrix of the samples' "Expression" values.

    Correlation values are calculated with `DataFrame.corr()` (Pearson by default) on
    the "Expression" values. Only the lower triangle of the matrix, without the
    diagonal, is displayed.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.
        labels: If True, correlation values are displayed in the heatmap.

    Raises:
        ValueError: If less than two samples are present in the qtable, or if the
            expression table contains fewer than two sample columns.

    Returns:
        A matplotlib Figure and a list of Axes objects, containing the correlation
        matrix plot and the color bar.
    """
    error_message = (
        "At least two samples are required to generate a correlation matrix."
    )
    if qtable.design.shape[0] < 2:
        raise ValueError(error_message)
    data = qtable.make_expression_table(
        samples_as_columns=True, exclude_invalid=exclude_invalid
    )
    samples = data.columns.tolist()
    # Size the matrix from the plotted columns rather than the design, so that axis
    # limits, tick labels and `corr.iloc` indexing always agree with each other.
    num_samples = len(samples)
    if num_samples < 2:
        raise ValueError(error_message)
    corr = data.corr()
    # Hide the upper triangle including the diagonal.
    mask = np.triu(np.ones_like(corr, dtype=bool))

    # Geometry in inches; the visible triangle spans (num_samples - 1) cells.
    num_cells = num_samples - 1
    cell_size_inch = 0.3
    suptitle_space_inch = 0.4
    ax_height_inch = ax_width_inch = cell_size_inch * num_cells
    ax_wspace_inch = 0.4
    cbar_height_inch = max(1.2, min(3, cell_size_inch * num_cells))
    cbar_width_inch = 0.27
    width_ratios = [ax_width_inch, cbar_width_inch]
    subplot_wspace = ax_wspace_inch / np.mean([ax_width_inch, cbar_width_inch])

    fig_width = ax_width_inch + cbar_width_inch + ax_wspace_inch
    fig_height = ax_height_inch + suptitle_space_inch
    fig_size = (fig_width, fig_height)

    # Convert inch measurements to figure fractions; the color bar is top-aligned
    # with the heatmap.
    subplot_top = 1 - (suptitle_space_inch / fig_height)
    cbar_width = cbar_width_inch / fig_width
    cbar_height = cbar_height_inch / fig_height
    cbar_x0 = (ax_width_inch + ax_wspace_inch) / fig_width
    cbar_y0 = (ax_height_inch / fig_height) - cbar_height

    fig, axes = plt.subplots(
        1,
        2,
        figsize=fig_size,
        gridspec_kw={
            "bottom": 0,
            "top": subplot_top,
            "left": 0,
            "right": 1,
            "wspace": subplot_wspace,
            "width_ratios": width_ratios,
        },
    )
    fig.suptitle('Pairwise correlation matrix of sample "Expression" values', y=1)
    ax_heatmap, ax_cbar = axes
    ax_cbar.set_position((cbar_x0, cbar_y0, cbar_width, cbar_height))

    palette = sns.color_palette("rainbow", desat=0.8)
    cmap = mcolors.LinearSegmentedColormap.from_list("rainbow_desat", palette)
    sns.heatmap(
        corr,
        mask=mask,
        cmap=cmap,
        vmax=1,
        vmin=0.5,
        square=False,
        linewidths=0.5,
        ax=ax_heatmap,
    )
    # Replace seaborn's automatically placed color bar with one in the reserved axes.
    cbar = ax_heatmap.collections[0].colorbar
    if cbar is not None:
        cbar.remove()
    fig.colorbar(ax_heatmap.collections[0], cax=ax_cbar)

    if labels:
        # Only annotate the visible lower-triangle cells (row index > column index).
        for j, i in itertools.combinations(range(num_samples), 2):
            corr_value = corr.iloc[i, j]
            ax_heatmap.text(
                j + 0.5,
                i + 0.5,
                f"{corr_value:.2f}",
                ha="center",
                va="center",
                fontsize=8,  # Fontsize cannot be larger to fit in the cell
            )
    # Need to manually set ticks because sometimes not all are properly included
    ax_heatmap.set_yticks([i + 0.5 for i in range(1, num_samples)])
    ax_heatmap.set_yticklabels(samples[1:], rotation=0)
    ax_heatmap.set_xticks([i + 0.5 for i in range(0, num_samples - 1)])
    ax_heatmap.set_xticklabels(samples[:-1], rotation=90)

    ax_heatmap.grid(False)
    ax_heatmap.tick_params(labelsize=plt.rcParams["axes.labelsize"])
    # Crop away the masked first row and last column.
    ax_heatmap.set_xlim(0, num_cells)
    ax_heatmap.set_ylim(1 + num_cells, 1)

    sns.despine(left=False, bottom=False, ax=ax_heatmap)
    for ax in [ax_heatmap, ax_cbar]:
        for spine in ["top", "right", "left", "bottom"]:
            ax.spines[spine].set_linewidth(0.75)
    return fig, axes
|