partis-bcr 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. bin/FastTree +0 -0
  2. bin/add-chimeras.py +59 -0
  3. bin/add-seqs-to-outputs.py +81 -0
  4. bin/bcr-phylo-run.py +799 -0
  5. bin/build.sh +24 -0
  6. bin/cf-alleles.py +97 -0
  7. bin/cf-germlines.py +57 -0
  8. bin/cf-linearham.py +199 -0
  9. bin/chimera-plot.py +76 -0
  10. bin/choose-partially-paired.py +143 -0
  11. bin/circle-plots.py +30 -0
  12. bin/compare-plotdirs.py +298 -0
  13. bin/diff-parameters.py +133 -0
  14. bin/docker-hub-push.sh +6 -0
  15. bin/extract-pairing-info.py +55 -0
  16. bin/gcdyn-simu-run.py +223 -0
  17. bin/gctree-run.py +244 -0
  18. bin/get-naive-probabilities.py +126 -0
  19. bin/iqtree-1.6.12 +0 -0
  20. bin/lonr.r +1020 -0
  21. bin/makeHtml +52 -0
  22. bin/mds-run.py +46 -0
  23. bin/parse-output.py +277 -0
  24. bin/partis +1869 -0
  25. bin/partis-pip +116 -0
  26. bin/partis.py +1869 -0
  27. bin/plot-gl-set-trees.py +519 -0
  28. bin/plot-hmms.py +151 -0
  29. bin/plot-lb-tree.py +427 -0
  30. bin/raxml-ng +0 -0
  31. bin/read-bcr-phylo-trees.py +38 -0
  32. bin/read-gctree-output.py +166 -0
  33. bin/run-chimeras.sh +64 -0
  34. bin/run-dtr-scan.sh +25 -0
  35. bin/run-paired-loci.sh +100 -0
  36. bin/run-tree-metrics.sh +88 -0
  37. bin/smetric-run.py +62 -0
  38. bin/split-loci.py +317 -0
  39. bin/swarm-2.1.13-linux-x86_64 +0 -0
  40. bin/test-germline-inference.py +425 -0
  41. bin/tree-perf-run.py +194 -0
  42. bin/vsearch-2.4.3-linux-x86_64 +0 -0
  43. bin/vsearch-2.4.3-macos-x86_64 +0 -0
  44. bin/xvfb-run +194 -0
  45. partis_bcr-1.0.1.data/scripts/cf-alleles.py +97 -0
  46. partis_bcr-1.0.1.data/scripts/cf-germlines.py +57 -0
  47. partis_bcr-1.0.1.data/scripts/extract-pairing-info.py +55 -0
  48. partis_bcr-1.0.1.data/scripts/gctree-run.py +244 -0
  49. partis_bcr-1.0.1.data/scripts/parse-output.py +277 -0
  50. partis_bcr-1.0.1.data/scripts/split-loci.py +317 -0
  51. partis_bcr-1.0.1.data/scripts/test.py +1005 -0
  52. {partis_bcr-1.0.0.dist-info → partis_bcr-1.0.1.dist-info}/METADATA +1 -1
  53. {partis_bcr-1.0.0.dist-info → partis_bcr-1.0.1.dist-info}/RECORD +101 -50
  54. partis_bcr-1.0.1.dist-info/top_level.txt +1 -0
  55. {partis → python}/glutils.py +1 -1
  56. python/main.py +30 -0
  57. {partis → python}/plotting.py +10 -1
  58. {partis → python}/treeutils.py +18 -16
  59. {partis → python}/utils.py +14 -7
  60. partis/main.py +0 -59
  61. partis_bcr-1.0.0.dist-info/top_level.txt +0 -1
  62. {partis_bcr-1.0.0.dist-info → partis_bcr-1.0.1.dist-info}/WHEEL +0 -0
  63. {partis_bcr-1.0.0.dist-info → partis_bcr-1.0.1.dist-info}/entry_points.txt +0 -0
  64. {partis_bcr-1.0.0.dist-info → partis_bcr-1.0.1.dist-info}/licenses/COPYING +0 -0
  65. {partis → python}/__init__.py +0 -0
  66. {partis → python}/alleleclusterer.py +0 -0
  67. {partis → python}/allelefinder.py +0 -0
  68. {partis → python}/alleleremover.py +0 -0
  69. {partis → python}/annotationclustering.py +0 -0
  70. {partis → python}/baseutils.py +0 -0
  71. {partis → python}/cache/__init__.py +0 -0
  72. {partis → python}/cache/cached_uncertainties.py +0 -0
  73. {partis → python}/clusterpath.py +0 -0
  74. {partis → python}/coar.py +0 -0
  75. {partis → python}/corrcounter.py +0 -0
  76. {partis → python}/datautils.py +0 -0
  77. {partis → python}/event.py +0 -0
  78. {partis → python}/fraction_uncertainty.py +0 -0
  79. {partis → python}/gex.py +0 -0
  80. {partis → python}/glomerator.py +0 -0
  81. {partis → python}/hist.py +0 -0
  82. {partis → python}/hmmwriter.py +0 -0
  83. {partis → python}/hutils.py +0 -0
  84. {partis → python}/indelutils.py +0 -0
  85. {partis → python}/lbplotting.py +0 -0
  86. {partis → python}/mds.py +0 -0
  87. {partis → python}/mutefreqer.py +0 -0
  88. {partis → python}/paircluster.py +0 -0
  89. {partis → python}/parametercounter.py +0 -0
  90. {partis → python}/paramutils.py +0 -0
  91. {partis → python}/partitiondriver.py +0 -0
  92. {partis → python}/partitionplotter.py +0 -0
  93. {partis → python}/performanceplotter.py +0 -0
  94. {partis → python}/plotconfig.py +0 -0
  95. {partis → python}/processargs.py +0 -0
  96. {partis → python}/prutils.py +0 -0
  97. {partis → python}/recombinator.py +0 -0
  98. {partis → python}/scanplot.py +0 -0
  99. {partis → python}/seqfileopener.py +0 -0
  100. {partis → python}/treegenerator.py +0 -0
  101. {partis → python}/viterbicluster.py +0 -0
  102. {partis → python}/vrc01.py +0 -0
  103. {partis → python}/waterer.py +0 -0
