levseq 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1130 @@
1
+ ###############################################################################
2
+ # #
3
+ # This program is free software: you can redistribute it and/or modify #
4
+ # it under the terms of the GNU General Public License as published by #
5
+ # the Free Software Foundation, either version 3 of the License, or #
6
+ # (at your option) any later version. #
7
+ # #
8
+ # This program is distributed in the hope that it will be useful, #
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of #
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
11
+ # GNU General Public License for more details. #
12
+ # #
13
+ # You should have received a copy of the GNU General Public License #
14
+ # along with this program. If not, see <http://www.gnu.org/licenses/>. #
15
+ # #
16
+ ###############################################################################
17
+
18
+ # Import all packages
19
+ from __future__ import annotations
20
+
21
+ from collections import Counter
22
+
23
+ import warnings
24
+
25
+ import os
26
+ import sys
27
+ from copy import deepcopy
28
+
29
+ import pandas as pd
30
+ import numpy as np
31
+
32
+ from Bio import AlignIO
33
+ from Bio.motifs import Motif
34
+ from Bio.PDB.Polypeptide import aa1
35
+ from Bio.Align import MultipleSeqAlignment
36
+
37
+ import matplotlib.pyplot as plt
38
+ import matplotlib as mpl
39
+
40
+ import holoviews as hv
41
+ from holoviews.streams import Tap
42
+
43
+ import ninetysix as ns
44
+ import colorcet as cc
45
+
46
+ from bokeh.plotting import figure
47
+ from bokeh.models import (
48
+ ColumnDataSource,
49
+ Range1d,
50
+ CustomJS,
51
+ RangeSlider,
52
+ TapTool,
53
+ HoverTool,
54
+ Label,
55
+ Div,
56
+ FactorRange,
57
+ Legend,
58
+ LegendItem,
59
+ )
60
+ from bokeh.models.glyphs import Text, Rect
61
+ from bokeh.layouts import column, gridplot, row, Spacer
62
+ from bokeh.events import Tap
63
+ from bokeh.io import save, show, output_file, output_notebook
64
+
65
+ import panel as pn
66
+
67
+ from levseq.utils import *
68
+
69
# NOTE: everything below runs as a side effect of importing this module.

# Configure bokeh to display its output in the notebook.
output_notebook()

# Load the panel extension and select the "vscode" comms backend.
pn.extension()
pn.config.comms = "vscode"

# Use bokeh as the holoviews plotting backend, with WebGL enabled on its renderer.
hv.extension("bokeh")
hv.renderer("bokeh").webgl = True
76
+
77
+ # warnings.filterwarnings("ignore")
78
+ #warnings.filterwarnings("ignore", category=Warning)
79
+
80
+ #with warnings.catch_warnings():
81
+ # warnings.simplefilter("ignore")
82
+
83
+ # Redirect stderr to devnull to suppress any remaining output
84
+ #sys.stderr = open(os.devnull, "w")
85
+
86
######## Define constants for MSA alignments ########

# Light gray used for positions that match the reference sequence
match_color = "#d9d9d9"

# Nucleotide colors for the MSA alignment plot
NUC_COLOR_DICT = {
    "A": "green",
    "T": "red",
    "G": "black",
    "C": "blue",
    "-": "white",
    "N": "gray",
}
100
+
101
+
102
def get_well_ids():
    """
    Generate the well IDs of a 96-well plate in row-major order.

    Returns:
    - list: 96 strings running "A1", "A2", ..., "A12", "B1", ..., "H12"
    """
    plate_rows = "ABCDEFGH"
    plate_columns = range(1, 13)

    # A fresh list is built on every call, so callers may mutate the result
    return [f"{letter}{number}" for letter in plate_rows for number in plate_columns]
116
+
117
+
118
# Well IDs computed once at import time; used as the options of the well selector
WELL_IDS = get_well_ids()
119
+
120
+
121
def well2nb(well_id):
    """
    Given a well ID, return its "NB"-prefixed well number.

    Wells are counted row by row with twelve wells per row, so "A1" is 1 and
    "H12" is 96. Numbers below 10 are padded with a leading zero.

    Args:
    - well_id: str, row letter followed by the column number, ie. 'A1'

    Returns:
    - str: label such as "NB01" or "NB96"
    """
    row_letter, column_text = well_id[0], well_id[1:]

    # Zero-based row index: "A" -> 0, "B" -> 1, ...
    row_index = ord(row_letter) - ord("A")
    well_number = row_index * 12 + int(column_text)

    prefix = "NB0" if well_number < 10 else "NB"
    return f"{prefix}{well_number}"
