AlphaMicrobiome 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.4
2
+ Name: AlphaMicrobiome
3
+ Version: 0.1.0
4
+ Summary: Microbiome analysis
5
+ Author: Zhao Yu
6
+ Requires-Python: >=3.12
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: ipykernel
9
+ Requires-Dist: loguru
10
+ Requires-Dist: tqdm
11
+ Requires-Dist: numpy
12
+ Requires-Dist: pandas
13
+ Requires-Dist: scipy
14
+ Requires-Dist: statsmodels
15
+ Requires-Dist: pingouin
16
+ Requires-Dist: matplotlib
17
+ Requires-Dist: seaborn
18
+ Requires-Dist: statannotations
19
+ Requires-Dist: scikit-learn
20
+ Requires-Dist: shap
21
+ Requires-Dist: torch
22
+ Requires-Dist: torchvision
23
+ Requires-Dist: scikit-bio
24
+ Requires-Dist: networkx
25
+ Requires-Dist: reflex
@@ -0,0 +1,17 @@
1
+ microbiome/__init__.py,sha256=Z2kmWD29Ze5ZsQo1vLaNYgHljtwCgAQGXeiMQhpNuzs,276
2
+ microbiome/amplicon.py,sha256=ArZZpV6Pk_rYt3Xj9ugcc8C1csSb80yIZZH7iGWQodA,26481
3
+ microbiome/betaNTI.py,sha256=ZBA7jmor1VGGopAXBnRaCSvuBK1o9MEr_DdnLiTwk7Q,26063
4
+ microbiome/colors.py,sha256=JcC6qSlxwaplgVjhM27n0iu90FabBgYnmCiCATR83lA,4222
5
+ microbiome/diversity.py,sha256=uUDg2MLnS3dGrK6nXsuszQPCET6Pexq0E1VnvzKNCFs,8598
6
+ microbiome/journal.py,sha256=BvfD7PDLYyOgjRTVENAimUEHQe8jiAOgy5uh6jHjYdo,1348
7
+ microbiome/metacyc_utils.py,sha256=cN8a8vQmVswVQ6BKuKfEbW27j-b9_bkbkBkSprAAxf8,3079
8
+ microbiome/network.py,sha256=Hbehi4Bk5OC6hLxE2S2yo6nkYv6veUIi1i-y4M1Kaag,33013
9
+ microbiome/plot.py,sha256=mzp8gZsgfgEUkxKkQd6kqciNPK2erYgqhInZYAiJ3Zk,1415
10
+ microbiome/skbio_utils.py,sha256=ltTn2i472VkSHGlSP6AhqTpwD3CdTVgC-kB-FTEScZk,8025
11
+ microbiome/stats.py,sha256=W-eqyA4OR0vLNDh44pnV8RBqVJ13Y2QhagRtYlu3EVM,6229
12
+ microbiome/usearch.py,sha256=0smk8a0XNYM0dLAsdiXmdAXrVX4SeMBF6jBBSZMxYdA,11259
13
+ microbiome/vae.py,sha256=gd0pcRU2ul2VtsOK6-E7M-5gBuhiUePHafZfqKjJLlU,17088
14
+ alphamicrobiome-0.1.0.dist-info/METADATA,sha256=dl-sKasqyBZ24m7RxE1EbKb5eAQzUhUDH_qHQ9dKmLg,598
15
+ alphamicrobiome-0.1.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
16
+ alphamicrobiome-0.1.0.dist-info/top_level.txt,sha256=SXrMtOEl7-o3hqgflE4JDpm5CM8htboeXeTk4Z2h4lI,11
17
+ alphamicrobiome-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ microbiome
microbiome/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+
2
+
3
# Eagerly imported submodules; these become attributes of the package on import.
from . import (
    amplicon,
    diversity,
    network,
    vae,
    stats,
    colors,
    plot,
)

__version__ = "0.1.0"
__author__ = "ZhaoYu"
__all__ = [
    'amplicon',
    'diversity',
    'network',
    'vae',
    'stats',
    'betaNTI',
    'colors',
    'plot',
]

# Submodules exported in __all__ but not imported above; resolved on first access.
_LAZY_SUBMODULES = ('betaNTI',)


