msreport 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,253 @@
1
+ import itertools
2
+ import warnings
3
+ from collections.abc import Iterable, Sequence
4
+ from typing import Optional
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import pandas as pd
9
+ import seaborn as sns
10
+
11
+ from msreport.qtable import Qtable
12
+
13
+ from .style import ColorWheelDict, with_active_style
14
+
15
+
16
@with_active_style
def replicate_ratios(
    qtable: Qtable,
    exclude_invalid: bool = True,
    xlim: Iterable[float] = (-2, 2),
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Figure to compare the similarity of expression values between replicates.

    Displays the distribution of pair-wise log2 ratios between samples of the same
    experiment. Comparisons of the same experiment are placed in the same row. Requires
    log2 transformed expression values.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.
        xlim: Specifies the displayed range for the log2 ratios on the x-axis. Default
            is from -2 to 2.

    Returns:
        A matplotlib Figure and the Axes objects containing the comparison plots. The
        Axes are returned as a NumPy array: two-dimensional (experiments x sample
        pairs) for a regular figure, and a one-element array when no experiment has
        at least two samples. Axes without a comparison are removed from the figure
        but remain in the returned array.
    """
    tag: str = "Expression"
    table = qtable.make_sample_table(
        tag, samples_as_columns=True, exclude_invalid=exclude_invalid
    )
    design = qtable.get_design()
    design_experiments = design["Experiment"].unique()

    # Touch every experiment of the design once before plotting; presumably this
    # fixes the color assignment independent of which experiments are drawn.
    color_wheel = ColorWheelDict()
    for exp in design_experiments:
        _ = color_wheel[exp]

    # Only experiments with at least two samples allow a pair-wise comparison. The
    # sample lists are fetched once and reused below.
    experiment_samples = {}
    for experiment in design_experiments:
        samples = qtable.get_samples(experiment)
        if len(samples) >= 2:
            experiment_samples[experiment] = samples

    if not experiment_samples:
        # Placeholder figure when there is nothing to compare.
        fig, ax = plt.subplots(1, 1, figsize=(2, 1.3))
        fig.suptitle("Pair wise comparison of replicates", y=1.1)
        ax.text(0.5, 0.5, "No replicate\ndata available", ha="center", va="center")
        ax.grid(False)
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        sns.despine(top=False, right=False, fig=fig)
        return fig, np.array([ax])

    num_experiments = len(experiment_samples)
    max_replicates = max(len(samples) for samples in experiment_samples.values())
    # Number of unordered sample pairs: n choose 2.
    max_combinations = max_replicates * (max_replicates - 1) // 2

    suptitle_space_inch = 0.55
    ax_height_inch = 0.6
    ax_width_inch = 1.55
    ax_hspace_inch = 0.35
    fig_height = (
        num_experiments * ax_height_inch
        + (num_experiments - 1) * ax_hspace_inch
        + suptitle_space_inch
    )
    fig_width = max_combinations * ax_width_inch
    fig_size = (fig_width, fig_height)

    subplot_top = 1 - (suptitle_space_inch / fig_height)
    subplot_hspace = ax_hspace_inch / ax_height_inch

    # squeeze=False always yields a 2-D array, so rows and columns can be indexed
    # uniformly even for a single experiment or a single sample pair.
    fig, axes = plt.subplots(
        num_experiments,
        max_combinations,
        figsize=fig_size,
        sharex=True,
        squeeze=False,
    )
    fig.subplots_adjust(
        bottom=0, top=subplot_top, left=0, right=1, hspace=subplot_hspace
    )
    fig.suptitle("Pair wise comparison of replicates", y=1)

    for row, (experiment, samples) in enumerate(experiment_samples.items()):
        # Resolve the replicate label of each sample once per experiment instead
        # of once per pair; the first matching design row is used.
        replicate_labels = {
            sample: design.loc[(design["Sample"] == sample), "Replicate"].tolist()[0]
            for sample in samples
        }
        color = color_wheel[experiment]

        for col, (s1, s2) in enumerate(itertools.combinations(samples, 2)):
            ax = axes[row, col]
            ratios = table[s1] - table[s2]
            ratios = ratios[np.isfinite(ratios)]
            title = f"{replicate_labels[s1]} vs {replicate_labels[s2]}"
            # Only the first plot of a row carries the experiment name.
            ylabel = experiment if col == 0 else ""

            sns.kdeplot(x=ratios, fill=True, ax=ax, zorder=3, color=color, alpha=0.5)
            ax.set_title(title, fontsize=plt.rcParams["axes.labelsize"])
            ax.set_ylabel(ylabel, rotation=0, va="center", ha="right")
            ax.set_xlabel("Ratio [log2]")
            ax.tick_params(labelleft=False)
            ax.locator_params(axis="x", nbins=5)

    # The x-axis is shared, setting the limits on one Axes is sufficient.
    axes[0, 0].set_xlim(xlim)
    for ax in axes.flatten():
        if not ax.has_data():
            # Experiments with fewer replicates leave unused grid cells.
            ax.remove()
            continue

        ax.axvline(x=0, color="#999999", lw=1, zorder=2)
        ax.grid(False, axis="y")
    sns.despine(top=True, right=True, fig=fig)

    return fig, axes
124
+
125
+
126
@with_active_style
def experiment_ratios(
    qtable: Qtable,
    experiments: Optional[Iterable[str]] = None,
    exclude_invalid: bool = True,
    ylim: Sequence[float] = (-2, 2),
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Figure to compare the similarity of expression values between experiments.

    Intended to evaluate the bulk distribution of expression values after normalization.
    For each experiment a subplot is generated, which displays the distribution of log2
    ratios to a pseudo reference experiment as a density plot. The pseudo reference
    values are calculated as the average intensity values of all experiments. Only rows
    with quantitative values in all experiment are considered.

    Requires "Events experiment" columns and that average experiment expression values
    are calculated. This can be achieved by calling
    `msreport.analyze.analyze_missingness(qtable: Qtable)` and
    `msreport.analyze.calculate_experiment_means(qtable: Qtable)`.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        experiments: Optional, list of experiments that will be displayed. If None, all
            experiments from `qtable.design` will be used.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.
        ylim: Specifies the displayed range for the log2 ratios on the y-axis. Default
            is from -2 to 2.

    Raises:
        ValueError: If only one experiment is specified in the `experiments` parameter
            or if the specified experiments are not present in the qtable design.

    Returns:
        A matplotlib Figure and the Axes objects containing the comparison plots,
        returned as a one-dimensional NumPy array. If fewer than two experiments are
        available, a placeholder figure with a single Axes is returned instead.
    """
    tag: str = "Expression"

    if experiments is None:
        experiments = qtable.design["Experiment"].unique().tolist()
    else:
        # Materialize once so that any iterable can be measured and iterated
        # repeatedly.
        experiments = list(experiments)
        if len(experiments) == 1:
            raise ValueError(
                "Only one experiment is specified, please provide at least two experiments."
            )
        experiments_not_in_design = set(experiments) - set(qtable.design["Experiment"])
        if experiments_not_in_design:
            raise ValueError(
                "All experiments must be present in qtable.design. The following "
                f"experiments are not present: {experiments_not_in_design}"
            )

    if len(experiments) < 2:
        # Placeholder figure when there is nothing to compare.
        fig, ax = plt.subplots(1, 1, figsize=(2.5, 1.3))
        fig.suptitle("Comparison of experiments means", y=1.1)
        ax.text(
            0.5,
            0.5,
            "Comparison not possible.\nOnly one experiment\npresent in design.",
            ha="center",
            va="center",
        )
        ax.grid(False)
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        sns.despine(top=False, right=False, fig=fig)
        return fig, np.array([ax])

    sample_data = qtable.make_sample_table(tag, samples_as_columns=True)
    experiment_means = {}
    for experiment in experiments:
        samples = qtable.get_samples(experiment)
        with warnings.catch_warnings():
            # Rows without any value in this experiment produce an all-NaN mean,
            # for which numpy emits a RuntimeWarning.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            row_means = np.nanmean(sample_data[samples], axis=1)
        experiment_means[experiment] = row_means
    # Built from plain arrays, so the rows carry a fresh positional index.
    experiment_data = pd.DataFrame(experiment_means)

    # Only consider rows with quantitative values in all experiments
    mask = np.all([(qtable.data[f"Events {exp}"] > 0) for exp in experiments], axis=0)
    if exclude_invalid:
        # Keep the mask a positional array. Combining it with the "Valid" Series
        # directly would make the row selection below align on the qtable index,
        # which does not match the positional index of `experiment_data`.
        mask = mask & np.asarray(qtable["Valid"], dtype=bool)
    experiment_data = experiment_data[mask]
    pseudo_reference = np.nanmean(experiment_data, axis=1)
    ratio_data = experiment_data.subtract(pseudo_reference, axis=0)

    # Touch every experiment of the design once before plotting; presumably this
    # fixes the color assignment independent of which experiments are drawn.
    color_wheel = ColorWheelDict()
    for exp in qtable.design["Experiment"].unique():
        _ = color_wheel[exp]
    num_experiments = len(experiments)

    suptitle_space_inch = 0.55
    ax_height_inch = 1.25
    ax_width_inch = 0.65
    ax_wspace_inch = 0.2
    fig_height = ax_height_inch + suptitle_space_inch
    fig_width = num_experiments * ax_width_inch + (num_experiments - 1) * ax_wspace_inch
    fig_size = (fig_width, fig_height)

    subplot_top = 1 - (suptitle_space_inch / fig_height)
    subplot_wspace = ax_wspace_inch / ax_width_inch

    fig, axes = plt.subplots(1, num_experiments, figsize=fig_size, sharey=True)
    fig.subplots_adjust(
        bottom=0, top=subplot_top, left=0, right=1, wspace=subplot_wspace
    )
    fig.suptitle("Comparison of experiments means", y=1)

    for exp_pos, experiment in enumerate(experiments):
        ax = axes[exp_pos]
        values = ratio_data[experiment]
        color = color_wheel[experiment]
        sns.kdeplot(y=values, fill=True, ax=ax, zorder=3, color=color, alpha=0.5)
        if exp_pos == 0:
            # The number of compared rows is shown once, above the first subplot.
            ax.set_title(
                f"n={len(values)}",
                fontsize=plt.rcParams["xtick.labelsize"],
                loc="left",
            )
        ax.tick_params(labelbottom=False)
        ax.set_xlabel(experiment, rotation=90)

    axes[0].set_ylabel("Ratio [log2]\nto pseudo reference")
    # The y-axis is shared, setting the limits on one Axes is sufficient.
    axes[0].set_ylim(ylim)
    for ax in axes:
        ax.axhline(y=0, color="#999999", lw=1, zorder=2)
        ax.grid(False, axis="x")
    sns.despine(top=True, right=True, fig=fig)
    return fig, axes
@@ -0,0 +1,355 @@
1
+ from typing import Any
2
+
3
+ import adjustText
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
+ import sklearn.decomposition
9
+ import sklearn.preprocessing
10
+
11
+ import msreport.helper
12
+ from msreport.qtable import Qtable
13
+
14
+ from .style import ColorWheelDict, with_active_style
15
+
16
+
17
@with_active_style
def sample_pca(
    qtable: Qtable,
    tag: str = "Expression",
    pc_x: str = "PC1",
    pc_y: str = "PC2",
    exclude_invalid: bool = True,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Figure to compare sample similarities with a principal component analysis.

    On the left subplot two PCA components of log2 transformed, mean centered intensity
    values are shown. On the right subplot the explained variance of the principal
    components is displayed as barplots.

    It is possible to use intensity columns that are either log-transformed or not. The
    intensity values undergo an automatic evaluation to determine if they are already
    in log-space, and if necessary, they are transformed accordingly.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        tag: A string that is used to extract intensity containing columns.
            Default "Expression".
        pc_x: Principal component to plot on x-axis of the scatter plot, default "PC1".
            The number of calculated principal components is the number of samples
            minus one, but at most nine.
        pc_y: Principal component to plot on y-axis of the scatter plot, default "PC2".
            The number of calculated principal components is the number of samples
            minus one, but at most nine.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.

    Raises:
        KeyError: If `pc_x` or `pc_y` is not one of the calculated principal
            components.

    Returns:
        A matplotlib Figure and the Axes objects containing the PCA plots, returned as
        a NumPy array. If the design contains less than three samples, a placeholder
        figure with a single Axes is returned instead.
    """
    design = qtable.get_design()
    if design.shape[0] < 3:
        # Placeholder figure when a PCA is not meaningful.
        fig, ax = plt.subplots(1, 1, figsize=(2, 1.3))
        fig.suptitle(f'PCA of "{tag}" values', y=1.1)
        ax.text(
            0.5,
            0.5,
            "PCA analysis cannot\nbe performed with\nless than 3 samples",
            ha="center",
            va="center",
        )
        ax.grid(False)
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        sns.despine(top=True, right=True, fig=fig)
        return fig, np.array([ax])

    table = qtable.make_sample_table(
        tag, samples_as_columns=True, exclude_invalid=exclude_invalid
    )
    # Treat zeros as missing and drop rows without any finite value.
    table = table.replace({0: np.nan})
    table = table[np.isfinite(table).sum(axis=1) > 0]
    if not msreport.helper.intensities_in_logspace(table):
        table = np.log2(table)
    # The PCA cannot handle missing values, they are replaced with zero.
    table[table.isna()] = 0

    # Samples become rows, each row is then centered per feature.
    table = table.transpose()
    sample_index = table.index.tolist()
    table = sklearn.preprocessing.scale(table, with_std=False)

    num_components = min(len(sample_index) - 1, 9)
    pca = sklearn.decomposition.PCA(n_components=num_components)
    components = pca.fit_transform(table)
    component_labels = ["PC{}".format(i + 1) for i in range(components.shape[1])]
    components_table = pd.DataFrame(
        data=components, columns=component_labels, index=sample_index
    )
    variance = pca.explained_variance_ratio_ * 100
    variance_lookup = dict(zip(component_labels, variance, strict=True))

    # Fail early with an explicit message instead of leaving a half-built figure
    # behind when an unknown component is requested.
    for requested_pc in (pc_x, pc_y):
        if requested_pc not in variance_lookup:
            raise KeyError(
                f"Principal component {requested_pc!r} is not available, "
                f"choose one of {component_labels}."
            )

    # Prepare colors. Every experiment is touched once before plotting;
    # presumably this fixes the color assignment.
    color_wheel = ColorWheelDict()
    for exp in qtable.get_experiments():
        _ = color_wheel[exp]

    # Prepare figure
    bar_width = 0.8
    bar_width_inches = 0.25
    x_padding = 0.25

    suptitle_space_inch = 0.4
    ax_height_inch = 2.7
    ax_width_inch = ax_height_inch
    ax_wspace_inch = 0.6
    bar_ax_width_inch = (num_components + (2 * x_padding)) * bar_width_inches
    width_ratios = [ax_width_inch, bar_ax_width_inch]

    fig_height = suptitle_space_inch + ax_height_inch
    fig_width = ax_width_inch + bar_ax_width_inch + ax_wspace_inch
    fig_size = (fig_width, fig_height)

    subplot_top = 1 - (suptitle_space_inch / fig_height)
    subplot_wspace = ax_wspace_inch / np.mean([ax_width_inch, bar_ax_width_inch])

    bar_half_width = 0.5
    lower_xbound = (0 - bar_half_width) - x_padding
    upper_xbound = (num_components - 1) + bar_half_width + x_padding

    fig, axes = plt.subplots(
        1,
        2,
        figsize=fig_size,
        gridspec_kw={
            "bottom": 0,
            "top": subplot_top,
            "left": 0,
            "right": 1,
            "wspace": subplot_wspace,
            "width_ratios": width_ratios,
        },
    )
    fig.suptitle(f'PCA of "{tag}" values', y=1)

    # Comparison of two principal components
    ax = axes[0]
    texts = []
    for sample, data in components_table.iterrows():
        experiment = qtable.get_experiment(str(sample))
        label = design.loc[(design["Sample"] == sample), "Replicate"].tolist()[0]
        color = color_wheel[experiment]
        edge_color = color_wheel.modified_color(experiment, 0.4)
        ax.scatter(
            data[pc_x],
            data[pc_y],
            color=color,
            edgecolor=edge_color,
            lw=0.7,
            s=50,
            label=experiment,
        )
        texts.append(ax.text(data[pc_x], data[pc_y], label))
    # Move the replicate labels so that they overlap less with points and each other.
    adjustText.adjust_text(
        texts,
        force_text=0.15,
        expand_points=(1.4, 1.4),
        lim=20,
        ax=ax,
    )
    ax.set_xlabel(f"{pc_x} ({variance_lookup[pc_x]:.1f}%)")
    ax.set_ylabel(f"{pc_y} ({variance_lookup[pc_y]:.1f}%)")
    ax.grid(axis="both", linestyle="dotted")

    # Explained variance bar plot
    ax = axes[1]
    xpos = range(len(variance))
    ax.bar(xpos, variance, width=bar_width, color="#D0D0D0", edgecolor="#000000")
    ax.set_xticks(xpos)
    ax.set_xticklabels(
        component_labels,
        rotation="vertical",
        ha="center",
        size=plt.rcParams["axes.labelsize"],
    )
    ax.set_ylabel("Explained variance [%]")
    ax.grid(False, axis="x")
    ax.set_xlim(lower_xbound, upper_xbound)

    # Every sample adds a scatter handle labelled with its experiment; keying by
    # label keeps a single legend entry per experiment.
    handles, labels = axes[0].get_legend_handles_labels()
    experiment_handles = dict(zip(labels, handles, strict=True))

    # The legend is anchored to the upper right corner of the bar plot.
    bar_ax_bbox = axes[1].get_position()
    legend_xgap_inches = 0.25
    legend_ygap_inches = 0.03
    legend_bbox_x = bar_ax_bbox.x1 + (legend_xgap_inches / fig.get_figwidth())
    legend_bbox_y = bar_ax_bbox.y1 + (legend_ygap_inches / fig.get_figheight())
    # One legend column per 12 experiments; the column count must be an integer.
    num_legend_cols = max(1, int(np.ceil(len(qtable.get_experiments()) / 12)))
    fig.legend(
        handles=list(experiment_handles.values()),
        loc="upper left",
        bbox_to_anchor=(legend_bbox_x, legend_bbox_y),
        title="Experiment",
        alignment="left",
        frameon=False,
        borderaxespad=0,
        ncol=num_legend_cols,
    )

    return fig, axes
200
+
201
+
202
@with_active_style
def expression_clustermap(
    qtable: Qtable,
    exclude_invalid: bool = True,
    remove_imputation: bool = True,
    mean_center: bool = False,
    cluster_samples: bool = True,
    cluster_method: str = "average",
) -> sns.matrix.ClusterGrid:
    """Plot sample expression values as a hierarchically-clustered heatmap.

    By default missing and imputed values are assigned an intensity value of 0 to
    perform the clustering. Once clustering is done, these values are removed from the
    heatmap, making them appear white.

    Args:
        qtable: A `Qtable` instance, which data is used for plotting.
        exclude_invalid: If True, rows are filtered according to the Boolean entries of
            the "Valid" column.
        remove_imputation: If True, imputed values are set to 0 before clustering.
            Defaults to True.
        mean_center: If True, the data is mean-centered before clustering. Defaults to
            False.
        cluster_samples: If True, sample order is determined by hierarchical clustering.
            Otherwise, the order is determined by the order of samples in the qtable
            design. Defaults to True.
        cluster_method: Linkage method to use for calculating clusters. See
            `scipy.cluster.hierarchy.linkage` documentation for more information.

    Raises:
        ValueError: If less than two samples are present in the qtable.

    Returns:
        A seaborn ClusterGrid instance. Note that ClusterGrid has a `savefig` method
        that can be used for saving the figure.
    """
    tag: str = "Expression"
    samples = qtable.get_samples()
    experiments = qtable.get_experiments()

    if len(samples) < 2:
        raise ValueError("At least two samples are required to generate a clustermap.")

    data = qtable.make_expression_table(samples_as_columns=True)
    # Work on an explicit copy, the values are modified in place below.
    data = data[samples].copy()

    for sample in samples:
        if remove_imputation:
            data.loc[qtable.data[f"Missing {sample}"], sample] = 0
        data[sample] = data[sample].fillna(0)

    if not mean_center:
        # Hide missing values in the heatmap, making them appear white
        mask_values = qtable.data[
            [f"Missing {sample}" for sample in samples]
        ].to_numpy()
    else:
        mask_values = np.zeros(data.shape, dtype=bool)

    if exclude_invalid:
        valid_rows = qtable.data["Valid"]
        data = data[valid_rows]
        # The mask is a plain array and is therefore filtered positionally.
        mask_values = mask_values[valid_rows.to_numpy()]

    # Touch every experiment once before plotting; presumably this fixes the
    # color assignment.
    color_wheel = ColorWheelDict()
    for exp in experiments:
        _ = color_wheel[exp]
    sample_colors = [color_wheel[qtable.get_experiment(sample)] for sample in samples]

    # Layout in inches, the figure grows with the number of samples.
    suptitle_space_inch = 0.4
    sample_width_inch = 0.27
    cbar_height_inch = 3
    cbar_width_inch = sample_width_inch
    cbar_gap_inch = sample_width_inch
    col_colors_height_inch = 0.12
    col_dendrogram_height_inch = 0.6 if cluster_samples else 0.0
    heatmap_height_inch = 3
    heatmap_width_inch = len(samples) * sample_width_inch

    fig_width = cbar_width_inch + heatmap_width_inch + cbar_gap_inch
    fig_height = (
        suptitle_space_inch
        + col_dendrogram_height_inch
        + col_colors_height_inch
        + heatmap_height_inch
    )
    fig_size = fig_width, fig_height

    # Axes positions as fractions of the figure size, stacked from the bottom:
    # heatmap, column colors, column dendrogram. The color bar sits to the right.
    heatmap_width = heatmap_width_inch / fig_width
    heatmap_x0 = 0
    heatmap_height = heatmap_height_inch / fig_height
    heatmap_y0 = 0
    col_colors_height = col_colors_height_inch / fig_height
    col_colors_y0 = heatmap_y0 + heatmap_height
    col_dendrogram_height = col_dendrogram_height_inch / fig_height
    col_dendrogram_y0 = col_colors_y0 + col_colors_height
    cbar_width = cbar_width_inch / fig_width
    cbar_x0 = (heatmap_width_inch + cbar_gap_inch) / fig_width
    cbar_height = cbar_height_inch / fig_height
    cbar_y0 = col_colors_y0 - cbar_height

    heatmap_args: dict[str, Any] = {
        "cmap": "magma",
        "yticklabels": False,
        "figsize": fig_size,
    }
    if mean_center:
        data = data.sub(data.mean(axis=1), axis=0)
        heatmap_args.update({"vmin": -2.5, "vmax": 2.5, "center": 0, "cmap": "vlag"})

    # Generate the plot
    grid = sns.clustermap(
        data,
        col_cluster=cluster_samples,
        col_colors=sample_colors,
        row_colors=["#000000"] * len(data),
        mask=mask_values,
        method=cluster_method,
        metric="euclidean",
        **heatmap_args,
    )
    # Relocate clustermap axes to create a consistent layout
    grid.figure.suptitle(f'Hierarchically-clustered heatmap of "{tag}" values', y=1)
    grid.figure.delaxes(grid.ax_row_colors)
    grid.figure.delaxes(grid.ax_row_dendrogram)
    grid.ax_heatmap.set_position(
        [heatmap_x0, heatmap_y0, heatmap_width, heatmap_height]
    )
    grid.ax_col_colors.set_position(
        [heatmap_x0, col_colors_y0, heatmap_width, col_colors_height]
    )
    grid.ax_col_dendrogram.set_position(
        [heatmap_x0, col_dendrogram_y0, heatmap_width, col_dendrogram_height]
    )
    grid.ax_cbar.set_position([cbar_x0, cbar_y0, cbar_width, cbar_height])

    # manually set xticks to guarantee that all samples are displayed
    if cluster_samples:
        sample_order = [samples[i] for i in grid.dendrogram_col.reordered_ind]
    else:
        sample_order = samples
    # Ticks are placed at the center of each heatmap column.
    sample_ticks = np.arange(len(sample_order)) + 0.5
    grid.ax_heatmap.grid(False)
    grid.ax_heatmap.set_xticks(sample_ticks, labels=sample_order)
    grid.ax_heatmap.tick_params(
        axis="x", labelsize=plt.rcParams["axes.labelsize"], rotation=90
    )

    grid.ax_heatmap.set_facecolor("#F9F9F9")

    # Draw a thin frame around the heatmap, the color bar and the column colors.
    for ax in [grid.ax_heatmap, grid.ax_cbar, grid.ax_col_colors]:
        sns.despine(top=False, right=False, left=False, bottom=False, ax=ax)
        for spine in ["top", "right", "left", "bottom"]:
            ax.spines[spine].set_linewidth(0.75)
    return grid