@@ -0,0 +1,519 @@
1
+ #!/usr/bin/env python3
2
+ # has to be its own script, since ete3 requires its own god damn python version, installed in a separated directory
3
+ from __future__ import absolute_import, division, unicode_literals
4
+ from __future__ import print_function
5
+ import time
6
+ import yaml
7
+ import itertools
8
+ import glob
9
+ import argparse
10
+ import copy
11
+ import random
12
+ import os
13
+ import tempfile
14
+ import subprocess
15
+ import sys
16
+ import colored_traceback.always
17
+ from collections import OrderedDict
18
+ from io import open
19
+ import ete3
20
+
21
+ # ----------------------------------------------------------------------------------------
22
def pairkey(name1, name2):
    """Return an order-independent key naming the pair of labels, e.g. pairkey('b', 'a') == 'a-&-b'."""
    first, second = sorted((name1, name2))
    return first + '-&-' + second
24
+
25
+ # ----------------------------------------------------------------------------------------
26
# background colors, keyed by gene category or gl set label (set_colors() adds more entries at runtime)
scolors = dict([
    ('novel', '#ffc300'),  # 'Gold'
    ('data', 'LightSteelBlue'),
    ('pale-green', '#85ad98'),
    ('pale-blue', '#94a3d1'),
    ('tigger-default', '#d77c7c'),  # alternative: '#c32222' (red)
    ('igdiscover', '#85ad98'),  # alternative: '#29a614' (green)
    ('partis', '#94a3d1'),  # alternative: '#2455ed' (blue)
])

# side-bar face colors, handed out to gl sets (in sorted label order) by set_colors()
listfaces = ['red', 'blue', 'green']
# colors/faces that actually got used, filled in at runtime and read back when writing the legend
used_colors = {}
used_faces = {}
# simulation-mode colors for each inferred-vs-truth status
simu_colors = OrderedDict([
    ('ok', 'DarkSeaGreen'),
    ('missing', '#d77c7c'),
    ('spurious', '#a44949'),
])
47
+
48
+ # ----------------------------------------------------------------------------------------
49
def set_colors(gl_sets, ref_label=None, mix_primary_colors=False):
    """Fill in the module-level color tables for the gl sets being compared.

    Mutates the module globals <scolors>, <used_colors> and <used_faces>.

    gl_sets: dict keyed by gl set label (only the keys are used here)
    ref_label: if set (simulation), only the simulation status colors (ok/missing/spurious) are installed
    mix_primary_colors: unused, kept for backwards compatibility
    Raises if there are more than three gl sets.
    """
    if ref_label is not None:  # simulation
        scolors.update(simu_colors)
        return

    names = sorted(gl_sets.keys())

    if len(names) == 1:  # single-sample data
        scolors[names[0]] = scolors['data']
        return

    if len(names) not in (2, 3):
        raise Exception('can only set colors for one, two, or three gl sets, but got %d' % len(names))

    # genes in every gl set
    scolors['all'] = plotting.getgrey('light' if len(names) == 2 else 'white')

    default_grey = plotting.getgrey('medium')  # for labels with no predefined entry in <scolors>
    for iname, name in enumerate(names):
        if name not in scolors:
            scolors[name] = default_grey
        used_colors[name] = scolors[name]
        used_faces[name] = listfaces[iname % len(listfaces)]

    # genes shared by a pair of sets: white when there are only two sets, otherwise a grey distinct from 'all'
    pair_shade = 'white' if len(names) == 2 else 'light-medium'
    for name1, name2 in itertools.combinations(names, 2):
        scolors[pairkey(name1, name2)] = plotting.getgrey(pair_shade)
81
+
82
+ # ----------------------------------------------------------------------------------------
83
def get_cmdfos(cmdstr, workdir, outfname):
    """Wrap a single command in the one-element list-of-dicts form passed to utils.run_cmds() in this file."""
    cmdfo = dict(cmd_str=cmdstr, workdir=workdir, outfname=outfname)
    return [cmdfo]
