zipstrain 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zipstrain/visualize.py ADDED
@@ -0,0 +1,586 @@
1
+ """zipstrain.visualize
2
+ ========================
3
+ This module provides statistical analysis and visualization functions for profiling and compare operations.
4
+ """
5
+
6
+ import polars as pl
7
+ import plotly.graph_objects as go
8
+ import seaborn as sns
9
+ import numpy as np
10
+ from itertools import chain, combinations
11
+ from collections import defaultdict
12
+ import matplotlib.patches as mpatches
13
+ import pandas as pd
14
+
15
+
16
+
17
def get_cdf(data, num_bins=10000, max_value=50000):
    """Calculate a reversed cumulative distribution of the given data.

    The values are histogrammed over ``num_bins`` evenly spaced edges spanning
    ``[0, max_value]``. Edges and counts are then reversed, so the cumulative
    sum runs from the largest bin down to the smallest one.

    Args:
        data: Sequence of numeric values. A first element equal to ``-1`` is
            treated as the "no data" sentinel.
        num_bins (int, optional): Number of bin edges (this yields ``num_bins - 1`` bins). Defaults to 10000.
        max_value (float, optional): Upper end of the histogram range. Values outside
            ``[0, max_value]`` are not counted. Defaults to 50000.
    Returns:
        tuple: ``(bin_edges, cdf)`` where ``bin_edges`` holds the ``num_bins`` edges in
        descending order and ``cdf`` holds the ``num_bins - 1`` cumulative fractions.
        ``([-1], [-1])`` is returned when ``data`` is empty or starts with the sentinel.
    """
    # Empty input used to raise IndexError on data[0]; treat it like the sentinel.
    if len(data) == 0 or data[0] == -1:
        return [-1], [-1]
    counts, bin_edges = np.histogram(data, bins=np.linspace(0, max_value, num_bins))
    counts = counts[::-1]
    bin_edges = bin_edges[::-1]
    cumulative_counts = np.cumsum(counts)
    total = cumulative_counts[-1]
    if total == 0:
        # Every value fell outside the histogram range: return zeros instead of 0/0 = NaN.
        return bin_edges, np.zeros(len(cumulative_counts), dtype=float)
    cdf = cumulative_counts / total
    return bin_edges, cdf