129
+
130
+
131
+ # Function for making plate maps
132
def _make_platemap(df, title, cmap=None):
    """Generates a plate heatmap from LevSeq data using Holoviews with
    bokeh backend.

    Called via `generate_platemaps`; see docs there.

    Input:
    ------
    df: DataFrame
        Data for the plate to draw. The columns read here are 'Row',
        'Column', 'Mutations', 'Alignment Count' and 'Alignment Probability'.
        A 'logseqdepth' column is added to the frame that is passed in.
    title: str
        Title of the plot.
    cmap: list-like or str, default None
        Colors for the well outlines; see `generate_platemaps`.

    Returns:
    --------
    Overlay of three layers: the log sequencing depth heatmap, the binned
    alignment probability outlines and the mutation labels.

    Raises:
    -------
    ValueError: if `cmap` is not a list/tuple or has more than four entries.
    """
    # Convert SeqDepth to log for easier visualization; wells with a count of
    # zero keep the 0.0 from the `out` buffer instead of becoming -inf.
    df["logseqdepth"] = np.log(
        df["Alignment Count"],
        out=np.zeros_like(df["Alignment Count"], dtype=float),
        where=(df["Alignment Count"] != 0),
    )

    # Create necessary Row and Column values and sort
    df = df.sort_values(["Column", "Row"])
    df["Column"] = df["Column"].astype("str")

    # Set some base opts
    opts = dict(invert_yaxis=True, title=title, show_legend=True)

    # logseqdepth heatmap
    seq_depth_cmap = list(reversed(cc.CET_D9))

    # Set the center
    center = np.log(10)

    # When every well is at or above the center, log(1) is appended below so
    # that the colormap still has values on the low side of the center.
    add_min = df["logseqdepth"].min() >= center

    # Adjust if it is greater than max of data (avoids ValueError)
    if df["logseqdepth"].max() <= center:
        # Adjust the center
        center = df["logseqdepth"].median()

    # center colormap
    if not add_min:
        color_levels = ns.viz._center_colormap(df["logseqdepth"], center)
    else:
        color_levels = ns.viz._center_colormap(
            list(df["logseqdepth"]) + [np.log(1)], center
        )

    # Get heights
    n_rows = len(df["Row"].unique())
    n_cols = len(df["Column"].unique())
    height = int(50 * n_rows)
    width = height * n_cols // n_rows

    def hook(plot, element):
        # Always show the full 8 x 12 plate, even when wells are missing
        plot.handles["y_range"].factors = list("HGFEDCBA")
        plot.handles["x_range"].factors = [str(value) for value in range(1, 13)]

    # generate the heatmap
    # NOTE(review): the redim key `row` does not match the "Row" dimension
    # (capital R) - confirm whether `Row=` was intended. Left unchanged here
    # because the hook above already fixes the axis factors.
    hm = (
        hv.HeatMap(
            df,
            kdims=["Column", "Row"],
            vdims=[
                "logseqdepth",
                "Mutations",
                "Alignment Count",
                "Alignment Probability",
            ],
        )
        .redim.values(row=np.unique(df["Row"]), Column=np.unique(df["Column"]))
        .opts(
            **opts,
            colorbar=True,
            cmap=seq_depth_cmap,
            height=height,
            width=width,
            line_width=4,
            clipping_colors={"NaN": "#DCDCDC"},
            color_levels=color_levels,
            tools=["hover"],
            colorbar_opts=dict(title="LogSeqDepth", background_fill_alpha=0),
            hooks=[hook],
        )
    )

    # function to bin the alignment frequencies into more relevant groupings
    def bin_align_freq(value):
        if value > 0.95:
            return "0.95+"
        if value > 0.9:
            return "0.90-0.95"
        if value > 0.8:
            return "0.80-0.90"
        # anything below 0.8 should really be discarded (NaN also lands here)
        return "<0.80"

    # Bin alignment frequencies for easier viz
    bins = ["0.95+", "0.90-0.95", "0.80-0.90", "<0.80"]
    if cmap is None:
        cmap = [cc.bmy[int((1.1 - i) * len(cc.bmy))] for i in [0.95, 0.9, 0.8, 0.4]]
    if "stoplight" in cmap:
        cmap = ["#337D1F", "#94CD35", "#FFC300", "#C62C20"]
    else:
        # Validate colormap
        if not isinstance(cmap, (list, tuple)):
            raise ValueError("cmap argument must be a list or tuple")
        if len(cmap) > 4:
            raise ValueError(
                "cmap argument has too many entries; only 4 should be passed"
            )
    # Pair each bin with its color (zip stops at the shorter of the two)
    cmap = dict(zip(bins, cmap))

    # apply binning function to the AlignmentFrequency
    df["AlignmentProbabilityBinned"] = df["Alignment Probability"].apply(bin_align_freq)

    # Set up size of the outline boxes
    box_size = height // n_rows * 1.2

    # alignment frequency heatmap for edges around wells
    boxes = hv.Points(
        df.sort_values(["Alignment Probability"], ascending=False),
        ["Column", "Row"],
        "AlignmentProbabilityBinned",
    ).opts(
        **opts,
        marker="square",
        line_color="AlignmentProbabilityBinned",
        line_join="miter",
        cmap=cmap,
        line_width=6,
        fill_alpha=0,
        line_alpha=1,
        legend_position="top",
        size=box_size,
    )

    # Use in apply statement for residue labels
    def split_variant_labels(mutation_string):
        num_mutations = len(mutation_string.split("_"))

        # Too many mutations to print in one well: show the count instead
        if num_mutations > 4:
            return str(num_mutations) + " muts"

        mutation_string = mutation_string.replace("?", "")
        return mutation_string.replace("_", "\n")

    _df = df.copy()
    _df["Labels"] = _df["Mutations"].apply(split_variant_labels)

    # Set the font size based on if #PARENT# is in a well and num of mutations
    max_num_mutations = _df["Labels"].apply(lambda x: len(x.split("\n"))).max()
    # `in` on a Series tests the index labels, so compare the values instead
    has_parent = (_df["Labels"] == "#PARENT#").any()

    # Both branches currently resolve to the same size
    if max_num_mutations > 3 or has_parent:
        label_fontsize = "8pt"
    else:
        label_fontsize = "8pt"

    labels = hv.Labels(
        _df,
        ["Column", "Row"],
        "Labels",
    ).opts(text_font_size=label_fontsize, **opts, text_color="#000000")

    # return formatted final plot
    return (hm * boxes * labels).opts(
        frame_height=550, frame_width=550 * 3 // 2, border=50, show_legend=True
    )
