python-katlas 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- katlas/__init__.py +1 -0
- katlas/_modidx.py +110 -0
- katlas/core.py +769 -0
- katlas/dl.py +355 -0
- katlas/feature.py +290 -0
- katlas/imports.py +7 -0
- katlas/plot.py +663 -0
- katlas/train.py +231 -0
- python_katlas-0.0.1.dist-info/LICENSE +201 -0
- python_katlas-0.0.1.dist-info/METADATA +402 -0
- python_katlas-0.0.1.dist-info/RECORD +14 -0
- python_katlas-0.0.1.dist-info/WHEEL +5 -0
- python_katlas-0.0.1.dist-info/entry_points.txt +2 -0
- python_katlas-0.0.1.dist-info/top_level.txt +1 -0
katlas/plot.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
1
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_plot.ipynb.
|
|
2
|
+
|
|
3
|
+
# %% auto 0
|
|
4
|
+
__all__ = ['set_sns', 'get_color_dict', 'logo_func', 'get_logo', 'get_logo2', 'plot_rank', 'plot_hist', 'plot_heatmap', 'plot_2d',
|
|
5
|
+
'plot_cluster', 'plot_bokeh', 'plot_count', 'plot_bar', 'plot_group_bar', 'plot_box', 'plot_corr',
|
|
6
|
+
'draw_corr', 'get_AUCDF', 'plot_confusion_matrix']
|
|
7
|
+
|
|
8
|
+
# %% ../nbs/02_plot.ipynb 4
|
|
9
|
+
import joblib,logomaker
|
|
10
|
+
import fastcore.all as fc, pandas as pd, numpy as np, seaborn as sns
|
|
11
|
+
from adjustText import adjust_text
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
from scipy.stats import spearmanr, pearsonr
|
|
15
|
+
from sklearn.metrics import confusion_matrix
|
|
16
|
+
from matplotlib import pyplot as plt
|
|
17
|
+
from matplotlib.ticker import MultipleLocator
|
|
18
|
+
from numpy import trapz
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Katlas
|
|
22
|
+
from .feature import *
|
|
23
|
+
from .core import *
|
|
24
|
+
|
|
25
|
+
# Bokeh
|
|
26
|
+
from bokeh.io import output_notebook, show
|
|
27
|
+
from bokeh.plotting import figure, ColumnDataSource
|
|
28
|
+
from bokeh.models import HoverTool, AutocompleteInput, CustomJS
|
|
29
|
+
from bokeh.layouts import column
|
|
30
|
+
from bokeh.palettes import Category20_20
|
|
31
|
+
from itertools import cycle
|
|
32
|
+
|
|
33
|
+
# %% ../nbs/02_plot.ipynb 6
|
|
34
|
+
def set_sns():
    "Configure seaborn for notebook display: 300 dpi figures, 'notebook' context and 'ticks' style"
    # dpi overrides applied both to on-screen figures and to saved files
    dpi_overrides = {"figure.dpi": 300, 'savefig.dpi': 300}
    sns.set(rc=dpi_overrides)
    sns.set_context('notebook')
    sns.set_style("ticks")
|
|
39
|
+
|
|
40
|
+
# %% ../nbs/02_plot.ipynb 7
|
|
41
|
+
def get_color_dict(categories, # list of names to assign color (must be hashable, duplicates allowed)
                   palette: str='tab20', # choose from sns.color_palette
                   ):
    "Assign colors to a list of names (allow duplicates), returns a dictionary of unique name with corresponding color"
    color_cycle = cycle(sns.color_palette(palette))

    # De-duplicate first (keeping first-seen order). Advancing the cycle on every
    # repeated name made the final colour of a name depend on the position of its
    # last occurrence, so two different names could end up sharing a colour even
    # when there were far fewer unique names than palette entries.
    unique_names = dict.fromkeys(categories)

    # Colours are reused (cycled) only once the unique names outnumber the palette.
    return {name: next(color_cycle) for name in unique_names}