27
+
28
def calculate_strainsharing(
    comps_lf:pl.LazyFrame,
    breadth_lf:pl.LazyFrame,
    sample_to_population:pl.LazyFrame,
    min_breadth:float=0.5,
    strain_similarity_threshold:float=99.9,
    min_total_positions:int=10000
    )->dict[str, list[float]]:
    """
    Calculate strain sharing between populations based on popANI between genomes in their profiles.
    Strain sharing between two samples is defined as the ratio of genomes passing a strain similarity threshold over the total number of genomes in each sample.
    So, for two samples A and B, the strain sharing is defined as (note the asymmetric nature of the calculation):
    strain_sharing(A, B) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in A)
    strain_sharing(B, A) = (number of genomes in A and B passing the strain similarity threshold) / (number of genomes in B)

    Args:
        comps_lf (pl.LazyFrame): LazyFrame with the pairwise comparisons. The columns read here are
            "sample_1", "sample_2", "genome", "total_positions" and "genome_pop_ani".
        breadth_lf (pl.LazyFrame): LazyFrame with a "genome" column and one breadth column per sample. Nulls are treated as 0.
        sample_to_population (pl.LazyFrame): LazyFrame with "sample" and "population" columns.
        min_breadth (float, optional): Minimum genome breadth to consider a genome for strain sharing. Defaults to 0.5.
        strain_similarity_threshold (float, optional): Minimum "genome_pop_ani" for a genome to count as shared. Defaults to 99.9.
        min_total_positions (int, optional): Comparisons must have more than this many total positions. Defaults to 10000.
    Returns:
        dict[str, list[float]]: Maps "<population_1>_<population_2>" to the list of strain sharing rates of the
        sample pairs in that direction. Sample pairs with no genome passing the filters do not appear, and
        pairs whose population is unknown are skipped.
    """
    # Collect the position filter once; the resulting frame is joined several times below.
    comps_lf = comps_lf.filter(
        pl.col("total_positions") > min_total_positions
    ).collect(engine="streaming").lazy()

    # Long format: one row per (genome, sample) with its breadth.
    breadth_lf_long = breadth_lf.fill_null(0.0).unpivot(
        index=["genome"],
        variable_name="sample",
        value_name="breadth",
    )
    # Denominator of the sharing rate: genomes per sample that reach min_breadth.
    genomes_per_sample = breadth_lf_long.group_by("sample").agg(
        num_genomes=(pl.col("breadth") >= min_breadth).sum()
    )

    comps_lf = comps_lf.join(
        genomes_per_sample,
        left_on="sample_1",
        right_on="sample",
        how="left",
    ).rename(
        {"num_genomes": "num_genomes_1"}
    ).join(
        genomes_per_sample,
        left_on="sample_2",
        right_on="sample",
        how="left",
    ).rename(
        {"num_genomes": "num_genomes_2"}
    )
    comps_lf = comps_lf.join(
        sample_to_population,
        left_on="sample_1",
        right_on="sample",
        how="left",
    ).rename(
        {"population": "population_1"}
    ).join(
        sample_to_population,
        left_on="sample_2",
        right_on="sample",
        how="left",
    ).rename(
        {"population": "population_2"}
    )
    comps_lf = comps_lf.join(
        breadth_lf_long,
        left_on=["genome", "sample_1"],
        right_on=["genome", "sample"],
        how="left",
    ).rename(
        {"breadth": "breadth_1"}
    ).join(
        breadth_lf_long,
        left_on=["genome", "sample_2"],
        right_on=["genome", "sample"],
        how="left",
    ).rename(
        {"breadth": "breadth_2"}
    )
    # Keep only genomes that are covered well enough in both samples and similar enough.
    comps_lf = comps_lf.filter(
        (pl.col("breadth_1") >= min_breadth) &
        (pl.col("breadth_2") >= min_breadth) &
        (pl.col("genome_pop_ani") >= strain_similarity_threshold)
    )

    shared_counts = comps_lf.group_by(
        ["sample_1", "sample_2"]
    ).agg(
        pl.col("genome").count().alias("shared_strain_count"),
        pl.col("num_genomes_1").first().alias("num_genomes_1"),
        pl.col("num_genomes_2").first().alias("num_genomes_2"),
        pl.col("population_1").first().alias("population_1"),
        pl.col("population_2").first().alias("population_2"),
    ).collect(engine="streaming")

    strainsharingrates = defaultdict(list)
    for row in shared_counts.iter_rows(named=True):
        population_1 = row["population_1"]
        population_2 = row["population_2"]
        # The population joins are left joins, so a sample missing from
        # sample_to_population yields None here; skip it instead of crashing
        # on the string concatenation.
        if population_1 is None or population_2 is None:
            continue
        shared = row["shared_strain_count"]
        # When both samples belong to the same population the two directions share one key.
        strainsharingrates[f"{population_1}_{population_2}"].append(shared / row["num_genomes_1"])
        strainsharingrates[f"{population_2}_{population_1}"].append(shared / row["num_genomes_2"])
    return strainsharingrates
