python-katlas 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
katlas/plot.py ADDED
@@ -0,0 +1,663 @@
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_plot.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['set_sns', 'get_color_dict', 'logo_func', 'get_logo', 'get_logo2', 'plot_rank', 'plot_hist', 'plot_heatmap', 'plot_2d',
5
+ 'plot_cluster', 'plot_bokeh', 'plot_count', 'plot_bar', 'plot_group_bar', 'plot_box', 'plot_corr',
6
+ 'draw_corr', 'get_AUCDF', 'plot_confusion_matrix']
7
+
8
+ # %% ../nbs/02_plot.ipynb 4
9
+ import joblib,logomaker
10
+ import fastcore.all as fc, pandas as pd, numpy as np, seaborn as sns
11
+ from adjustText import adjust_text
12
+ from pathlib import Path
13
+
14
+ from scipy.stats import spearmanr, pearsonr
15
+ from sklearn.metrics import confusion_matrix
16
+ from matplotlib import pyplot as plt
17
+ from matplotlib.ticker import MultipleLocator
18
+ from numpy import trapz
19
+
20
+
21
+ # Katlas
22
+ from .feature import *
23
+ from .core import *
24
+
25
+ # Bokeh
26
+ from bokeh.io import output_notebook, show
27
+ from bokeh.plotting import figure, ColumnDataSource
28
+ from bokeh.models import HoverTool, AutocompleteInput, CustomJS
29
+ from bokeh.layouts import column
30
+ from bokeh.palettes import Category20_20
31
+ from itertools import cycle
32
+
33
+ # %% ../nbs/02_plot.ipynb 6
34
def set_sns():
    "Configure seaborn for high-resolution notebook figures (300 dpi, 'notebook' context, 'ticks' style)"
    # Same resolution for figures shown inline and figures written to disk
    resolution = {"figure.dpi": 300, 'savefig.dpi': 300}
    sns.set(rc=resolution)
    sns.set_context('notebook')
    sns.set_style("ticks")
39
+
40
+ # %% ../nbs/02_plot.ipynb 7
41
def get_color_dict(categories, # list of names to assign color
                   palette: str='tab20', # choose from sns.color_palette
                   ):
    """Assign colors to a list of names (duplicates allowed).

    Returns a dictionary mapping each unique name to a color taken from
    `sns.color_palette(palette)`. Colors are handed out in order of first
    appearance and the palette is recycled once it is exhausted.
    """
    color_cycle = cycle(sns.color_palette(palette))

    # Deduplicate first (dict.fromkeys keeps first-seen order) so that one
    # color is consumed per unique name rather than per occurrence; otherwise
    # repeated names burn through the palette and distinct names can end up
    # sharing a color long before the palette is actually exhausted.
    unique_names = dict.fromkeys(categories)

    return {name: next(color_cycle) for name in unique_names}
49
+
50
+ # %% ../nbs/02_plot.ipynb 11
51
def logo_func(df:pd.DataFrame, # a dataframe that contains ratios for each amino acid at each position
              title: str='logo', # title of the motif logo
              ):
    "Draw a motif logo with logomaker from a matrix `df`, using a fixed amino-acid color scheme"

    # Color per group of amino-acid letters. Some letters appear in more than
    # one key, so the key order below is kept exactly as originally written.
    color_scheme = {
        'AG': '#037f04',
        'DEsty': '#da143e', # original author's note: lowercase s/t/y end up with the ST / Y colors below anyway
        'F': '#84380b',
        'HQN': '#8d2be1',
        'LMIFWTVC': '#d9a41c',
        'P': '#000000',
        'RK': '#0000ff',
        'ST': '#8d008d', # original author's note: logomaker does not distinguish upper and lower case, so this overrides s,t above
        'Y': '#84380b',
        # Previously tried entries ('st', 'y', 'pS/pT', 'pY' -> '#8d2be1') were
        # dropped; per the original note, logomaker does not support
        # multi-letter symbols such as pS or pT.
    }

    # Build the logo; negative values are not flipped below the axis
    motif = logomaker.Logo(df,
                           color_scheme=color_scheme,
                           flip_below=False,
                           figsize=(7,2.5))

    # Integer position labels on x, no y ticks, and the requested title
    motif.style_xticks(fmt='%d')
    axis = motif.ax
    axis.set_yticks([])
    axis.set_title(title)