def __getattr__(name):
    '''Lazily import submodules that are exported but not eagerly imported (PEP 562).

    ``betaNTI`` is listed in ``__all__`` but is not part of the eager import above,
    so ``microbiome.betaNTI`` used to raise AttributeError unless the submodule had
    been imported elsewhere. Importing it on first attribute access fixes that
    without changing what is imported when the package itself is imported.

    Raises:
        AttributeError: for any other unknown attribute name.
    '''
    if name in _LAZY_SUBMODULES:
        import importlib
        return importlib.import_module(f'{__name__}.{name}')
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
microbiome/amplicon.py ADDED
@@ -0,0 +1,582 @@
1
+ from loguru import logger
2
+ from typing import List, Tuple, Dict, Optional, Union
3
+
4
+ from pathlib import Path
5
+
6
+ import re
7
+ from tqdm import tqdm
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+
14
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder, OneHotEncoder
15
+
16
+
17
def sintax_to_taxonomy(sintax_file_path: str, fill_unclassified: bool = True,
                       min_confidence: Optional[float] = None) -> pd.DataFrame:
    '''Convert a USEARCH SINTAX output file into a wide taxonomy table.

    Args:
        sintax_file_path: str, path to the tab-separated sintax file. Its columns are read as
            OTU_ID, taxonomy-with-confidence (e.g. "d:Bacteria(1.0000),p:Firmicutes(0.9800)"),
            strand, and taxonomy.
        fill_unclassified: bool, whether to replace missing ranks ("NA") with "Unclassified_<Rank>".
        min_confidence: float or None, optional confidence cutoff. A rank whose confidence
            (the number in parentheses) is below this value is treated as missing.
            None (default) keeps every parsed rank.

    Returns:
        DataFrame with columns:
            OTU_ID Domain Kingdom Phylum Class Order Family Genus Species
        e.g.
            OTU_1 Bacteria Unclassified_Kingdom Firmicutes Bacilli Lactobacillales ...

    Examples:
        >>>taxonomy_df = sintax_to_taxonomy(sintax_file_path='./asv.sintax')
        >>>confident_df = sintax_to_taxonomy(sintax_file_path='./asv.sintax', min_confidence=0.8)
    '''
    rank_map = {
        "d": "Domain",
        "k": "Kingdom",
        "p": "Phylum",
        "c": "Class",
        "o": "Order",
        "f": "Family",
        "g": "Genus",
        "s": "Species",
    }
    rank_cols = list(rank_map.values())
    # One taxonomy item, e.g. "p:Actinomycetota(0.9800)" -> ("p", "Actinomycetota", "0.9800").
    item_pattern = re.compile(r"([dkpcfogs]):([^()]+)\(([\d.]+)\)")

    otu_sintax_df = pd.read_csv(sintax_file_path, sep='\t', header=None,
                                names=['OTU_ID', 'Taxonomy_with_confidence', 'Strand', 'Taxonomy'])

    taxonomy_records = []
    for otu_id, tax_str in zip(otu_sintax_df['OTU_ID'], otu_sintax_df['Taxonomy_with_confidence']):
        # Every rank starts as "NA" and is overwritten when the sintax string provides it.
        record = {'OTU_ID': otu_id, **dict.fromkeys(rank_cols, "NA")}
        if isinstance(tax_str, str):  # unclassified OTUs have an empty (NaN) field
            for item in tax_str.split(","):
                match = item_pattern.match(item.strip())
                if match is None:
                    continue
                rank_short, tax_name, confidence = match.groups()
                if min_confidence is not None and float(confidence) < min_confidence:
                    continue
                record[rank_map[rank_short]] = tax_name
        taxonomy_records.append(record)

    # Explicit columns keep the rank order and also work when there are no records.
    taxonomy_df = pd.DataFrame(taxonomy_records, columns=['OTU_ID'] + rank_cols)

    if fill_unclassified:
        for rank in rank_cols:
            taxonomy_df[rank] = taxonomy_df[rank].fillna("NA").replace("NA", f"Unclassified_{rank}")

    return taxonomy_df