306
+
307
+
308
+ # Main function to return heatmap with or without alignment
309
def generate_platemaps(
    max_combo_data,
    result_folder,
    cmap=None,
    show_msa=False,
    widget_location="top_left",
):
    """Generates interactive plate heatmaps from LevSeq data.

    Input:
    ------
    max_combo_data: path (str) or DataFrame
        Path to 'variants.csv' from an LevSeq experiment or
        a pandas DataFrame of that file. Besides the columns used by the
        platemap itself, 'Plate' and 'barcode_plate' are read here, and
        'nc_variant' when `show_msa` is True.
    result_folder: str
        Folder holding the per-well alignment files, laid out as
        <result_folder>/<plate>/<RBxx>/<NBxx>/msa_<plate>_<well>.fa.
        Only used when `show_msa` is True.
    cmap: list-like or str, default None
        The colormap to use for the well outline indicating alignment
        frequency. If None, defaults to a Plasma-like (colorcet.bmy)
        colormap. If 'stoplight', uses a green-yellow-red colormap (not
        the most colorblind friendly, but highly intuitive). Otherwise
        you may pass any list-like object containing four colors (e.g.,
        ['#337D1F', '#94CD35', '#FFC300', '#C62C20'] for 'stoplight').
    show_msa: bool, default False
        If True, return a panel layout with plate and well dropdowns that
        shows the platemap together with the alignment plot of the
        selected well. If False, return only the platemaps.
    widget_location: string, default 'top_left'
        Location of the widget for navigating plots. Must be one of:
        ['left', 'bottom', 'right', 'top', 'top_left', 'top_right',
        'bottom_left', 'bottom_right', 'left_top', 'left_bottom',
        'right_top', 'right_bottom']. Only used when `show_msa` is False.

    Returns:
    --------
    A panel Column (dropdowns above the plots) when `show_msa` is True,
    otherwise a holoviews HoloMap keyed by plate.
    """

    # Convert to dataframe if necessary
    if isinstance(max_combo_data, str):
        max_combo_df = pd.read_csv(max_combo_data)
    else:
        max_combo_df = max_combo_data.copy()

    # Identify unique plates
    unique_plates = max_combo_df.Plate.unique()

    # Work on a separate frame so the barcode column of the data stays as is
    temp_df = max_combo_df[["Plate", "barcode_plate"]].drop_duplicates("Plate").copy()

    # Convert barcode_plate to string and modify its format
    temp_df["barcode_plate"] = temp_df["barcode_plate"].apply(
        lambda x: "RB0" + str(x) if x < 10 else "RB" + str(x)
    )

    # Create a dictionary with unique Plate to modified barcode_plate mapping
    plate2barcode = temp_df.set_index("Plate")["barcode_plate"].to_dict()

    # make logseqdepth column
    max_combo_df["logseqdepth"] = np.log(
        max_combo_df["Alignment Count"],
        out=np.zeros_like(max_combo_df["Alignment Count"], dtype=float),
        where=max_combo_df["Alignment Count"] != 0,
    )
    # Set the center
    center = np.log(10)

    add_min = max_combo_df["logseqdepth"].min() >= center

    # Adjust if it is greater than max of data (avoids ValueError)
    if max_combo_df["logseqdepth"].max() <= center:
        # Adjust the center
        center = max_combo_df["logseqdepth"].median()

    # center colormap
    if not add_min:
        color_levels = ns.viz._center_colormap(max_combo_df["logseqdepth"], center)
    else:
        color_levels = ns.viz._center_colormap(
            list(max_combo_df["logseqdepth"]) + [np.log(1)], center
        )

    # dictionary for storing plots, one platemap per plate
    hm_dict = {}
    for plate in unique_plates:

        # Split to just the information of interest
        df = max_combo_df.loc[max_combo_df.Plate == plate].copy()

        # generate a holoviews plot
        hm_dict[plate] = _make_platemap(df, title=plate, cmap=cmap)

    # Uniform color levels. This has to run after the platemaps exist;
    # looping over the still-empty dictionary changes nothing.
    for _hm in hm_dict.values():
        _hm.opts({"HeatMap": {"color_levels": color_levels}})

    if not show_msa:
        # plot from the dictionary
        hm_holomap = hv.HoloMap(hm_dict, kdims=["Plate"])
        # Update widget location
        hv.output(widget_location=widget_location)

        return hm_holomap

    # Create dropdowns
    plate_selector = pn.widgets.Select(name="Plate", options=list(unique_plates))
    well_selector = pn.widgets.Select(name="Well", options=WELL_IDS)

    def get_plate_well(plate, well_id):

        """
        Get the platemap and alignment plot for a given well

        Args:
        - plate: str, plate name
        - well_id: str, well ID, ie. 'A1'
        """

        # Split to just the information of interest
        plate_df = max_combo_df.loc[max_combo_df.Plate == plate]

        hm_bokeh = hv.render(hm_dict[plate], backend="bokeh")
        hm_bokeh.toolbar_location = "right"
        hm_bokeh.toolbar.active_drag = None
        hm_bokeh.toolbar.active_scroll = None

        # Path of the alignment file of this well
        aln_path = os.path.join(
            result_folder,
            plate,
            plate2barcode[plate],
            well2nb(well_id),
            f"msa_{plate}_{well_id}.fa",
        )

        # Compare the column as text so that integer columns (as read from a
        # csv) and string columns both match the well ID.
        well_rows = plate_df[
            (plate_df["Row"] == well_id[0])
            & (plate_df["Column"].astype(str) == well_id[1:])
        ]

        # Use the consensus of the alignment (well_seq == "") when the well has
        # no row or no usable nc_variant, instead of failing the whole panel.
        well_seq = well_rows["nc_variant"].values[0] if len(well_rows) else ""
        if not isinstance(well_seq, str):
            well_seq = ""

        # plot the alignment using the nc_variant sequence
        aln = plot_sequence_alignment(
            aln_path,
            parent_name=plate,
            well_seq=well_seq,
            markdown_title=f"{result_folder} {plate} {plate2barcode[plate]} {well2nb(well_id)} {well_id}",
        )

        # stack the platemap on top of the alignment plot
        return gridplot(
            [[hm_bokeh], [aln]],
            toolbar_location="right",
            sizing_mode="fixed",  # "stretch_width",
        )

    # Function to update the plots based on dropdown selection
    @pn.depends(plate=plate_selector.param.value, well_id=well_selector.param.value)
    def update_plot(plate, well_id):
        """
        Update the plot based on the dropdown selection

        Args:
        - plate: str, plate name
        - well_id: str, well ID, ie. 'A1'
        """
        return get_plate_well(plate, well_id)

    # Layout the dropdowns and the plot
    return pn.Column(pn.Row(plate_selector, well_selector), pn.Column(update_plot))