81
+
82
+ # %% ../nbs/02_plot.ipynb 12
83
def get_logo(df: pd.DataFrame, # stacked Dataframe with kinase as index, substrates as columns
             kinase: str, # a specific kinase name in index
             ):
    "Plot the motif logo of one kinase taken from a stacked df (index as kinase, columns as substrates)"

    # Un-normalized matrix of this kinase, used only for the S/T balance
    # (presumably positions x amino acids -- confirm against get_one_kinase)
    raw = get_one_kinase(df,kinase,normalize=False)

    s_total = raw['S'].sum()
    t_total = raw['T'].sum()

    # Weighted S and T signals, each corrected by a quarter of the other
    S_ctrl = 0.75*s_total - 0.25*t_total
    T_ctrl = 0.75*t_total - 0.25*s_total

    # Scale both by the larger one, then express them as fractions of their sum
    larger = max(S_ctrl, T_ctrl)
    S0 = S_ctrl / larger
    T0 = T_ctrl / larger

    combined = S0+T0
    S_ratio = S0/combined
    T_ratio = T0/combined

    # Normalized matrix of the same kinase
    normalized = get_one_kinase(df,kinase, normalize=True)

    # Per row: divide by the row median, then log2-transform
    ratio = np.log2(normalized.apply(lambda r: r/r.median(),axis=1))

    # Largest per-row sum of the positive values; sets the height of the extra row
    m = ratio.apply(lambda row: row[row > 0].sum(), axis=1).max()

    # Extra row labelled 0 carrying the S/T split scaled to that height
    center_row = pd.DataFrame({'S': S_ratio*m, 'T':T_ratio*m}, index=[0])

    logo_matrix = pd.concat([ratio, center_row], ignore_index=False).fillna(0)

    logo_func(logo_matrix, kinase)
120
+
121
+ # %% ../nbs/02_plot.ipynb 16
122
def get_logo2(full: pd.DataFrame, # a dataframe that contains the full matrix of a kinase, with index as amino acid, and columns as positions
              title: str = 'logo', # title of the graph
              ):
    "Plot logo from a full frequency matrix of a kinase"

    # Share of s, t and y within column 0
    center = full[0][['s','t','y']]
    S_ratio,T_ratio,Y_ratio = center/center.sum()

    # Column 0 is handled separately through the ratios above
    full = full.drop(columns=[0])

    # Replace zeros by the smallest strictly positive value in the matrix
    min_val = full[full > 0].min().min()
    full = full.replace(0, min_val)

    # Rows become positions, columns become amino acids
    by_position = full.T

    # Center every row on its median (subtraction). The earlier approach of
    # dividing by the median and taking log2 is no longer used here.
    ratio = by_position.apply(lambda r: r-r.median(),axis=1)

    # Largest per-row sum of the positive values
    m = ratio.apply(lambda row: row[row > 0].sum(), axis=1).max()

    # Row labelled 0: S, T, Y heights relative to that maximum
    center_row = pd.DataFrame({'S': S_ratio*m, 'T':T_ratio*m,'Y':Y_ratio*m}, index=[0])

    # Matrix handed to logomaker; cells missing after the concat become 0
    logo_matrix = pd.concat([ratio, center_row], ignore_index=False).fillna(0)

    logo_func(logo_matrix,title)