|
|
49
|
+
|
|
50
|
+
# %% ../nbs/02_plot.ipynb 11
|
|
51
|
+
def logo_func(df:pd.DataFrame, # a dataframe that contains ratios for each amino acid at each position
              title: str='logo', # title of the motif logo
              ):
    "Draw a motif logo with logomaker from a df matrix and put `title` above it"

    # Letter groups -> hex colour, handed to logomaker as its color_scheme.
    # NOTE(review): 'F' is listed both on its own and inside 'LMIFWTVC'; which
    # entry wins is decided by logomaker - confirm if the colour of F matters.
    residue_colors = {
        'AG': '#037f04',
        'DEsty': '#da143e', # author's note: lower-case s/t/y still appear in the same colour as capital S/T/Y
        'F': '#84380b',
        'HQN': '#8d2be1',
        'LMIFWTVC': '#d9a41c',
        'P': '#000000',
        'RK': '#0000ff',
        'ST': '#8d008d', # author's note: logomaker does not distinguish upper/lower case, so this overrides s,t above
        'Y': '#84380b',
    }

    # Build the logo; negative values are drawn below the axis without flipping the letters
    motif_logo = logomaker.Logo(df,
                                color_scheme=residue_colors,
                                flip_below=False,
                                figsize=(7,2.5))

    # Integer position labels on x, no y ticks, title on top
    motif_logo.style_xticks(fmt='%d')
    logo_ax = motif_logo.ax
    logo_ax.set_yticks([])
    logo_ax.set_title(title)
|
|
81
|
+
|
|
82
|
+
# %% ../nbs/02_plot.ipynb 12
|
|
83
|
+
def get_logo(df: pd.DataFrame, # stacked Dataframe with kinase as index, substrates as columns
             kinase: str, # a specific kinase name in index
             ):
    "Plot the motif logo of one kinase taken from a stacked df (index as kinase, columns as substrates)"

    # Un-normalized matrix of the kinase; only used to derive the S vs T balance
    raw_matrix = get_one_kinase(df,kinase,normalize=False)
    s_total = raw_matrix['S'].sum()
    t_total = raw_matrix['T'].sum()

    # Weighted differences of the two column sums (0.75 / 0.25 weights)
    s_ctrl = 0.75*s_total - 0.25*t_total
    t_ctrl = 0.75*t_total - 0.25*s_total

    # Scale both by the larger one, then turn them into fractions that sum to 1
    larger_ctrl = max(s_ctrl, t_ctrl)
    s_scaled = s_ctrl / larger_ctrl
    t_scaled = t_ctrl / larger_ctrl
    scaled_sum = s_scaled + t_scaled
    s_fraction = s_scaled/scaled_sum
    t_fraction = t_scaled/scaled_sum

    # Normalized matrix of the kinase -> log2(value / row median)
    normalized = get_one_kinase(df,kinase, normalize=True)
    log_ratio = np.log2(normalized.apply(lambda r: r/r.median(),axis=1))

    # Tallest positive stack over all rows; the S/T row is scaled to this height
    tallest = log_ratio.apply(lambda row: row[row > 0].sum(), axis=1).max()

    # Extra row (index 0) carrying only S and T; every other letter becomes 0 via fillna
    center_row = pd.DataFrame({'S': s_fraction*tallest, 'T':t_fraction*tallest}, index=[0])
    logo_matrix = pd.concat([log_ratio, center_row], ignore_index=False).fillna(0)

    logo_func(logo_matrix, kinase)
|
|
120
|
+
|
|
121
|
+
# %% ../nbs/02_plot.ipynb 16
|
|
122
|
+
def get_logo2(full: pd.DataFrame, # a dataframe that contains the full matrix of a kinase, with index as amino acid, and columns as positions
              title: str = 'logo', # title of the graph
              ):

    "Plot a logo from the full frequency matrix of a kinase"

    # Share of s, t and y within column 0 (the three fractions sum to 1)
    center_sty = full[0][['s','t','y']]
    S_ratio,T_ratio,Y_ratio = center_sty/center_sty.sum()

    # Column 0 is handled separately through the row appended below
    flanks = full.drop(columns=[0])

    # Replace exact zeros by the smallest positive value found in the matrix
    smallest_positive = flanks[flanks > 0].min().min()
    flanks = flanks.replace(0, smallest_positive)

    # Rows become positions, columns amino acids
    by_position = flanks.T

    # Height of each letter = value minus the row median
    # (the earlier log2(value/median) variant is no longer used here)
    centered = by_position.apply(lambda r: r-r.median(),axis=1)

    # Tallest positive stack over all positions
    tallest = centered.apply(lambda row: row[row > 0].sum(), axis=1).max()

    # Row for position 0: S/T/Y scaled so that together they reach the tallest stack
    center_row = pd.DataFrame({'S': S_ratio*tallest, 'T':T_ratio*tallest,'Y':Y_ratio*tallest}, index=[0])

    # Matrix handed to logomaker; letters absent from a row are filled with 0
    logo_matrix = pd.concat([centered, center_row], ignore_index=False).fillna(0)

    logo_func(logo_matrix,title)