503
+
504
+
505
+ ########### Functions for the MSA alignment plot ###########
506
def get_sequence_colors(seqs: list, palette="viridis") -> list[str]:

    """
    Get colors for sequences without parent seq highlighting differences

    Nucleotide colors from NUC_COLOR_DICT are used when every symbol is found
    there. Otherwise all symbols are colored from a palette built over
    ALL_AAS plus "-" and "X".

    Args:
    - seqs: list of sequences
    - palette: str, name of the matplotlib color palette for the fallback

    Returns:
    - list: one color per symbol, sequences concatenated in order

    Raises:
    - KeyError: if a symbol is in neither color table
    """

    # Fallback alphabet: ALL_AAS followed by the gap and unknown symbols
    aas = [*ALL_AAS, "-", "X"]

    # 20 palette colors, then white for "-" and gray for "X"
    pal = plt.colormaps[palette]
    pal = [mpl.colors.to_hex(i) for i in pal(np.linspace(0, 1, 20))]
    pal.append("white")
    pal.append("gray")

    pcolors = dict(zip(aas, pal))
    symbols = [symbol for seq in seqs for symbol in seq]

    try:
        return [NUC_COLOR_DICT[symbol] for symbol in symbols]
    except KeyError:
        # Not a pure nucleotide alignment: color every symbol from the palette
        return [pcolors[symbol] for symbol in symbols]