159
+
160
+ # %% ../nbs/02_plot.ipynb 19
161
@fc.delegates(sns.scatterplot)
def plot_rank(sorted_df: pd.DataFrame, # a sorted dataframe
              x: str, # column name for x axis
              y: str, # column name for y aixs
              n_hi: int=10, # if not None, show the head n names
              n_lo: int=10, # if not None, show the tail n names
              figsize: tuple=(10,8), # figure size
              **kwargs # arguments for sns.scatterplot()
              ):
    """Plot rank from a sorted dataframe.

    Draws a scatter plot of column `y` against column `x` and labels the
    first `n_hi` and last `n_lo` rows with their `x` value. A row that falls
    in both the head and the tail (short dataframes) is labelled only once.
    """

    plt.figure(figsize=figsize)

    # Drop any custom index so that row labels are 0..n-1 positions
    sorted_df = sorted_df.reset_index(drop=True)

    sns_plot = sns.scatterplot(data=sorted_df,
                               x = x,
                               y = y, **kwargs)

    sns_plot.set_xticks([])

    # Collect the row positions to annotate: head first, then tail
    to_label = []
    if n_hi is not None:
        to_label.extend(sorted_df.head(n_hi).index)
    if n_lo is not None:
        to_label.extend(sorted_df.tail(n_lo).index)

    # Remove repeats while keeping order; without this, a dataframe shorter
    # than n_hi + n_lo gets the same point annotated twice.
    to_label = list(dict.fromkeys(to_label))

    texts = [plt.text(i, sorted_df.at[i, y], sorted_df.at[i, x], ha='center', va='bottom')
             for i in to_label]

    if texts:
        # Use adjustText to move labels apart and connect them to their points
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))

    plt.tight_layout()
203
+
204
+ # %% ../nbs/02_plot.ipynb 23
205
@fc.delegates(sns.histplot)
def plot_hist(df: pd.DataFrame, # a dataframe that contain values for plot
              x: str, # column name of values
              figsize: tuple=(6,2), # figure size
              **kwargs, # arguments for sns.histplot()
              ):
    """Plot a histogram of column `x` of `df` with a KDE overlay.

    Defaults (polygon element, no edge color, alpha 0.5, 100 bins, KDE on)
    can each be overridden through `kwargs`.
    """

    hist_params = {'element':'poly',
                   'edgecolor': None,
                   'alpha':0.5,
                   'bins':100,
                   'kde':True}

    # Caller-supplied values win. Passing both dicts separately to histplot
    # would raise "got multiple values for keyword argument" on any overlap.
    hist_params.update(kwargs)

    plt.figure(figsize=figsize)
    sns.histplot(data=df,x=x,**hist_params)
220
+
221
+ # %% ../nbs/02_plot.ipynb 27
222
@fc.delegates(sns.heatmap)
def plot_heatmap(matrix, # a matrix of values
                 title: str='heatmap', # title of the heatmap
                 figsize: tuple=(6,10), # figure size of the heatmap
                 cmap: str='binary', # color map, default is dark&white
                 **kwargs, # arguments for sns.heatmap()
                 ):
    """Plot heatmap based on a matrix of values.

    Cell annotations are off by default; pass `annot=True` in `kwargs` to
    turn them on.
    """

    # Default only: a hard-coded annot=False alongside **kwargs would raise
    # TypeError as soon as the caller passed annot themselves.
    kwargs.setdefault('annot', False)

    plt.figure(figsize=figsize)
    sns.heatmap(matrix, cmap=cmap, **kwargs)
    plt.title(title)
235
+
236
+ # %% ../nbs/02_plot.ipynb 31
237
@fc.delegates(sns.scatterplot)
def plot_2d(X: pd.DataFrame, # a dataframe that has first column to be x, and second column to be y
            **kwargs, # arguments for sns.scatterplot
            ):
    "Make 2D plot from a dataframe that has first column to be x, and second column to be y"

    # alpha defaults to 0.7 but stays overridable; passing it both explicitly
    # and through **kwargs would raise TypeError.
    kwargs.setdefault('alpha', 0.7)

    plt.figure(figsize=(7,7))
    sns.scatterplot(data = X,x=X.columns[0],y=X.columns[1],**kwargs)