|
|
159
|
+
|
|
160
|
+
# %% ../nbs/02_plot.ipynb 19
|
|
161
|
+
@fc.delegates(sns.scatterplot)
def plot_rank(sorted_df: pd.DataFrame, # a sorted dataframe
              x: str, # column name for x axis (its values are also used as the text labels)
              y: str, # column name for y axis
              n_hi: int=10, # if not None, show the head n names
              n_lo: int=10, # if not None, show the tail n names
              figsize: tuple=(10,8), # figure size
              **kwargs # arguments for sns.scatterplot()
              ):

    "Plot rank from a sorted dataframe and annotate the first `n_hi` / last `n_lo` rows"

    plt.figure(figsize=figsize)

    # Drop any customized index so that the row label equals the row position
    sorted_df = sorted_df.reset_index(drop=True)

    sns_plot = sns.scatterplot(data=sorted_df, x=x, y=y, **kwargs)
    sns_plot.set_xticks([])

    # Collect the row positions to annotate: head first, then tail
    positions = []
    if n_hi is not None:
        positions.extend(sorted_df.head(n_hi).index)
    if n_lo is not None:
        positions.extend(sorted_df.tail(n_lo).index)

    # When n_hi + n_lo exceeds the number of rows, head and tail overlap;
    # keep each row once so that no point is labelled twice.
    positions = list(dict.fromkeys(positions))

    texts = []
    if positions:
        for i, row in sorted_df.loc[positions].iterrows():
            texts.append(plt.text(i, row[y], row[x], ha='center', va='bottom'))

    if texts:
        # Use adjustText to move the labels apart and connect them to their points
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))

    plt.tight_layout()
|
|
203
|
+
|
|
204
|
+
# %% ../nbs/02_plot.ipynb 23
|
|
205
|
+
@fc.delegates(sns.histplot)
def plot_hist(df: pd.DataFrame, # a dataframe that contain values for plot
              x: str, # column name of values
              figsize: tuple=(6,2), # figure size
              **kwargs, # arguments for sns.histplot(); they take precedence over the defaults below
              ):

    "Plot a histogram of column `x` with a KDE overlay"

    # Default look of the histogram
    hist_params = {'element':'poly',
                   'edgecolor': None,
                   'alpha':0.5,
                   'bins':100,
                   'kde':True}

    # Let the caller override any default; unpacking both dicts into the call
    # raised "got multiple values for keyword argument" for e.g. bins=50.
    hist_params.update(kwargs)

    plt.figure(figsize=figsize)
    sns.histplot(data=df,x=x,**hist_params)
|
|
220
|
+
|
|
221
|
+
# %% ../nbs/02_plot.ipynb 27
|
|
222
|
+
@fc.delegates(sns.heatmap)
def plot_heatmap(matrix, # a matrix of values
                 title: str='heatmap', # title of the heatmap
                 figsize: tuple=(6,10), # figure size of the heatmap
                 cmap: str='binary', # color map, default is dark&white
                 **kwargs, # arguments for sns.heatmap()
                 ):

    "Plot heatmap based on a matrix of values"

    # Cell annotations stay off by default, but the caller may now pass annot=True;
    # the previously hard-coded annot=False made that raise a TypeError.
    kwargs.setdefault('annot', False)

    plt.figure(figsize=figsize)
    sns.heatmap(matrix, cmap=cmap, **kwargs)
    plt.title(title)
|
|
235
|
+
|
|
236
|
+
# %% ../nbs/02_plot.ipynb 31
|
|
237
|
+
@fc.delegates(sns.scatterplot)
def plot_2d(X: pd.DataFrame, # a dataframe that has first column to be x, and second column to be y
            **kwargs, # arguments for sns.scatterplot
            ):
    "Make 2D plot from a dataframe that has first column to be x, and second column to be y"

    # 0.7 stays the default transparency, but a caller-supplied alpha now wins
    # instead of colliding with the hard-coded keyword (TypeError).
    kwargs.setdefault('alpha', 0.7)

    plt.figure(figsize=(7,7))
    sns.scatterplot(data = X,x=X.columns[0],y=X.columns[1],**kwargs)