87
+
88
+ # ----------------------------------------------------------------------------------------
89
def make_tree(all_genes, workdir, use_cache=False):
    """Align every gene in <all_genes> ({name: seq}) with muscle, then build a tree from the alignment with raxml.

    Returns the path of raxml's result tree file in <workdir>. If <use_cache> is set nothing is run and that
    path is returned as-is (so it has to exist from a previous run). After a real run the alignment and all
    raxml outputs other than the result tree are deleted.
    """
    aligned_fname = workdir + '/all-aligned.fa'
    raxml_label = 'xxx'

    def raxml_fname(ftype):  # path of the raxml output file of type <ftype>
        return '%s/RAxML_%s.%s' % (workdir, ftype, raxml_label)

    raxml_output_fnames = [raxml_fname(ftype) for ftype in ['parsimonyTree', 'log', 'result', 'info', 'bestTree']]
    # build the tree file name directly: searching the full paths for 'result' would pick the wrong file
    # (and then delete the real result) if e.g. <workdir> sits under a directory called 'results'
    treefname = raxml_fname('result')
    if use_cache:  # don't re-run muscle & raxml, just use the previous run's output tree file
        return treefname
    utils.prep_dir(workdir, wildlings=['*.' + raxml_label, os.path.basename(aligned_fname), 'out', 'err', os.path.basename(aligned_fname) + '.reduced'])

    # write and align an .fa with all alleles from any gl set
    start = time.time()
    with tempfile.NamedTemporaryFile(mode='w') as tmpfile:
        for name, seq in all_genes.items():
            tmpfile.write('>%s\n%s\n' % (name, seq))
        tmpfile.flush()  # muscle reads this file by name, so everything has to be on disk before it runs
        cmdstr = '%s -in %s -out %s' % (args.muscle_path, tmpfile.name, aligned_fname)
        if args.debug:
            print(' %s %s' % (utils.color('red', 'run'), cmdstr))
        utils.run_cmds(get_cmdfos(cmdstr, workdir, aligned_fname), ignore_stderr=True)

    # get a tree for the aligned .fa
    cmdstr = '%s -mGTRCAT -n%s -s%s -p1 -w%s' % (args.raxml_path, raxml_label, aligned_fname, workdir)
    if args.debug:
        print(' %s %s' % (utils.color('red', 'run'), cmdstr))
    utils.run_cmds(get_cmdfos(cmdstr, workdir, treefname), ignore_stderr=True)
    print(' raxml time: %.1f' % (time.time() - start))  # NOTE this also includes the muscle alignment time

    os.remove(aligned_fname)  # rm muscle output
    for fn in raxml_output_fnames:  # rm all the raxml outputs except the one file we really want
        if fn != treefname:
            os.remove(fn)

    return treefname
121
+
122
+ # ----------------------------------------------------------------------------------------
123
def getstatus(gene_categories, node, ref_label=None, debug=False):
    """Return the gene category that <node>'s gene falls into.

    Non-leaf nodes always get 'internal'. For a leaf, exactly one category in <gene_categories>
    (category name -> collection of genes) must contain node.name; if none or several do, an exception is raised.
    <ref_label> is unused.
    """
    gene = node.name
    if not node.is_leaf():
        return 'internal'

    matching = []
    for cat, genes in gene_categories.items():
        if gene in genes:
            matching.append(cat)

    if not matching:
        catstrs = ['%s:\n %s' % (k, ' '.join(gene_categories[k])) for k in gene_categories]
        raise Exception('[probably need to bust plot cache/rewrite tree file] couldn\'t find a category for %s among:\n %s' % (node.name, '\n '.join(catstrs)))
    if len(matching) > 1:
        raise Exception('wtf?')

    category = matching[0]
    if debug:
        print('%-50s %s' % (gene, category))
    return category
136
+
137
+ # ----------------------------------------------------------------------------------------
138
def print_results(gene_categories, gl_sets, ref_label=None):
    """Print one summary line per gene category.

    Each line has the category name, the total gene count of the gl set of the same name (if the category
    is a gl set label), and the number and names of the category's genes (names are omitted for 'ok').
    The word 'only' is included unless <ref_label> is set (simulation).
    Raises if a category has no entry in <scolors>.
    """
    width = max(len(cat) for cat in gene_categories)
    name_fmt = ' %-' + str(width) + 's'
    only_str = '' if ref_label is not None else 'only'
    for name, genes in gene_categories.items():
        if name not in scolors:
            raise Exception('status \'%s\' not in scolors' % name)
        genestr = '' if name == 'ok' else ' '.join([utils.color_gene(g) for g in genes])
        print(name_fmt % name, end=' ')
        if name in gl_sets:
            print(' total %2d' % len(gl_sets[name]), end=' ')
        else:
            print(' ', end=' ')
        if len(genes) > 0:
            print(' %s %2d %s' % (only_str, len(genes), genestr))
        else:
            print(' %s %s' % (only_str, utils.color('blue', 'none')))
158
+
159
+ # ----------------------------------------------------------------------------------------
160
def write_results(outdir, gene_categories, gl_sets):
    """Write <gene_categories> to <outdir>/results.yaml as {category: [gene, ...]}.

    Each category's genes are sorted. They are stored as sets, and iterating a set of strings can give a
    different order from one python run to the next, which made the output file non-reproducible.
    <gl_sets> is unused (kept for interface compatibility).
    """
    yamlfo = {gcat: sorted(genes) for gcat, genes in gene_categories.items()}
    with open(outdir + '/results.yaml', 'w') as yamlfile:
        yaml.dump(yamlfo, yamlfile, width=150)
