XspecT 0.2.7__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xspect/definitions.py +0 -7
- xspect/download_models.py +25 -24
- xspect/fastapi.py +23 -26
- xspect/file_io.py +86 -2
- xspect/main.py +360 -98
- xspect/mlst_feature/mlst_helper.py +4 -6
- xspect/model_management.py +7 -15
- xspect/models/probabilistic_filter_model.py +16 -5
- xspect/models/probabilistic_filter_svm_model.py +33 -18
- xspect/models/probabilistic_single_filter_model.py +8 -1
- xspect/models/result.py +32 -66
- xspect/ncbi.py +265 -0
- xspect/train.py +258 -242
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/METADATA +15 -21
- xspect-0.4.1.dist-info/RECORD +24 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/WHEEL +1 -1
- xspect/pipeline.py +0 -201
- xspect/run.py +0 -38
- xspect/train_filter/__init__.py +0 -0
- xspect/train_filter/create_svm.py +0 -45
- xspect/train_filter/extract_and_concatenate.py +0 -124
- xspect/train_filter/ncbi_api/__init__.py +0 -0
- xspect/train_filter/ncbi_api/download_assemblies.py +0 -31
- xspect/train_filter/ncbi_api/ncbi_assembly_metadata.py +0 -110
- xspect/train_filter/ncbi_api/ncbi_children_tree.py +0 -53
- xspect/train_filter/ncbi_api/ncbi_taxon_metadata.py +0 -55
- xspect-0.2.7.dist-info/RECORD +0 -33
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/entry_points.txt +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info/licenses}/LICENSE +0 -0
- {xspect-0.2.7.dist-info → xspect-0.4.1.dist-info}/top_level.txt +0 -0
xspect/models/probabilistic_filter_svm_model.py
CHANGED

@@ -4,7 +4,6 @@
 
 import csv
 import json
-from linecache import getline
 from pathlib import Path
 from sklearn.svm import SVC
 from Bio.SeqRecord import SeqRecord
@@ -30,6 +29,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         c: float,
         fpr: float = 0.01,
         num_hashes: int = 7,
+        training_accessions: dict[str, list[str]] = None,
+        svm_accessions: dict[str, list[str]] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -40,14 +41,17 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=num_hashes,
+            training_accessions=training_accessions,
         )
         self.kernel = kernel
         self.c = c
+        self.svm_accessions = svm_accessions
 
     def to_dict(self) -> dict:
         return super().to_dict() | {
             "kernel": self.kernel,
             "C": self.c,
+            "svm_accessions": self.svm_accessions,
         }
 
     def set_svm_params(self, kernel: str, c: float) -> None:
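Note: with these changes, the accessions used to train the underlying filters and the SVM are persisted in the model metadata via to_dict. A minimal sketch of the extra serialized fields; the label keys and accession values below are illustrative, not taken from a real model:

# Hypothetical excerpt of a serialized SVM model after this change;
# label keys and accession lists are made up for illustration.
metadata_excerpt = {
    "kernel": "rbf",
    "C": 1.0,
    "training_accessions": {"species_a": ["GCF_000000001.1"]},  # filter training set
    "svm_accessions": {"species_a": ["GCF_000000002.1"]},       # SVM training set
}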
@@ -62,32 +66,41 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         svm_path: Path,
         display_names: dict = None,
         svm_step: int = 1,
+        training_accessions: list[str] = None,
+        svm_accessions: list[str] = None,
     ) -> None:
         """Fit the SVM to the sequences and labels"""
 
         # Since the SVM works with score data, we need to train
         # the underlying data structure for score generation first
-        super().fit(
+        super().fit(
+            dir_path,
+            display_names=display_names,
+            training_accessions=training_accessions,
+        )
+
+        self.svm_accessions = svm_accessions
 
         # calculate scores for SVM training
         score_list = []
-
-
-
-            if file.suffix[1:] not in fasta_endings + fastq_endings:
+
+        for species_folder in svm_path.iterdir():
+            if not species_folder.is_dir():
                 continue
-
-
-
-
-
-
-
-
-
-
-
-
+            for file in species_folder.iterdir():
+                if file.suffix[1:] not in fasta_endings + fastq_endings:
+                    continue
+                print(f"Calculating {file.name} scores for SVM training...")
+                res = super().predict(file, step=svm_step)
+                scores = res.get_scores()["total"]
+                accession = file.stem
+                label_id = species_folder.name
+
+                # format scores for csv
+                scores = dict(sorted(scores.items()))
+                scores = ",".join([str(score) for score in scores.values()])
+                scores = f"{accession},{scores},{label_id}"
+                score_list.append(scores)
 
         # csv header
         keys = list(self.display_names.keys())
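The rewritten scoring loop walks svm_path one species folder at a time, scores every FASTA/FASTQ file with the parent filter model, and serializes each result as one CSV row of the form <accession>,<score_1>,...,<score_n>,<label_id>, with scores ordered by sorted label key. A small self-contained sketch of the row formatting; the labels and score values are illustrative:

# Reproduces the row format built in the loop above; values are made up.
scores = {"470": 0.97, "28901": 0.03}  # label -> score for one file
accession, label_id = "GCF_000000001", "470"
ordered = dict(sorted(scores.items()))  # sort by label key, as in the diff
row = f"{accession},{','.join(str(s) for s in ordered.values())},{label_id}"
# row == "GCF_000000001,0.03,0.97,470"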
@@ -162,6 +175,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             model_json["C"],
             fpr=model_json["fpr"],
             num_hashes=model_json["num_hashes"],
+            training_accessions=model_json["training_accessions"],
+            svm_accessions=model_json["svm_accessions"],
         )
         model.display_names = model_json["display_names"]
 
xspect/models/probabilistic_single_filter_model.py
CHANGED

@@ -25,6 +25,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         model_type: str,
         base_path: Path,
         fpr: float = 0.01,
+        training_accessions: list[str] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -35,11 +36,16 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=1,
+            training_accessions=training_accessions,
         )
         self.bf = None
 
-    def fit(
+    def fit(
+        self, file_path: Path, display_name: str, training_accessions: list[str] = None
+    ) -> None:
         """Fit the cobs classic index to the sequences and labels"""
+        self.training_accessions = training_accessions
+
         # estimate number of kmers
         total_length = 0
         for record in get_record_iterator(file_path):
@@ -88,6 +94,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             model_json["model_type"],
             path.parent,
             fpr=model_json["fpr"],
+            training_accessions=model_json["training_accessions"],
        )
         model.display_names = model_json["display_names"]
         bloom_path = model.base_path / model.slug() / "filter.bloom"
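One observation on both from-JSON loaders: they index model_json["training_accessions"] (and, for the SVM model, model_json["svm_accessions"]) directly, so model files written before this change, which lack those keys, would raise KeyError on load. If backward compatibility were intended, a .get() lookup would tolerate the missing keys; a hypothetical variant, not what the diff does:

# Hypothetical backward-compatible lookup for pre-0.4 model files:
training_accessions = model_json.get("training_accessions")  # None if absent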
xspect/models/result.py
CHANGED
@@ -1,50 +1,7 @@
 """Module for storing the results of XspecT models."""
 
-from enum import Enum
-
-
-def get_last_processing_step(result: "ModelResult") -> "ModelResult":
-    """Get the last subprocessing step of the result. First path only."""
-
-    # traverse result tree to get last step
-    while result.subprocessing_steps:
-        result = result.subprocessing_steps[-1].result
-    return result
-
-
-class StepType(Enum):
-    """Enum for defining the type of a subprocessing step."""
-
-    PREDICTION = 1
-    FILTERING = 2
-
-    def __str__(self) -> str:
-        return self.name.lower()
-
-
-class SubprocessingStep:
-    """Class for storing a subprocessing step of an XspecT model."""
-
-    def __init__(
-        self,
-        subprocessing_type: StepType,
-        label: str,
-        treshold: float,
-        result: "ModelResult",
-    ):
-        self.subprocessing_type = subprocessing_type
-        self.label = label
-        self.treshold = treshold
-        self.result = result
-
-    def to_dict(self) -> dict:
-        """Return the subprocessing step as a dictionary."""
-        return {
-            "subprocessing_type": str(self.subprocessing_type),
-            "label": self.label,
-            "treshold": self.treshold,
-            "result": self.result.to_dict() if self.result else {},
-        }
+from json import dumps
+from pathlib import Path
 
 
 class ModelResult:
@@ -58,6 +15,7 @@ class ModelResult:
         num_kmers: dict[str, int],
         sparse_sampling_step: int = 1,
         prediction: str = None,
+        input_source: str = None,
     ):
         if "total" in hits:
             raise ValueError(
@@ -68,15 +26,7 @@ class ModelResult:
         self.num_kmers = num_kmers
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
-        self.subprocessing_steps = []
-
-    def add_subprocessing_step(self, subprocessing_step: SubprocessingStep) -> None:
-        """Add a subprocessing step to the result."""
-        if subprocessing_step.label in self.subprocessing_steps:
-            raise ValueError(
-                f"Subprocessing step {subprocessing_step.label} already exists in the result"
-            )
-        self.subprocessing_steps.append(subprocessing_step)
+        self.input_source = input_source
 
     def get_scores(self) -> dict:
         """Return the scores of the model."""
@@ -108,19 +58,33 @@ class ModelResult:
         return total_hits
 
     def get_filter_mask(self, label: str, filter_threshold: float) -> dict[str, bool]:
-        """Return a mask for filtered subsequences.
-
+        """Return a mask for filtered subsequences.
+
+        The mask is a dictionary with subsequence names as keys and boolean values
+        indicating whether the subsequence is above the filter threshold for the given label.
+        A value of -1 for filter_threshold indicates that the subsequence with the maximum score
+        for the given label should be returned.
+        """
+        if filter_threshold < 0 and not filter_threshold == -1 or filter_threshold > 1:
             raise ValueError("The filter threshold must be between 0 and 1.")
 
         scores = self.get_scores()
         scores.pop("total")
-
-
-
-
+        if not filter_threshold == -1:
+            return {
+                subsequence: score[label] >= filter_threshold
+                for subsequence, score in scores.items()
+            }
+        else:
+            return {
+                subsequence: score[label] == max(score.values())
+                for subsequence, score in scores.items()
+            }
 
-    def
-
+    def get_filtered_subsequence_labels(
+        self, label: str, filter_threshold: float = 0.7
+    ) -> list[str]:
+        """Return the labels of filtered subsequences."""
         return [
             subsequence
             for subsequence, mask in self.get_filter_mask(
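get_filter_mask now has two modes: a threshold in [0, 1] marks every subsequence whose score for the label meets the threshold, while the sentinel -1 marks the subsequences where the label's score is the maximum across labels. Note the guard relies on Python operator precedence, parsing as (filter_threshold < 0 and filter_threshold != -1) or filter_threshold > 1, so -1 is the only negative value accepted. A sketch of the expected behavior on made-up scores:

# Illustrative scores (subsequence -> label -> score), not from a real run:
# scores = {"contig1": {"A": 0.8, "B": 0.1}, "contig2": {"A": 0.4, "B": 0.6}}
# result.get_filter_mask("A", 0.7) -> {"contig1": True, "contig2": False}
# result.get_filter_mask("A", -1)  -> {"contig1": True, "contig2": False}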
@@ -137,13 +101,15 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
-            "subprocessing_steps": [
-                subprocessing_step.to_dict()
-                for subprocessing_step in self.subprocessing_steps
-            ],
+            "input_source": self.input_source,
         }
 
         if self.prediction is not None:
             res["prediction"] = self.prediction
 
         return res
+
+    def save(self, path: Path) -> None:
+        """Save the result as a JSON file."""
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(dumps(self.to_dict(), indent=4))
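With the subprocessing-step tree removed, a result now serializes as a flat dict (hits, scores, num_kmers, input_source, and an optional prediction), and the new save method writes that dict to disk. Minimal usage sketch, assuming result is a populated ModelResult:

from pathlib import Path

# Writes to_dict() as indented JSON, per the save() implementation above.
result.save(Path("result.json"))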
xspect/ncbi.py
ADDED
@@ -0,0 +1,265 @@
+"""NCBI handler for the NCBI Datasets API."""
+
+from enum import Enum
+from pathlib import Path
+import requests
+import time
+
+# pylint: disable=line-too-long
+
+
+class AssemblyLevel(Enum):
+    """Enum for the assembly level."""
+
+    REFERENCE = "reference"
+    COMPLETE_GENOME = "complete_genome"
+    CHROMOSOME = "chromosome"
+    SCAFFOLD = "scaffold"
+    CONTIG = "contig"
+
+
+class AssemblySource(Enum):
+    """Enum for the assembly source."""
+
+    REFSEQ = "refseq"
+    GENBANK = "genbank"
+
+
+class NCBIHandler:
+    """This class uses the NCBI Datasets API to get the taxonomy tree of a given Taxon.
+
+    The taxonomy tree consists of only the next children to the parent taxon.
+    The children are only of the next lower rank of the parent taxon.
+    """
+
+    def __init__(
+        self,
+        api_key: str = None,
+    ):
+        """Initialise the NCBI handler."""
+        self.api_key = api_key
+        self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
+        self.last_request_time = 0
+        self.min_interval = (
+            1 / 10 if api_key else 1 / 5
+        )  # NCBI allows 10 requests per second with an API key, otherwise 5 requests per second
+
+    def _enforce_rate_limit(self):
+        """Enforce rate limiting for the NCBI Datasets API.
+
+        This method ensures that the requests to the API are limited to 5 requests per second
+        without an API key and 10 requests per second with an API key.
+        It uses a simple time-based approach to enforce the rate limit.
+        """
+        now = time.time()
+        elapsed_time = now - self.last_request_time
+        if elapsed_time < self.min_interval:
+            time.sleep(self.min_interval - elapsed_time)
+        self.last_request_time = now  # Update last request time
+
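The limiter spaces calls by min_interval seconds (0.1 s with an API key, 0.2 s without). As written, last_request_time is stamped with the time taken before any sleep, so a call that slept and the call immediately after it can fire nearly back-to-back, briefly exceeding the target rate. A sketch of a stricter variant that records the post-sleep release time (an editorial suggestion, not part of the diff):

# Hypothetical stricter variant: stamp the time the request is released.
def _enforce_rate_limit(self):
    now = time.time()
    if now - self.last_request_time < self.min_interval:
        time.sleep(self.min_interval - (now - self.last_request_time))
    self.last_request_time = time.time()  # after any sleep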
+    def _make_request(self, endpoint: str, timeout: int = 5) -> dict:
+        """Make a request to the NCBI Datasets API.
+
+        Args:
+            endpoint (str): The endpoint to make the request to.
+            timeout (int, optional): The timeout for the request in seconds. Defaults to 5.
+
+        Returns:
+            dict: The response from the API.
+        """
+        self._enforce_rate_limit()
+
+        endpoint = endpoint if endpoint.startswith("/") else "/" + endpoint
+        headers = {}
+        if self.api_key:
+            headers["api-key"] = self.api_key
+        response = requests.get(
+            self.base_url + endpoint, headers=headers, timeout=timeout
+        )
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return response.json()
+
+    def get_genus_taxon_id(self, genus: str) -> int:
+        """
+        Get the taxon id for a given genus name.
+
+        This function checks if the genus name is valid by making a request to the NCBI Datasets API.
+        If the genus name is valid, it returns the taxon id.
+        If the genus name is not valid, it raises an exception.
+
+        Args:
+            genus (str): The genus name to validate.
+
+        Returns:
+            int: The taxon id for the given genus name.
+
+        Raises:
+            ValueError: If the genus name is not valid.
+        """
+        endpoint = f"/taxonomy/taxon/{genus}"
+        response = self._make_request(endpoint)
+
+        try:
+            taxonomy = response["taxonomy_nodes"][0]["taxonomy"]
+
+            taxon_id = taxonomy["tax_id"]
+            rank = taxonomy["rank"]
+            lineage = taxonomy["lineage"]
+
+            if rank != "GENUS":
+                raise ValueError(f"Genus name {genus} is not a genus.")
+            if lineage[2] != 2:
+                raise ValueError(f"Genus name {genus} does not belong to bacteria.")
+
+            return taxon_id
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus name: {genus}") from e
+
+    def get_species(self, genus_id: int) -> list[int]:
+        """
+        Get the species for a given genus id.
+
+        This function makes a request to the NCBI Datasets API to get the species for a given genus id.
+        It returns a list of species taxonomy ids.
+
+        Args:
+            genus_id (int): The genus id to get the species for.
+
+        Returns:
+            list[int]: A list containing the species taxonomy ids.
+        """
+        endpoint = f"/taxonomy/taxon/{genus_id}/filtered_subtree"
+        response = self._make_request(endpoint)
+
+        try:
+            species_ids = response["edges"][str(genus_id)]["visible_children"]
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus id: {genus_id}") from e
+        return species_ids
+
+    def get_taxon_names(self, taxon_ids: list[int]) -> dict[int, str]:
+        """
+        Get the names for a given list of taxon ids.
+
+        This function makes a request to the NCBI Datasets API to get the names for a given list of taxon ids.
+        It returns a dictionary with the taxon ids as keys and the names as values.
+
+        Args:
+            taxon_ids (list[int]): The list of taxon ids to get the names for.
+
+        Returns:
+            dict[int, str]: A dictionary containing the taxon ids and their corresponding names.
+        """
+        if len(taxon_ids) > 1000:
+            raise ValueError("Maximum number of taxon ids is 1000.")
+        if len(taxon_ids) < 1:
+            raise ValueError("At least one taxon id is required.")
+
+        endpoint = f"/taxonomy/taxon/{','.join(map(str, taxon_ids))}?page_size=1000"
+        response = self._make_request(endpoint)
+
+        try:
+            taxon_names = {
+                int(taxonomy_node["taxonomy"]["tax_id"]): taxonomy_node["taxonomy"][
+                    "organism_name"
+                ]
+                for taxonomy_node in response["taxonomy_nodes"]
+            }
+            if len(taxon_names) != len(taxon_ids):
+                raise ValueError("Not all taxon ids were found.")
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid taxon ids: {taxon_ids}") from e
+
+        return taxon_names
+
+    def get_accessions(
+        self,
+        taxon_id: int,
+        assembly_level: AssemblyLevel,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int = 10000,
+        exclude_atypical: bool = True,
+        exclude_paired_reports: bool = True,
+        current_version_only: bool = True,
+    ) -> list[str]:
+        """
+        Get the accessions for a given taxon id.
+
+        This function makes a request to the NCBI Datasets API to get the accessions for a given taxon id.
+        It filters the accessions based on the assembly level, assembly source, and other parameters.
+        It returns a list with the respective accessions.
+
+        Args:
+            taxon_id (int): The taxon id to get the accessions for.
+            assembly_level (AssemblyLevel): The assembly level to get the accessions for.
+            assembly_source (AssemblySource): The assembly source to get the accessions for.
+            count (int): The number of accessions to get.
+            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
+            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
+            current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
+
+        Returns:
+            list[str]: A list containing the accessions.
+        """
+        endpoint = (
+            f"/genome/taxon/{taxon_id}/dataset_report?"
+            f"filters.assembly_source={assembly_source.value}&"
+            f"filters.exclude_atypical={exclude_atypical}&"
+            f"filters.exclude_paired_reports={exclude_paired_reports}&"
+            f"filters.current_version_only={current_version_only}&"
+            f"page_size={count * 2}&"  # to avoid having less than count if n50 or ANI is not met
+        )
+        endpoint += (
+            "&filters.reference_only=true"
+            if assembly_level == AssemblyLevel.REFERENCE
+            else f"&filters.assembly_level={assembly_level.value}"
+        )
+
+        response = self._make_request(endpoint)
+        try:
+            accessions = [
+                report["accession"]
+                for report in response["reports"]
+                if report["assembly_stats"]["contig_n50"] >= min_n50
+                and report["average_nucleotide_identity"]["taxonomy_check_status"]
+                == "OK"
+            ]
+        except (IndexError, KeyError, TypeError):
+            print(f"Could not get accessions for taxon with ID: {taxon_id}. Skipping.")
+            return []
+        return accessions[:count]  # Limit to count
+
+    def get_highest_quality_accessions(
+        self, taxon_id: int, assembly_source: AssemblySource, count: int
+    ) -> list[str]:
+        """Get the highest quality accessions for a given taxon id (based on the assembly level)."""
+        accessions = []
+        for assembly_level in list(AssemblyLevel):
+            accessions += self.get_accessions(
+                taxon_id,
+                assembly_level,
+                assembly_source,
+                count,
+            )
+            if len(set(accessions)) >= count:
+                break
+        return list(set(accessions))[:count]  # Remove duplicates and limit to count
+
+    def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
+        """Download assemblies for a list of accessions."""
+        endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
+
+        self._enforce_rate_limit()
+
+        response = requests.get(self.base_url + endpoint, stream=True, timeout=5)
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        with open(output_dir / "ncbi_dataset.zip", "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
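This new module consolidates the removed train_filter/ncbi_api package into a single handler. A hedged end-to-end sketch of the intended flow (genus lookup, species enumeration, accession selection, download); the genus, counts, and output path below are illustrative:

from pathlib import Path

handler = NCBIHandler()  # no API key: limited to 5 requests/s
genus_id = handler.get_genus_taxon_id("Acinetobacter")  # must be a bacterial genus
species_ids = handler.get_species(genus_id)
names = handler.get_taxon_names(species_ids[:10])  # taxon id -> organism name
accessions = handler.get_highest_quality_accessions(
    species_ids[0], AssemblySource.REFSEQ, count=8
)
handler.download_assemblies(accessions, Path("downloads"))  # writes downloads/ncbi_dataset.zip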