537
+
538
+
539
def get_sequence_diff_colorNseq(seqs: list, seq_ids: list, parent_seq: str) -> tuple:

    """
    Get colors and nucleotides for input sequences highlighting differences from parent

    All four returned lists hold one entry per nucleotide, with the sequences
    concatenated in input order.

    Args:
    - seqs: list of sequences
    - seq_ids: list of sequence ids, parallel to seqs
    - parent_seq: str, parent sequence to compare against

    Returns:
    - block_colors: list, block color per nucleotide; the nucleotide color
        where it differs from the parent, match_color otherwise
    - nuc_textcolors: list, text color per nucleotide; the nucleotide color for
        the row whose id is "parent", "gray" for every other row
    - diff_nucs: list, the nucleotide where it differs from the parent,
        " " otherwise
    - text_annot: list, "-" where a gap differs from the parent, " " otherwise
    """
    block_colors = []
    nuc_textcolors = []
    diff_nucs = []
    text_annot = []

    for seq, seq_id in zip(seqs, seq_ids):
        if seq_id == "parent":
            # The parent row is drawn from parent_seq itself, all as matches
            for parent_nuc in parent_seq:
                block_colors.append(match_color)
                nuc_textcolors.append(NUC_COLOR_DICT[parent_nuc])
                diff_nucs.append(" ")
                text_annot.append(" ")
            continue

        for nuc, parent_nuc in zip(seq, parent_seq):
            differs = nuc != parent_nuc
            block_colors.append(NUC_COLOR_DICT[nuc] if differs else match_color)
            diff_nucs.append(nuc if differs else " ")
            # Only gaps that differ from the parent get a text annotation
            text_annot.append("-" if differs and nuc == "-" else " ")
            nuc_textcolors.append("gray")

    return block_colors, nuc_textcolors, diff_nucs, text_annot
588
+
589
+
590
def get_cons(aln: MultipleSeqAlignment) -> list[float]:

    """
    Get conservation values from alignment

    The value of a column is the count of its most common non-gap symbol
    divided by the number of sequences. A column made up only of gaps gets 0.

    Args:
    - aln: MultipleSeqAlignment, input alignment

    Returns:
    - list: conservation values, one per alignment column
    """

    n_seqs = len(aln)
    conservation = []
    for position in range(aln.get_alignment_length()):
        counts = Counter(aln[:, position])
        # Gaps do not count towards conservation
        counts.pop("-", None)
        # default=0 covers all-gap columns, where max() of nothing would raise
        conservation.append(max(counts.values(), default=0) / n_seqs)
    return conservation
610
+
611
+
612
def get_cons_seq(aln: MultipleSeqAlignment, ifdeg: bool = True) -> str:

    """
    Get consensus sequence from alignment

    Args:
    - aln: MultipleSeqAlignment, input alignment
    - ifdeg: bool, if True, return the degenerate consensus,
        otherwise the plain consensus

    Returns:
    - consensus sequence of a motif built over the "ACGT" alphabet
    """

    motif = Motif("ACGT", aln.alignment)
    return motif.degenerate_consensus if ifdeg else motif.consensus
632
+
633
+
634
def get_cons_diff_colorNseq(cons_seq: str, parent_seq: str) -> tuple:

    """
    Get consensus sequence highlighting differences from parent

    The two sequences are walked in lockstep; positions beyond the shorter
    one are dropped.

    Args:
    - cons_seq: str, consensus sequence
    - parent_seq: str, parent sequence

    Returns:
    - colors: list, match_color where the consensus equals the parent,
        otherwise the nucleotide color ("#f2f2f2" for symbols without one)
    - cons_seq_diff: list, the consensus nucleotide where it differs from
        the parent, " " otherwise
    """

    colors = []
    cons_seq_diff = []
    for cons_nuc, parent_nuc in zip(cons_seq, parent_seq):
        if cons_nuc == parent_nuc:
            cons_seq_diff.append(" ")
            colors.append(match_color)
            continue

        cons_seq_diff.append(cons_nuc)
        colors.append(NUC_COLOR_DICT.get(cons_nuc, "#f2f2f2"))

    return colors, cons_seq_diff