244
+
245
+ # %% ../nbs/02_plot.ipynb 33
246
def plot_cluster(df: pd.DataFrame, # a dataframe of values that is waited for dimensionality reduction
                 method: str='pca', # dimensionality reduction method, choose from pca, umap, and tsne
                 hue: str=None, # colname of color
                 complexity: int=30, # recommend 30 for tsne, 15 for umap, none for pca
                 palette: str='tab20', # color scheme, could be tab10 if less categories
                 legend: bool=False, # whether or not add the legend on the side
                 name_list=None, # a list of names to annotate each dot in the plot
                 seed: int=123, # seed for dimensionality reduction
                 s: int=50, # size of the dot
                 **kwargs # arguments for dimensional reduction method to be used
                 ):
    """Given a dataframe of values, plot it in 2d, method could be pca, tsne, or umap.

    The embedding comes from `reduce_feature`; its two columns are used as the
    x and y axes. When `name_list` is given, the i-th name annotates the i-th
    embedded point.
    """

    embedding_df = reduce_feature(df, method=method, seed=seed, complexity = complexity,**kwargs)

    # The embedding is expected to have exactly two columns
    x_col, y_col = embedding_df.columns

    sns.relplot(data=embedding_df, x=x_col, y=y_col, hue=hue, palette=palette, s=s, alpha=0.8, legend=legend)
    plt.xticks([])
    plt.yticks([])

    if name_list is not None:
        # Pair coordinates and names by position. Label-based `[i]` lookups
        # fail (or pick the wrong row) when the embedding or `name_list`
        # carries a non 0..n-1 index.
        xs = embedding_df[x_col].to_numpy()
        ys = embedding_df[y_col].to_numpy()
        texts = [plt.text(px, py, name, fontsize=8)
                 for px, py, name in zip(xs, ys, name_list)]
        adjust_text(texts, arrowprops=dict(arrowstyle='-', color='black'))
269
+
270
+ # %% ../nbs/02_plot.ipynb 37
271
def plot_bokeh(X:pd.DataFrame, # a dataframe of two columns from dimensionality reduction
               idx, # pd.Series or list that indicates identities for searching box
               hue=None, # pd.Series or list that indicates category for each sample
               s: int=3, # dot size
               **kwargs # key:args format for information to include in the dot information box
               ):
    """Make interactive 2D plot with a searching box and window of dot information when pointing.

    The first two columns of `X` are the coordinates. `idx` feeds both the
    hover tooltip and the autocomplete search box; selecting an identity zooms
    to that dot and highlights it in red. Every extra keyword becomes an
    additional column shown in the tooltip. When `hue` is None all dots are
    drawn in navy.
    """

    output_notebook()

    idx = list(idx)
    n_points = len(X)

    if hue is not None:
        hue = list(hue)
        # One color per unique category, in order of first appearance; the
        # palette is recycled only when there are more categories than colors.
        color_map = dict(zip(dict.fromkeys(hue), cycle(Category20_20)))
        colors = [color_map[category] for category in hue]
    else:
        colors = ['navy'] * n_points

    data_dict={
        'x': X.iloc[:,0],
        'y': X.iloc[:,1],
        'identity': idx,
        'color': colors,
        'original_color': colors,
        'size': [s] * n_points,
        'original_size': [s] * n_points, # lets the callback restore the size after a highlight
        'highlighted': ['no'] * n_points # To keep track of which dot is highlighted
    }

    # Extra per-point information for the hover box
    for key, value in kwargs.items():
        data_dict[key] = value

    source = ColumnDataSource(data=data_dict)

    p = figure(tools="pan,box_zoom,wheel_zoom,reset")
    p.scatter('x', 'y', source=source, alpha=0.6, color='color', size='size')

    # Disable grid lines
    p.xgrid.visible = False
    p.ygrid.visible = False

    # Hover tool: identity first, then one entry per extra keyword
    hover = HoverTool()
    tooltips = [("Identity", "@identity")]
    for key in kwargs.keys():
        tooltips.append((key.capitalize(), f"@{key}"))
    hover.tooltips = tooltips
    p.add_tools(hover)

    # NOTE(review): the callback lower-cases identities, so `idx` is assumed
    # to contain strings -- confirm against callers.
    autocomplete = AutocompleteInput(title="Search by Identity:", completions=idx)

    # On search: undo the previous highlight (color AND size), then zoom to
    # and highlight the matching dot.
    callback = CustomJS(args=dict(source=source, plot=p), code="""
        const data = source.data;
        const search_val = cb_obj.value.toLowerCase();
        const x = data['x'];
        const y = data['y'];
        const identity = data['identity'];
        const color = data['color'];
        const original_color = data['original_color'];
        const size = data['size'];
        const original_size = data['original_size'];
        const highlighted = data['highlighted'];

        for (let i = 0; i < identity.length; i++) {
            if (highlighted[i] === 'yes') {
                color[i] = original_color[i];
                size[i] = original_size[i];
                highlighted[i] = 'no';
            }
            if (identity[i].toLowerCase() === search_val) {
                plot.x_range.start = x[i] - 5;
                plot.x_range.end = x[i] + 5;
                plot.y_range.start = y[i] - 5;
                plot.y_range.end = y[i] + 5;
                color[i] = 'red';
                size[i] = 15;
                highlighted[i] = 'yes';
            }
        }
        source.change.emit();
    """)
    autocomplete.js_on_change('value', callback)

    # Show layout: search box above the plot
    layout = column(autocomplete, p)
    show(layout)