131
+
132
def calculate_ibs(
    sample_to_population:pl.LazyFrame,
    comps_lf:pl.LazyFrame,
    max_perc_id_genes:float=15,
    min_total_positions:int=10000,
    )->pl.DataFrame:
    """
    Collect the Identity By State (IBS) lengths per genome and population pairing.
    For every genome, the "max_consecutive_length" values of the retained comparisons are gathered
    into one list per comparison type (within a population or between two populations).
    Args:
        sample_to_population (pl.LazyFrame): LazyFrame with "sample" and "population" columns.
        comps_lf (pl.LazyFrame): LazyFrame with the pairwise comparisons. The columns read here are
            "sample_1", "sample_2", "genome", "perc_id_genes", "total_positions" and "max_consecutive_length".
        max_perc_id_genes (float, optional): Comparisons with a larger "perc_id_genes" are discarded. Defaults to 15.
        min_total_positions (int, optional): Comparisons must have more than this many total positions. Defaults to 10000.
    Returns:
        pl.DataFrame: One row per genome and one list column per comparison type, named
        "within_population_<P>|<P>" or "between_population_<P>|<Q>" (populations ordered by min/max).
        Cells without any comparison hold the sentinel list [-1].
    """
    comps_lf_filtered = comps_lf.filter(
        (pl.col("perc_id_genes") <= max_perc_id_genes) &
        (pl.col("total_positions") > min_total_positions)
    )
    # Inner joins: comparisons involving a sample without a population are dropped.
    comps_lf_filtered = comps_lf_filtered.join(
        sample_to_population,
        left_on="sample_1",
        right_on="sample",
        how="inner",
    ).rename(
        {"population": "population_1"}
    ).join(
        sample_to_population,
        left_on="sample_2",
        right_on="sample",
        how="inner",
        suffix="_2",
    ).rename(
        {"population": "population_2"}
    )
    comps_lf_filtered = comps_lf_filtered.with_columns(
        pl.when(pl.col("population_1") == pl.col("population_2"))
        .then(
            pl.lit("within_population_")
            + pl.col("population_1")
            + pl.lit("|")
            + pl.col("population_2")
        )
        .otherwise(
            pl.concat_str(
                [
                    pl.lit("between_population_"),
                    # min/max make the label independent of the sample order.
                    pl.concat_str(
                        [
                            pl.min_horizontal("population_1", "population_2"),
                            pl.lit("|"),
                            pl.max_horizontal("population_1", "population_2"),
                        ]
                    ),
                ]
            )
        )
        .alias("comparison_type")
    ).fill_null(-1)

    lengths_per_type = comps_lf_filtered.group_by(["genome", "comparison_type"]).agg(
        pl.col("max_consecutive_length"),
    ).collect(engine="streaming")
    # `on=` replaces the `columns=` keyword that polars deprecated in 1.0.
    return lengths_per_type.pivot(
        on="comparison_type",
        index="genome",
        values="max_consecutive_length",
    ).with_columns(
        pl.col("*").exclude("genome").fill_null([-1])
    )