662
+
663
+
664
def aggregate_gray_blocks(x_vals: list, y_vals: list, colors: list, text: list):

    """
    Aggregate gray blocks in the MSA alignment plot
    to reduce the number of elements for plotting

    Consecutive entries whose color is "gray" and that share the same y are
    merged into one wide block. All other entries are passed through as
    blocks of width 1 centered on their x.

    Args:
    - x_vals: list, x values (block centers)
    - y_vals: list, y values
    - colors: list, colors
    - text: list, text

    Returns:
    - aggregated_x: list, aggregated x values (block centers)
    - aggregated_y: list, aggregated y values
    - aggregated_width: list, aggregated width values
    - aggregated_height: list, aggregated height values (always 1)
    - aggregated_colors: list, aggregated colors
    - aggregated_text: list, aggregated text ("" for merged gray blocks)
    """

    aggregated_x = []
    aggregated_y = []
    aggregated_width = []
    aggregated_height = []
    aggregated_colors = []
    aggregated_text = []

    def emit_gray_run(start, row, width):
        # Each unit block is centered on its x, so a run that starts at
        # `start` and spans `width` blocks is centered at
        # start + (width - 1) / 2. (Using width / 2 shifts the merged block
        # half a unit to the right of the blocks it replaces.)
        aggregated_x.append(start + (width - 1) / 2)
        aggregated_y.append(row)
        aggregated_width.append(width)
        aggregated_height.append(1)
        aggregated_colors.append("gray")
        aggregated_text.append("")

    run_start = None
    run_row = None
    run_width = 0

    # assumes x increases by 1 between consecutive entries of a row - this is
    # what makes "number of merged entries" equal to the width of the run
    for x, y, color, t in zip(x_vals, y_vals, colors, text):
        if color == "gray":
            if run_start is not None and y == run_row:
                # Continue aggregating in the same row
                run_width += 1
                continue

            if run_start is not None:
                # Row changed, finalize the current block
                emit_gray_run(run_start, run_row, run_width)

            # Start a new gray run
            run_start, run_row, run_width = x, y, 1
            continue

        # A non-gray block ends any open gray run
        if run_start is not None:
            emit_gray_run(run_start, run_row, run_width)
            run_start = None
            run_width = 0

        # Add the non-gray block as it is
        aggregated_x.append(x)
        aggregated_y.append(y)
        aggregated_width.append(1)
        aggregated_height.append(1)
        aggregated_colors.append(color)
        aggregated_text.append(t)

    # Add any remaining aggregated gray block
    if run_start is not None:
        emit_gray_run(run_start, run_row, run_width)

    return (
        aggregated_x,
        aggregated_y,
        aggregated_width,
        aggregated_height,
        aggregated_colors,
        aggregated_text,
    )
759
+
760
+
761
def aggregate_conservation(x_vals, heights, colors):

    """
    Aggregate gray blocks in the conservation plot
    to reduce the number of elements for plotting

    Consecutive entries that are "gray" with a height of exactly 2 are
    replaced by a single entry placed at the center of the run. All other
    entries are passed through unchanged.

    NOTE(review): no width is returned, so the length of a merged run is not
    available to the caller - confirm this is intended.

    Args:
    - x_vals: list, x values (bar centers)
    - heights: list, heights
    - colors: list, colors

    Returns:
    - aggregated_x: list, aggregated x values
    - aggregated_height: list, aggregated heights
    - aggregated_colors: list, aggregated colors
    """

    aggregated_x = []
    aggregated_height = []
    aggregated_colors = []

    def emit_gray_run(start, width):
        # Bars are centered on their x, so a run of `width` bars starting at
        # `start` is centered at start + (width - 1) / 2, not start + width / 2.
        aggregated_x.append(start + (width - 1) / 2)
        aggregated_height.append(2)
        aggregated_colors.append("gray")

    run_start = None
    run_width = 0

    for x, height, color in zip(x_vals, heights, colors):
        if color == "gray" and height == 2:
            # Start or continue aggregating gray blocks with height = 2
            if run_start is None:
                run_start = x
                run_width = 1
            else:
                run_width += 1
            continue

        # Any other entry ends the open gray run
        if run_start is not None:
            emit_gray_run(run_start, run_width)
            run_start = None
            run_width = 0

        # Add the non-gray or non-height-2 block as it is
        aggregated_x.append(x)
        aggregated_height.append(height)
        aggregated_colors.append(color)

    # Append the final aggregated gray block if any
    if run_start is not None:
        emit_gray_run(run_start, run_width)

    return aggregated_x, aggregated_height, aggregated_colors
813
+
814
+
815
def plot_empty(msg="", plot_width=1000, plot_height=200) -> figure:
    """
    Build a blank bokeh plot, optionally carrying a text message

    Args:
    - msg: str, message to display
    - plot_width: int, width of the plot
    - plot_height: int, height of the plot

    Returns:
    - figure: bokeh plot with grid and both axes hidden
    """

    blank = figure(
        width=plot_width,
        height=plot_height,
        x_range=(0, 1),
        y_range=(0, 2),
        sizing_mode="fixed",  # "stretch_width",
    )

    # Message sits at mid height of the (0, 2) y range
    blank.add_layout(Label(x=0.3, y=1, text=msg))

    # Hide everything but the message
    for decoration in (blank.grid, blank.xaxis, blank.yaxis):
        decoration.visible = False

    return blank