367
+
368
+ # %% ../nbs/02_plot.ipynb 40
369
def plot_count(cnt, # from df['x'].value_counts()
               tick_spacing: float= None, # tick spacing for x axis
               palette: str='tab20'):
    "Draw a horizontal bar plot from the output of df['x'].value_counts(), writing each count at the end of its bar"

    bar_colors = sns.color_palette(palette)
    ax = cnt.plot.barh(color = bar_colors)

    # One label per bar: x is the count itself, y is the bar position
    for position, count in enumerate(cnt):
        plt.text(count, position, str(count),fontsize=10,rotation=-90, va='center')

    # Without a spacing, leave the default x ticks untouched
    if tick_spacing is None:
        return

    # Set x-ticks at regular intervals
    ax.xaxis.set_major_locator(MultipleLocator(tick_spacing))
383
+
384
+ # %% ../nbs/02_plot.ipynb 42
385
@fc.delegates(sns.barplot)
def plot_bar(df,
             value, # colname of value
             group, # colname of group
             title = None,
             figsize = (12,5),
             fontsize=14,
             dots = True, # whether or not add dots in the graph
             rotation=90,
             ascending=False,
             **kwargs
             ):
    "Bar graph of `value` per `group` from an unstacked dataframe, with bars ordered by group mean"

    plt.figure(figsize=figsize)

    # Bars are ordered by the mean of each group
    group_means = df.groupby(group)[value].mean()
    idx = group_means.sort_values(ascending=ascending).index

    sns.barplot(data=df, x=group, y=value, order=idx, **kwargs)

    if dots:
        # Overlay the individual observations as white dots with a black edge
        sns.stripplot(data=df,
                      x=group,
                      y=value,
                      order=idx,
                      alpha=0.8,
                      marker='o',
                      color='white',
                      edgecolor='black',
                      linewidth=1.5,
                      jitter=True,
                      s=5)

    # Same font size for the tick labels of both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    # No x label; the y label is the value column name
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)

    # Rotate X labels
    plt.xticks(rotation=rotation)

    # Optional title
    if title is not None:
        plt.title(title,fontsize=fontsize)

    # Hide the right and top borders
    plt.gca().spines[['right', 'top']].set_visible(False)
438
+
439
+ # %% ../nbs/02_plot.ipynb 45
440
@fc.delegates(sns.barplot)
def plot_group_bar(df,
                   value_cols, # list of column names for values, the order depends on the first item
                   group, # column name of group (e.g., 'kinase')
                   figsize=(12, 5),
                   order=None,
                   title=None,
                   fontsize=14,
                   rotation=90,
                   **kwargs):
    """Plot grouped bar graph from dataframe.

    `value_cols` are melted into a long format so that each group on the x
    axis gets one bar per value column (shown as hue 'Ranking'). Error-bar
    defaults (`capsize`, `errwidth`, `errcolor`) can be overridden via `kwargs`.
    """

    # Melt the dataframe to go from wide to long format
    df_melted = df.melt(id_vars=group, value_vars=value_cols, var_name='Ranking', value_name='Value')

    plt.figure(figsize=figsize)

    # Error-bar defaults; caller values take precedence. Passing them as
    # literals next to **kwargs would raise TypeError on any overlap.
    # NOTE(review): errwidth/errcolor are deprecated in recent seaborn in
    # favor of err_kws -- kept here for compatibility with older versions.
    bar_kws = {'capsize': 0.1, 'errwidth': 1.5, 'errcolor': 'gray'}
    bar_kws.update(kwargs)

    # Create the bar plot
    sns.barplot(data=df_melted, x=group, y='Value', hue='Ranking', order=order,
                **bar_kws)

    # Increase font size for the x-axis and y-axis tick labels
    plt.tick_params(axis='x', labelsize=fontsize)
    plt.tick_params(axis='y', labelsize=fontsize)

    # Modify x and y label and increase font size
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel('Value', fontsize=fontsize)

    # Rotate X labels
    plt.xticks(rotation=rotation)

    # Plot titles
    if title is not None:
        plt.title(title, fontsize=fontsize)

    plt.gca().spines[['right', 'top']].set_visible(False)
    plt.legend(fontsize=fontsize) # if change legend location, use loc='upper right'