|
|
244
|
+
|
|
245
|
+
# %% ../nbs/02_plot.ipynb 33
|
|
246
|
+
def plot_cluster(df: pd.DataFrame, # a dataframe of values that is waited for dimensionality reduction
                 method: str='pca', # dimensionality reduction method, choose from pca, umap, and tsne
                 hue: str=None, # passed to sns.relplot as hue (NOTE(review): only the two embedding columns are in the plotted frame, so a vector may be required - confirm)
                 complexity: int=30, # recommend 30 for tsne, 15 for umap, none for pca
                 palette: str='tab20', # color scheme, could be tab10 if less categories
                 legend: bool=False, # whether or not add the legend on the side
                 name_list=None, # a list of names to annotate each dot in the plot
                 seed: int=123, # seed for dimensionality reduction
                 s: int=50, # size of the dot
                 **kwargs # arguments for dimensional reduction method to be used
                 ):

    "Given a dataframe of values, plot it in 2d, method could be pca, tsne, or umap"

    # reduce_feature is expected to return exactly two columns (unpacked below)
    embedding_df = reduce_feature(df, method=method, seed=seed, complexity = complexity,**kwargs)
    x_col, y_col = embedding_df.columns

    sns.relplot(data=embedding_df, x=x_col, y=y_col, hue=hue, palette=palette, s=s, alpha=0.8, legend=legend)
    plt.xticks([])
    plt.yticks([])

    if name_list is not None:
        # Work on plain positional sequences: indexing the columns / name_list with
        # an integer goes through the pandas *label* lookup, which fails (or picks
        # the wrong row) whenever the index is not the default 0..n-1.
        xs = embedding_df[x_col].to_numpy()
        ys = embedding_df[y_col].to_numpy()
        names = list(name_list)
        texts = [plt.text(px, py, names[i], fontsize=8) for i, (px, py) in enumerate(zip(xs, ys))]
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
|
|
269
|
+
|
|
270
|
+
# %% ../nbs/02_plot.ipynb 37
|
|
271
|
+
def plot_bokeh(X:pd.DataFrame, # a dataframe of two columns from dimensionality reduction
               idx, # pd.Series or list that indicates identities for searching box (the search callback calls toLowerCase on them, so they should be strings)
               hue=None, # pd.Series or list that indicates category for each sample; None draws every dot in navy
               s: int=3, # dot size
               **kwargs # key:args format for information to include in the dot information box
               ):

    "Make interactive 2D plot with a searching box and window of dot information when pointing "

    output_notebook()

    idx = list(idx)

    def assign_colors(categories, palette):
        "assign each unique name in a list with a color, returns a color list of same length"
        color_cycle = cycle(palette)
        # advance the cycle once per *unique* name so that repeated names do not
        # push different categories onto the same colour
        color_map = {category: next(color_cycle) for category in dict.fromkeys(categories)}
        return [color_map[category] for category in categories]

    # hue is optional: converting it with list() before this check made hue=None
    # crash and left the navy fallback unreachable
    if hue is not None:
        colors = assign_colors(list(hue), Category20_20)
    else:
        colors = ['navy'] * len(X)

    data_dict={
        'x': X.iloc[:,0],
        'y': X.iloc[:,1],
        'identity': idx,
        'color': colors,
        'original_color': colors,
        'size': [s] * len(X),
        'original_size': [s] * len(X), # lets the callback restore the size after a highlight
        'highlighted': ['no'] * len(X) # To keep track of which dot is highlighted
    }

    # Extra per-dot columns supplied by the caller (shown in the hover box)
    for key, value in kwargs.items():
        data_dict[key] = value

    source = ColumnDataSource(data=data_dict)

    p = figure(tools="pan,box_zoom,wheel_zoom,reset")
    p.scatter('x', 'y', source=source, alpha=0.6, color='color', size='size')

    # Disable grid lines
    p.xgrid.visible = False
    p.ygrid.visible = False

    # Add hover tool: identity plus one line per extra column
    hover = HoverTool()
    tooltips = [("Identity", "@identity")]
    for key in kwargs.keys():
        tooltips.append((key.capitalize(), f"@{key}"))
    hover.tooltips = tooltips
    p.add_tools(hover)

    autocomplete = AutocompleteInput(title="Search by Identity:", completions=idx)

    # On every search: un-highlight the previous hit (restoring its original colour
    # and size), then highlight the dot whose identity matches and zoom onto it.
    callback = CustomJS(args=dict(source=source, plot=p), code="""
        const data = source.data;
        const search_val = cb_obj.value.toLowerCase();
        const x = data['x'];
        const y = data['y'];
        const identity = data['identity'];
        const color = data['color'];
        const original_color = data['original_color'];
        const size = data['size'];
        const original_size = data['original_size'];
        const highlighted = data['highlighted'];

        for (let i = 0; i < identity.length; i++) {
            if (highlighted[i] === 'yes') {
                color[i] = original_color[i];
                size[i] = original_size[i];
                highlighted[i] = 'no';
            }
            if (identity[i].toLowerCase() === search_val) {
                plot.x_range.start = x[i] - 5;
                plot.x_range.end = x[i] + 5;
                plot.y_range.start = y[i] - 5;
                plot.y_range.end = y[i] + 5;
                color[i] = 'red';
                size[i] = 15;
                highlighted[i] = 'yes';
            }
        }
        source.change.emit();
    """)
    autocomplete.js_on_change('value', callback)

    # Show layout: search box above the plot
    layout = column(autocomplete, p)
    show(layout)