76
+
77
+
78
class Amplicon:
    '''Operator of feature tables, e.g. OTU, Taxonomy, Metadata and so on.

    Examples:
        >>>amp = Amplicon(otu_file_path='otutab.txt', sintax_file_path='otus.sintax', metadata_file_path='metadata.xlsx')
        >>>otu_table, taxonomy_table, metadata_table = amp.features_parser()
        >>>otu_meta, otu_taxa = Amplicon.merge_otu_metadata_taxonomy_table(otu_table, metadata_table, taxonomy_table)
    '''

    def __init__(self,
                 otu_file_path: str,
                 sintax_file_path: str,
                 metadata_file_path: str,
                 otu_id_file_path: Optional[str] = None,
                 metadata_sheet_name: Union[str, int] = 'clean',
                 ):
        '''Store the input paths; nothing is read until `features_parser` is called.

        Args:
            otu_file_path: str, tab-separated OTU table whose first column holds the OTU ids.
            sintax_file_path: str, SINTAX taxonomy file (parsed by `sintax_to_taxonomy`).
            metadata_file_path: str, Excel workbook whose sheet contains a 'SampleID' column.
            otu_id_file_path: str or None, optional header-less two-column TSV (Index, OTU_ID).
            metadata_sheet_name: str or int, workbook sheet holding the metadata; defaults to 'clean'.
        '''
        self.otu_file_path = otu_file_path
        self.sintax_file_path = sintax_file_path
        self.metadata_file_path = metadata_file_path
        self.otu_id_file_path = otu_id_file_path
        self.metadata_sheet_name = metadata_sheet_name

    def features_parser(self):
        '''Load features (otu_table, taxonomy_table, metadata_table).

        Returns:
            if otu_id_file_path is not None:
                (index2id, id2index, otu_table, taxonomy_table, metadata_table)
            else:
                (otu_table, taxonomy_table, metadata_table)
            otu_table: OTU ids as index, samples as columns (columns axis named 'SampleID').
            taxonomy_table: indexed by 'OTU_ID'.
            metadata_table: indexed by 'SampleID'.
        '''
        otu_table = pd.read_csv(self.otu_file_path, sep='\t', index_col=0)
        otu_table.columns.name = "SampleID"

        taxonomy_table = sintax_to_taxonomy(sintax_file_path=self.sintax_file_path)
        taxonomy_table.set_index('OTU_ID', inplace=True)

        # NOTE(review): pandas needs an Excel engine (e.g. openpyxl) for .xlsx files; it is not
        # listed in this package's Requires-Dist — confirm it is installed in target environments.
        # getattr keeps subclasses that bypass this __init__ working with the previous default sheet.
        sheet_name = getattr(self, 'metadata_sheet_name', 'clean')
        metadata_table = pd.read_excel(self.metadata_file_path, sheet_name=sheet_name)
        metadata_table.set_index('SampleID', inplace=True)

        if self.otu_id_file_path is None:
            return (otu_table, taxonomy_table, metadata_table)

        # Optional mapping between integer indices and OTU ids.
        otu_id_df = pd.read_csv(self.otu_id_file_path, sep='\t', header=None, names=['Index', 'OTU_ID'])
        index2id = dict(zip(otu_id_df['Index'], otu_id_df['OTU_ID']))
        id2index = dict(zip(otu_id_df['OTU_ID'], otu_id_df['Index']))
        return (index2id, id2index, otu_table, taxonomy_table, metadata_table)

    @staticmethod
    def merge_otu_metadata_taxonomy_table(otu_table_df: pd.DataFrame, metadata_table_df: pd.DataFrame = None, taxonomy_table_df: pd.DataFrame = None):
        '''Melt the OTU table to long format and merge it with metadata and/or taxonomy.

        Args:
            otu_table_df: wide OTU table, OTU ids as the index (any index name, e.g. 'OTU_ID'
                or USEARCH's '#OTU ID'), samples as columns.
            metadata_table_df: metadata table with 'SampleID' as the index.
            taxonomy_table_df: taxonomy table with 'OTU_ID' as the index.
        Returns:
            both tables given: (otu_metadata_table_df, otu_taxonomy_table_df)
            only one given: that merged table.
            Each is indexed by 'OTU_ID' with columns [SampleID, Abundance, <merged columns>...].
        Raises:
            ValueError: when neither metadata_table_df nor taxonomy_table_df is given.

        Examples:
            >>>otu_meta, otu_taxa = merge_otu_metadata_taxonomy_table(otu_table, metadata_table, taxonomy_table)
        '''
        if metadata_table_df is None and taxonomy_table_df is None:
            raise ValueError('ERROR: There are no metadata and taxonomy table.')

        # Wide -> long with columns [OTU_ID, SampleID, Abundance]. rename_axis makes the id
        # column 'OTU_ID' regardless of what the input index was called.
        otu_long_df = (
            otu_table_df.rename_axis('OTU_ID')
            .reset_index()
            .melt(id_vars=['OTU_ID'], var_name='SampleID', value_name='Abundance')
        )

        otu_metadata = None
        otu_taxonomy = None
        if metadata_table_df is not None:
            otu_metadata = pd.merge(left=otu_long_df, right=metadata_table_df,
                                    left_on='SampleID', right_index=True, how='left').set_index('OTU_ID')
        if taxonomy_table_df is not None:
            otu_taxonomy = pd.merge(left=otu_long_df, right=taxonomy_table_df,
                                    left_on='OTU_ID', right_index=True, how='left').set_index('OTU_ID')

        if otu_metadata is not None and otu_taxonomy is not None:
            return (otu_metadata, otu_taxonomy)
        return otu_metadata if otu_metadata is not None else otu_taxonomy

    @staticmethod
    def group_by_metadata_taxonomy_table(otu_taxonomy_table_df: pd.DataFrame, taxonomy_rank_col: Union[str, List[str], None] = None, metadata_table_df: pd.DataFrame = None):
        '''Sum abundances per taxonomy rank and sample, optionally attaching metadata.

        Args:
            otu_taxonomy_table_df: long OTU table merged with taxonomy, columns [SampleID, Abundance, Domain, Kingdom, ...]
            taxonomy_rank_col: str or list of str, rank column(s) to group by, e.g. ['Phylum', 'Genus'].
                Defaults to ['Species'].
            metadata_table_df: optional metadata table with 'SampleID' as the index.
        Returns:
            DataFrame indexed by taxonomy_rank_col with columns [SampleID, Abundance, <metadata columns>...];
            without metadata_table_df only [SampleID, Abundance].

        Examples:
            >>>grouped_otu_df = group_by_metadata_taxonomy_table(otu_taxonomy_table_df, taxonomy_rank_col=["Phylum"], metadata_table_df=metadata_table)
        '''
        if taxonomy_rank_col is None:
            rank_cols = ['Species']
        elif isinstance(taxonomy_rank_col, str):
            rank_cols = [taxonomy_rank_col]
        else:
            rank_cols = list(taxonomy_rank_col)

        grouped = (
            otu_taxonomy_table_df.groupby(by=rank_cols + ['SampleID'])['Abundance']
            .agg('sum')
            .reset_index()
            .set_index(rank_cols)
        )

        if metadata_table_df is None:
            return grouped
        return pd.merge(left=grouped, right=metadata_table_df, left_on='SampleID', right_index=True, how='left')

    @staticmethod
    def subset_by_group(taxa_metadata_table_df: pd.DataFrame, group_col: str, group_name: Optional[str] = None,
                        save_dir_path: Optional[str] = None, prefix: str = 'otu'):
        '''Extract rows of taxa_metadata_table where group_col == group_name.

        Args:
            taxa_metadata_table_df: DataFrame, long table indexed by taxa (OTU, Species, Genus, ...)
                with columns [SampleID, Abundance, Group, SoilType, ...].
            group_col: str, metadata column to group by, e.g. 'PartName'.
            group_name: str, group value to keep, e.g. 'Bulk'. If None, every group is returned as a dict.
            save_dir_path: str or None, if given each subset is written as a wide
                (taxa x SampleID) TSV named "<prefix>_<group_col>_<group>.txt".
            prefix: str, file name prefix. If a column with this name exists it is used as the
                row key of the saved wide table; otherwise the table's index is used.
        Returns:
            if group_name is None:
                groups_dict: dict, key is group value, value is DataFrame of the group.
            else:
                df_group: DataFrame, filtered DataFrame of the group.
        '''

        def long_to_wide(df_long: pd.DataFrame) -> pd.DataFrame:
            '''Pivot a long subset into a taxa x SampleID abundance matrix.'''
            flat = df_long.reset_index()
            if prefix in flat.columns:
                row_keys = [prefix]
            else:
                # reset_index puts the former index levels first.
                row_keys = list(flat.columns[:df_long.index.nlevels])
            df_wide = flat.pivot(index=row_keys, columns='SampleID', values='Abundance')
            if df_wide.index.nlevels == 1:
                # Index renamed for the downstream FastSpar input requirement.
                df_wide.index.rename('OTU_ID', inplace=True)
            return df_wide

        def save(df_group: pd.DataFrame, label) -> None:
            '''Write df_group as a wide TSV when save_dir_path is set.'''
            if save_dir_path is None:
                return
            save_path = Path(save_dir_path) / f"{prefix}_{group_col}_{label}.txt"
            save_path.parent.mkdir(parents=True, exist_ok=True)
            long_to_wide(df_group).to_csv(save_path, sep='\t')

        if group_name is not None:
            df_group = taxa_metadata_table_df[taxa_metadata_table_df[group_col] == group_name]
            save(df_group, group_name)
            return df_group

        groups_dict = {}
        groups = taxa_metadata_table_df[group_col].unique().tolist()
        for g in tqdm(groups, desc=f'Grouping by {group_col}'):
            df_group = taxa_metadata_table_df[taxa_metadata_table_df[group_col] == g]
            save(df_group, g)
            groups_dict[g] = df_group
        return groups_dict