202
+
203
def plot_ibs_heatmap(
    df:pl.DataFrame,
    vert_thresh:float=0.001,
    populations:list[str]|None=None,
    num_bins:int=10000,
    min_member:int=50,
    title:str="IBS Heatmap",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Genome",
    ):
    """
    Plot a heatmap of the IBS distance for every genome and population pair.
    For each pair, the reversed CDFs of the between-population and within-population lengths are
    computed with get_cdf, and the distance is the difference between the bin edges at which each
    CDF first reaches vert_thresh (within minus between). Negative distances are shown as 0, and
    genomes / population pairs whose distances sum to 0 are removed from the plot.
    Args:
        df (pl.DataFrame): DataFrame with a "genome" column and one list column per comparison type,
            named "within_population_<P>|<P>" or "between_population_<P>|<Q>".
        vert_thresh (float, optional): CDF value at which the two curves are compared. Defaults to 0.001.
        populations (list[str] | None, optional): Populations to compare. Defaults to all populations found in the column names.
        num_bins (int, optional): Number of bin edges passed to get_cdf. Defaults to 10000.
        min_member (int, optional): Lists with fewer entries are treated as missing. Defaults to 50.
        title (str, optional): Title of the plot. Defaults to "IBS Heatmap".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Genome".
    Returns:
        go.Figure: Plotly figure containing the IBS heatmap.
    """
    # Replace under-populated lists with the [-1] "missing" sentinel.
    df = df.with_columns(
        [
            pl.when(pl.col(c).list.len() < min_member)
            .then(pl.lit([-1]))
            .otherwise(pl.col(c))
            .alias(c)
            for c in df.columns if c != "genome"
        ]
    )
    if populations is None:
        populations = set(chain.from_iterable(
            i.replace("within_population_", "").replace("between_population_", "").split("|")
            for i in df.columns if i != "genome"
        ))
        populations = sorted(populations)
    heatmap_data = df.rows_by_key("genome", unique=True, include_key=False, named=True)
    fig_data = {}
    for genome, genome_data in heatmap_data.items():
        fig_data[genome] = {}
        for pop1, pop2 in combinations(populations, 2):
            pair_label = f"{min(pop1,pop2)}-{max(pop1,pop2)}"
            key_between = f"between_population_{min(pop1,pop2)}|{max(pop1,pop2)}"
            key_within_1 = f"within_population_{pop1}|{pop1}"
            key_within_2 = f"within_population_{pop2}|{pop2}"
            if genome_data.get(key_between, [-1]) == [-1] or genome_data.get(key_within_1, [-1]) == [-1] or genome_data.get(key_within_2, [-1]) == [-1]:
                fig_data[genome][pair_label] = -1
                continue
            between = get_cdf(genome_data[key_between], num_bins=num_bins)
            within = get_cdf(genome_data[key_within_1] + genome_data[key_within_2], num_bins=num_bins)

            between_hits = np.where(between[1] >= vert_thresh)[0]
            within_hits = np.where(within[1] >= vert_thresh)[0]
            if len(between_hits) == 0 or len(within_hits) == 0:
                # A CDF never reached the threshold: mark as missing instead of raising IndexError.
                fig_data[genome][pair_label] = -1
                continue
            between_intersect = between[0][between_hits[0]]
            within_intersect = within[0][within_hits[0]]
            fig_data[genome][pair_label] = within_intersect - between_intersect
    ###Filter the dataframe to only have useful information
    heatmap_df = pd.DataFrame(fig_data).T
    heatmap_df = heatmap_df.mask(heatmap_df < 0, 0)
    heatmap_df = heatmap_df[heatmap_df.sum(axis=1) > 0]
    heatmap_df = heatmap_df[[col for col in heatmap_df.columns if heatmap_df[col].sum() > 0]]
    heatmap_df_sorted = heatmap_df.assign(row_sum=heatmap_df.sum(axis=1)).sort_values("row_sum", ascending=True).drop(columns="row_sum")

    fig = go.Figure(data=go.Heatmap(
        z=heatmap_df_sorted.values,
        x=heatmap_df_sorted.columns,
        y=heatmap_df_sorted.index
    ))
    # The title arguments were previously accepted but never applied to the figure.
    fig.update_layout(
        title={"text": title, "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title
    )
    return fig
267
+
268
def plot_strainsharing(
    strainsharingrates:dict[str, list[float]],
    sample_frac:float=1,
    title:str="Strain Sharing Rates",
    xaxis_title:str="Population Pair",
    yaxis_title:str="Strain Sharing Rate",
    ):
    """
    Plot the strain sharing rates between populations as one box per population pair.
    Args:
        strainsharingrates (dict[str, list[float]]): Dictionary containing the strain sharing rates between populations. It is not modified.
        sample_frac (float, optional): Fraction of the rates of every pair to draw (randomly, without replacement) for plotting. Defaults to 1.
        title (str, optional): Title of the plot. Defaults to "Strain Sharing Rates".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Population Pair".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Strain Sharing Rate".
    Returns:
        go.Figure: Plotly figure containing the strain sharing plot.
    """
    # Subsample into a new dict: writing back into the argument would shrink
    # the caller's data on every call with sample_frac < 1.
    sampled_rates = {
        pair: np.random.choice(rates, size=int(len(rates) * sample_frac), replace=False)
        for pair, rates in strainsharingrates.items()
    }
    fig = go.Figure()
    for pair, rates in sampled_rates.items():
        fig.add_trace(go.Box(
            y=rates,
            name=pair,
            boxpoints='all',
            jitter=0.3,
            pointpos=0
        ))
    fig.update_layout(
        title={"text": title, "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title
    )
    return fig
302
def plot_ibs(df:pl.DataFrame,
             genome:str,
             population_1:str,
             population_2:str,
             vert_thresh_hor_distance:float=0.001,
             num_bins:int=10000,
             title:str="IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>",
             xaxis_title:str="Max Consecutive Length",
             yaxis_title:str="CDF"
             ):
    """
    Plot the Identity By State (IBS) curves for a given genome and two populations.
    The reversed CDFs of the within-population and between-population lengths are drawn on
    log-log axes, together with a horizontal segment at vert_thresh_hor_distance connecting
    the two curves and a label with its length.
    Args:
        df (pl.DataFrame): DataFrame with a "genome" column and one list column per comparison type,
            named "within_population_<P>|<P>" or "between_population_<P>|<Q>".
        genome (str): The genome to plot the IBS for.
        population_1 (str): The first population to plot the IBS for.
        population_2 (str): The second population to plot the IBS for.
        vert_thresh_hor_distance (float, optional): CDF value at which the horizontal distance between the curves is measured. Defaults to 0.001.
        num_bins (int, optional): Number of bin edges passed to get_cdf. Defaults to 10000.
        title (str, optional): Title of the plot; "<GENOME>", "<POPULATION_1>" and "<POPULATION_2>" are substituted. Defaults to "IBS for <GENOME>: <POPULATION_1> vs <POPULATION_2>".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Max Consecutive Length".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "CDF".
    Returns:
        go.Figure: Plotly figure containing the IBS plot.
    Raises:
        ValueError: If the genome is not in the dataframe, if any of the three comparison
            types has no data, or if a CDF never reaches vert_thresh_hor_distance.
    """
    df_filtered = df.filter(pl.col("genome") == genome)
    if df_filtered.is_empty():
        raise ValueError(f"Genome {genome} not found in the dataframe.")
    key_within_1 = f"within_population_{population_1}|{population_1}"
    key_within_2 = f"within_population_{population_2}|{population_2}"
    key_between = f"between_population_{min(population_1,population_2)}|{max(population_1,population_2)}"

    lengths = {}
    for key in (key_within_1, key_within_2, key_between):
        cell = df_filtered.get_column(key)[0] if key in df_filtered.columns else None
        values = [] if cell is None else cell.to_list()
        # Missing data can show up as an absent column, a null, an empty list
        # or the [-1] sentinel; all of them mean "not enough data".
        if not values or values[0] == -1:
            raise ValueError(f"Not enough data for populations {population_1} and {population_2} in genome {genome}.")
        lengths[key] = values

    plot_data = {
        "within_population": lengths[key_within_1] + lengths[key_within_2],
        "between_population": lengths[key_between],
    }
    fig = go.Figure()
    between_pop_cdf = get_cdf(plot_data["between_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=between_pop_cdf[0][1:].copy(),
        y=between_pop_cdf[1][1:].copy(),
        mode='lines',
        name='between_population',
        line=dict(color='blue')
    ))
    within_pop_cdf = get_cdf(plot_data["within_population"], num_bins=num_bins)
    fig.add_trace(go.Scatter(
        x=within_pop_cdf[0][1:].copy(),
        y=within_pop_cdf[1][1:].copy(),
        mode='lines',
        name='within_population',
        line=dict(color='green')
    ))

    within_hits = np.where(within_pop_cdf[1] >= vert_thresh_hor_distance)[0]
    between_hits = np.where(between_pop_cdf[1] >= vert_thresh_hor_distance)[0]
    if len(within_hits) == 0 or len(between_hits) == 0:
        raise ValueError(f"The CDF never reaches {vert_thresh_hor_distance} for populations {population_1} and {population_2} in genome {genome}.")
    within_intersect = within_pop_cdf[0][within_hits[0]]
    between_intersect = between_pop_cdf[0][between_hits[0]]
    distance = within_intersect - between_intersect

    fig.update_layout(
        title={"text": title.replace("<GENOME>", genome).replace("<POPULATION_1>", population_1).replace("<POPULATION_2>", population_2), "x": 0.5},
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )
    ###Add a horizontal line from (between_intersect, vert_thresh_hor_distance) to (within_intersect, vert_thresh_hor_distance)
    fig.add_trace(go.Scatter(
        x=[between_intersect, within_intersect],
        y=[vert_thresh_hor_distance, vert_thresh_hor_distance],
        mode='lines+markers',
        line=dict(color='black'),
        showlegend=False
    ))
    ###Add a text annotation at the middle of the horizontal line with the distance
    fig.add_trace(go.Scatter(
        x=[(between_intersect+within_intersect)/2],
        y=[vert_thresh_hor_distance],
        mode="text",
        text=[str(int(distance))],
        textposition="top center",
        showlegend=False
    ))
    ##make both axes logarithmic
    fig.update_xaxes(type='log')
    fig.update_yaxes(type='log')

    return fig
390
def calculate_identical_frac_vs_popani(
    genome:str,
    population_1:str,
    population_2:str,
    sample_to_population:pl.LazyFrame,
    comps_lf:pl.LazyFrame,
    min_shared_genes_count:int=100,
    min_total_positions:int=10000
    ):
    """
    Gather the percentage of identical genes and the popANI of one genome, grouped by population relationship.
    Only comparisons where both samples belong to population_1 or population_2 are kept.
    Args:
        genome (str): The genome whose comparisons are analysed.
        population_1 (str): The first population to compare.
        population_2 (str): The second population to compare.
        sample_to_population (pl.LazyFrame): LazyFrame containing the sample to population mapping.
        comps_lf (pl.LazyFrame): LazyFrame containing the pairwise comparisons of the samples.
        min_shared_genes_count (int, optional): Comparisons must have more shared genes than this. Defaults to 100.
        min_total_positions (int, optional): Comparisons must have more total positions than this. Defaults to 10000.
    Returns:
        pl.DataFrame: One row per relationship ("<population_1>-<population_1>", "<population_2>-<population_2>"
        or "<population_1>-<population_2>") with the "perc_id_genes" and "genome_pop_ani" values as lists.
    """
    is_target_genome = pl.col('genome') == genome
    enough_genes = pl.col("shared_genes_count") > min_shared_genes_count
    enough_positions = pl.col("total_positions") > min_total_positions
    selected = (
        comps_lf
        .filter(is_target_genome & enough_genes & enough_positions)
        .collect(engine="streaming")
        .lazy()
    )

    # Attach the population of each side of the comparison.
    with_first_population = selected.join(
        sample_to_population,
        left_on='sample_1',
        right_on='sample',
        how='left',
    ).rename({"population": "population_1"})
    with_both_populations = with_first_population.join(
        sample_to_population,
        left_on='sample_2',
        right_on='sample',
        how='left',
        suffix='_2',
    ).rename({"population": "population_2"})

    wanted_populations = {population_1, population_2}
    restricted = (
        with_both_populations
        .filter(
            pl.col("population_1").is_in(wanted_populations)
            & pl.col("population_2").is_in(wanted_populations)
        )
        .collect(engine="streaming")
        .lazy()
    )

    label_same_1 = f"{population_1}-{population_1}"
    label_same_2 = f"{population_2}-{population_2}"
    label_diff = f"{population_1}-{population_2}"
    both_in_first = (pl.col("population_1") == population_1) & (pl.col("population_2") == population_1)
    both_in_second = (pl.col("population_1") == population_2) & (pl.col("population_2") == population_2)
    relationship = (
        pl.when(both_in_first)
        .then(pl.lit(label_same_1))
        .when(both_in_second)
        .then(pl.lit(label_same_2))
        .otherwise(pl.lit(label_diff))
        .alias("relationship")
    )

    grouped = restricted.with_columns(relationship).group_by("relationship").agg(
        pl.col("perc_id_genes"),
        pl.col("genome_pop_ani"),
    )
    return grouped.collect(engine="streaming")
453
+
454
def plot_identical_frac_vs_popani(df:pl.DataFrame,
                                  genome:str,
                                  title:str="Fraction of Identical Genes vs popANI for <GENOME>",
                                  xaxis_title:str="Genome-Wide popANI",
                                  yaxis_title:str="Fraction of Identical Genes",
                                  ):
    """
    Draw a scatter plot of identical-gene fraction against popANI, one marker series per relationship.
    Args:
        df (pl.DataFrame): DataFrame with the columns "relationship", "perc_id_genes" and "genome_pop_ani".
        genome (str): Genome name substituted for "<GENOME>" in the title.
        title (str, optional): Title of the plot. Defaults to "Fraction of Identical Genes vs popANI for <GENOME>".
        xaxis_title (str, optional): Title of the x-axis. Defaults to "Genome-Wide popANI".
        yaxis_title (str, optional): Title of the y-axis. Defaults to "Fraction of Identical Genes".
    Returns:
        go.Figure: Plotly figure containing the fraction of identical genes vs popANI plot.
    """
    relationships = df["relationship"]
    identical_fractions = df["perc_id_genes"]
    pop_anis = df["genome_pop_ani"]

    fig = go.Figure()
    # One scatter trace per row: x is the popANI list, y the identical-gene list.
    for label, y_values, x_values in zip(relationships, identical_fractions, pop_anis):
        trace = go.Scatter(x=x_values, y=y_values, mode='markers', name=label)
        fig.add_trace(trace)

    resolved_title = title.replace("<GENOME>", genome)
    fig.update_layout(
        title=resolved_title,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )
    return fig
484
+
485
def plot_clustermap(
    comps_lf:pl.LazyFrame,
    genome:str,
    sample_to_population:pl.LazyFrame,
    min_comp_len:int=10000,
    impute_method:str|float=97.0,
    max_null_samples:int=200,
    color_map:dict|None=None,
    ):
    """
    Plot a clustermap of the pairwise popANI between the samples of a given genome.
    Args:
        comps_lf (pl.LazyFrame): LazyFrame containing the comparison data. The columns read here are
            "genome", "total_positions", "sample_1", "sample_2" and "genome_pop_ani".
        genome (str): The genome to plot.
        sample_to_population (pl.LazyFrame): LazyFrame with "sample" and "population" columns.
        min_comp_len (int, optional): Comparisons must have more total positions than this. Defaults to 10000.
        impute_method (str | float, optional): A number replaces missing similarities with that value.
            A string is accepted but currently performs no imputation. Defaults to 97.0.
        max_null_samples (int, optional): Samples with more missing comparisons than this are removed. Defaults to 200.
        color_map (dict | None, optional): Maps a population to its color. Defaults to a seaborn "hls" palette.
    Returns:
        The seaborn clustermap object returned by sns.clustermap, with a population legend added to its heatmap axes.
    """
    # Filter the comparison data for the specific genome
    comps_lf_filtered = comps_lf.filter(
        (pl.col("genome") == genome) & (pl.col("total_positions") > min_comp_len)
    ).select(
        pl.col("sample_1"),
        pl.col("sample_2"),
        pl.col("genome_pop_ani"),
    )
    comps_lf_filtered_opposite = comps_lf_filtered.select(
        pl.col("sample_2").alias("sample_1"),
        pl.col("sample_1").alias("sample_2"),
        pl.col("genome_pop_ani"),
    )
    # Combine the filtered data with its mirrored pairs so the matrix is symmetric
    comps_lf_filtered = pl.concat([comps_lf_filtered, comps_lf_filtered_opposite])
    # Synthetic self-comparisons (similarity 100) so every sample appears once on the diagonal
    self_similarity = (
        pl.concat([
            comps_lf_filtered.select(pl.col("sample_1").alias("sample_1")),
            comps_lf_filtered.select(pl.col("sample_2").alias("sample_1"))
        ])
        .unique()
        .sort("sample_1").with_columns(
            pl.col("sample_1").alias("sample_2"),
            pl.lit(100.0).alias("genome_pop_ani"),
        )
    )

    # Combine the self similarity with the filtered data
    comps_lf_filtered = pl.concat([self_similarity, comps_lf_filtered]).collect()
    # Pivot the data for the clustermap (`on=` replaces the `columns=` keyword deprecated in polars 1.0)
    clustermap_data = comps_lf_filtered.pivot(
        on="sample_2",
        index="sample_1",
        values="genome_pop_ani"
    )
    # Samples with too many missing comparisons are removed from both rows and columns
    exclude_samples = clustermap_data.null_count().transpose(include_header=True, header_name="column", column_names=["null_count"]).filter(pl.col("null_count") > max_null_samples)["column"].to_list()
    clustermap_data = clustermap_data.filter(~pl.col("sample_1").is_in(exclude_samples))
    clustermap_data = clustermap_data.select(*[col for col in clustermap_data.columns if col not in exclude_samples])
    if isinstance(impute_method, str):
        pass  # To be implemented later
    elif isinstance(impute_method, (int, float)):
        clustermap_data = clustermap_data.fill_null(impute_method)
    sample_to_population = clustermap_data.select(pl.col("sample_1")).join(
        sample_to_population.collect(),
        left_on="sample_1",
        right_on="sample",
        how="left")
    sample_to_population_dict = dict(zip(sample_to_population["sample_1"], sample_to_population["population"]))
    if color_map is None:
        num_categories = sample_to_population["population"].n_unique()
        groups = sample_to_population["population"].unique().sort().to_list()
        qualitative_palette = sns.color_palette("hls", num_categories)
        row_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data["sample_1"]]
        col_colors = [qualitative_palette[groups.index(sample_to_population_dict[sample])] for sample in clustermap_data.columns if sample != "sample_1"]
    else:
        groups = list(color_map.keys())
        qualitative_palette = list(color_map.values())
        row_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data["sample_1"]]
        col_colors = [color_map[sample_to_population_dict[sample]] for sample in clustermap_data.columns if sample != "sample_1"]
    fig = sns.clustermap(
        clustermap_data.to_pandas().set_index("sample_1"),
        figsize=(30, 30),
        xticklabels=True,
        yticklabels=True,
        row_colors=row_colors,
        col_colors=col_colors
    )
    fig.ax_heatmap.set_xticklabels(fig.ax_heatmap.get_xmajorticklabels(), fontsize=0.1)
    fig.ax_heatmap.set_yticklabels(fig.ax_heatmap.get_ymajorticklabels(), fontsize=0.1)
    legend_handles = [mpatches.Patch(color=color, label=label)
                      for label, color in zip(groups, qualitative_palette)]
    fig.ax_heatmap.legend(handles=legend_handles,
                          title='Population',
                          title_fontsize=16,  # bigger title
                          fontsize=14,  # bigger labels
                          handlelength=2.5,  # wider color boxes
                          handleheight=2,
                          bbox_to_anchor=(-0.15, 1), loc="lower left")
    return fig