|
|
367
|
+
|
|
368
|
+
# %% ../nbs/02_plot.ipynb 40
|
|
369
|
+
def plot_count(cnt, # from df['x'].value_counts()
               tick_spacing: float= None, # tick spacing for x axis
               palette: str='tab20'):

    "Draw a horizontal bar plot from df['x'].value_counts() with the count written at the end of each bar"

    bar_colors = sns.color_palette(palette)
    ax = cnt.plot.barh(color = bar_colors)

    # One text per bar: x = bar length, y = bar position
    for position, count in enumerate(cnt):
        plt.text(count, position, str(count),fontsize=10,rotation=-90, va='center')

    # Nothing more to do unless regular x ticks were requested
    if tick_spacing is None:
        return

    ax.xaxis.set_major_locator(MultipleLocator(tick_spacing))
|
|
383
|
+
|
|
384
|
+
# %% ../nbs/02_plot.ipynb 42
|
|
385
|
+
@fc.delegates(sns.barplot)
def plot_bar(df,
             value, # colname of value
             group, # colname of group
             title = None,
             figsize = (12,5),
             fontsize=14,
             dots = True, # whether or not add dots in the graph
             rotation=90,
             ascending=False,
             **kwargs
             ):

    "Bar plot of `value` per `group` from an unstacked dataframe, bars ordered by group mean"

    plt.figure(figsize=figsize)

    # Bar order: groups sorted by their mean value
    group_means = df.groupby(group)[value].mean()
    idx = group_means.sort_values(ascending=ascending).index

    sns.barplot(data=df, x=group, y=value, order=idx, **kwargs)

    if dots:
        # Individual observations as white, black-edged, jittered points
        sns.stripplot(data=df,
                      x=group,
                      y=value,
                      order=idx,
                      alpha=0.8,
                      marker='o',
                      color='white',
                      edgecolor='black',
                      linewidth=1.5,
                      jitter=True,
                      s=5)

    # Same label size for the tick labels of both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    # No x label (group names are on the ticks); y label is the value column name
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)

    # Rotate X labels
    plt.xticks(rotation=rotation)

    if title is not None:
        plt.title(title,fontsize=fontsize)

    # Hide the right and top spines
    plt.gca().spines[['right', 'top']].set_visible(False)
