XspecT 0.2.7__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -4,7 +4,6 @@
 
 import csv
 import json
-from linecache import getline
 from pathlib import Path
 from sklearn.svm import SVC
 from Bio.SeqRecord import SeqRecord
@@ -30,6 +29,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         c: float,
         fpr: float = 0.01,
         num_hashes: int = 7,
+        training_accessions: dict[str, list[str]] = None,
+        svm_accessions: dict[str, list[str]] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -40,14 +41,17 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=num_hashes,
+            training_accessions=training_accessions,
         )
         self.kernel = kernel
         self.c = c
+        self.svm_accessions = svm_accessions
 
     def to_dict(self) -> dict:
         return super().to_dict() | {
             "kernel": self.kernel,
             "C": self.c,
+            "svm_accessions": self.svm_accessions,
         }
 
     def set_svm_params(self, kernel: str, c: float) -> None:
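
The `to_dict` change above means the accessions used for SVM training are now persisted with the model metadata. A minimal sketch of the resulting entry, assuming values of the shape the new constructor accepts (the species name and accession are illustrative, not taken from the package):

    # Shape of the dict returned by to_dict() after this change; the
    # "kernel" and "C" keys are unchanged, "svm_accessions" is new.
    entry = {
        "kernel": "linear",
        "C": 1.0,
        "svm_accessions": {"Acinetobacter baumannii": ["GCF_000069245.1"]},
    }
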
@@ -62,32 +66,41 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         svm_path: Path,
         display_names: dict = None,
         svm_step: int = 1,
+        training_accessions: list[str] = None,
+        svm_accessions: list[str] = None,
     ) -> None:
         """Fit the SVM to the sequences and labels"""
 
         # Since the SVM works with score data, we need to train
         # the underlying data structure for score generation first
-        super().fit(dir_path, display_names=display_names)
+        super().fit(
+            dir_path,
+            display_names=display_names,
+            training_accessions=training_accessions,
+        )
+
+        self.svm_accessions = svm_accessions
 
         # calculate scores for SVM training
         score_list = []
-        for file in svm_path.iterdir():
-            if not file.is_file():
-                continue
-            if file.suffix[1:] not in fasta_endings + fastq_endings:
+
+        for species_folder in svm_path.iterdir():
+            if not species_folder.is_dir():
                 continue
-            print(f"Calculating {file.name} scores for SVM training...")
-            res = super().predict(file, step=svm_step)
-            scores = res.get_scores()["total"]
-            accession = "".join(file.name.split("_")[:2])
-            file_header = getline(str(file), 1)
-            label_id = file_header.replace("\n", "").replace(">", "")
-
-            # format scores for csv
-            scores = dict(sorted(scores.items()))
-            scores = ",".join([str(score) for score in scores.values()])
-            scores = f"{accession},{scores},{label_id}"
-            score_list.append(scores)
+            for file in species_folder.iterdir():
+                if file.suffix[1:] not in fasta_endings + fastq_endings:
+                    continue
+                print(f"Calculating {file.name} scores for SVM training...")
+                res = super().predict(file, step=svm_step)
+                scores = res.get_scores()["total"]
+                accession = file.stem
+                label_id = species_folder.name
+
+                # format scores for csv
+                scores = dict(sorted(scores.items()))
+                scores = ",".join([str(score) for score in scores.values()])
+                scores = f"{accession},{scores},{label_id}"
+                score_list.append(scores)
 
         # csv header
         keys = list(self.display_names.keys())
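
Note the restructuring here: previously `fit` read a flat directory of FASTA/FASTQ files and recovered each label from the file's first header line via `linecache.getline` (hence the dropped import in the first hunk); now it expects one sub-directory per species under `svm_path`, using the folder name as the label and the file stem as the accession. A minimal sketch of how the new loop derives both values, with a hypothetical layout:

    from pathlib import Path

    # Hypothetical training layout assumed by the reworked fit():
    # svm_path/<species label>/<accession>.fasta
    file = Path("svm_path/Acinetobacter baumannii/GCF_000069245.1.fasta")
    accession = file.stem        # "GCF_000069245.1"
    label_id = file.parent.name  # "Acinetobacter baumannii"
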
@@ -162,6 +175,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             model_json["C"],
             fpr=model_json["fpr"],
             num_hashes=model_json["num_hashes"],
+            training_accessions=model_json["training_accessions"],
+            svm_accessions=model_json["svm_accessions"],
         )
         model.display_names = model_json["display_names"]
 
@@ -25,6 +25,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         model_type: str,
         base_path: Path,
         fpr: float = 0.01,
+        training_accessions: list[str] = None,
     ) -> None:
         super().__init__(
             k=k,
@@ -35,11 +36,16 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             base_path=base_path,
             fpr=fpr,
             num_hashes=1,
+            training_accessions=training_accessions,
         )
         self.bf = None
 
-    def fit(self, file_path: Path, display_name: str) -> None:
+    def fit(
+        self, file_path: Path, display_name: str, training_accessions: list[str] = None
+    ) -> None:
         """Fit the cobs classic index to the sequences and labels"""
+        self.training_accessions = training_accessions
+
         # estimate number of kmers
         total_length = 0
         for record in get_record_iterator(file_path):
@@ -88,6 +94,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             model_json["model_type"],
             path.parent,
             fpr=model_json["fpr"],
+            training_accessions=model_json["training_accessions"],
         )
         model.display_names = model_json["display_names"]
         bloom_path = model.base_path / model.slug() / "filter.bloom"
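
Taken together, these hunks thread `training_accessions` through the single-filter model the same way: accepted by the constructor and by `fit`, and restored from the model JSON on load. A hedged usage sketch, assuming an already-constructed `ProbabilisticSingleFilterModel` named `model` (path, display name, and accession are illustrative):

    from pathlib import Path

    model.fit(
        Path("genomes/GCF_000069245.1.fasta"),
        display_name="Acinetobacter baumannii",
        training_accessions=["GCF_000069245.1"],
    )
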
xspect/models/result.py CHANGED
@@ -1,50 +1,7 @@
 """Module for storing the results of XspecT models."""
 
-from enum import Enum
-
-
-def get_last_processing_step(result: "ModelResult") -> "ModelResult":
-    """Get the last subprocessing step of the result. First path only."""
-
-    # traverse result tree to get last step
-    while result.subprocessing_steps:
-        result = result.subprocessing_steps[-1].result
-    return result
-
-
-class StepType(Enum):
-    """Enum for defining the type of a subprocessing step."""
-
-    PREDICTION = 1
-    FILTERING = 2
-
-    def __str__(self) -> str:
-        return self.name.lower()
-
-
-class SubprocessingStep:
-    """Class for storing a subprocessing step of an XspecT model."""
-
-    def __init__(
-        self,
-        subprocessing_type: StepType,
-        label: str,
-        treshold: float,
-        result: "ModelResult",
-    ):
-        self.subprocessing_type = subprocessing_type
-        self.label = label
-        self.treshold = treshold
-        self.result = result
-
-    def to_dict(self) -> dict:
-        """Return the subprocessing step as a dictionary."""
-        return {
-            "subprocessing_type": str(self.subprocessing_type),
-            "label": self.label,
-            "treshold": self.treshold,
-            "result": self.result.to_dict() if self.result else {},
-        }
+from json import dumps
+from pathlib import Path
 
 
 class ModelResult:
@@ -58,6 +15,7 @@ class ModelResult:
         num_kmers: dict[str, int],
         sparse_sampling_step: int = 1,
         prediction: str = None,
+        input_source: str = None,
     ):
         if "total" in hits:
             raise ValueError(
@@ -68,15 +26,7 @@ class ModelResult:
         self.num_kmers = num_kmers
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
-        self.subprocessing_steps = []
-
-    def add_subprocessing_step(self, subprocessing_step: SubprocessingStep) -> None:
-        """Add a subprocessing step to the result."""
-        if subprocessing_step.label in self.subprocessing_steps:
-            raise ValueError(
-                f"Subprocessing step {subprocessing_step.label} already exists in the result"
-            )
-        self.subprocessing_steps.append(subprocessing_step)
+        self.input_source = input_source
 
     def get_scores(self) -> dict:
         """Return the scores of the model."""
@@ -108,19 +58,33 @@ class ModelResult:
         return total_hits
 
     def get_filter_mask(self, label: str, filter_threshold: float) -> dict[str, bool]:
-        """Return a mask for filtered subsequences."""
-        if filter_threshold < 0 or filter_threshold > 1:
+        """Return a mask for filtered subsequences.
+
+        The mask is a dictionary with subsequence names as keys and boolean values
+        indicating whether the subsequence is above the filter threshold for the given label.
+        A value of -1 for filter_threshold indicates that the subsequence with the maximum score
+        for the given label should be returned.
+        """
+        if filter_threshold < 0 and not filter_threshold == -1 or filter_threshold > 1:
             raise ValueError("The filter threshold must be between 0 and 1.")
 
         scores = self.get_scores()
         scores.pop("total")
-        return {
-            subsequence: score[label] >= filter_threshold
-            for subsequence, score in scores.items()
-        }
+        if not filter_threshold == -1:
+            return {
+                subsequence: score[label] >= filter_threshold
+                for subsequence, score in scores.items()
+            }
+        else:
+            return {
+                subsequence: score[label] == max(score.values())
+                for subsequence, score in scores.items()
+            }
 
-    def get_filtered_subsequences(self, label: str, filter_threshold: 0.7) -> list[str]:
-        """Return the filtered subsequences."""
+    def get_filtered_subsequence_labels(
+        self, label: str, filter_threshold: float = 0.7
+    ) -> list[str]:
+        """Return the labels of filtered subsequences."""
         return [
             subsequence
             for subsequence, mask in self.get_filter_mask(
@@ -137,13 +101,15 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
-            "subprocessing_steps": [
-                subprocessing_step.to_dict()
-                for subprocessing_step in self.subprocessing_steps
-            ],
+            "input_source": self.input_source,
         }
 
         if self.prediction is not None:
             res["prediction"] = self.prediction
 
         return res
+
+    def save(self, path: Path) -> None:
+        """Save the result as a JSON file."""
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(dumps(self.to_dict(), indent=4))
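
The net effect on `ModelResult`: the subprocessing-step tree is gone, replaced by a flat result with an optional `input_source`, a -1 sentinel in `get_filter_mask` that selects the best-scoring label per subsequence, and a `save` method. A short sketch of the new surface (constructor arguments are illustrative; -1 passes the revised bounds check because of the explicit `not filter_threshold == -1` clause):

    from pathlib import Path
    from xspect.models.result import ModelResult

    res = ModelResult(
        hits={"contig_1": {"species_a": 48, "species_b": 3}},  # hypothetical hits
        num_kmers={"contig_1": 50},
        input_source="sample.fasta",
    )
    mask = res.get_filter_mask("species_a", -1)  # True where species_a has the max score
    res.save(Path("result.json"))                # writes to_dict() as indented JSON
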
xspect/ncbi.py ADDED
@@ -0,0 +1,265 @@
+"""NCBI handler for the NCBI Datasets API."""
+
+from enum import Enum
+from pathlib import Path
+import requests
+import time
+
+# pylint: disable=line-too-long
+
+
+class AssemblyLevel(Enum):
+    """Enum for the assembly level."""
+
+    REFERENCE = "reference"
+    COMPLETE_GENOME = "complete_genome"
+    CHROMOSOME = "chromosome"
+    SCAFFOLD = "scaffold"
+    CONTIG = "contig"
+
+
+class AssemblySource(Enum):
+    """Enum for the assembly source."""
+
+    REFSEQ = "refseq"
+    GENBANK = "genbank"
+
+
+class NCBIHandler:
+    """This class uses the NCBI Datasets API to get the taxonomy tree of a given Taxon.
+
+    The taxonomy tree consists of only the next children to the parent taxon.
+    The children are only of the next lower rank of the parent taxon.
+    """
+
+    def __init__(
+        self,
+        api_key: str = None,
+    ):
+        """Initialise the NCBI handler."""
+        self.api_key = api_key
+        self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
+        self.last_request_time = 0
+        self.min_interval = (
+            1 / 10 if api_key else 1 / 5
+        )  # NCBI allows 10 requests per second with an API key, otherwise 5 requests per second
+
+    def _enforce_rate_limit(self):
+        """Enforce rate limiting for the NCBI Datasets API.
+
+        This method ensures that the requests to the API are limited to 5 requests per second
+        without an API key and 10 requests per second with an API key.
+        It uses a simple time-based approach to enforce the rate limit.
+        """
+        now = time.time()
+        elapsed_time = now - self.last_request_time
+        if elapsed_time < self.min_interval:
+            time.sleep(self.min_interval - elapsed_time)
+        self.last_request_time = now  # Update last request time
+
+    def _make_request(self, endpoint: str, timeout: int = 5) -> dict:
+        """Make a request to the NCBI Datasets API.
+
+        Args:
+            endpoint (str): The endpoint to make the request to.
+            timeout (int, optional): The timeout for the request in seconds. Defaults to 5.
+
+        Returns:
+            dict: The response from the API.
+        """
+        self._enforce_rate_limit()
+
+        endpoint = endpoint if endpoint.startswith("/") else "/" + endpoint
+        headers = {}
+        if self.api_key:
+            headers["api-key"] = self.api_key
+        response = requests.get(
+            self.base_url + endpoint, headers=headers, timeout=timeout
+        )
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        return response.json()
+
+    def get_genus_taxon_id(self, genus: str) -> int:
+        """
+        Get the taxon id for a given genus name.
+
+        This function checks if the genus name is valid by making a request to the NCBI Datasets API.
+        If the genus name is valid, it returns the taxon id.
+        If the genus name is not valid, it raises an exception.
+
+        Args:
+            genus (str): The genus name to validate.
+
+        Returns:
+            int: The taxon id for the given genus name.
+
+        Raises:
+            ValueError: If the genus name is not valid.
+        """
+        endpoint = f"/taxonomy/taxon/{genus}"
+        response = self._make_request(endpoint)
+
+        try:
+            taxonomy = response["taxonomy_nodes"][0]["taxonomy"]
+
+            taxon_id = taxonomy["tax_id"]
+            rank = taxonomy["rank"]
+            lineage = taxonomy["lineage"]
+
+            if rank != "GENUS":
+                raise ValueError(f"Genus name {genus} is not a genus.")
+            if lineage[2] != 2:
+                raise ValueError(f"Genus name {genus} does not belong to bacteria.")
+
+            return taxon_id
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus name: {genus}") from e
+
+    def get_species(self, genus_id: int) -> list[int]:
+        """
+        Get the species for a given genus id.
+
+        This function makes a request to the NCBI Datasets API to get the species for a given genus id.
+        It returns a list of species taxonomy ids.
+
+        Args:
+            genus_id (int): The genus id to get the species for.
+
+        Returns:
+            list[int]: A list containing the species taxonomy ids.
+        """
+        endpoint = f"/taxonomy/taxon/{genus_id}/filtered_subtree"
+        response = self._make_request(endpoint)
+
+        try:
+            species_ids = response["edges"][str(genus_id)]["visible_children"]
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid genus id: {genus_id}") from e
+        return species_ids
+
+    def get_taxon_names(self, taxon_ids: list[int]) -> dict[int, str]:
+        """
+        Get the names for a given list of taxon ids.
+
+        This function makes a request to the NCBI Datasets API to get the names for a given list of taxon ids.
+        It returns a dictionary with the taxon ids as keys and the names as values.
+
+        Args:
+            taxon_ids (list[int]): The list of taxon ids to get the names for.
+
+        Returns:
+            dict[int, str]: A dictionary containing the taxon ids and their corresponding names.
+        """
+        if len(taxon_ids) > 1000:
+            raise ValueError("Maximum number of taxon ids is 1000.")
+        if len(taxon_ids) < 1:
+            raise ValueError("At least one taxon id is required.")
+
+        endpoint = f"/taxonomy/taxon/{','.join(map(str, taxon_ids))}?page_size=1000"
+        response = self._make_request(endpoint)
+
+        try:
+            taxon_names = {
+                int(taxonomy_node["taxonomy"]["tax_id"]): taxonomy_node["taxonomy"][
+                    "organism_name"
+                ]
+                for taxonomy_node in response["taxonomy_nodes"]
+            }
+            if len(taxon_names) != len(taxon_ids):
+                raise ValueError("Not all taxon ids were found.")
+        except (IndexError, KeyError, TypeError) as e:
+            raise ValueError(f"Invalid taxon ids: {taxon_ids}") from e
+
+        return taxon_names
+
+    def get_accessions(
+        self,
+        taxon_id: int,
+        assembly_level: AssemblyLevel,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int = 10000,
+        exclude_atypical: bool = True,
+        exclude_paired_reports: bool = True,
+        current_version_only: bool = True,
+    ) -> list[str]:
+        """
+        Get the accessions for a given taxon id.
+
+        This function makes a request to the NCBI Datasets API to get the accessions for a given taxon id.
+        It filters the accessions based on the assembly level, assembly source, and other parameters.
+        It returns a list with the respective accessions.
+
+        Args:
+            taxon_id (int): The taxon id to get the accessions for.
+            assembly_level (AssemblyLevel): The assembly level to get the accessions for.
+            assembly_source (AssemblySource): The assembly source to get the accessions for.
+            count (int): The number of accessions to get.
+            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
+            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
+            current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
+
+        Returns:
+            list[str]: A list containing the accessions.
+        """
+        endpoint = (
+            f"/genome/taxon/{taxon_id}/dataset_report?"
+            f"filters.assembly_source={assembly_source.value}&"
+            f"filters.exclude_atypical={exclude_atypical}&"
+            f"filters.exclude_paired_reports={exclude_paired_reports}&"
+            f"filters.current_version_only={current_version_only}&"
+            f"page_size={count * 2}&"  # to avoid having less than count if n50 or ANI is not met
+        )
+        endpoint += (
+            "&filters.reference_only=true"
+            if assembly_level == AssemblyLevel.REFERENCE
+            else f"&filters.assembly_level={assembly_level.value}"
+        )
+
+        response = self._make_request(endpoint)
+        try:
+            accessions = [
+                report["accession"]
+                for report in response["reports"]
+                if report["assembly_stats"]["contig_n50"] >= min_n50
+                and report["average_nucleotide_identity"]["taxonomy_check_status"]
+                == "OK"
+            ]
+        except (IndexError, KeyError, TypeError):
+            print(f"Could not get accessions for taxon with ID: {taxon_id}. Skipping.")
+            return []
+        return accessions[:count]  # Limit to count
+
+    def get_highest_quality_accessions(
+        self, taxon_id: int, assembly_source: AssemblySource, count: int
+    ) -> list[str]:
+        """Get the highest quality accessions for a given taxon id (based on the assembly level)."""
+        accessions = []
+        for assembly_level in list(AssemblyLevel):
+            accessions += self.get_accessions(
+                taxon_id,
+                assembly_level,
+                assembly_source,
+                count,
+            )
+            if len(set(accessions)) >= count:
+                break
+        return list(set(accessions))[:count]  # Remove duplicates and limit to count
+
+    def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
+        """Download assemblies for a list of accessions."""
+        endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
+
+        self._enforce_rate_limit()
+
+        response = requests.get(self.base_url + endpoint, stream=True, timeout=5)
+        if response.status_code != 200:
+            response.raise_for_status()
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        with open(output_dir / "ncbi_dataset.zip", "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
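
A hedged end-to-end sketch of the new handler (genus, counts, and output path are illustrative; without an API key the handler throttles itself to 5 requests per second, per `min_interval` above):

    from pathlib import Path
    from xspect.ncbi import AssemblySource, NCBIHandler

    handler = NCBIHandler()  # an api_key raises the limit to 10 requests/second
    genus_id = handler.get_genus_taxon_id("Acinetobacter")
    species_ids = handler.get_species(genus_id)
    names = handler.get_taxon_names(species_ids[:10])
    accessions = handler.get_highest_quality_accessions(
        genus_id, AssemblySource.REFSEQ, count=8
    )
    handler.download_assemblies(accessions, Path("downloads"))
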