164
+
165
+ # ----------------------------------------------------------------------------------------
166
def get_gene_sets(glsfnames, glslabels, ref_label=None, classification_fcn=None, debug=False):
    """Read each germline set and sort the genes of region <args.region> into overlap categories.

    glsfnames: germline file paths; each must sit in a directory named <args.locus>, whose parent is the germline dir
    glslabels: labels corresponding to <glsfnames> (must be the same length)
    ref_label: label of the simulation/truth set; if set there must be exactly two sets, and the categories
               are renamed to 'missing' (truth only), 'spurious' (inferred only) and 'ok' (both)
    classification_fcn: if set, genes are further grouped by classification_fcn(gene)
                        NOTE(review): the categorization below looks like it assumes un-grouped sets
    Returns (all_genes, gl_sets, gcats):
        all_genes: {gene: seq} for genes in *any* set
        gl_sets: {label: {gene: seq}}
        gcats: OrderedDict {category: set of genes}, with categories 'all' (genes in every set; only kept with
               more than two sets), each label (genes only in that set), and pairkey(l1, l2) (genes in those two
               sets but not in 'all')
    """
    if len(glsfnames) != len(glslabels):
        raise Exception('need the same number of germline file names and labels, but got %d and %d' % (len(glsfnames), len(glslabels)))

    glfos = {}
    for label, fname in zip(glslabels, glsfnames):
        if os.path.isdir(fname):
            raise Exception('directory passed instead of germline file name: %s' % fname)
        locusdir = os.path.dirname(fname)
        if os.path.basename(locusdir) != args.locus:
            raise Exception('unexpected germline directory structure (should have locus \'%s\' at end): %s' % (args.locus, fname))
        # strip exactly the final <locus> component (str.replace() would also remove any earlier '/<locus>' match in the path)
        gldir = os.path.dirname(locusdir) or '.'
        glfos[label] = glutils.read_glfo(gldir, args.locus)

    if args.region == 'v':  # don't want to deal with d and j synchronization yet
        # synchronize to somebody -- either simulation (<ref_label>), partis, or else the first one
        if ref_label is not None:
            sync_label = ref_label
        elif 'partis' in glslabels:
            sync_label = 'partis'
        else:
            sync_label = glslabels[0]
        for label in glslabels:
            if label == sync_label:
                continue
            if debug:
                print(' synchronizing %s names to match %s' % (label, sync_label))
            glutils.synchronize_glfos(ref_glfo=glfos[sync_label], new_glfo=glfos[label], region=args.region, ref_label=sync_label, debug=debug)
    else:
        print(' not synchronizing gl sets for %s' % args.region)

    gl_sets = {label: dict(glfos[label]['seqs'][args.region]) for label in glfos}
    all_genes = {g: s for gls in gl_sets.values() for g, s in gls.items()}

    if classification_fcn is not None:
        all_primary_versions = set([classification_fcn(g) for g in all_genes])
        gl_sets = {pv : {label : {g : gl_sets[label][g] for g in gl_sets[label] if classification_fcn(g) == pv} for label in gl_sets} for pv in all_primary_versions}
        all_genes = {pv : {g : s for g, s in all_genes.items() if classification_fcn(g) == pv} for pv in all_primary_versions}

    if len(gl_sets) > 3:
        raise Exception('not implemented')

    gcats = OrderedDict()
    gcats['all'] = set()
    if len(gl_sets) > 2:
        # gcats['all'] is genes that are in *every* gl set, whereas <all_genes> is genes that are in *any* of them
        gcats['all'] = set(all_genes)
        for genes in gl_sets.values():
            gcats['all'] &= set(genes)

    for name, genes in gl_sets.items():
        gcats[name] = set(genes) - gcats['all']

    for ds_1, ds_2 in itertools.combinations(gl_sets, 2):
        gcats[ds_1] -= set(gl_sets[ds_2])
        gcats[ds_2] -= set(gl_sets[ds_1])
        gcats[pairkey(ds_1, ds_2)] = (set(gl_sets[ds_1]) & set(gl_sets[ds_2])) - gcats['all']

    if ref_label is not None:  # simulation: rename to inferred-vs-truth status (new keys are appended in this order)
        if len(gl_sets) != 2 or ref_label not in gl_sets:
            raise Exception('with a reference label need exactly two gl sets, one of which is \'%s\' (got: %s)' % (ref_label, ' '.join(gl_sets)))
        inf_label = [ds for ds in gl_sets if ds != ref_label][0]
        gcats['missing'] = gcats.pop(ref_label)
        gcats['spurious'] = gcats.pop(inf_label)
        gcats['ok'] = gcats.pop(pairkey(ref_label, inf_label))

    # sanity check: the categories should partition <all_genes> exactly
    any_check = []
    for genes in gcats.values():
        any_check += genes
    if sorted(any_check) != sorted(all_genes):
        raise Exception('you done messed up')

    if len(gl_sets) <= 2:
        del gcats['all']

    return all_genes, gl_sets, gcats
239
+
240
+ # ----------------------------------------------------------------------------------------
241
def set_node_style(node, status, n_gl_sets, ref_label=None):
    """Style an ete3 tree node according to its gene category <status> (as returned by getstatus()).

    Non-internal nodes get their background color from <scolors> (raising if <status> has no entry), the
    color is recorded in <used_colors>, and novel genes get a dot. Every leaf then gets exactly one aligned
    face marking which gl set(s) the gene is in, or a blank placeholder, so all leaves take the same vertical space.
    <ref_label> is unused.
    """
    if status != 'internal':
        if status not in scolors:
            raise Exception('status \'%s\' not in scolors' % status)
        node.img_style['bgcolor'] = scolors[status]
        if status not in used_colors:
            used_colors[status] = scolors[status]

        if glutils.is_novel(node.name):
            node.add_face(ete3.CircleFace(args.novel_dot_size, scolors['novel']), column=1)

    if not node.is_leaf():
        return

    names = status.split('-&-')  # a pair status names both of its gl sets
    rectnames = [n for n in names if n in used_faces]
    if args.pie_chart_faces and len(names) > 1:
        face = ete3.PieChartFace(percents=[100. / len(names) for _ in range(len(names))], width=args.leafheight, height=args.leafheight, colors=[scolors[n] for n in names], line_color=None)
    elif len(names) == 1 and names[0] in used_faces:
        face = ete3.RectFace(width=5, height=args.leafheight, bgcolor=used_faces[names[0]], fgcolor=None)
    elif n_gl_sets > 2 and len(rectnames) > 0:
        # split the bar evenly among the sets that actually have a face color, so the percents sum to 100
        face = ete3.StackedBarFace(percents=[100. / len(rectnames) for _ in range(len(rectnames))], width=5 * len(rectnames), height=args.leafheight, colors=[used_faces[rn] for rn in rectnames], line_color=None)
    else:  # every leaf has to have a face, so that every leaf takes up the same vertical space
        face = ete3.RectFace(width=1, height=args.leafheight, bgcolor=None, fgcolor=None)
    node.add_face(face, column=0, position='aligned')