|
|
438
|
+
|
|
439
|
+
# %% ../nbs/02_plot.ipynb 45
|
|
440
|
+
@fc.delegates(sns.barplot)
def plot_group_bar(df,
                   value_cols, # list of column names for values (original note: "the order depends on the first item"; no ordering is computed in this function)
                   group, # column name of group (e.g., 'kinase')
                   figsize=(12, 5),
                   order=None,
                   title=None,
                   fontsize=14,
                   rotation=90,
                   **kwargs):

    "Plot grouped bar graph from dataframe: one bar per value column within each group"

    # Wide -> long: one row per (group, value column) pair
    long_df = df.melt(id_vars=group,
                      value_vars=value_cols,
                      var_name='Ranking',
                      value_name='Value')

    plt.figure(figsize=figsize)

    # Bars coloured by the originating column; error bar look set explicitly
    errorbar_style = dict(capsize=0.1, errwidth=1.5, errcolor='gray')
    sns.barplot(data=long_df, x=group, y='Value', hue='Ranking', order=order,
                **errorbar_style,
                **kwargs)

    # Same label size for the tick labels of both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    # No x label; fixed y label
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel('Value', fontsize=fontsize)

    # Rotate X labels
    plt.xticks(rotation=rotation)

    if title is not None:
        plt.title(title, fontsize=fontsize)

    # Hide the right and top spines, then draw the legend
    plt.gca().spines[['right', 'top']].set_visible(False)
    plt.legend(fontsize=fontsize) # if change legend location, use loc='upper right'
|
|
481
|
+
|
|
482
|
+
# %% ../nbs/02_plot.ipynb 48
|
|
483
|
+
@fc.delegates(sns.boxplot)
def plot_box(df,
             value, # colname of value
             group, # colname of group
             title=None,
             figsize=(6,3),
             fontsize=14,
             dots=True,
             rotation=90,
             **kwargs
             ):

    "Box plot of `value` per `group`, boxes ordered by decreasing group median"

    plt.figure(figsize=figsize)

    # Box order: groups sorted by their median, largest first
    group_medians = df[[group,value]].groupby(group).median()
    idx = group_medians.sort_values(value,ascending=False).index

    sns.boxplot(data=df, x=group, y=value, order=idx, **kwargs)

    if dots:
        # Overlay the individual observations as small black jittered points
        sns.stripplot(data=df, x=group, y=value, order=idx, jitter=True, color='black', size=3)

    # Same label size for the tick labels of both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)

    plt.xticks(rotation=rotation)

    if title is not None:
        plt.title(title,fontsize=fontsize)

    # Unlike the bar plots, the right and top spines are left visible here.
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
# %% ../nbs/02_plot.ipynb 51
|
|
525
|
+
@fc.delegates(sns.regplot)
def plot_corr(x, # x axis values, or colname of x axis
              y, # y axis values, or colname of y axis
              xlabel=None,# x axis label
              ylabel=None,# y axis label
              data = None, # dataframe that contains data
              text_location = [0.8,0.1], # (x, y) of the annotation in axes coordinates; only read, never modified
              **kwargs # arguments for sns.regplot()
              ):
    "Regression plot of x against y, annotated with the Pearson correlation and its p value"

    # With a dataframe, x and y are column names; resolve them to the columns
    if data is not None:
        x=data[x]
        y=data[y]

    pear, pvalue = pearsonr(x, y)

    # Gray regression line by default; a caller-supplied line_kws now wins instead
    # of colliding with the hard-coded keyword (TypeError).
    kwargs.setdefault('line_kws', {'color': 'gray'})

    sns.regplot(x=x, y=y, **kwargs)

    if xlabel is not None:
        plt.xlabel(xlabel)

    if ylabel is not None:
        plt.ylabel(ylabel)

    # Annotation placed relative to the axes (0-1 range), not in data units
    plt.text(s=f'Pearson = {round(pear,2)}\n   p = {pvalue:.2e}',
             x=text_location[0],y=text_location[1],
             transform=plt.gca().transAxes,
             ha='center', va='center')
|
|
560
|
+
|
|
561
|
+
# %% ../nbs/02_plot.ipynb 55
|
|
562
|
+
def draw_corr(corr):

    "Draw an annotated heatmap of the lower triangle of a correlation matrix (e.g. from df.corr())"

    # True on and above the diagonal -> those cells are hidden
    hidden_cells = np.triu(np.ones_like(corr, dtype=bool))

    plt.figure(figsize=(20, 16))

    # Fixed -1..1 colour scale, two-decimal annotations
    sns.heatmap(corr,
                mask=hidden_cells,
                annot=True,
                fmt='.2f',
                cmap='coolwarm',
                vmin=-1,
                vmax=1)