230
+
231
+
232
class OTUQC:
    '''OTU table quality control: sample/OTU filtering, top-k selection, sparsity curve,
    normalization and rarefaction.

    All methods expect an OTU table with OTU ids as the index and sample ids as the columns.

    Examples:
        >>>otu_table_filtered = OTUQC.filter_otu_table_by_otu(otu_table_df=otu_table, min_prevalence=0.01, min_abundance=0.00001)
        >>>otu_table_filtered_topk = OTUQC.keep_top_otus(otu_table_df=otu_table_filtered, top_k=3000)
        >>>otu_table_for_network = OTUQC.otu_for_network(otu_table_df=otu_table, min_prevalence=0.01, min_abundance=0.00001, top_k=3000)
        >>>sparsities = OTUQC.sparsity_curve(otu_table_df=otu_table_filtered_topk, steps=[1000, 2000, 5000, 10000, 20000])
        >>>rarefied = OTUQC.rarefy(otu_table_df=otu_table, depth_method='min', random_state=0)
    '''

    @staticmethod
    def filter_otu_table_by_sample(otu_table_df: pd.DataFrame, method: str = 'iqr', min_max_threshold: Tuple[float, float] = (0.25, 0.75), iqr_scale_factor: float = 1.5, show: bool = True) -> pd.DataFrame:
        '''Filter samples (columns) based on their total abundance.

        Args:
            otu_table_df: OTU table DataFrame, OTU_ID as index, sample IDs as columns.
            method: str, how min_max_threshold is interpreted:
                'iqr' (default): the two values are quantiles q1/q3; samples are kept in
                    [q1 - f * IQR, q3 + f * IQR] with f = iqr_scale_factor;
                'quantile': keep samples between the two quantiles of total abundance;
                'abundance': the two values are absolute total-abundance bounds.
            min_max_threshold: Tuple[float, float], lower/upper value interpreted according to method.
            iqr_scale_factor: float, scale factor for the IQR method.
            show: bool, whether to show the filtering plots.
        Returns:
            filtered_otu_table_df: OTU table restricted to the kept samples.
        Raises:
            ValueError: if method is not supported.
        '''
        total_abundance = otu_table_df.sum(axis=0)  # one total per sample
        min_quantile, max_quantile = min_max_threshold

        if method == 'iqr':
            q1 = total_abundance.quantile(q=min_quantile)
            q3 = total_abundance.quantile(q=max_quantile)
            iqr = q3 - q1
            min_threshold = q1 - iqr_scale_factor * iqr
            max_threshold = q3 + iqr_scale_factor * iqr
        elif method == 'quantile':
            min_threshold = total_abundance.quantile(q=min_quantile)
            max_threshold = total_abundance.quantile(q=max_quantile)
        elif method == 'abundance':
            min_threshold = min_quantile
            max_threshold = max_quantile
        else:
            raise ValueError(f"Filtering method '{method}' is not supported. Choose from 'iqr', 'quantile' or 'abundance'.")

        logger.info(f"Sample total abundance filtering thresholds: min={min_threshold}, max={max_threshold}")

        keep = (total_abundance >= min_threshold) & (total_abundance <= max_threshold)
        filtered_otu_table_df = otu_table_df.loc[:, keep]

        kept_pct = keep.mean() * 100 if len(keep) else 0.0  # guard: table without samples
        logger.info(f"Raw/Filtered shape: {otu_table_df.shape} -> {filtered_otu_table_df.shape}, kept samples: {kept_pct:.2f}%")

        if show:
            plt.subplot(221)
            sns.boxplot(data=total_abundance, orient='h')
            plt.title('Total Abundance per Sample')

            plt.subplot(222)
            sns.histplot(total_abundance, kde=True)
            plt.axvline(min_threshold, color='green', linestyle='--', label='Min Threshold', alpha=0.7)
            plt.axvline(max_threshold, color='blue', linestyle='--', label='Max Threshold', alpha=0.7)
            plt.legend()
            plt.title('Total Abundance Distribution')

            plt.subplot(223)
            sns.boxplot(data=filtered_otu_table_df.sum(axis=0), orient='h')
            plt.title('Filtered Total Abundance per Sample')

            plt.subplot(224)
            sns.histplot(filtered_otu_table_df.sum(axis=0), kde=True)
            plt.axvline(min_threshold, color='green', linestyle='--', label='Min Threshold', alpha=0.7)
            plt.axvline(max_threshold, color='blue', linestyle='--', label='Max Threshold', alpha=0.7)
            plt.legend()
            plt.title('Filtered Total Abundance Distribution')

            plt.tight_layout()
            plt.show()

        return filtered_otu_table_df

    @staticmethod
    def filter_otu_table_by_otu(otu_table_df: pd.DataFrame, min_prevalence: float = 0.01, min_abundance: Optional[float] = 0.00001, show: bool = False) -> pd.DataFrame:
        '''Filter OTUs (rows) based on prevalence and mean relative abundance.

        The core microbiome can be extracted by setting min_prevalence=0.5, min_abundance=None.

        Args:
            otu_table_df: OTU table DataFrame, OTU_ID as index, sample IDs as columns.
            min_prevalence: float, minimum fraction of samples (0-1) in which the OTU is present; default 0.01.
            min_abundance: float or None, minimum mean relative abundance (0-1); None disables this filter.
            show: bool, whether to show the filtering plots.
        Returns:
            filtered_otu_table_df: OTU table restricted to the kept OTUs.
        '''
        n_samples = otu_table_df.shape[1]

        # Fraction of samples in which each OTU has a non-zero count.
        prevalence = (otu_table_df > 0).sum(axis=1) / n_samples
        keep_prevalence = prevalence >= min_prevalence

        if min_abundance is not None:
            rel_abundance = otu_table_df.div(otu_table_df.sum(axis=0), axis=1)
            global_abundance = rel_abundance.mean(axis=1)  # mean is stricter than sum
            keep_abundance = global_abundance >= min_abundance
        else:
            keep_abundance = pd.Series(True, index=otu_table_df.index)

        keep = keep_prevalence & keep_abundance
        filtered_otu_table_df = otu_table_df.loc[keep, :]

        n_otus = otu_table_df.shape[0]
        kept_pct = filtered_otu_table_df.shape[0] * 100 / n_otus if n_otus else 0.0
        logger.info(f"Raw/Filtered shape: {otu_table_df.shape} -> {filtered_otu_table_df.shape}, kept OTUs: {kept_pct:.2f}%")

        if show:
            plt.subplot(221)
            sns.histplot(prevalence, kde=True)
            plt.axvline(min_prevalence, color='green', linestyle='--', label='Min Prevalence', alpha=0.7)
            plt.legend(loc='upper right')
            plt.xscale('log')
            plt.title('OTU Prevalence Distribution')

            plt.subplot(222)
            sns.histplot(prevalence.loc[filtered_otu_table_df.index], kde=True)
            plt.xscale('log')
            plt.title('Filtered OTU Prevalence Distribution')

            if min_abundance is not None:
                plt.subplot(223)
                sns.histplot(global_abundance, kde=True)
                plt.axvline(min_abundance, color='green', linestyle='--', label='Min Abundance', alpha=0.7)
                plt.legend(loc='upper right')
                plt.xscale('log')
                plt.title('OTU Global Abundance Distribution')

                plt.subplot(224)
                sns.histplot(rel_abundance.loc[filtered_otu_table_df.index, :].mean(axis=1), kde=True)
                plt.xscale('log')
                plt.title('Filtered OTU Global Abundance Distribution')

            plt.tight_layout()
            plt.show()

        return filtered_otu_table_df

    @staticmethod
    def keep_top_otus(otu_table_df: pd.DataFrame, top_k: int = 3000):
        '''Keep the top_k OTUs with the largest total abundance across samples.

        Args:
            otu_table_df: OTU table DataFrame, OTU_ID as index, sample IDs as columns.
            top_k: int, number of OTUs to keep.
        Returns:
            filtered_otu_table_df: the top_k rows, ordered by decreasing total abundance.
        '''
        total_abundance = otu_table_df.sum(axis=1)
        top = total_abundance.sort_values(ascending=False).head(top_k).index
        return otu_table_df.loc[top, :]

    @staticmethod
    def otu_for_network(otu_table_df: pd.DataFrame, min_prevalence: float = 0.01, min_abundance: Optional[float] = 0.00001, top_k: int = 3000):
        '''Prevalence/abundance filtering (filter_otu_table_by_otu) followed by keep_top_otus.'''
        filtered_otu_table = OTUQC.filter_otu_table_by_otu(
            otu_table_df=otu_table_df,
            min_prevalence=min_prevalence,
            min_abundance=min_abundance,
            show=False,
        )
        return OTUQC.keep_top_otus(otu_table_df=filtered_otu_table, top_k=top_k)

    @staticmethod
    def sparsity_curve(otu_table_df: pd.DataFrame, steps=(1000, 2000, 5000, 10000, 20000)):
        '''Plot the sparsity of the OTU-OTU Spearman correlation matrix for the first k OTUs.

        Sparsity is the fraction of correlation coefficients exactly equal to 0. The first k
        rows are used as-is, so pass a table sorted by abundance (e.g. from keep_top_otus)
        to obtain a "top-k" curve.

        Args:
            otu_table_df: OTU table DataFrame, OTU_ID as index, sample IDs as columns.
            steps: iterable of k values; a k larger than the number of OTUs yields NaN.
        Returns:
            list of sparsities, one per step.
        '''
        steps = list(steps)
        n_otus = otu_table_df.shape[0]
        sparsities = []

        for k in steps:
            if k > n_otus:  # k == n_otus is still computable (uses every OTU)
                sparsities.append(np.nan)
                continue

            corr = otu_table_df.iloc[:k, :].T.corr(method='spearman')
            sparsity = (corr == 0).mean().mean()
            sparsities.append(sparsity)
            logger.info(f"{k} OTUs -> Sparsity: {sparsity:.4f}")

        plt.plot(steps, sparsities, marker='o')
        plt.xlabel('OTU numbers')
        plt.ylabel('Sparsity')
        plt.title('Sparsity Curve')
        plt.grid(True)

        return sparsities

    @staticmethod
    def normalize(otu_table_df: pd.DataFrame, method: str = 'rel', random_state: Optional[int] = None) -> pd.DataFrame:
        '''Normalize OTU table with the specified method.

        Args:
            otu_table_df: OTU table DataFrame, OTU_ID as index, sample IDs as columns.
            method: str, one of
                'rel' / 'tss' (default): total-sum scaling to relative abundance; empty samples become 0;
                'clr': log relative abundance minus the per-sample mean over non-zero entries;
                    zeros stay 0;
                'rarefy': subsample every sample without replacement to the smallest sample total;
                'none': return an unchanged copy.
                'deseq2', 'css', 'tmm', 'tpm' are recognised but raise NotImplementedError.
            random_state: int or None, seed used by 'rarefy'. If None, the seed is drawn from
                NumPy's global RandomState, so np.random.seed() keeps results reproducible.
        Returns:
            normalized_otu_table_df: normalized OTU table DataFrame.
        Raises:
            ValueError: unknown method, or a sample that cannot be rarefied.
            NotImplementedError: recognised but unimplemented method.
        Examples:
            >>> OTUQC.normalize(otu_table_df, method='tss')
            >>> OTUQC.normalize(otu_table_df, method='clr')
        '''
        if method == 'none':
            return otu_table_df.copy()

        if method in ('rel', 'tss'):
            sums = otu_table_df.sum(axis=0).replace(0, np.nan)  # avoid division by zero
            return otu_table_df.div(sums, axis=1).fillna(0)

        if method == 'clr':
            rel_abundance = otu_table_df.div(otu_table_df.sum(axis=0), axis=1)
            log_rel_abundance = np.log(rel_abundance.replace(0, np.nan))  # zeros excluded from the mean
            gm = log_rel_abundance.mean(axis=0)
            return log_rel_abundance.subtract(gm, axis=1).fillna(0)

        if method == 'rarefy':
            depth = int(otu_table_df.sum(axis=0).min())
            return OTUQC._rarefy_table(otu_table_df, depth, OTUQC._make_rng(random_state))

        not_implemented = {
            'deseq2': "DESeq2 normalization requires R's DESeq2 package and is not implemented in this function.",
            'css': "CSS normalization is not implemented in this function.",
            'tmm': "TMM normalization is not implemented in this function.",
            'tpm': "TPM normalization requires gene length information and is not implemented in this function.",
        }
        if method in not_implemented:
            raise NotImplementedError(not_implemented[method])

        raise ValueError(f"Normalization method '{method}' is not supported. Choose from 'rel', 'tss', 'clr', 'rarefy', or 'none', 'deseq2', 'css', 'tmm', 'tpm'.")

    def _choose_rarefaction_depth(self, strategy: str, q: float): ...  # placeholder, not implemented (returns None)

    @staticmethod
    def _make_rng(random_state: Optional[int]) -> np.random.Generator:
        '''Build a Generator from random_state, or from a seed drawn from the global RandomState.'''
        if random_state is None:
            random_state = np.random.randint(0, 2**31 - 1)
        return np.random.default_rng(random_state)

    @staticmethod
    def _rarefy_table(otu_table_df: pd.DataFrame, depth: int, rng: np.random.Generator) -> pd.DataFrame:
        '''Subsample every sample (column) to exactly depth reads without replacement.

        Counts are converted to int64, so the table is expected to hold integer counts.

        Raises:
            ValueError: depth <= 0, or a sample whose total count is below depth.
        '''
        if depth <= 0:
            raise ValueError(f"Rarefaction depth must be a positive integer, got {depth}.")
        totals = otu_table_df.sum(axis=0)
        too_shallow = totals[totals < depth]
        if not too_shallow.empty:
            raise ValueError(f"Cannot rarefy sample with total count {too_shallow.iloc[0]} to depth {depth}.")

        def subsample(col: pd.Series) -> pd.Series:
            # A multivariate hypergeometric draw == picking `depth` reads from the sample's
            # read pool without replacement.
            drawn = rng.multivariate_hypergeometric(col.to_numpy(dtype=np.int64), depth)
            return pd.Series(drawn, index=col.index)

        return otu_table_df.apply(subsample, axis=0)

    @staticmethod
    def rarefy(otu_table_df: pd.DataFrame, depth_method: Union[str, int] = 'min', random_state: Optional[int] = None) -> pd.DataFrame:
        '''Rarefy (subsample without replacement) every sample to a common depth.

        Args:
            otu_table_df: integer count OTU table, OTU_ID as index, sample IDs as columns.
            depth_method: target depth:
                int: use this depth directly;
                'min' (default): the smallest sample total;
                'NN%': NN percent of the smallest sample total, e.g. '90%'.
            random_state: int or None, seed for the subsampling (see normalize).
        Returns:
            rarefied_otu_table_df: rarefied counts. Samples with fewer reads than the target
            depth are dropped, and a warning is logged.
        Raises:
            ValueError: unsupported depth_method, or a non-positive depth.
        '''
        sample_depths = otu_table_df.sum(axis=0)

        if isinstance(depth_method, (int, np.integer)) and not isinstance(depth_method, bool):
            depth = int(depth_method)
        elif isinstance(depth_method, str) and depth_method == 'min':
            depth = int(sample_depths.min())
        elif isinstance(depth_method, str) and depth_method.endswith('%'):
            perc = float(depth_method.strip('%')) / 100.0
            depth = int(sample_depths.min() * perc)
        else:
            raise ValueError(f"Unsupported depth_method {depth_method!r}; use an int, 'min' or a percentage such as '90%'.")

        keep = sample_depths >= depth
        if not keep.all():
            logger.warning(f"Dropping {int((~keep).sum())} sample(s) with fewer than {depth} reads before rarefaction.")

        return OTUQC._rarefy_table(otu_table_df.loc[:, keep], depth, OTUQC._make_rng(random_state))