269
+
270
+ # ----------------------------------------------------------------------------------------
271
def get_entirety_of_gene_family(root, family):
    """Return the names of all leaves under <root> whose utils.gene_family() is <family>."""
    members = set()
    for leaf in root:
        if utils.gene_family(leaf.name) == family:
            members.add(leaf.name)
    return members
274
+
275
+ # ----------------------------------------------------------------------------------------
276
def set_distance_to_zero(node, debug=False):
    """Decide whether <node>'s branch length should be collapsed to (effectively) zero.

    Returns True for:
      - the root;
      - a leaf that is the only gene of its family in the whole tree;
      - an internal node whose leaves span several gene families;
      - an internal node whose leaves are exactly all of one family's members in the tree.
    Returns False otherwise.
    """
    if node.is_root():
        return True

    if node.is_leaf():
        family_size = len(get_entirety_of_gene_family(node.get_tree_root(), utils.gene_family(node.name)))
        if family_size != 1:
            return False
        # this is the *only* one from its family
        if debug:
            print(' family %s is of length 1 %s (set to zero)' % (utils.gene_family(node.name), node.name))
        return True

    descendents = set([leaf.name for leaf in node])
    families = set([utils.gene_family(d) for d in descendents])
    if debug:
        print(' %s' % ' '.join([utils.shorten_gene_name(d) for d in descendents]))
        print(' %s' % ' '.join(families))
    if not families:
        raise Exception('zero length gene family set from %s' % ' '.join([leaf.name for leaf in node]))
    if len(families) > 1:
        return True

    (family,) = families
    whole_family = get_entirety_of_gene_family(node.get_tree_root(), family)
    if debug:
        n_missing = len(whole_family - descendents)
        if n_missing > 0:
            print(' missing %d / %d of family' % (n_missing, len(whole_family)))
        elif len(descendents - whole_family) > 0:
            raise Exception('wtf should\'ve already returned')
        else:
            print(' setting to zero')
    return descendents == whole_family
307
+
308
+ # ----------------------------------------------------------------------------------------
309
def write_legend(plotdir):
    """Render a standalone legend to <plotdir>/legend.svg.

    With --ref-label (simulation) there is one entry per simulation status. Otherwise there is one entry per
    color in <used_colors>, and all pairwise-overlap statuses share a single 'both'/'two' entry. The overlap
    entries are moved to the bottom, and a final 'novel allele' dot entry is added.
    """
    def get_leg_name(status):
        # user-supplied --legends text takes precedence for plain gl set labels
        if args.legends is not None and status in args.glslabels:
            return args.legends[args.glslabels.index(status)]
        if status not in ('both', 'all'):
            return status
        if len(args.glsfnames) == 2:
            return 'both'
        if len(args.glsfnames) == 3:
            return 'two' if status == 'both' else 'all three'
        raise Exception('can\'t make a legend when --glsfnames is length %d' % len(args.glsfnames))

    def add_stuff(status, leg_name, color):
        legfo[leg_name] = color
        if status in used_faces:
            facefo[leg_name] = used_faces[status]

    legfo, facefo = {}, {}  # legend name -> background color / face color
    if args.ref_label is not None:
        for status, color in simu_colors.items():
            add_stuff(status, status, color)
    else:
        pair_entry_added = False
        for status, color in used_colors.items():
            if '-&-' not in status:
                add_stuff(status, get_leg_name(status), color)
                continue
            # make sure each member of the pair gets its own entry, even if no leaf had that single status
            for substatus in status.split('-&-'):
                if get_leg_name(substatus) not in legfo:
                    add_stuff(substatus, get_leg_name(substatus), scolors[substatus])
            if pair_entry_added:  # all pairs share one legend entry
                continue
            pair_entry_added = True
            add_stuff(status, get_leg_name('both'), color)

    # alphabetical, but with the overlap entries at the bottom
    lnames = sorted(legfo.keys())
    for status in ['both', 'all']:
        lname = get_leg_name(status)
        if lname in lnames:
            lnames.remove(lname)
            lnames.append(lname)

    etree = ete3.ClusterTree()
    tstyle = ete3.TreeStyle()
    tstyle.show_scale = False
    header = tstyle.title  # every legend entry is stacked as a face in the tree style's title area

    dummy_column, pic_column, text_column = 0, 1, 2
    leg_title_height = 1.5 * args.leafheight

    for icol in range(text_column + 1):  # top border: one blank square in each column
        header.add_face(ete3.RectFace(0.9*args.leafheight, 0.9*args.leafheight, fgcolor=None, bgcolor=None), column=icol)

    header.add_face(ete3.TextFace(' ', fsize=leg_title_height), column=dummy_column)  # left border

    if args.legend_title is not None:
        header.add_face(ete3.TextFace('', fsize=leg_title_height), column=pic_column)  # keeps the first legend entry from getting added on the title's line
        header.add_face(ete3.TextFace(args.legend_title, fsize=leg_title_height, fgcolor='black', bold=True), column=text_column)

    swatch_size = 2. * args.leafheight
    for leg_name in lnames:
        color = legfo[leg_name]
        if leg_name in facefo:  # swatch split between the background color and the face color
            swatch = ete3.StackedBarFace([80., 20.], width=swatch_size, height=swatch_size, colors=[color, facefo[leg_name]], line_color='black')
        else:
            swatch = ete3.RectFace(swatch_size, swatch_size, fgcolor='black', bgcolor=color)  # looks like maybe they reversed fg/bg kwarg names
        header.add_face(swatch, column=pic_column)
        header.add_face(ete3.TextFace(' ' + leg_name, fsize=args.leafheight, fgcolor='black'), column=text_column)

    # final entry explaining the novel-allele dot drawn on leaves
    header.add_face(ete3.CircleFace(1.5*args.novel_dot_size, scolors['novel']), column=pic_column)
    header.add_face(ete3.TextFace('novel allele', fsize=args.leafheight), column=text_column)

    etree.render(plotdir + '/legend.svg', tree_style=tstyle)