843
+
844
+
845
def plot_sequence_alignment(
    aln_path: str,
    parent_name: str = "parent",
    well_seq: str = "",
    markdown_title: str = "Multiple sequence alignment",
    fontsize: str = "4pt",
    plot_width: int = 1000,
    sizing_mode: str = "fixed",  # "stretch_width",
    palette: str = "viridis",
    row_height: float = 8,
) -> figure:

    """
    Plot sequence alignment

    The result stacks a title, a conservation bar plot and the alignment
    block plot. If the alignment file is missing or holds fewer than two
    sequences, an empty plot with a message is returned instead.

    Args:
    - aln_path: str, path to the alignment file (fasta)
    - parent_name: str, name of the parent sequence; records with the id
        "parent" or "Parent" are accepted as the parent as well
    - well_seq: str, sequence of the well; if "", the consensus of the
        alignment is used in its place
    - markdown_title: str, title of the plot
    - fontsize: str, fontsize of the text (currently not used)
    - plot_width: int, width of the plot
    - sizing_mode: str, sizing mode of the plot
    - palette: str, color palette used when the alignment has no parent
    - row_height: float, height of the row

    Returns:
    - bokeh gridplot layout
    """

    # get text from markdown
    msa_title = Div(
        text=f"""
    {markdown_title}
    """
    )

    # check if alignment file exists
    if not os.path.exists(aln_path):
        p = plot_empty("Alignment file not found", plot_width)
        return gridplot(
            [[msa_title], [p]],
            toolbar_location=None,
            sizing_mode=sizing_mode,
        )

    # read in alignment
    aln = AlignIO.read(aln_path, "fasta")

    seqs = [rec.seq for rec in aln]
    ids = [rec.id for rec in aln]
    rev_ids = list(reversed(ids))

    seq_len = len(seqs[0])
    numb_seq = len(seqs)

    # get parent sequence
    parent_seq = None
    for rec in aln:
        if rec.id in ["parent", "Parent", parent_name]:
            parent_seq = rec.seq
            break

    # check if alignment has at least two sequences
    if numb_seq <= 1:
        p = plot_empty("Alignment plot needs at least two sequences", plot_width)
        return gridplot(
            [[msa_title], [p]],
            toolbar_location=None,
            sizing_mode=sizing_mode,
        )

    seq_nucs = [i for s in seqs for i in s]

    # get colors for the alignment
    if parent_seq is None:
        # No parent in the alignment: color every symbol on its own. All
        # columns of the data source need one entry per symbol, so the parent
        # column is filled with empty strings.
        block_colors = get_sequence_colors(seqs=seqs, palette=palette)
        text = seq_nucs
        text_colors = ["black"] * len(seq_nucs)
        parent_nucs = [""] * len(seq_nucs)
    else:
        parent_nucs = list(parent_seq) * numb_seq
        (
            block_colors,
            text_colors,
            _diff_nucs,
            text,
        ) = get_sequence_diff_colorNseq(seqs=seqs, seq_ids=ids, parent_seq=parent_seq)

    # get conservation values
    cons = get_cons(aln)

    # get consensus sequence
    if well_seq == "":
        cons_seq = get_cons_seq(aln)
    else:
        cons_seq = well_seq

    cons_nucs = list(cons_seq) * numb_seq

    # coords of the plot
    x = np.arange(1, seq_len + 1)
    y = np.arange(0, numb_seq, 1)
    xx, yy = np.meshgrid(x, y)
    gx = xx.ravel()
    gy = yy.flatten()
    recty = gy + 0.5

    # Flip y so that the first sequence is drawn in the top row
    y_flipped = numb_seq - gy - 1
    ids_repeated = [ids[yi] for yi in y_flipped]
    rev_ids_repeated = [rev_ids[yi] for yi in y_flipped]

    # set up msa source
    msa_source = ColumnDataSource(
        dict(
            x=gx,
            y=y_flipped,
            ids=ids_repeated,
            rev_ids=rev_ids_repeated,
            seq_nucs=seq_nucs,
            parent_nucs=parent_nucs,
            cons_nucs=cons_nucs,
            recty=numb_seq - recty,
            block_colors=block_colors,
            text=text,
            text_colors=text_colors,
        )
    )

    # define the plot size
    plot_height = len(seqs) * row_height + 20

    # Aggregating gray blocks
    # NOTE(review): aggregate_gray_blocks only merges blocks colored "gray",
    # while positions matching the parent are colored match_color
    # ("#d9d9d9"), so matched positions are not merged - confirm intent.
    agg_x, agg_y, agg_width, agg_height, agg_colors, agg_text = aggregate_gray_blocks(
        msa_source.data["x"],
        msa_source.data["y"],
        msa_source.data["block_colors"],
        msa_source.data["text"],
    )

    # Create the updated ColumnDataSource with aggregated gray blocks
    msa_source_aggregated = ColumnDataSource(
        dict(
            x=agg_x,
            y=agg_y,
            width=agg_width,
            height=agg_height,
            fill_color=agg_colors,
            text=agg_text,
        )
    )

    # Create a new figure
    p_aln = figure(
        title=None,
        width=plot_width,
        height=plot_height,
        x_range=(0, max(agg_x) + 1),
        y_range=(-0.5, max(agg_y) + 0.5),
    )

    # Add aggregated rectangles (blocks of sequences)
    rect_glyph = Rect(
        x="x",
        y="y",
        width="width",
        height="height",
        fill_color="fill_color",
        line_color=None,
    )

    # Add the rectangles to the plot
    p_aln.add_glyph(msa_source_aggregated, rect_glyph)

    p_aln.grid.visible = False
    p_aln.yaxis.visible = False
    p_aln.yaxis.major_label_text_font_size = "0pt"
    p_aln.yaxis.minor_tick_line_width = 0
    p_aln.yaxis.major_tick_line_width = 0

    # conservation plot; without a parent the consensus is compared with
    # itself, which colors every bar as a match
    reference_seq = cons_seq if parent_seq is None else parent_seq
    cons_colors, _cons_text = get_cons_diff_colorNseq(
        cons_seq=cons_seq, parent_seq=reference_seq
    )

    cons_source = ColumnDataSource(
        dict(
            x=x,
            cons=cons,
            # Bars are centered on y=0 and the y range is (0, 1), so only the
            # upper half shows: twice the height gives a visible bar of `cons`
            cons_height=[2 * c for c in cons],
            cons_colors=cons_colors,
            parent_nucs=[""] * seq_len if parent_seq is None else list(parent_seq),
            cons_nucs=list(cons_seq),
        )
    )

    # Aggregate the conservation bars
    agg_cons_x, agg_cons_height, agg_cons_colors = aggregate_conservation(
        cons_source.data["x"],
        cons_source.data["cons_height"],
        cons_source.data["cons_colors"],
    )

    # Create the updated ColumnDataSource with aggregated gray blocks
    cons_source_aggregated = ColumnDataSource(
        dict(
            x=agg_cons_x,
            height=agg_cons_height,
            fill_color=agg_cons_colors,
        )
    )

    # Create a new figure for the aggregated conservation plot; it shares the
    # x range of the alignment plot
    p_cons = figure(
        title=None,
        width=plot_width,
        height=60,
        x_range=p_aln.x_range,
        y_range=(Range1d(0, 1)),
    )

    # Add aggregated rectangles (conservation bars)
    cons_rects = Rect(
        x="x", y=0, width=1, height="height", fill_color="fill_color", line_color=None
    )

    p_cons.add_glyph(cons_source_aggregated, cons_rects)

    # Adding the legend items for the first five entries of NUC_COLOR_DICT.
    # Each item is backed by an invisible zero-size rectangle that exists
    # only to carry the color into the legend.
    legend_items = []
    for nucleotide, color in list(NUC_COLOR_DICT.items())[:5]:
        legend_source = ColumnDataSource(
            dict(x=[0], y=[0], width=[0], height=[0], fill_color=[color])
        )
        rect = p_cons.rect(
            x="x",
            y="y",
            width="width",
            height="height",
            fill_color="fill_color",
            line_color="#F5F5F5",
            source=legend_source,
            visible=False,
        )
        legend_items.append(LegendItem(label=nucleotide, renderers=[rect]))

    # Create and add the horizontal legend to the plot
    legend = Legend(items=legend_items, location="top_center", orientation="horizontal")
    p_cons.add_layout(legend, "above")  # Place legend above the plot

    # Remove the legend box and background
    p_cons.legend.border_line_color = None  # Removes the border of the legend
    p_cons.legend.background_fill_alpha = 0  # Makes the legend background transparent

    # Tighten the legend layout
    p_cons.legend.padding = 0
    p_cons.legend.spacing = 0

    # Hide the x-axis labels and keep the y-axis visible
    p_cons.xaxis.visible = False
    p_cons.yaxis.visible = True
    p_cons.yaxis.ticker = [1]
    p_cons.yaxis.axis_label = "Alignment conservation values"
    p_cons.yaxis.axis_label_orientation = "horizontal"

    p_cons.grid.visible = False
    p_cons.background_fill_color = "white"

    return gridplot(
        [[msa_title], [p_cons], [p_aln]],
        toolbar_location=None,
        sizing_mode=sizing_mode,
    )