516
+
517
+
518
class Metadata:
    '''Helpers that turn raw metadata columns into model-ready features.'''

    @staticmethod
    def metadata_continuous_normalize(metadata_table_df: pd.DataFrame, continuous_cols: List[str], method: str = 'standard') -> pd.DataFrame:
        '''Rescale the requested continuous metadata columns.

        Args:
            metadata_table_df: metadata table DataFrame, SampleID as index.
            continuous_cols: list of str, the columns to rescale.
            method: str, 'standard' (z-score), 'minmax' (scale to [0, 1]) or 'log' (log1p).
        Returns:
            A new DataFrame containing only continuous_cols, rescaled; the input is not modified.
        Raises:
            ValueError: when method is not one of the supported options.
        '''
        subset = metadata_table_df[continuous_cols].copy()

        if method == 'log':
            subset[continuous_cols] = np.log1p(subset[continuous_cols])
            return subset

        if method not in ('standard', 'minmax'):
            raise ValueError(f"Normalization method '{method}' is not supported. Choose from 'standard', 'minmax', or 'log'.")

        scaler = StandardScaler() if method == 'standard' else MinMaxScaler()
        subset[continuous_cols] = scaler.fit_transform(subset[continuous_cols])
        return subset

    @staticmethod
    def metadata_categorical_encode(metadata_table_df: pd.DataFrame, categorical_cols: List[str], method: str = 'onehot') -> pd.DataFrame:
        '''Encode the requested categorical metadata columns.

        Args:
            metadata_table_df: metadata table DataFrame, SampleID as index.
            categorical_cols: list of str, the columns to encode.
            method: str, 'onehot' (dummy columns, first level of each column dropped)
                or 'label' (integer codes per column).
        Returns:
            A new DataFrame holding the encoded columns, indexed like the input.
        Raises:
            ValueError: when method is not one of the supported options.
        '''
        subset = metadata_table_df[categorical_cols].copy()

        if method == 'label':
            for column in categorical_cols:
                subset[column] = LabelEncoder().fit_transform(subset[column])
            return subset

        if method != 'onehot':
            raise ValueError(f"Encoding method '{method}' is not supported. Choose from 'label' or 'onehot'.")

        encoder = OneHotEncoder(sparse_output=False, drop='first')
        dense = encoder.fit_transform(subset[categorical_cols])
        return pd.DataFrame(
            dense,
            index=subset.index,
            columns=encoder.get_feature_names_out(categorical_cols),
        )
581
+
582
+