481
+
482
+ # %% ../nbs/02_plot.ipynb 48
483
@fc.delegates(sns.boxplot)
def plot_box(df,
             value, # colname of value
             group, # colname of group
             title=None,
             figsize=(6,3),
             fontsize=14,
             dots=True,
             rotation=90,
             **kwargs
             ):
    "Box plot of `value` per `group`, boxes ordered by descending group median"

    plt.figure(figsize=figsize)

    # Order of the boxes: highest median first
    medians = df[[group,value]].groupby(group).median()
    idx = medians.sort_values(value,ascending=False).index

    sns.boxplot(data=df, x=group, y=value, order=idx, **kwargs)

    if dots:
        # Individual observations as small black dots on top of the boxes
        sns.stripplot(x=group, y=value, data=df, order=idx, jitter=True, color='black', size=3)

    # Same font size for the tick labels of both axes
    for axis_name in ('x', 'y'):
        plt.tick_params(axis=axis_name, labelsize=fontsize)

    # No x label; the y label is the value column name
    plt.xlabel('', fontsize=fontsize)
    plt.ylabel(value, fontsize=fontsize)

    plt.xticks(rotation=rotation)

    if title is not None:
        plt.title(title,fontsize=fontsize)

    # The right and top spines are intentionally left visible here
    # (the removal call is disabled):
    # plt.gca().spines[['right', 'top']].set_visible(False)
522
+
523
+
524
+ # %% ../nbs/02_plot.ipynb 51
525
@fc.delegates(sns.regplot)
def plot_corr(x, # x axis values, or colname of x axis
              y, # y axis values, or colname of y axis
              xlabel=None,# x axis label
              ylabel=None,# y axis label
              data = None, # dataframe that contains data
              text_location = (0.8,0.1), # (x, y) of the annotation in axes coordinates
              **kwargs
              ):
    """Plot the correlation of two variables with a regression line.

    `x` and `y` are either the values themselves or, when `data` is given,
    column names of `data`. The Pearson coefficient and its p value are
    written on the plot at `text_location`.
    """
    if data is not None:
        x=data[x]
        y=data[y]

    pear, pvalue = pearsonr(x, y)

    # Gray regression line by default; still overridable by the caller
    # (a literal line_kws next to **kwargs would raise TypeError on overlap).
    kwargs.setdefault('line_kws', {'color': 'gray'})

    sns.regplot(x=x, y=y, **kwargs)

    if xlabel is not None:
        plt.xlabel(xlabel)

    if ylabel is not None:
        plt.ylabel(ylabel)

    # Annotation placed in axes coordinates so it is independent of the data range
    plt.text(s=f'Pearson = {round(pear,2)}\n p = {"{:.2e}".format(pvalue)}',
             x=text_location[0],y=text_location[1],
             transform=plt.gca().transAxes,
             ha='center', va='center')
560
+
561
+ # %% ../nbs/02_plot.ipynb 55
562
def draw_corr(corr):
    "Annotated heatmap of a correlation matrix (e.g. from df.corr()), showing only the cells below the diagonal"

    # True on and above the diagonal -> those cells are hidden
    hidden = np.triu(np.ones_like(corr, dtype=bool))

    plt.figure(figsize=(20, 16))
    sns.heatmap(corr,
                annot=True,
                cmap='coolwarm',
                vmin=-1,
                vmax=1,
                mask=hidden,
                fmt='.2f')