398
+
399
+ # ----------------------------------------------------------------------------------------
400
def draw_tree(plotdir, plotname, treestr, gl_sets, all_genes, gene_categories, ref_label=None, arc_start=None, arc_span=None):
    """Render the newick tree <treestr>, with each leaf styled by its gene category.

    Writes <plotdir>/<plotname>.svg (without leaf names) and a version with leaf names (named via
    utils.insert_before_suffix('-leaf-names', ...)). If --param-dirs is set, it also writes a '-gene-counts'
    version in which each node is labeled with its per-gl-set usage percentage ('-' where absent). If there is
    more than one gl set, write_legend() is called as well.
    <arc_start> and <arc_span> are unused.
    Raises if any gene in <all_genes> is not a leaf of the tree.
    """
    etree = ete3.ClusterTree(treestr)
    leaf_names = set()  # make sure we get out all the genes we put in
    for node in etree.traverse():
        if set_distance_to_zero(node):
            # data crashes sometimes with float division by zero if you set it to 0., but simulation sometimes gets screwed up for some other reason (that I don't understand) if it's 1e-9
            node.dist = 0. if ref_label is not None else 1e-9
        status = getstatus(gene_categories, node, ref_label=ref_label)
        set_node_style(node, status, len(gl_sets), ref_label=ref_label)
        if node.is_leaf():
            leaf_names.add(node.name)
    missing_genes = set(all_genes) - leaf_names
    if len(missing_genes) > 0:  # report the genes that are missing (not the ones that made it into the tree)
        raise Exception('missing genes from final tree: %s' % ' '.join(sorted(missing_genes)))

    if args.param_dirs is not None:
        countfo = OrderedDict()
        for label, pdir in zip(args.glslabels, args.param_dirs):  # it would be cleaner to do this somewhere else
            if pdir == 'None':  # placeholder for gl sets without a parameter dir (not the best way to do this)
                continue
            countfo[label] = utils.read_overall_gene_probs(pdir, normalize=True)[args.region]
        for node in etree.traverse():
            node.countstr = '%s' % ' '.join([('%.2f' % (100 * cfo[node.name])) if node.name in cfo else '-' for cfo in countfo.values()])

    if ref_label is None:  # have to do it in a separate loop so it doesn't screw up the distance setting
        for node in [n for n in etree.traverse() if n.is_leaf()]:
            node.name = utils.shorten_gene_name(node.name)

    tstyle = ete3.TreeStyle()
    tstyle.show_scale = False

    if len(args.glslabels) > 1:
        write_legend(plotdir)
    if args.title is not None:
        fsize = 13
        tstyle.title.add_face(ete3.TextFace(args.title, fsize=fsize, bold=True), column=0)
        if args.title_color is not None:
            # --title-color is either a key in <scolors> or a literal color
            tcol = scolors[args.title_color] if args.title_color in scolors else args.title_color
            rect_width = 3 if len(args.title) < 12 else 2
            tstyle.title.add_face(ete3.RectFace(width=rect_width*fsize, height=fsize, bgcolor=tcol, fgcolor=None), column=1)
    suffix = '.svg'
    imagefname = plotdir + '/' + plotname + suffix
    print(' %s' % imagefname)
    etree.render(utils.insert_before_suffix('-leaf-names', imagefname), tree_style=tstyle)
    tstyle.show_leaf_name = False
    etree.render(imagefname, tree_style=tstyle)

    # NOTE all the node names are screwed up after this, so you'll have to fix them if you add another step
    if args.param_dirs is not None:
        for node in etree.traverse():
            node.name = node.countstr
        tstyle.show_leaf_name = True
        etree.render(utils.insert_before_suffix('-gene-counts', imagefname), tree_style=tstyle)
453
+
454
+ # ----------------------------------------------------------------------------------------
455
def plot_trees(args, plotdir, plotname, glsfnames, glslabels):
    """Top-level driver for the script.

    Categorizes the genes in the gl sets, sets up colors, then prints and writes the summary. Unless
    --only-print is set, it also builds a muscle + raxml tree of all genes and draws it.
    """
    ref_label = args.ref_label
    all_genes, gl_sets, gene_categories = get_gene_sets(glsfnames, glslabels, ref_label=ref_label)
    set_colors(gl_sets, ref_label=ref_label)
    print_results(gene_categories, gl_sets, ref_label=ref_label)
    write_results(plotdir, gene_categories, gl_sets)
    if not args.only_print:
        treefname = make_tree(all_genes, plotdir + '/workdir', use_cache=args.use_cache)
        with open(treefname) as treefile:
            treestr = treefile.read().strip()
        draw_tree(plotdir, plotname, treestr, gl_sets, all_genes, gene_categories, ref_label=ref_label)