@@ -0,0 +1,27 @@
1
+ Metadata-Version: 2.3
2
+ Name: zipstrain
3
+ Version: 0.2.4
4
+ Summary:
5
+ Author: ParsaGhadermazi
6
+ Author-email: 54489047+ParsaGhadermazi@users.noreply.github.com
7
+ Requires-Python: >=3.12
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.12
10
+ Classifier: Programming Language :: Python :: 3.13
11
+ Requires-Dist: aiofiles (>=25.1.0,<26.0.0)
12
+ Requires-Dist: click (>=8.3.1,<9.0.0)
13
+ Requires-Dist: intervaltree (>=3.1.0,<4.0.0)
14
+ Requires-Dist: matplotlib (>=3.10.7,<4.0.0)
15
+ Requires-Dist: numpy (>=2.3.4,<3.0.0)
16
+ Requires-Dist: pandas (>=2.3.3,<3.0.0)
17
+ Requires-Dist: plotly (>=6.3.1,<7.0.0)
18
+ Requires-Dist: polars (>=1.34.0,<2.0.0)
19
+ Requires-Dist: psutil (>=7.1.2,<8.0.0)
20
+ Requires-Dist: pyarrow (>=22.0.0,<23.0.0)
21
+ Requires-Dist: pydantic (>=2.12.3,<3.0.0)
22
+ Requires-Dist: rich (>=14.2.0,<15.0.0)
23
+ Requires-Dist: scipy (>=1.16.2,<2.0.0)
24
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
25
+ Description-Content-Type: text/markdown
26
+
27
+
@@ -0,0 +1,12 @@
1
+ zipstrain/__init__.py,sha256=zXmDhRWzn8lYZIfDSmeT8F-w1jrXWrD6D3d97trj7Pw,277
2
+ zipstrain/cli.py,sha256=bBRRCWV3a5iiC5gJLOowrHF2Ay2Guh3Pg27vU1cnVMk,43691
3
+ zipstrain/compare.py,sha256=4Jb5LhjTotw2_xi6CAJHNYYD2QJgX0A003lglkJRXaI,17372
4
+ zipstrain/database.py,sha256=6nD8olTxqcmESdyANTaleL296dsaK5gc0GcDkbqw1Vk,39237
5
+ zipstrain/profile.py,sha256=ZhkdJ0PGVQAR9U8UHYrfGRY0EUZwVktwvJ6bCsjw6uY,8273
6
+ zipstrain/task_manager.py,sha256=it_Swv0oWHMV8jshwFuZiXSiBesCxedrDxrBBvJB86w,85567
7
+ zipstrain/utils.py,sha256=bpnuwCt0LB-K61xNFgXYrYnTAQIW5s7Z-a685DJvM8Q,15660
8
+ zipstrain/visualize.py,sha256=RXaNxoNCFtSRAnRfWxaxgENppfpE-Uu1LYh_FGG-fVs,24950
9
+ zipstrain-0.2.4.dist-info/METADATA,sha256=Y1F6Hi_QHdMy2iUYqXeViPkXEYnTS15VzboBXBLFLgs,936
10
+ zipstrain-0.2.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
11
+ zipstrain-0.2.4.dist-info/entry_points.txt,sha256=tuIX83yN65q-_IoQV94AcHWyvRopa0lxPIDTR29UxFk,47
12
+ zipstrain-0.2.4.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: poetry-core 2.1.3
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ zipstrain=zipstrain.cli:cli
3
+