572
+
573
+ # %% ../nbs/02_plot.ipynb 59
574
def get_AUCDF(df,col, reverse=False,plot=True,xlabel='Rank of reported kinase'):
    """Plot CDF curve and get relative area under the curve.

    The values of `df[col]` are sorted and paired with their percentile rank
    (ties share the average rank). The area under that curve (trapezoidal
    rule) is divided by the area of the rectangle spanned by the first and
    last points. With `reverse=True` the curve is mirrored vertically and the
    sign of the ratio is flipped back so the result stays positive.

    Returns the ratio; when `plot` is True it also shows a histogram of the
    values with the CDF overlaid on a secondary axis.
    """

    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name;
    # prefer the new one and fall back for older NumPy versions.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # sort col values as x values
    x_values = df[col].sort_values().values

    # percentile ranks as y values; method='average' takes duplicates into account
    y_values = pd.Series(x_values).rank(method='average', pct=True).values

    if reverse:
        # Mirror the curve while keeping the distribution's integrity
        y_values = 1 - y_values + y_values.min()

    # calculate the area under the curve using the trapezoidal rule
    area_under_curve = integrate(y_values, x_values)

    # bounding rectangle between the first and the last point
    # (negative when reversed, since y then decreases)
    total_area = (x_values[-1] - x_values[0]) * (y_values[-1] - y_values[0])

    AUCDF = area_under_curve / total_area
    if reverse:
        AUCDF = -AUCDF

    if plot:
        # Create a figure and a primary axis
        fig, ax1 = plt.subplots(figsize=(7,5))

        fontsize=17

        # Plot the histogram on the primary axis
        sns.histplot(x_values,bins=20,ax=ax1)
        ax1.set_xlabel(xlabel,fontsize=fontsize)
        ax1.set_ylabel('Substrates',color='darkblue',fontsize=fontsize)
        ax1.tick_params(axis='y', labelcolor='darkblue',labelsize=fontsize)
        ax1.tick_params(axis='x', labelcolor='black',labelsize=fontsize)
        ax1.set_xlim(min(x_values),max(x_values))

        # Create a secondary axis for the CDF
        ax2 = ax1.twinx()

        # Plot the CDF on the secondary axis
        ax2.plot(x_values, y_values, color='darkred', linestyle='-', linewidth=2.0)

        # Dashed black reference diagonal, mirrored when reversed
        if reverse:
            ax2.plot([max(x_values),0],[0, max(y_values)], 'k--')
        else:
            ax2.plot([0, max(x_values)], [0, max(y_values)], 'k--')

        ax2.set_ylabel('Probability', color='darkred',fontsize=fontsize,rotation=270,labelpad=18)

        # Put the score where the curve leaves room: left half when reversed
        text_x = 0.45 if reverse else 0.95
        ax2.text(text_x, 0.3, f"AUCDF:{AUCDF.round(4)}", transform=plt.gca().transAxes, ha='right', va='bottom',fontsize=fontsize)

        ax2.tick_params(axis='y', labelcolor='darkred',labelsize=fontsize)
        ax2.set_ylim(0, 1) # Probabilities range from 0 to 1

        # Show the plot
        plt.title(f'{len(x_values)} kinase-substrate pairs',fontsize=fontsize)
        plt.show()

    return AUCDF
637
+
638
+ # %% ../nbs/02_plot.ipynb 62
639
def plot_confusion_matrix(target, # pd.Series
                          pred, # pd.Series
                          class_names:list=['0','1'],
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Plot the confusion matrix.

    Rows are true labels and columns are predicted labels. With
    `normalize=True` every row is divided by its sum, so cells show the
    fraction of each true class. `class_names` label the ticks of both axes.
    """

    cm = confusion_matrix(target, pred)

    if normalize:
        # Row-wise normalization: each row is divided by its number of true samples
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
        fmt = '.2f'
    else:
        print('Confusion matrix, without normalization')
        # Raw counts are integers; the default '.2g' annotation format would
        # show counts of 100 or more in scientific notation.
        fmt = 'd'

    plt.figure(figsize=(6,6))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap=cmap) # Plot the heatmap
    plt.title(title)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    # Heatmap cells are centered at i + 0.5
    tick_positions = np.arange(len(class_names)) + 0.5
    plt.xticks(tick_positions, class_names)
    plt.yticks(tick_positions, class_names, rotation=0)