XspecT 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of XspecT might be problematic.

@@ -25,6 +25,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         model_type: str,
         base_path: Path,
         fpr: float = 0.01,
+        training_accessions: list[str] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -35,11 +36,16 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=1,
+            training_accessions=training_accessions,
         )
         self.bf = None
 
-    def fit(self, file_path: Path, display_name: str) -> None:
+    def fit(
+        self, file_path: Path, display_name: str, training_accessions: list[str] = None
+    ) -> None:
         """Fit the cobs classic index to the sequences and labels"""
+        self.training_accessions = training_accessions
+
         # estimate number of kmers
         total_length = 0
         for record in get_record_iterator(file_path):
@@ -88,6 +94,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             model_json["model_type"],
             path.parent,
             fpr=model_json["fpr"],
+            training_accessions=model_json["training_accessions"],
         )
         model.display_names = model_json["display_names"]
         bloom_path = model.base_path / model.slug() / "filter.bloom"
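The hunks above thread a new training_accessions parameter through the model constructor, the fit() method, and the metadata that is read back when a saved model JSON is loaded. The following usage sketch is not part of the diff: it assumes an already constructed ProbabilisticSingleFilterModel instance (the remaining constructor arguments are not shown in these hunks), and the accession IDs, display name, and file path are placeholders.

from pathlib import Path

# Sketch only: `model` is assumed to be a ProbabilisticSingleFilterModel
# built with the constructor shown above; accessions and paths are placeholders.
accessions = ["GCF_000000000.1", "GCF_000000001.1"]

# 0.2.7: model.fit(Path("training.fasta"), "Acinetobacter")
# 0.4.0: the accessions used for training can be recorded alongside the filter
model.fit(
    Path("training.fasta"),          # training sequences
    "Acinetobacter",                 # display name
    training_accessions=accessions,  # stored on the model and read back from its JSON on load
)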
xspect/models/result.py CHANGED
@@ -1,50 +1,7 @@
 """Module for storing the results of XspecT models."""
 
-from enum import Enum
-
-
-def get_last_processing_step(result: "ModelResult") -> "ModelResult":
-    """Get the last subprocessing step of the result. First path only."""
-
-    # traverse result tree to get last step
-    while result.subprocessing_steps:
-        result = result.subprocessing_steps[-1].result
-    return result
-
-
-class StepType(Enum):
-    """Enum for defining the type of a subprocessing step."""
-
-    PREDICTION = 1
-    FILTERING = 2
-
-    def __str__(self) -> str:
-        return self.name.lower()
-
-
-class SubprocessingStep:
-    """Class for storing a subprocessing step of an XspecT model."""
-
-    def __init__(
-        self,
-        subprocessing_type: StepType,
-        label: str,
-        treshold: float,
-        result: "ModelResult",
-    ):
-        self.subprocessing_type = subprocessing_type
-        self.label = label
-        self.treshold = treshold
-        self.result = result
-
-    def to_dict(self) -> dict:
-        """Return the subprocessing step as a dictionary."""
-        return {
-            "subprocessing_type": str(self.subprocessing_type),
-            "label": self.label,
-            "treshold": self.treshold,
-            "result": self.result.to_dict() if self.result else {},
-        }
+from json import dumps
+from pathlib import Path
 
 
 class ModelResult:
@@ -58,6 +15,7 @@ class ModelResult:
         num_kmers: dict[str, int],
         sparse_sampling_step: int = 1,
         prediction: str = None,
+        input_source: str = None,
     ):
         if "total" in hits:
             raise ValueError(
@@ -68,15 +26,7 @@ class ModelResult:
         self.num_kmers = num_kmers
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
-        self.subprocessing_steps = []
-
-    def add_subprocessing_step(self, subprocessing_step: SubprocessingStep) -> None:
-        """Add a subprocessing step to the result."""
-        if subprocessing_step.label in self.subprocessing_steps:
-            raise ValueError(
-                f"Subprocessing step {subprocessing_step.label} already exists in the result"
-            )
-        self.subprocessing_steps.append(subprocessing_step)
+        self.input_source = input_source
 
     def get_scores(self) -> dict:
         """Return the scores of the model."""
@@ -119,8 +69,10 @@ class ModelResult:
             for subsequence, score in scores.items()
         }
 
-    def get_filtered_subsequences(self, label: str, filter_threshold: 0.7) -> list[str]:
-        """Return the filtered subsequences."""
+    def get_filtered_subsequence_labels(
+        self, label: str, filter_threshold: float = 0.7
+    ) -> list[str]:
+        """Return the labels of filtered subsequences."""
         return [
             subsequence
             for subsequence, mask in self.get_filter_mask(
@@ -137,13 +89,15 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
-            "subprocessing_steps": [
-                subprocessing_step.to_dict()
-                for subprocessing_step in self.subprocessing_steps
-            ],
+            "input_source": self.input_source,
         }
 
         if self.prediction is not None:
             res["prediction"] = self.prediction
 
         return res
+
+    def save(self, path: Path) -> None:
+        """Save the result as a JSON file."""
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(dumps(self.to_dict(), indent=4))
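In xspect/models/result.py, the SubprocessingStep and StepType machinery is removed in favour of a flat input_source field, and ModelResult gains a save() method that writes the to_dict() output as indented JSON. A brief usage sketch, not part of the diff; result is assumed to be a ModelResult returned by an XspecT model, since the full constructor call is not visible here:

from pathlib import Path

# Sketch only: `result` is assumed to be a ModelResult produced by an XspecT model.
data = result.to_dict()
print(data["scores"])        # per-label scores
print(data["input_source"])  # new in 0.4.0: records where the input came from

# New in 0.4.0: persist the result as pretty-printed JSON.
result.save(Path("result.json"))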
xspect/ncbi.py ADDED
@@ -0,0 +1,265 @@
+"""NCBI handler for the NCBI Datasets API."""
+
+from enum import Enum
+from pathlib import Path
+import requests
+import time
+
+# pylint: disable=line-too-long
+
+
+class AssemblyLevel(Enum):
+    """Enum for the assembly level."""
+
+    REFERENCE = "reference"
+    COMPLETE_GENOME = "complete_genome"
+    CHROMOSOME = "chromosome"
+    SCAFFOLD = "scaffold"
+    CONTIG = "contig"
+
+
+class AssemblySource(Enum):
+    """Enum for the assembly source."""
+
+    REFSEQ = "refseq"
+    GENBANK = "genbank"
+
+
+class NCBIHandler:
+    """This class uses the NCBI Datasets API to get the taxonomy tree of a given Taxon.
+
+    The taxonomy tree consists of only the next children to the parent taxon.
+    The children are only of the next lower rank of the parent taxon.
+    """
+
+    def __init__(
+        self,
+        api_key: str = None,
+    ):
+        """Initialise the NCBI handler."""
+        self.api_key = api_key
+        self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
+        self.last_request_time = 0
+        self.min_interval = (
+            1 / 10 if api_key else 1 / 5
+        )  # NCBI allows 10 requests per second with if an API key, otherwise 5 requests per second
+
+    def _enforce_rate_limit(self):
+        """Enforce rate limiting for the NCBI Datasets API.
+
+        This method ensures that the requests to the API are limited to 5 requests per second
+        without an API key and 10 requests per second with an API key.
+        It uses a simple time-based approach to enforce the rate limit.
+        """
+        now = time.time()
+        elapsed_time = now - self.last_request_time
+        if elapsed_time < self.min_interval:
+            time.sleep(self.min_interval - elapsed_time)
+        self.last_request_time = now  # Update last request time
+
+    def _make_request(self, endpoint: str, timeout: int = 5) -> dict:
+        """Make a request to the NCBI Datasets API.
+
+        Args:
+            endpoint (str): The endpoint to make the request to.
+            timeout (int, optional): The timeout for the request in seconds. Defaults to 5.
+
+        Returns:
+            dict: The response from the API.
+        """
+        self._enforce_rate_limit()
+
+        endpoint = endpoint if endpoint.startswith("/") else "/" + endpoint
+        headers = {}
+        if self.api_key:
+            headers["api-key"] = self.api_key
+        response = requests.get(
+            self.base_url + endpoint, headers=headers, timeout=timeout
+        )
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return response.json()
+
+    def get_genus_taxon_id(self, genus: str) -> int:
+        """
+        Get the taxon id for a given genus name.
+
+        This function checks if the genus name is valid by making a request to the NCBI Datasets API.
+        If the genus name is valid, it returns the taxon id.
+        If the genus name is not valid, it raises an exception.
+
+        Args:
+            genus (str): The genus name to validate.
+
+        Returns:
+            int: The taxon id for the given genus name.
+
+        Raises:
+            ValueError: If the genus name is not valid.
+        """
+        endpoint = f"/taxonomy/taxon/{genus}"
+        response = self._make_request(endpoint)
+
+        try:
+            taxonomy = response["taxonomy_nodes"][0]["taxonomy"]
+
+            taxon_id = taxonomy["tax_id"]
+            rank = taxonomy["rank"]
+            lineage = taxonomy["lineage"]
+
+            if rank != "GENUS":
+                raise ValueError(f"Genus name {genus} is not a genus.")
+            if lineage[2] != 2:
+                raise ValueError(f"Genus name {genus} does not belong to bacteria.")
+
+            return taxon_id
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus name: {genus}") from e
+
+    def get_species(self, genus_id: int) -> list[int]:
+        """
+        Get the species for a given genus id.
+
+        This function makes a request to the NCBI Datasets API to get the species for a given genus id.
+        It returns a list of species taxonomy ids.
+
+        Args:
+            genus_id (int): The genus id to get the species for.
+
+        Returns:
+            list[int]: A list containing the species taxnomy ids.
+        """
+        endpoint = f"/taxonomy/taxon/{genus_id}/filtered_subtree"
+        response = self._make_request(endpoint)
+
+        try:
+            species_ids = response["edges"][str(genus_id)]["visible_children"]
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus id: {genus_id}") from e
+        return species_ids
+
+    def get_taxon_names(self, taxon_ids: list[int]) -> dict[int, str]:
+        """
+        Get the names for a given list of taxon ids.
+
+        This function makes a request to the NCBI Datasets API to get the names for a given list of taxon ids.
+        It returns a dictionary with the taxon ids as keys and the names as values.
+
+        Args:
+            taxon_ids (list[int]): The list of taxon ids to get the names for.
+
+        Returns:
+            dict[int, str]: A dictionary containing the taxon ids and their corresponding names.
+        """
+        if len(taxon_ids) > 1000:
+            raise ValueError("Maximum number of taxon ids is 1000.")
+        if len(taxon_ids) < 1:
+            raise ValueError("At least one taxon id is required.")
+
+        endpoint = f"/taxonomy/taxon/{','.join(map(str, taxon_ids))}?page_size=1000"
+        response = self._make_request(endpoint)
+
+        try:
+            taxon_names = {
+                int(taxonomy_node["taxonomy"]["tax_id"]): taxonomy_node["taxonomy"][
+                    "organism_name"
+                ]
+                for taxonomy_node in response["taxonomy_nodes"]
+            }
+            if len(taxon_names) != len(taxon_ids):
+                raise ValueError("Not all taxon ids were found.")
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid taxon ids: {taxon_ids}") from e
+
+        return taxon_names
+
+    def get_accessions(
+        self,
+        taxon_id: int,
+        assembly_level: AssemblyLevel,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int = 10000,
+        exclude_atypical: bool = True,
+        exclude_paired_reports: bool = True,
+        current_version_only: bool = True,
+    ) -> list[str]:
+        """
+        Get the accessions for a given taxon id.
+
+        This function makes a request to the NCBI Datasets API to get the accessions for a given taxon id.
+        It filters the accessions based on the assembly level, assembly source, and other parameters.
+        It returns a list with the respective accessions.
+
+        Args:
+            taxon_id int: The taxon id to get the accessions for.
+            assembly_level (AssemblyLevel): The assembly level to get the accessions for.
+            assembly_source (AssemblySource): The assembly source to get the accessions for.
+            count (int): The number of accessions to get.
+            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
+            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
+            current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
+
+        Returns:
+            list[str]: A list containing the accessions.
+        """
+        endpoint = (
+            f"/genome/taxon/{taxon_id}/dataset_report?"
+            f"filters.assembly_source={assembly_source.value}&"
+            f"filters.exclude_atypical={exclude_atypical}&"
+            f"filters.exclude_paired_reports={exclude_paired_reports}&"
+            f"filters.current_version_only={current_version_only}&"
+            f"page_size={count * 2}&"  # to avoid having less than count if n50 or ANI is not met
+        )
+        endpoint += (
+            "&filters.reference_only=true"
+            if assembly_level == AssemblyLevel.REFERENCE
+            else f"&filters.assembly_level={assembly_level.value}"
+        )
+
+        response = self._make_request(endpoint)
+        try:
+            accessions = [
+                report["accession"]
+                for report in response["reports"]
+                if report["assembly_stats"]["contig_n50"] >= min_n50
+                and report["average_nucleotide_identity"]["taxonomy_check_status"]
+                == "OK"
+            ]
+        except (IndexError, KeyError, TypeError):
+            print(f"Could not get accessions for taxon with ID: {taxon_id}. Skipping.")
+            return []
+        return accessions[:count]  # Limit to count
+
+    def get_highest_quality_accessions(
+        self, taxon_id: int, assembly_source: AssemblySource, count: int
+    ) -> list[str]:
+        """Get the highest quality accessions for a given taxon id (based on the assembly level)."""
+        accessions = []
+        for assembly_level in list(AssemblyLevel):
+            accessions += self.get_accessions(
+                taxon_id,
+                assembly_level,
+                assembly_source,
+                count,
+            )
+            if len(set(accessions)) >= count:
+                break
+        return list(set(accessions))[:count]  # Remove duplicates and limit to count
+
+    def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
+        """Download assemblies for a list of accessions."""
+        endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
+
+        self._enforce_rate_limit()
+
+        response = requests.get(self.base_url + endpoint, stream=True, timeout=5)
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        with open(output_dir / "ncbi_dataset.zip", "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+ f.write(chunk)