468
+
469
+ # ----------------------------------------------------------------------------------------
470
# ----------------------------------------------------------------------------------------
# script entry point: parse args, import the partis modules, validate inputs, then plot
example_str = '\n '.join(['example usage (note that this example as it is will be a) really slow, since the files are the full imgt set, with ~250 genes, and b) not very interesting, since the two .fasta files are the same):',
                          './bin/plot-gl-set-trees.py --glsfnames data/germlines/human/igh/ighv.fasta:data/germlines/human/igh/ighv.fasta --glslabels foo:bar --locus igh'])
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter, epilog=example_str)
parser.add_argument('--plotdir', default=os.getcwd() + '/gl-set-tree-plots')
parser.add_argument('--plotname', default='test')
parser.add_argument('--glsfnames', required=True, help='colon-separated list of germline ig fasta file names')
parser.add_argument('--glslabels', required=True, help='colon-separated list of labels corresponding to --glsfnames')
parser.add_argument('--param-dirs', help='parameter dirs for each gls fname, for getting counts for each gene')
parser.add_argument('--locus', required=True, choices=['igh', 'igk', 'igl'])
parser.add_argument('--legends', help='colon-separated list of legend labels')
parser.add_argument('--legend-title')
parser.add_argument('--pie-chart-faces', action='store_true')
parser.add_argument('--use-cache', action='store_true', help='use existing raxml output from a previous run (crashes if it isn\'t there)')
parser.add_argument('--only-print', action='store_true', help='just print the summary, without making any plots')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--title')
parser.add_argument('--title-color')
parser.add_argument('--region', default='v')
parser.add_argument('--partis-dir', default=os.getcwd(), help='path to main partis install dir')
parser.add_argument('--muscle-path', default='./packages/muscle/muscle3.8.31_i86linux64')
parser.add_argument('--raxml-path', default='./packages/standard-RAxML/raxmlHPC-SSE3')  # laptop: '-AVX'
parser.add_argument('--ref-label', help='label (in --glslabels) corresponding to simulation/truth')

args = parser.parse_args()

sys.path.insert(1, args.partis_dir)  # the partis modules are imported below as python.<module>
try:
    import python.utils as utils
    import python.glutils as glutils
    import python.plotting as plotting
except ImportError as e:
    print(e)
    raise Exception('couldn\'t import from main partis dir \'%s\' (set with --partis-dir)' % args.partis_dir)

args.glsfnames = utils.get_arg_list(args.glsfnames)
args.glslabels = utils.get_arg_list(args.glslabels)
args.param_dirs = utils.get_arg_list(args.param_dirs)
args.legends = utils.get_arg_list(args.legends)

# the per-gl-set lists are zip()'d with, or indexed by position in, --glslabels, so a length mismatch would
# otherwise silently drop gl sets (or crash much later with an IndexError)
if len(args.glsfnames) != len(args.glslabels):
    raise Exception('--glsfnames and --glslabels must have the same length (got %d and %d)' % (len(args.glsfnames), len(args.glslabels)))
for argname, argvals in (('--param-dirs', args.param_dirs), ('--legends', args.legends)):
    if argvals is not None and len(argvals) != len(args.glslabels):
        raise Exception('%s must have the same length as --glslabels (got %d and %d)' % (argname, len(argvals), len(args.glslabels)))
if len(args.glslabels) != len(set(args.glslabels)):
    raise Exception('duplicate labels in --glslabels: %s' % ' '.join(args.glslabels))
if args.ref_label is not None and args.ref_label not in args.glslabels:
    raise Exception('--ref-label \'%s\' not among --glslabels: %s' % (args.ref_label, ' '.join(args.glslabels)))

if not os.path.exists(args.muscle_path):
    raise Exception('muscle binary %s doesn\'t exist (set with --muscle-path)' % args.muscle_path)
if not os.path.exists(args.raxml_path):
    raise Exception('raxml binary %s doesn\'t exist (set with --raxml-path)' % args.raxml_path)
os.makedirs(args.plotdir, exist_ok=True)
args.leafheight = 10  # height of leaf faces (legend entries are scaled from this too)
args.novel_dot_size = 2.5