|
|
572
|
+
|
|
573
|
+
# %% ../nbs/02_plot.ipynb 59
|
|
574
|
+
def get_AUCDF(df, # dataframe that holds the values
              col, # column name of the values (e.g. ranks)
              reverse=False, # if True, use the decreasing (mirrored) curve
              plot=True, # whether to draw the histogram + CDF figure
              xlabel='Rank of reported kinase', # x axis label of the figure
              ):

    "Plot CDF curve and get relative area under the curve"

    # sort col values as x values
    x_values = df[col].sort_values().values

    # percentile rank of every value; ties share their average rank
    y_values = pd.Series(x_values).rank(method='average', pct=True).values

    if reverse:
        # mirror the curve while keeping its smallest value unchanged
        y_values = 1 - y_values + y_values.min()

    # Trapezoidal rule. `trapz` is deprecated since NumPy 2.0 in favour of
    # `trapezoid`, so prefer the new name and fall back on older NumPy.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    area_under_curve = integrate(y_values, x_values)

    # area of the bounding box spanned by the curve (negative when reversed)
    # NOTE(review): this is 0 when all values are identical, which makes the
    # ratio below nan/inf - behaviour left as in the original.
    total_area = (x_values[-1] - x_values[0]) * (y_values[-1] - y_values[0])

    AUCDF = area_under_curve / total_area
    if reverse:
        AUCDF = -AUCDF # undo the sign of the negative bounding box

    if plot:
        # Create a figure and a primary axis
        fig, ax1 = plt.subplots(figsize=(7,5))

        fontsize=17

        # Histogram of the values on the primary axis
        sns.histplot(x_values,bins=20,ax=ax1)
        ax1.set_xlabel(xlabel,fontsize=fontsize)
        ax1.set_ylabel('Substrates',color='darkblue',fontsize=fontsize)
        ax1.tick_params(axis='y', labelcolor='darkblue',labelsize=fontsize)
        ax1.tick_params(axis='x', labelcolor='black',labelsize=fontsize)
        ax1.set_xlim(min(x_values),max(x_values))

        # Secondary axis for the CDF
        ax2 = ax1.twinx()
        ax2.plot(x_values, y_values, color='darkred', linestyle='-', linewidth=2.0)

        # Dashed black reference diagonal, drawn from x = 0
        if reverse:
            ax2.plot([max(x_values),0],[0, max(y_values)], 'k--')
        else:
            ax2.plot([0, max(x_values)], [0, max(y_values)], 'k--')

        ax2.set_ylabel('Probability', color='darkred',fontsize=fontsize,rotation=270,labelpad=18)

        # Score label; moved to the left half for the reversed curve
        if reverse:
            ax2.text(0.45, 0.3, f"AUCDF:{AUCDF.round(4)}", transform=plt.gca().transAxes, ha='right', va='bottom',fontsize=fontsize)
        else:
            ax2.text(0.95, 0.3, f"AUCDF:{AUCDF.round(4)}", transform=plt.gca().transAxes, ha='right', va='bottom',fontsize=fontsize)

        ax2.tick_params(axis='y', labelcolor='darkred',labelsize=fontsize)
        ax2.set_ylim(0, 1) # Probabilities range from 0 to 1

        # Show the plot
        plt.title(f'{len(x_values)} kinase-substrate pairs',fontsize=fontsize)
        plt.show()

    return AUCDF
|
|
637
|
+
|
|
638
|
+
# %% ../nbs/02_plot.ipynb 62
|
|
639
|
+
def plot_confusion_matrix(target, # pd.Series
                          pred, # pd.Series
                          class_names:list=['0','1'], # tick labels, one per class; only read, never modified
                          normalize=False, # if True, divide every row by its sum
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):

    "Plot the confusion matrix."

    cm = confusion_matrix(target, pred)

    if normalize:
        # each row (true class) is divided by its total
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
        annot_fmt = '.2g' # seaborn's default format, fine for fractions
    else:
        print('Confusion matrix, without normalization')
        # plain integers: the default '.2g' renders counts >= 100 in scientific
        # notation (e.g. 1.2e+02), which hides the actual numbers
        annot_fmt = 'd'

    plt.figure(figsize=(6,6))
    sns.heatmap(cm, annot=True, fmt=annot_fmt, cmap=cmap) # Plot the heatmap
    plt.title(title)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    # +0.5 puts each label at the centre of its cell
    plt.xticks(np.arange(len(class_names)) + 0.5, class_names)
    plt.yticks(np.arange(len(class_names)) + 0.5, class_names, rotation=0)
|