XspecT 0.5.3__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,6 +1,7 @@
 """Probabilistic filter model for sequence data"""
 
 import json
+import shutil
 from math import ceil
 from pathlib import Path
 from typing import Any
@@ -9,9 +10,19 @@ from Bio.SeqRecord import SeqRecord
 from Bio import SeqIO
 from slugify import slugify
 import cobs_index as cobs
-from xspect.definitions import fasta_endings, fastq_endings
+from xspect.definitions import (
+    fasta_endings,
+    fastq_endings,
+    get_xspect_misclassification_path,
+)
 from xspect.file_io import get_record_iterator
+from xspect.misclassification_detection.mapping import MappingHandler
 from xspect.models.result import ModelResult
+from collections import defaultdict
+from xspect.ncbi import NCBIHandler
+from xspect.misclassification_detection.point_pattern_analysis import (
+    PointPatternAnalysis,
+)
 
 
 class ProbabilisticFilterModel:
@@ -135,8 +146,8 @@ class ProbabilisticFilterModel:
             display_names (dict | None): A dictionary mapping file names to display names.
                 If None, uses file names as display names.
             training_accessions (dict[str, list[str]] | None): A dictionary mapping filter IDs to
-                lists of accession numbers used for training the model. If None, no training accessions
-                are set.
+                lists of accession numbers used for training the model. If None, no training
+                accessions are set.
         Raises:
             ValueError: If the directory path is invalid, does not exist, or is not a directory.
         """
@@ -230,6 +241,8 @@ class ProbabilisticFilterModel:
         ),
         filter_ids: list[str] = None,
         step: int = 1,
+        display_name: bool = False,
+        validation: bool = False,
     ) -> ModelResult:
         """
         Returns a model result object for the sequence(s) based on the filters in the model
@@ -246,6 +259,8 @@ class ProbabilisticFilterModel:
             filter_ids (list[str]): A list of filter IDs to filter the results. If None,
                 all results are returned.
             step (int): The step size for the k-mer search. Default is 1.
+            display_name (bool): Includes a display name for each tax_ID.
+            validation (bool): Sorts out misclassified reads.
 
         Returns:
             ModelResult: An object containing the hits for each sequence, the number of kmers,
@@ -253,11 +268,12 @@
 
         Raises:
             ValueError: If the input sequence is not valid, or if it is not a Seq object,
-                a list of Seq objects, a SeqIO iterator, or a Path object to a fasta/fastq file.
+                a list of Seq objects, a SeqIO iterator, or a Path object to a fasta/fastq
+                file.
         """
         if isinstance(sequence_input, (SeqRecord)):
             return ProbabilisticFilterModel.predict(
-                self, [sequence_input], filter_ids, step=step
+                self, [sequence_input], filter_ids, step, display_name, validation
             )
 
         if self._is_sequence_list(sequence_input) | self._is_sequence_iterator(
@@ -265,19 +281,42 @@
         ):
             hits = {}
             num_kmers = {}
+            if validation and self._is_sequence_iterator(sequence_input):
+                sequence_input = list(sequence_input)
+
             for individual_sequence in sequence_input:
                 individual_hits = self.calculate_hits(
-                    individual_sequence.seq, filter_ids, step=step
+                    individual_sequence.seq, filter_ids, step
                 )
                 num_kmers[individual_sequence.id] = self._count_kmers(
-                    individual_sequence, step=step
+                    individual_sequence, step
                 )
+
+                if display_name:
+                    individual_hits.update(
+                        {
+                            f"{key} -{self.display_names.get(key, 'Unknown').replace(
+                                self.model_display_name, '', 1)}": individual_hits.pop(
+                                key
+                            )
+                            for key in list(individual_hits.keys())
+                        }
+                    )
+
                 hits[individual_sequence.id] = individual_hits
+
+            if validation:
+                hits = self.detecting_misclassification(hits, sequence_input)
+
             return ModelResult(self.slug(), hits, num_kmers, sparse_sampling_step=step)
 
         if isinstance(sequence_input, Path):
             return ProbabilisticFilterModel.predict(
-                self, get_record_iterator(sequence_input), step=step
+                self,
+                get_record_iterator(sequence_input),
+                step=step,
+                display_name=display_name,
+                validation=validation,
             )
 
         raise ValueError(
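
For orientation, this is how the extended predict() call might look from user code. A minimal sketch, assuming the class lives in xspect.models.probabilistic_filter_model (this file's header is not shown in the diff) and that a trained model instance is available:

```python
from pathlib import Path

# Import path assumed; the diff does not show this module's file header.
from xspect.models.probabilistic_filter_model import ProbabilisticFilterModel


def classify_reads(model: ProbabilisticFilterModel, reads: Path):
    """Sketch of a predict() call using the two new keyword flags."""
    return model.predict(
        reads,  # Path input is dispatched via get_record_iterator
        filter_ids=None,  # keep all filters
        step=1,  # dense k-mer sampling
        display_name=True,  # new: append display names to the tax IDs in hit keys
        validation=True,  # new: sort out misclassified reads after scoring
    )
```
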
@@ -460,3 +499,98 @@ class ProbabilisticFilterModel:
             sequence_input,
             (SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),
         )
+
+    def detecting_misclassification(
+        self,
+        hits: dict[str, dict[str, int]],
+        seq_records: list[SeqRecord],
+        min_reads: int = 10,
+    ) -> dict[str, dict[str, int]]:
+        """
+        Notes:
+            Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025):
+            "An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment".
+
+        Detects misclassification for short sequences.
+
+        This function is an alignment-based procedure that groups species by highest XspecT scores.
+        Each species group is mapped against the respective reference genome.
+        Start coordinates are extracted and scanned for local clustering.
+        When local clustering is detected, all sequences belonging to the species are sorted out.
+
+        Args:
+            hits (dict): The species annotations from the prediction step.
+            seq_records (list): The provided sequences.
+            min_reads (int): Minimum number of reads that a species group must have.
+
+        Returns:
+            dict: hits where misclassified sequences have been sorted out.
+        """
+        rec_by_id = {record.id: record for record in seq_records}
+        grouped: dict[int, list[SeqRecord]] = defaultdict(list)
+        misclassified = {}
+
+        # group by species annotation
+        for record, score_dict in hits.items():
+            if record == "misclassified":
+                continue
+            sorted_hits = sorted(
+                score_dict.items(), key=lambda entry: entry[1], reverse=True
+            )
+            if sorted_hits[0][1] > sorted_hits[1][1]:  # unique highest score
+                highest_tax_id = int(sorted_hits[0][0])  # tax_id
+                if record in rec_by_id:
+                    # groups all reads with the highest score by tax_id
+                    grouped[highest_tax_id].append(rec_by_id[record])
+        filtered_grouped = {
+            tax_id: seq for tax_id, seq in grouped.items() if len(seq) > min_reads
+        }
+        largest_group = max(
+            filtered_grouped,
+            key=lambda tax_id: len(filtered_grouped[tax_id]),
+            default=None,
+        )
+
+        # mapping procedure
+        handler = NCBIHandler()
+        out_dir = get_xspect_misclassification_path()
+        out_dir.mkdir(parents=True, exist_ok=True)
+        for tax_id, reads in filtered_grouped.items():
+            if tax_id == largest_group:
+                continue
+
+            tax_dir = out_dir / str(tax_id)
+            tax_dir.mkdir(parents=True, exist_ok=True)
+            fasta_path = tax_dir / f"{tax_id}.fasta"
+            SeqIO.write(reads, fasta_path, "fasta")
+            reference_path = tax_dir / f"{tax_id}.fna"
+            # download reference once
+            if not (reference_path.exists() and reference_path.stat().st_size > 0):
+                handler.download_reference_genome(tax_id, tax_dir)
+            if not reference_path.exists():
+                shutil.rmtree(tax_dir)
+                continue
+
+            mapping_handler = MappingHandler(str(reference_path), str(fasta_path))
+            mapping_handler.map_reads_onto_reference()
+            mapping_handler.extract_starting_coordinates()
+            genome_length = mapping_handler.get_total_genome_length()
+            start_coordinates = mapping_handler.get_start_coordinates()
+
+            if len(start_coordinates) < min_reads:
+                continue
+
+            # cluster analysis
+            analysis = PointPatternAnalysis(start_coordinates, genome_length)
+            clustered = analysis.ripleys_k_edge_corrected()
+
+            if clustered[0]:  # True or False
+                bucket = misclassified.setdefault(tax_id, {})
+                for read in reads:
+                    data = hits.pop(read.id, None)  # remove false reads from main hits
+                    if data is not None:
+                        bucket[read.id] = data
+
+        if misclassified:
+            hits["misclassified"] = misclassified
+        return hits
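
The clustering decision itself is delegated to PointPatternAnalysis.ripleys_k_edge_corrected(), whose implementation is not part of this diff. As intuition for what it tests, here is a minimal, uncorrected 1-D Ripley's K over read start coordinates: under complete spatial randomness on a genome of length L, K(t) is approximately 2t, and values far above that indicate the local clustering this method screens for. All names in this sketch are illustrative:

```python
def ripleys_k_1d(starts: list[int], genome_length: int, t: int) -> float:
    """Naive 1-D Ripley's K estimate without edge correction (illustration only)."""
    n = len(starts)
    if n < 2:
        return 0.0
    # Count ordered pairs of distinct reads whose start coordinates lie within t.
    close_pairs = sum(
        1
        for i in range(n)
        for j in range(n)
        if i != j and abs(starts[i] - starts[j]) <= t
    )
    return genome_length * close_pairs / (n * (n - 1))


# Reads piled into a ~100 bp window of a 4 Mbp genome: K(200) is far above
# 2 * 200, the expectation under randomness, so these starts count as clustered.
print(ripleys_k_1d([100, 120, 150, 160, 170, 190], 4_000_000, 200))
```
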
@@ -55,10 +55,14 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             base_path (Path): The base path where the model will be stored.
             kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
             c (float): Regularization parameter for the SVM.
-            fpr (float, optional): False positive rate for the probabilistic filter. Defaults to 0.01.
-            num_hashes (int, optional): Number of hashes for the probabilistic filter. Defaults to 7.
-            training_accessions (dict[str, list[str]] | None, optional): Accessions used for training the probabilistic filter. Defaults to None.
-            svm_accessions (dict[str, list[str]] | None, optional): Accessions used for training the SVM. Defaults to None.
+            fpr (float, optional): False positive rate for the probabilistic filter.
+                Defaults to 0.01.
+            num_hashes (int, optional): Number of hashes for the probabilistic filter.
+                Defaults to 7.
+            training_accessions (dict[str, list[str]] | None, optional): Accessions used for
+                training the probabilistic filter. Defaults to None.
+            svm_accessions (dict[str, list[str]] | None, optional): Accessions used for
+                training the SVM. Defaults to None.
         """
         super().__init__(
             k=k,
@@ -112,17 +116,18 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         """
         Fit the SVM to the sequences and labels.
 
-        This method first trains the probabilistic filter model and then
-        calculates scores for the SVM training. It expects the sequences to be in
-        the specified directory and the SVM training sequences to be in the
-        specified SVM path. The scores are saved in a CSV file for later use.
+        This method first trains the probabilistic filter model and then calculates scores for
+        the SVM training. It expects the sequences to be in the specified directory and the SVM
+        training sequences to be in the specified SVM path. The scores are saved in a CSV file
+        for later use.
 
         Args:
             dir_path (Path): The directory containing the training sequences.
             svm_path (Path): The directory containing the SVM training sequences.
             display_names (dict[str, str] | None): A mapping of accession IDs to display names.
             svm_step (int): Step size for sparse sampling in SVM training.
-            training_accessions (dict[str, list[str]] | None): Accessions used for training the probabilistic filter.
+            training_accessions (dict[str, list[str]] | None): Accessions used for training the
+                probabilistic filter.
             svm_accessions (dict[str, list[str]] | None): Accessions used for training the SVM.
         """
 
@@ -178,6 +183,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         ),
         filter_ids: list[str] = None,
         step: int = 1,
+        display_name: bool = False,
+        validation: bool = False,
     ) -> ModelResult:
         """
         Predict the labels of the sequences.
@@ -187,19 +194,26 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         with the probabilistic filter model, and it will return a `ModelResult`.
 
         Args:
-            sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator | SeqIO.QualityIO.FastqPhredIterator | Path): The input sequences to predict.
-            filter_ids (list[str], optional): A list of IDs to filter the predictions. Defaults to None.
+            sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator |
+                SeqIO.QualityIO.FastqPhredIterator | Path): The input sequences to predict.
+            filter_ids (list[str], optional): A list of IDs to filter the predictions.
             step (int, optional): Step size for sparse sampling. Defaults to 1.
+            display_name (bool): Includes a display name for each tax_ID.
+            validation (bool): Sorts out misclassified reads.
 
         Returns:
-            ModelResult: The result of the prediction containing hits, number of kmers, and the predicted label.
+            ModelResult: The result of the prediction containing hits, number of kmers, and the
+                predicted label.
         """
         # get scores and format them for the SVM
-        res = super().predict(sequence_input, filter_ids, step=step)
+        res = super().predict(
+            sequence_input, filter_ids, step, display_name, validation
+        )
         svm_scores = dict(sorted(res.get_scores()["total"].items()))
         svm_scores = [list(svm_scores.values())]
 
         svm = self._get_svm(filter_ids)
+        res.hits["misclassified"] = res.misclassified
         return ModelResult(
             self.slug(),
             res.hits,
@@ -217,7 +231,8 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         training data to only include those keys.
 
         Args:
-            id_keys (list[str] | None): A list of IDs to filter the training data. If None, all data is used.
+            id_keys (list[str] | None): A list of IDs to filter the training data.
+                If None, all data is used.
 
         Returns:
             SVC: The trained SVM model.
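
For context on the score handling in predict() above: sorting the total-score dictionary by key before taking its values guarantees a stable feature-column order for the SVM. A hedged sketch with scikit-learn (the docstring confirms the model is an SVC; the training data and labels here are invented):

```python
from sklearn.svm import SVC

# Toy per-filter score rows for three training genomes (values invented).
X_train = [[0.95, 0.10, 0.20], [0.15, 0.90, 0.25], [0.10, 0.20, 0.85]]
y_train = ["470", "471", "48296"]
svm = SVC(kernel="linear", C=1.0).fit(X_train, y_train)

# At prediction time, sorting by key fixes the column order of the feature vector.
scores = {"471": 0.18, "470": 0.91, "48296": 0.22}
features = [list(dict(sorted(scores.items())).values())]  # [[0.91, 0.18, 0.22]]
print(svm.predict(features)[0])  # expected: "470"
```
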
@@ -34,8 +34,8 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
     ) -> None:
         """Initialize probabilistic single filter model.
 
-        This model uses a Bloom filter to store k-mers from the training sequences. It is designed to
-        be used with a single filter, which is suitable e.g. for genus-level classification.
+        This model uses a Bloom filter to store k-mers from the training sequences. It is designed
+        to be used with a single filter, which is suitable e.g. for genus-level classification.
 
         Args:
             k (int): Length of the k-mers to use for filtering
@@ -45,7 +45,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
             model_type (str): Type of the model, e.g. "probabilistic_single_filter"
             base_path (Path): Base path where the model will be saved
             fpr (float): False positive rate for the Bloom filter, default is 0.01
-            training_accessions (list[str] | None): List of accessions used for training, default is None
+            training_accessions (list[str] | None): List of accessions used for training
         """
         super().__init__(
             k=k,
@@ -75,7 +75,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         Args:
             file_path (Path): Path to the file containing sequences in FASTA format
             display_name (str): Display name for the model
-            training_accessions (list[str] | None): List of accessions used for training, default is None
+            training_accessions (list[str] | None): List of accessions used for training
         """
         self.training_accessions = training_accessions
 
@@ -104,7 +104,7 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         Calculates the number of k-mers in the sequence that are present in the Bloom filter.
 
         Args:
-            sequence (Seq | SeqRecord): Sequence to calculate hits for, can be a Bio.Seq or Bio.SeqRecord object
+            sequence (Seq | SeqRecord): Sequence to calculate hits for
             filter_ids (list[str] | None): List of filter IDs to use, default is None
             step (int): Step size for generating k-mers, default is 1
         Returns:
@@ -162,13 +162,15 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         """
         Generate kmers from the sequence
 
-        Generates k-mers from the sequence, considering both the forward and reverse complement strands.
+        Generates k-mers from the sequence, considering both the forward and reverse complement
+        strands.
 
         Args:
             sequence (Seq): Sequence to generate k-mers from
             step (int): Step size for generating k-mers, default is 1
         Yields:
-            str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and reverse complement)
+            str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and
+                reverse complement)
         """
         num_kmers = ceil((len(sequence) - self.k + 1) / step)
         for i in range(num_kmers):
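
The minimizer described in the docstring above, the lexicographically smaller of a k-mer and its reverse complement, can be illustrated with Biopython. This sketch is illustrative and not the model's own implementation:

```python
from Bio.Seq import Seq


def canonical_kmer(kmer: str) -> str:
    """Return the lexicographically smaller of a k-mer and its reverse complement."""
    reverse_complement = str(Seq(kmer).reverse_complement())
    return min(kmer, reverse_complement)


assert canonical_kmer("TTTT") == "AAAA"  # the reverse complement sorts first
assert canonical_kmer("ACGT") == "ACGT"  # palindromic under reverse complement
```
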
xspect/models/result.py CHANGED
@@ -40,6 +40,7 @@ class ModelResult:
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
         self.input_source = input_source
+        self.misclassified = self.hits.pop("misclassified", None)
 
     def get_scores(self) -> dict:
         """
@@ -50,7 +51,8 @@ class ModelResult:
 
         Returns:
             dict: A dictionary where keys are subsequence names and values are dictionaries
-                with labels as keys and scores as values. Also includes a 'total' key for overall scores.
+                with labels as keys and scores as values. Also includes a 'total' key for
+                overall scores.
         """
         scores = {
             subsequence: {
@@ -78,7 +80,8 @@ class ModelResult:
         The total hits are calculated by summing the hits for each label across all subsequences.
 
         Returns:
-            dict: A dictionary where keys are labels and values are the total number of hits for that label.
+            dict: A dictionary where keys are labels and values are the total number of hits for
+                that label.
         """
         total_hits = {label: 0 for label in list(self.hits.values())[0]}
         for _, subsequence_hits in self.hits.items():
@@ -97,8 +100,8 @@ class ModelResult:
 
         Args:
             label (str): The label for which to filter the subsequences.
-            filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
-                or -1 to return the subsequence with the maximum score for the label.
+            filter_threshold (float): The threshold for filtering subsequences. Must be between 0
+                and 1, or -1 to return the subsequence with the maximum score for the label.
 
         Returns:
             dict[str, bool]: A dictionary where keys are subsequence names and values are booleans
@@ -114,11 +117,10 @@ class ModelResult:
                 subsequence: score[label] >= filter_threshold
                 for subsequence, score in scores.items()
             }
-        else:
-            return {
-                subsequence: score[label] == max(score.values())
-                for subsequence, score in scores.items()
-            }
+        return {
+            subsequence: score[label] == max(score.values())
+            for subsequence, score in scores.items()
+        }
 
     def get_filtered_subsequence_labels(
         self, label: str, filter_threshold: float = 0.7
@@ -126,15 +128,17 @@ class ModelResult:
         """
         Return the labels of filtered subsequences.
 
-        This method filters subsequences based on the scores for a given label and a filter threshold.
+        This method filters subsequences based on the scores for a given label and a filter
+        threshold.
 
         Args:
             label (str): The label for which to filter the subsequences.
-            filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
-                or -1 to return the subsequence with the maximum score for the label.
+            filter_threshold (float): The threshold for filtering subsequences. Must be between 0
+                and 1, or -1 to return the subsequence with the maximum score for the label.
 
         Returns:
-            list[str]: A list of subsequence names that meet the filter criteria for the given label.
+            list[str]: A list of subsequence names that meet the filter criteria for the given
+                label.
         """
         return [
             subsequence
@@ -148,11 +152,13 @@ class ModelResult:
         """
         Return the result as a dictionary.
 
-        This method converts the ModelResult object into a dictionary format suitable for serialization.
+        This method converts the ModelResult object into a dictionary format suitable for
+        serialization.
 
         Returns:
             dict: A dictionary representation of the ModelResult object, including model slug,
-                sparse sampling step, hits, scores, number of k-mers, input source, and prediction if available.
+                sparse sampling step, hits, scores, number of k-mers, input source, and prediction
+                if available.
         """
         res = {
             "model_slug": self.model_slug,
@@ -160,6 +166,7 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
+            "misclassified": self.misclassified,
             "input_source": self.input_source,
         }
 
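Taken together, the result.py changes pop any "misclassified" bucket out of hits in the constructor, keep it as a separate attribute, and serialize it alongside the other result fields. A small sketch of the expected behavior, with constructor arguments matching the calls visible in this diff and invented values:

```python
from xspect.models.result import ModelResult

hits = {
    "read1": {"470": 42, "471": 3},
    "misclassified": {471: {"read9": {"471": 40, "470": 2}}},
}
result = ModelResult("demo-model", hits, {"read1": 50}, sparse_sampling_step=1)

assert "misclassified" not in result.hits  # popped in __init__ ...
assert result.misclassified is not None  # ... and kept as its own attribute
```
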
xspect/ncbi.py CHANGED
@@ -1,10 +1,12 @@
 """NCBI handler for the NCBI Datasets API."""
 
+import shutil
 from enum import Enum
 from pathlib import Path
 import time
 from loguru import logger
 import requests
+import zipfile
 
 # pylint: disable=line-too-long
 
@@ -194,8 +196,9 @@ class NCBIHandler:
         assembly_level: AssemblyLevel,
         assembly_source: AssemblySource,
         count: int,
-        min_n50: int = 10000,
-        exclude_atypical: bool = True,
+        min_n50: int,
+        exclude_atypical: bool,
+        allow_inconclusive: bool,
         exclude_paired_reports: bool = True,
         current_version_only: bool = True,
     ) -> list[str]:
@@ -211,11 +214,13 @@ class NCBIHandler:
             assembly_level (AssemblyLevel): The assembly level to get the accessions for.
             assembly_source (AssemblySource): The assembly source to get the accessions for.
             count (int): The number of accessions to get.
-            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
-            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            min_n50 (int): The minimum contig n50 to filter the accessions.
+            exclude_atypical (bool): Whether to exclude atypical accessions.
+            allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.
             exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
             current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.
 
+
         Returns:
             list[str]: A list containing the accessions.
         """
@@ -240,8 +245,11 @@ class NCBIHandler:
                     report["accession"]
                     for report in response["reports"]
                     if report["assembly_stats"]["contig_n50"] >= min_n50
-                    and report["average_nucleotide_identity"]["taxonomy_check_status"]
-                    == "OK"
+                    and (
+                        allow_inconclusive
+                        or report["average_nucleotide_identity"]["taxonomy_check_status"]
+                        == "OK"
+                    )
                 ]
             except (IndexError, KeyError, TypeError):
                 logger.debug(
@@ -251,7 +259,13 @@ class NCBIHandler:
         return accessions[:count]  # Limit to count
 
     def get_highest_quality_accessions(
-        self, taxon_id: int, assembly_source: AssemblySource, count: int
+        self,
+        taxon_id: int,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int,
+        exclude_atypical: bool,
+        allow_inconclusive: bool,
     ) -> list[str]:
         """
         Get the highest quality accessions for a given taxon id (based on the assembly level).
@@ -263,6 +277,9 @@ class NCBIHandler:
             taxon_id (int): The taxon id to get the accessions for.
             assembly_source (AssemblySource): The assembly source to get the accessions for.
             count (int): The number of accessions to get.
+            min_n50 (int): The minimum contig n50 to filter the accessions.
+            exclude_atypical (bool): Whether to exclude atypical accessions.
+            allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.
 
         Returns:
             list[str]: A list containing the highest quality accessions.
@@ -274,6 +291,9 @@ class NCBIHandler:
                 assembly_level,
                 assembly_source,
                 count,
+                min_n50=min_n50,
+                exclude_atypical=exclude_atypical,
+                allow_inconclusive=allow_inconclusive,
             )
             if len(set(accessions)) >= count:
                 break
@@ -302,3 +322,58 @@ class NCBIHandler:
         with open(output_dir / "ncbi_dataset.zip", "wb") as f:
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
+
+    def download_reference_genome(self, taxon_id: int, output_dir: Path) -> Path | None:
+        """
+        Notes:
+            Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025):
+            "An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment".
+
+        Downloads the reference genome from the RefSeq DB for a given taxon ID.
+
+        This function queries the NCBI Datasets API for the reference genome and downloads it.
+
+        Args:
+            taxon_id (int): The taxonomy ID of the species.
+            output_dir (Path): Directory where the genome will be saved.
+
+        Returns:
+            Path | None: Path to the extracted reference genome (.fna), or None if no
+                reference accession was found.
+        """
+        accessions = self.get_accessions(
+            taxon_id=taxon_id,
+            assembly_level=AssemblyLevel.REFERENCE,
+            assembly_source=AssemblySource.REFSEQ,
+            count=1,  # only one reference exists
+            min_n50=0,
+            exclude_atypical=True,
+            allow_inconclusive=False,
+        )
+
+        if not accessions:
+            return None
+
+        logger.info(
+            f"Downloading reference genome for taxon {taxon_id}: {accessions[0]}"
+        )
+        self.download_assemblies(accessions, output_dir)
+
+        zip_path = output_dir / "ncbi_dataset.zip"
+
+        fna_file = ""
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
+            for file in zip_ref.namelist():
+                if file.endswith(".fna"):
+                    extracted_path = zip_ref.extract(file, path=output_dir)
+                    fna_file = output_dir / f"{taxon_id}.fna"
+                    Path(extracted_path).rename(
+                        fna_file
+                    )  # consistent file name (tax_id)
+                    logger.info(f"Extracted reference genome to {fna_file}")
+                    break
+
+        # clean up
+        zip_path.unlink()
+        shutil.rmtree(output_dir / "ncbi_dataset")
+
+        return fna_file
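
A hypothetical call of the new helper; taxon ID 470 (Acinetobacter baumannii) and the output directory are illustrative, and the no-argument NCBIHandler() constructor matches its use in detecting_misclassification above:

```python
from pathlib import Path
from xspect.ncbi import NCBIHandler

out_dir = Path("refs") / "470"
out_dir.mkdir(parents=True, exist_ok=True)  # download_assemblies writes into it

reference = NCBIHandler().download_reference_genome(470, out_dir)
if reference:
    print(f"Reference genome extracted to {reference}")
else:
    print("No RefSeq reference accession found for this taxon")
```
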
xspect/train.py CHANGED
@@ -186,6 +186,11 @@ def train_from_ncbi(
     author: str | None = None,
     author_email: str | None = None,
     ncbi_api_key: str | None = None,
+    min_n50: int = 10000,
+    exclude_atypical: bool = True,
+    allow_inconclusive: bool = False,
+    allow_candidatus: bool = False,
+    allow_sp: bool = False,
 ):
     """
     Train a model using NCBI assembly data for a given genus.
@@ -200,6 +205,11 @@ def train_from_ncbi(
         author (str, optional): Author of the model. Defaults to None.
         author_email (str, optional): Author's email. Defaults to None.
         ncbi_api_key (str, optional): NCBI API key for accessing NCBI resources. Defaults to None.
+        min_n50 (int, optional): Minimum N50 value for assemblies. Defaults to 10000.
+        exclude_atypical (bool, optional): Exclude atypical assemblies. Defaults to True.
+        allow_inconclusive (bool, optional): Allow use of accessions with inconclusive taxonomy check status. Defaults to False.
+        allow_candidatus (bool, optional): Allow use of Candidatus species for training. Defaults to False.
+        allow_sp (bool, optional): Allow use of species with "sp." in their names. Defaults to False.
 
     Raises:
         TypeError: If `genus` is not a string.
@@ -221,8 +231,8 @@ def train_from_ncbi(
     filtered_species_ids = [
         tax_id
         for tax_id in species_ids
-        if "candidatus" not in species_names[tax_id].lower()
-        and " sp." not in species_names[tax_id].lower()
+        if (allow_candidatus or "candidatus" not in species_names[tax_id].lower())
+        and (allow_sp or " sp." not in species_names[tax_id].lower())
     ]
     filtered_species_names = {
         str(tax_id): species_names[tax_id] for tax_id in filtered_species_ids
@@ -231,7 +241,12 @@ def train_from_ncbi(
     accessions = {}
     for tax_id in filtered_species_ids:
         taxon_accessions = ncbi_handler.get_highest_quality_accessions(
-            tax_id, AssemblySource.REFSEQ, 8
+            tax_id,
+            AssemblySource.REFSEQ,
+            8,
+            min_n50,
+            exclude_atypical,
+            allow_inconclusive,
         )
         if not taxon_accessions:
             logger.warning(f"No assemblies found for tax_id {tax_id}. Skipping.")
@@ -241,7 +256,9 @@ def train_from_ncbi(
 
     if not accessions:
         raise ValueError(
-            "No species with accessions found. Please check the genus name."
+            "No species with accessions found. "
+            "Please check if the genus name is correct or if there are any data quality issues "
+            "(e.g. inconclusive taxonomy check status, atypical assemblies, low N50 values)."
         )
 
     with TemporaryDirectory() as tmp_dir:
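
Finally, a hypothetical train_from_ncbi call exercising the new quality knobs; the genus-first positional signature is inferred from the docstring, and every keyword below restates its default from this diff except the stricter min_n50:

```python
from xspect.train import train_from_ncbi

train_from_ncbi(
    "Acinetobacter",
    min_n50=50_000,  # demand more contiguous assemblies than the 10000 default
    exclude_atypical=True,  # default: drop assemblies NCBI flags as atypical
    allow_inconclusive=False,  # default: keep only taxonomy_check_status == "OK"
    allow_candidatus=False,  # default: drop "Candidatus ..." species
    allow_sp=False,  # default: drop "... sp." placeholder species
)
```
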