plot_trees(args, args.plotdir, args.plotname, args.glsfnames, args.glslabels)
bin/plot-hmms.py ADDED
@@ -0,0 +1,151 @@
1
+ #!/usr/bin/env python3
2
+ from __future__ import absolute_import, division, unicode_literals
3
+ from __future__ import print_function
4
+ import os
5
+ import argparse
6
+ from collections import OrderedDict
7
+ import glob
8
+ import operator
9
+ import yaml
10
+ import sys
11
+ from subprocess import check_call
12
+ import matplotlib as mpl
13
+ from io import open
14
+ mpl.use('Agg')
15
+ import matplotlib.pyplot as plt
16
+
17
+ from pathlib import Path
18
+ partis_dir = str(Path(__file__).parent.parent)
19
+ if not os.path.exists(partis_dir):
20
+ print('WARNING current script dir %s doesn\'t exist, so python path may not be correctly set' % partis_dir)
21
+ sys.path.insert(1, partis_dir) # + '/python')
22
+ import python.plotting as plotting
23
+ import python.paramutils as paramutils
24
+ import python.utils as utils
25
+
26
+ # ----------------------------------------------------------------------------------------
27
class ModelPlotter(object):
    """
    Plot transition and emission probabilities for partis hmm model (.yaml) files.

    For each model file, one transition plot is written to <base_plotdir>/transitions and one
    emission plot to <base_plotdir>/emissions.
    """

    def __init__(self, args, base_plotdir):
        """
        Prepare the output subdirectories and collect the list of model files to plot.

        args: parsed command line. Reads args.hmmdir (a directory that is globbed for *.yaml, and
              takes precedence if set) and args.infiles (a list string passed to utils.get_arg_list,
              used only when args.hmmdir is None).
        base_plotdir: directory under which the 'transitions' and 'emissions' subdirs are prepared.
        Raises Exception if no model files are found.
        """
        self.base_plotdir = base_plotdir

        self.eps_to_skip = 1e-3  # NOTE(review): not read anywhere in this class -- confirm whether still needed
        print('skipping eps %f' % self.eps_to_skip)

        for ptype in ('transitions', 'emissions'):
            # presumably clears out stale plots matching these patterns -- see utils.prep_dir
            utils.prep_dir(self.base_plotdir + '/' + ptype, wildlings=['*.png', '*.svg'])

        if args.hmmdir is not None:
            self.filelist = glob.glob(args.hmmdir + '/*.yaml')
        else:
            self.filelist = utils.get_arg_list(args.infiles)
        if not self.filelist:  # also covers a None list
            raise Exception('zero files passed to modelplotter')

    @staticmethod
    def _transition_sort_key(item):
        """
        Sort key for (to_state_name, simplified_name) pairs.

        Integer entries (internal positions) come first, in numeric order, followed by string entries
        (e.g. insert/end states) in lexicographic order. Sorting everything by str() instead would put
        position 10 before position 2; mixed int/str values can't be compared directly in python 3.
        """
        simplified = item[1]
        if isinstance(simplified, int):
            return (0, simplified, '')
        return (1, 0, str(simplified))

    def plot(self):
        """Load each model file and write its transition and emission plots."""
        for infname in self.filelist:
            gene_name = os.path.basename(infname).replace('.yaml', '')  # the sanitized name, actually
            with open(infname) as infile:
                # NOTE yaml.Loader can construct arbitrary python objects: only use this on trusted,
                # locally-generated model files
                model = yaml.load(infile, Loader=yaml.Loader)
            self.make_transition_plot(gene_name, model)
            self.make_emission_plot(gene_name, model)

    def make_transition_plot(self, gene_name, model):
        """
        Plot each state's outgoing transition probabilities as a cumulative stack, one x bin per state.

        For each target state, a horizontal line is drawn at the running total and a vertical line
        spans that transition's share. Colors: blue = insert, red = end, green = internal.
        NOTE shares a lot with make_mutefreq_plot() in python/paramutils.py
        """
        fig, ax = plotting.mpl_init()
        fig.set_size_inches(plotting.plot_ratios[utils.get_region(gene_name)])

        print(utils.color_gene(utils.unsanitize_name(gene_name)))
        legend_colors = set()  # add a color to this the first time you plot it, so each legend entry appears once
        alpha, width = 0.6, 3
        for ibin, state in enumerate(model.states):
            # bin label
            ax.text(-0.5 + ibin, -0.075, paramutils.simplify_state_name(state.name), rotation='vertical', size=8)

            # internal (IG*/TR*) target states are converted to int so they sort numerically;
            # other targets keep their full name
            simplified_to_states = {}
            for name in state.transitions:
                if name.startswith(('IG', 'TR')):
                    simplified_to_states[name] = int(paramutils.simplify_state_name(name))
                else:
                    simplified_to_states[name] = name
            sorted_to_states = sorted(simplified_to_states.items(), key=self._transition_sort_key)

            total = 0.0
            for to_state, simple_to_state in sorted_to_states:
                prob = state.transitions[to_state]

                if 'insert' in str(simple_to_state):
                    label, color = 'insert', '#3498db'  # blue
                elif str(simple_to_state) == 'end':
                    label, color = 'end', 'red'
                else:  # regional/internal states
                    assert to_state.startswith(('IG', 'TR'))
                    label, color = 'internal', 'green'

                label_to_use = None
                if color not in legend_colors:
                    label_to_use = label
                    legend_colors.add(color)

                # horizontal line at height total+prob
                ax.plot([-0.5 + ibin, 0.5 + ibin], [total + prob, total + prob], color=color, linewidth=width, alpha=alpha, label=label_to_use)

                # vertical line from total to total + prob
                ax.plot([ibin, ibin], [total + 0.01, total + prob], color=color, alpha=alpha, linewidth=width)

                total += prob

        ax.get_xaxis().set_visible(False)
        plotting.mpl_finish(ax, self.base_plotdir + '/transitions', gene_name, ybounds=(-0.01, 1.01), xbounds=(-3, len(model.states) + 3), leg_loc=(0.95, 0.1), adjust={'left' : 0.1, 'right' : 0.8}, leg_prop={'size' : 8})

    def make_emission_plot(self, gene_name, model):
        """
        Collect per-state emission probabilities and hand them to paramutils.make_mutefreq_plot().

        The 'init' state is the only state expected to have no emissions (asserted); it is skipped.
        """
        plotting_info = []
        for state in model.states:
            if state.emissions is None:
                assert state.name == 'init'
                continue
            plotting_info.append({
                'name' : state.name,
                'nuke_freqs' : state.emissions['probs'],
                'gl_nuke' : state.extras['germline'] if 'germline' in state.extras else None,
            })

        paramutils.make_mutefreq_plot(self.base_plotdir + '/emissions', gene_name, plotting_info, debug=True)
131
+
132
+ # ----------------------------------------------------------------------------------------
133
+
134
if __name__ == '__main__':
    # argument parsing lives under the main guard so that importing this file has no side effects
    parser = argparse.ArgumentParser()
    parser.add_argument('--hmmdir', help='directory with .yaml hmm model files, e.g. test/reference-results/test/parameters/simu/hmm/hmms')
    parser.add_argument('--infiles', help='colon-separated list of .yaml hmm model files (either set this, or set --hmmdir)')
    parser.add_argument('--outdir', required=True)
    args = parser.parse_args()

    # validate before creating anything on disk, so a bad invocation doesn't leave an empty output dir behind
    if args.hmmdir is None and args.infiles is None:
        raise Exception('have to specify either --hmmdir or --infiles')
    os.makedirs(args.outdir, exist_ok=True)  # exist_ok avoids the check-then-create race

    print(' %s the top line in the emission plots is usually yellow because the three non-germline bases are equally likely, and G comes last when sorted alphabetically' % utils.color('red', 'note'))
    mplot = ModelPlotter(args, args.outdir)
    mplot.plot()