XspecT 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of XspecT might be problematic.

xspect/classify.py CHANGED
@@ -46,6 +46,7 @@ def classify_species(
     output_path: Path,
     step: int = 1,
     display_name: bool = False,
+    validation: bool = False,
 ):
     """
     Classify the species of sequences.
@@ -59,6 +60,7 @@ def classify_species(
         output_path (Path): The path to the output file where results will be saved.
         step (int): The number of k-mers to be skipped.
         display_name (bool): Includes a display name for each tax_ID.
+        validation (bool): Sorts out misclassified reads.
     """
     ProbabilisticFilterSVMModel = import_module(
         "xspect.models.probabilistic_filter_svm_model"
@@ -69,7 +71,12 @@ def classify_species(
     input_paths, get_output_path = prepare_input_output_paths(input_path)

     for idx, current_path in enumerate(input_paths):
-        result = model.predict(current_path, step=step, display_name=display_name)
+        result = model.predict(
+            current_path,
+            step=step,
+            display_name=display_name,
+            validation=validation,
+        )
         result.input_source = current_path.name
         cls_path = get_output_path(idx, output_path)
         result.save(cls_path)
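For orientation, a minimal sketch of calling the extended `classify_species` from Python. Only the trailing parameters (`output_path`, `step`, `display_name`, `validation`) are visible in this hunk, so the leading arguments below are assumptions:

```python
from pathlib import Path
from xspect.classify import classify_species

classify_species(
    "Acinetobacter",       # assumed: genus of the trained species model
    Path("sample.fasta"),  # assumed: input path
    Path("result.json"),   # output_path
    step=1,
    display_name=True,
    validation=True,       # new in 0.6.0: sort out misclassified reads
)
```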
xspect/definitions.py CHANGED
@@ -89,3 +89,22 @@ def get_xspect_mlst_path() -> Path:
     mlst_path = get_xspect_root_path() / "mlst"
     mlst_path.mkdir(exist_ok=True, parents=True)
     return mlst_path
+
+
+def get_xspect_misclassification_path() -> Path:
+    """
+    Return the path to the XspecT misclassification directory.
+
+    Returns the path to the XspecT misclassification directory, which is located
+    within the XspecT data directory. If the directory does not exist, it is created.
+
+    Notes:
+        Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025).
+        (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+
+    Returns:
+        Path: The path to the XspecT misclassification directory.
+    """
+    misclassification_path = get_xspect_root_path() / "misclassification"
+    misclassification_path.mkdir(exist_ok=True, parents=True)
+    return misclassification_path
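A one-line usage sketch; the new helper mirrors the existing `get_xspect_mlst_path` pattern:

```python
from xspect.definitions import get_xspect_misclassification_path

# Directory under the XspecT data root; created on first call if missing.
out_dir = get_xspect_misclassification_path()
```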
xspect/main.py CHANGED
@@ -87,13 +87,62 @@ def train():
     help="Email of the author.",
     default=None,
 )
-def train_ncbi(model_genus, svm_steps, author, author_email):
+@click.option(
+    "--min-n50",
+    type=int,
+    help="Minimum contig N50 to filter the accessions (default: 10000).",
+    default=10000,
+)
+@click.option(
+    "--include-atypical/--exclude-atypical",
+    help="Include or exclude atypical accessions (default: exclude).",
+    default=False,
+)
+@click.option(
+    "--allow-inconclusive",
+    is_flag=True,
+    help="Allow the use of accessions with inconclusive taxonomy check status for training.",
+    default=False,
+)
+@click.option(
+    "--allow-candidatus",
+    is_flag=True,
+    help="Allow the use of Candidatus species for training.",
+    default=False,
+)
+@click.option(
+    "--allow-sp",
+    is_flag=True,
+    help="Allow the use of species with 'sp.' in their names for training.",
+    default=False,
+)
+def train_ncbi(
+    model_genus,
+    svm_steps,
+    author,
+    author_email,
+    min_n50,
+    include_atypical,
+    allow_inconclusive,
+    allow_candidatus,
+    allow_sp,
+):
     """Train a species and a genus model based on NCBI data."""
     click.echo(f"Training {model_genus} species and genus metagenome model.")
     try:
         train_from_ncbi = import_module("xspect.train").train_from_ncbi

-        train_from_ncbi(model_genus, svm_steps, author, author_email)
+        train_from_ncbi(
+            model_genus,
+            svm_steps,
+            author,
+            author_email,
+            min_n50=min_n50,
+            exclude_atypical=not include_atypical,
+            allow_inconclusive=allow_inconclusive,
+            allow_candidatus=allow_candidatus,
+            allow_sp=allow_sp,
+        )
     except ValueError as e:
         click.echo(f"Error: {e}")
         return
@@ -287,8 +336,19 @@ def classify_genus(model_genus, input_path, output_path, sparse_sampling_step):
     help="Includes the display names next to taxonomy-IDs.",
     is_flag=True,
 )
+@click.option(
+    "-v",
+    "--validation",
+    help="Detects misclassification for small reads or contigs.",
+    is_flag=True,
+)
 def classify_species(
-    model_genus, input_path, output_path, sparse_sampling_step, display_names
+    model_genus,
+    input_path,
+    output_path,
+    sparse_sampling_step,
+    display_names,
+    validation,
 ):
     """Classify samples using a species model."""
     click.echo("Classifying...")
@@ -300,6 +360,7 @@ def classify_species(
         Path(output_path),
         sparse_sampling_step,
         display_names,
+        validation,
     )

xspect/misclassification_detection/mapping.py ADDED
@@ -0,0 +1,168 @@
+"""
+Mapping handler for the alignment-based misclassification detection.
+
+Notes:
+    Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025).
+    (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+"""
+
+import mappy, pysam, os, csv
+from Bio import SeqIO
+from xspect.definitions import fasta_endings
+
+__author__ = "Cetin, Oemer"
+
+
+class MappingHandler:
+    """Handler class for all mapping-related procedures."""
+
+    def __init__(self, ref_genome_path: str, reads_path: str) -> None:
+        """
+        Initialise the mapping handler.
+
+        This method sets up the paths to the reference genome and query sequences.
+        Additionally, the paths to the output formats (SAM, BAM and TSV) are generated.
+
+        Args:
+            ref_genome_path (str): The path to the reference genome.
+            reads_path (str): The path to the query sequences.
+        """
+        if not os.path.isfile(ref_genome_path):
+            raise ValueError("The path to the reference genome does not exist.")
+
+        if not os.path.isfile(reads_path):
+            raise ValueError("The path to the reads does not exist.")
+
+        if not ref_genome_path.endswith(
+            tuple(fasta_endings)
+        ) or not reads_path.endswith(tuple(fasta_endings)):
+            raise ValueError("The files must be FASTA files!")
+
+        stem = reads_path.rsplit(".", 1)[0] + "_mapped"
+        self.ref_genome_path = ref_genome_path
+        self.reads_path = reads_path
+        self.sam = stem + ".sam"
+        self.bam = stem + ".sorted.bam"
+        self.tsv = stem + ".start_coordinates.tsv"
+
+    def map_reads_onto_reference(self) -> None:
+        """
+        Map reads against the respective reference genome.
+
+        This function creates a SAM file via mappy and converts it into a sorted,
+        indexed BAM file.
+        """
+        # create header (one entry per sequence of the reference genome)
+        ref_seq = [
+            {"SN": rec.id, "LN": len(rec.seq)}
+            for rec in SeqIO.parse(self.ref_genome_path, "fasta")
+        ]
+        header = {"HD": {"VN": "1.0"}, "SQ": ref_seq}
+        target_id = {sequence["SN"]: number for number, sequence in enumerate(ref_seq)}
+
+        reads = list(SeqIO.parse(self.reads_path, "fasta"))
+        if not reads:
+            raise ValueError("Reads file is empty.")
+
+        read_length = len(reads[0].seq)
+        preset = "map-ont" if read_length > 150 else "sr"
+        # create SAM file
+        aln = mappy.Aligner(self.ref_genome_path, preset=preset)
+        with pysam.AlignmentFile(self.sam, "w", header=header) as out:
+            for read in reads:
+                read_seq = str(read.seq)
+                for hit in aln.map(read_seq):
+                    if hit.cigar_str is None:
+                        continue
+                    # add soft-clips so that the CIGAR length equals len(read_seq)
+                    left_clip = hit.q_st
+                    right_clip = len(read_seq) - hit.q_en
+                    cigar = (
+                        (f"{left_clip}S" if left_clip > 0 else "")
+                        + hit.cigar_str
+                        + (f"{right_clip}S" if right_clip > 0 else "")
+                    )
+
+                    mapped_region = pysam.AlignedSegment()
+                    mapped_region.query_name = read.id
+                    mapped_region.query_sequence = read_seq
+                    mapped_region.flag = 16 if hit.strand == -1 else 0
+                    mapped_region.reference_id = target_id[hit.ctg]
+                    mapped_region.reference_start = hit.r_st
+                    mapped_region.mapping_quality = (
+                        hit.mapq or 255
+                    )  # 0-60 (255 means unavailable)
+                    mapped_region.cigarstring = cigar
+                    out.write(mapped_region)
+                    break  # keep only the primary alignment
+
+        # create BAM file
+        pysam.sort("-o", self.bam, self.sam)
+        pysam.index(self.bam)
+
+    def get_total_genome_length(self) -> int:
+        """
+        Get the genome length from a BAM file.
+
+        This function opens a BAM file and extracts the genome length information.
+
+        Returns:
+            int: The genome length.
+        """
+        with pysam.AlignmentFile(self.bam, "rb") as bam:
+            return sum(bam.lengths)
+
+    def extract_starting_coordinates(self) -> None:
+        """
+        Extract starting coordinates of mapped regions from a BAM file.
+
+        This function scans through a BAM file and creates a TSV file containing
+        the starting coordinate of each mapped read.
+        """
+        # create TSV file with all start positions
+        with open(self.tsv, "w") as tsv:
+            tsv.write("reference_genome\tread\tmapped_starting_coordinate\n")
+            try:
+                with pysam.AlignmentFile(self.bam, "rb") as bam:
+                    entry = {
+                        i: seq["SN"] for i, seq in enumerate(bam.header.to_dict()["SQ"])
+                    }
+                    seen = set()
+                    for ref_seq in bam.references:
+                        for hit in bam.fetch(ref_seq):
+                            if (
+                                hit.is_unmapped
+                                or hit.is_secondary
+                                or hit.is_supplementary
+                            ):
+                                continue
+                            key = (hit.reference_id, hit.reference_start)
+                            if key in seen:
+                                continue
+                            seen.add(key)
+                            tsv.write(
+                                f"{entry[hit.reference_id]}\t{hit.query_name}\t{hit.reference_start}\n"
+                            )
+            except ValueError:
+                tsv.write("dummy_reference\tdummy_read\t1000\n")
+
+    def get_start_coordinates(self) -> list[int]:
+        """
+        Get the coordinates from the TSV file.
+
+        This function opens the TSV file and collects all starting coordinates in a list.
+
+        Returns:
+            list[int]: The list containing all starting coordinates.
+
+        Raises:
+            ValueError: If no column with starting coordinates is found.
+        """
+        coordinates = []
+        with open(self.tsv, "r", newline="") as f:
+            reader = csv.DictReader(f, delimiter="\t")
+            for row in reader:
+                val = row.get("mapped_starting_coordinate")
+                if val is None:
+                    raise ValueError("Column with starting coordinates not found.")
+                coordinates.append(int(val))
+        return coordinates
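A minimal usage sketch of the new handler, assuming a reference genome and a reads file already exist under the (illustrative) paths below:

```python
from xspect.misclassification_detection.mapping import MappingHandler

# File names are illustrative; both must carry a FASTA extension.
handler = MappingHandler("470.fna", "470.fasta")
handler.map_reads_onto_reference()      # writes 470_mapped.sam / 470_mapped.sorted.bam
handler.extract_starting_coordinates()  # writes 470_mapped.start_coordinates.tsv
genome_length = handler.get_total_genome_length()
starts = handler.get_start_coordinates()
print(genome_length, len(starts))
```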
xspect/misclassification_detection/point_pattern_analysis.py ADDED
@@ -0,0 +1,102 @@
+"""
+Point pattern density analysis tool for the alignment-based misclassification detection.
+
+Notes:
+    Developed by Oemer Cetin as part of a BSc thesis (2025), Goethe University Frankfurt am Main.
+    (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+"""
+
+import numpy
+
+__author__ = "Cetin, Oemer"
+
+
+class PointPatternAnalysis:
+    """Class for all point pattern density analysis procedures."""
+
+    def __init__(self, points: list[int], length: int):
+        """
+        Initialise the class for point pattern analysis.
+
+        This method sets up the required list of data points (sorted) and the length
+        of the reference genome, from which the intensity required for the statistics
+        is derived.
+
+        Args:
+            points (list): The start coordinates of mapped regions on the genome.
+            length (int): The length of the reference genome.
+        """
+        if len(points) < 2:
+            raise ValueError("Need at least 2 points.")
+        self.sorted_points = numpy.sort(numpy.asarray(points, dtype=float))
+        self.n = len(points)
+        self.length = float(length)
+
+    def ripleys_k(self) -> tuple[bool, float, float]:
+        """
+        Calculate the K-function for the given point distribution.
+
+        This method calculates the K-function to describe the point distribution.
+        The result is then compared with what would be expected under a completely
+        random distribution. (Under complete randomness, the 1D K-function equals 2*r.)
+
+        Returns:
+            tuple: A tuple containing the information whether the points are clustered or not.
+        """
+        r = 0.01 * self.length
+        left = 0
+        right = 0
+        total_neighbors = 0
+
+        for i in range(self.n):
+            while self.sorted_points[i] - self.sorted_points[left] > r:
+                left += 1
+            if right < i:
+                right = i
+            while (
+                right + 1 < self.n
+                and self.sorted_points[right + 1] - self.sorted_points[i] <= r
+            ):
+                right += 1
+            total_neighbors += right - left
+        k = (self.length / (self.n * (self.n - 1))) * total_neighbors
+        return (k > 2 * r), k, 2 * r
+
+    def ripleys_k_edge_corrected(self) -> tuple[bool, float, float]:
+        """
+        Calculate the K-function for the given point distribution with an edge correction factor.
+
+        This method calculates the K-function to describe the point distribution.
+        This time an additional factor is multiplied in for each data point to account
+        for edge effects. The result is then compared with what would be expected under
+        a completely random distribution. (Under complete randomness, the 1D K-function
+        equals 2*r.)
+
+        Returns:
+            tuple: A tuple containing the information whether the points are clustered or not.
+        """
+        r = 0.01 * self.length
+        left = 0
+        right = 0
+        total_weighted = 0
+
+        for i in range(self.n):
+            while self.sorted_points[i] - self.sorted_points[left] > r:
+                left += 1
+            if right < i:
+                right = i
+            while (
+                right + 1 < self.n
+                and self.sorted_points[right + 1] - self.sorted_points[i] <= r
+            ):
+                right += 1
+
+            neighbors = right - left
+            if neighbors > 0:
+                a = max(0, self.sorted_points[i] - r)
+                b = min(self.length, self.sorted_points[i] + r)
+                overlap = b - a
+                weight = (2 * r) / overlap if overlap > 0 else 0
+
+                total_weighted += weight * neighbors
+
+        k = (self.length / (self.n * (self.n - 1))) * total_weighted
+        return bool(k > 2 * r), float(k), 2 * r
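A small self-contained check of the clustering test on synthetic start coordinates (all values invented): under complete spatial randomness the K estimate stays near the 2r threshold, while a tight cluster pushes it far above it.

```python
import numpy

from xspect.misclassification_detection.point_pattern_analysis import (
    PointPatternAnalysis,
)

rng = numpy.random.default_rng(42)
genome_length = 1_000_000  # illustrative genome size

# Uniform start coordinates approximate complete spatial randomness;
# a tight cluster around one locus mimics a misclassified read group.
uniform = rng.integers(0, genome_length, size=200).tolist()
clustered = rng.normal(500_000, 2_000, size=200).astype(int).tolist()

for label, points in (("uniform", uniform), ("clustered", clustered)):
    is_clustered, k, threshold = PointPatternAnalysis(
        points, genome_length
    ).ripleys_k_edge_corrected()
    print(f"{label}: clustered={is_clustered}, K={k:.0f}, 2r={threshold:.0f}")
```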
xspect/misclassification_detection/simulate_reads.py ADDED
@@ -0,0 +1,55 @@
+"""
+Read simulation for the alignment-based misclassification detection (used for testing purposes).
+
+Notes:
+    Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025).
+    (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+"""
+
+import random
+from Bio import SeqIO
+
+__author__ = "Cetin, Oemer"
+
+
+def extract_random_reads(
+    fasta_file, output_fasta, read_length=150, num_reads=1000, seed=42
+) -> None:
+    """
+    Uniformly extract reads from a genome and write them to a FASTA file.
+
+    Args:
+        fasta_file (str): Path to the input FASTA file.
+        output_fasta (str): Output FASTA file to write simulated reads to.
+        read_length (int): Length of each read to extract.
+        num_reads (int): Total number of reads to extract.
+        seed (int): A seed for reproducibility.
+
+    Raises:
+        ValueError: If the sequences are shorter than the chosen read length.
+    """
+    random.seed(seed)
+    sequences = [
+        record
+        for record in SeqIO.parse(fasta_file, "fasta")
+        if len(record.seq) >= read_length
+    ]
+    if not sequences:
+        raise ValueError("No sequences long enough for the desired read length.")
+
+    # the probability of extracting a read from a large contig is higher
+    seq_lengths = [len(rec.seq) for rec in sequences]
+    total_length = sum(seq_lengths)
+    weights = [single_length / total_length for single_length in seq_lengths]
+
+    with open(output_fasta, "w") as o:
+        for i in range(num_reads):
+            # random.choices() returns a list!
+            selected = random.choices(sequences, weights=weights, k=1)[0]
+            seq_length = len(selected.seq)
+            start = random.randint(0, seq_length - read_length)
+            read_seq = selected.seq[start : start + read_length]
+            o.write(
+                f">read_{i}_{selected.id}_{start}-{start + read_length}\n{read_seq}\n"
+            )
+    print("The reads have been simulated successfully.")
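Usage is a single call; the file names are placeholders:

```python
from xspect.misclassification_detection.simulate_reads import extract_random_reads

# Draw 500 reads of 250 bp from an assembly (file names are illustrative).
extract_random_reads(
    "assembly.fasta",
    "simulated_reads.fasta",
    read_length=250,
    num_reads=500,
    seed=7,
)
```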
xspect/models/probabilistic_filter_model.py CHANGED
@@ -1,6 +1,7 @@
 """Probabilistic filter model for sequence data"""

 import json
+import shutil
 from math import ceil
 from pathlib import Path
 from typing import Any
@@ -9,9 +10,19 @@ from Bio.SeqRecord import SeqRecord
 from Bio import SeqIO
 from slugify import slugify
 import cobs_index as cobs
-from xspect.definitions import fasta_endings, fastq_endings
+from xspect.definitions import (
+    fasta_endings,
+    fastq_endings,
+    get_xspect_misclassification_path,
+)
 from xspect.file_io import get_record_iterator
+from xspect.misclassification_detection.mapping import MappingHandler
 from xspect.models.result import ModelResult
+from collections import defaultdict
+from xspect.ncbi import NCBIHandler
+from xspect.misclassification_detection.point_pattern_analysis import (
+    PointPatternAnalysis,
+)


 class ProbabilisticFilterModel:
@@ -231,6 +242,7 @@ class ProbabilisticFilterModel:
         filter_ids: list[str] = None,
         step: int = 1,
         display_name: bool = False,
+        validation: bool = False,
     ) -> ModelResult:
         """
         Returns a model result object for the sequence(s) based on the filters in the model
@@ -248,6 +260,7 @@ class ProbabilisticFilterModel:
             all results are returned.
             step (int): The step size for the k-mer search. Default is 1.
             display_name (bool): Includes a display name for each tax_ID.
+            validation (bool): Sorts out misclassified reads.

         Returns:
             ModelResult: An object containing the hits for each sequence, the number of kmers,
@@ -260,7 +273,7 @@ class ProbabilisticFilterModel:
         """
         if isinstance(sequence_input, (SeqRecord)):
             return ProbabilisticFilterModel.predict(
-                self, [sequence_input], filter_ids, step=step, display_name=display_name
+                self, [sequence_input], filter_ids, step, display_name, validation
             )

         if self._is_sequence_list(sequence_input) | self._is_sequence_iterator(
@@ -268,13 +281,17 @@ class ProbabilisticFilterModel:
         ):
             hits = {}
             num_kmers = {}
+            if validation and self._is_sequence_iterator(sequence_input):
+                sequence_input = list(sequence_input)
+
             for individual_sequence in sequence_input:
                 individual_hits = self.calculate_hits(
-                    individual_sequence.seq, filter_ids, step=step
+                    individual_sequence.seq, filter_ids, step
                 )
                 num_kmers[individual_sequence.id] = self._count_kmers(
-                    individual_sequence, step=step
+                    individual_sequence, step
                 )
+
                 if display_name:
                     individual_hits.update(
                         {
@@ -285,7 +302,12 @@ class ProbabilisticFilterModel:
                             for key in list(individual_hits.keys())
                         }
                     )
+
                 hits[individual_sequence.id] = individual_hits
+
+            if validation:
+                hits = self.detecting_misclassification(hits, sequence_input)
+
             return ModelResult(self.slug(), hits, num_kmers, sparse_sampling_step=step)

         if isinstance(sequence_input, Path):
@@ -294,6 +316,7 @@ class ProbabilisticFilterModel:
                 get_record_iterator(sequence_input),
                 step=step,
                 display_name=display_name,
+                validation=validation,
             )

         raise ValueError(
@@ -476,3 +499,98 @@ class ProbabilisticFilterModel:
             sequence_input,
             (SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),
         )
+
+    def detecting_misclassification(
+        self,
+        hits: dict[str, dict[str, int]],
+        seq_records: list[SeqRecord],
+        min_reads: int = 10,
+    ) -> dict[str, dict[str, int]]:
+        """
+        Detect misclassification for short sequences.
+
+        This function is an alignment-based procedure that groups sequences by their
+        highest-scoring species. Each species group is mapped against the respective
+        reference genome, the start coordinates are extracted, and they are scanned
+        for local clustering. When local clustering is detected, all sequences
+        belonging to that species are sorted out.
+
+        Notes:
+            Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025).
+            (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+
+        Args:
+            hits (dict): The species annotations from the prediction step.
+            seq_records (list): The provided sequences.
+            min_reads (int): Minimum number of reads that a species group must have.
+
+        Returns:
+            dict: The hits from which misclassified sequences have been sorted out.
+        """
+        rec_by_id = {record.id: record for record in seq_records}
+        grouped: dict[int, list[SeqRecord]] = defaultdict(list)
+        misclassified = {}

+        # group by species annotation
+        for record, score_dict in hits.items():
+            if record == "misclassified":
+                continue
+            sorted_hits = sorted(
+                score_dict.items(), key=lambda entry: entry[1], reverse=True
+            )
+            if sorted_hits[0][1] > sorted_hits[1][1]:  # unique highest score
+                highest_tax_id = int(sorted_hits[0][0])  # tax_id
+                if record in rec_by_id:
+                    # group all reads with the highest score by tax_id
+                    grouped[highest_tax_id].append(rec_by_id[record])
+        filtered_grouped = {
+            tax_id: seq for tax_id, seq in grouped.items() if len(seq) > min_reads
+        }
+        largest_group = max(
+            filtered_grouped,
+            key=lambda tax_id: len(filtered_grouped[tax_id]),
+            default=None,
+        )
+
+        # mapping procedure
+        handler = NCBIHandler()
+        out_dir = get_xspect_misclassification_path()
+        out_dir.mkdir(parents=True, exist_ok=True)
+        for tax_id, reads in filtered_grouped.items():
+            if tax_id == largest_group:
+                continue
+
+            tax_dir = out_dir / str(tax_id)
+            tax_dir.mkdir(parents=True, exist_ok=True)
+            fasta_path = tax_dir / f"{tax_id}.fasta"
+            SeqIO.write(reads, fasta_path, "fasta")
+            reference_path = tax_dir / f"{tax_id}.fna"
+            # download the reference only once
+            if not (reference_path.exists() and reference_path.stat().st_size > 0):
+                handler.download_reference_genome(tax_id, tax_dir)
+                if not reference_path.exists():
+                    shutil.rmtree(tax_dir)
+                    continue
+
+            mapping_handler = MappingHandler(str(reference_path), str(fasta_path))
+            mapping_handler.map_reads_onto_reference()
+            mapping_handler.extract_starting_coordinates()
+            genome_length = mapping_handler.get_total_genome_length()
+            start_coordinates = mapping_handler.get_start_coordinates()
+
+            if len(start_coordinates) < min_reads:
+                continue
+
+            # cluster analysis
+            analysis = PointPatternAnalysis(start_coordinates, genome_length)
+            clustered = analysis.ripleys_k_edge_corrected()
+
+            if clustered[0]:  # True or False
+                bucket = misclassified.setdefault(tax_id, {})
+                for read in reads:
+                    data = hits.pop(read.id, None)  # remove false reads from main hits
+                    if data is not None:
+                        bucket[read.id] = data
+
+        if misclassified:
+            hits["misclassified"] = misclassified
+        return hits
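To make the effect on the `hits` structure concrete, a hand-made before/after example (read IDs, taxonomy IDs, and k-mer hit counts are all invented, and a real minority group needs more than `min_reads` members to be tested at all):

```python
# Before validation: each read maps to per-taxon k-mer hit counts.
hits_before = {
    "read_1": {"470": 1875, "480": 310},  # majority group (tax 470)
    "read_2": {"470": 1790, "480": 295},
    "read_9": {"480": 1660, "470": 250},  # minority group (tax 480) whose
                                          # mapped starts cluster locally
}

# After detecting_misclassification(): reads of a locally clustered
# minority group are moved under the "misclassified" key, grouped by
# taxonomy ID.
hits_after = {
    "read_1": {"470": 1875, "480": 310},
    "read_2": {"470": 1790, "480": 295},
    "misclassified": {480: {"read_9": {"480": 1660, "470": 250}}},
}
```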
xspect/models/probabilistic_filter_svm_model.py CHANGED
@@ -184,6 +184,7 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         filter_ids: list[str] = None,
         step: int = 1,
         display_name: bool = False,
+        validation: bool = False,
     ) -> ModelResult:
         """
         Predict the labels of the sequences.
@@ -198,28 +199,27 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
             filter_ids (list[str], optional): A list of IDs to filter the predictions.
             step (int, optional): Step size for sparse sampling. Defaults to 1.
             display_name (bool): Includes a display name for each tax_ID.
+            validation (bool): Sorts out misclassified reads.

         Returns:
             ModelResult: The result of the prediction containing hits, number of kmers, and the
             predicted label.
         """
         # get scores and format them for the SVM
-        res = super().predict(sequence_input, filter_ids, step, display_name)
+        res = super().predict(
+            sequence_input, filter_ids, step, display_name, validation
+        )
         svm_scores = dict(sorted(res.get_scores()["total"].items()))
         svm_scores = [list(svm_scores.values())]

         svm = self._get_svm(filter_ids)
-        svm_prediction = str(svm.predict(svm_scores)[0])
-        if display_name:
-            svm_prediction = f"{svm_prediction} -{self.display_names.get(svm_prediction, 'Unknown')}".replace(
-                self.model_display_name, "", 1
-            )
+        res.hits["misclassified"] = res.misclassified
         return ModelResult(
             self.slug(),
             res.hits,
             res.num_kmers,
             sparse_sampling_step=step,
-            prediction=svm_prediction,
+            prediction=str(svm.predict(svm_scores)[0]),
         )

     def _get_svm(self, id_keys) -> SVC:
xspect/models/result.py CHANGED
@@ -40,6 +40,7 @@ class ModelResult:
         self.sparse_sampling_step = sparse_sampling_step
         self.prediction = prediction
         self.input_source = input_source
+        self.misclassified = self.hits.pop("misclassified", None)

     def get_scores(self) -> dict:
         """
@@ -165,6 +166,7 @@ class ModelResult:
             "hits": self.hits,
             "scores": self.get_scores(),
             "num_kmers": self.num_kmers,
+            "misclassified": self.misclassified,
             "input_source": self.input_source,
         }

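A sketch of the new constructor behavior (slug and values invented; the positional parameter order is inferred from the `ModelResult(...)` calls above): the `misclassified` bucket is split off from `hits` at construction time and surfaces again in `to_dict()`:

```python
from xspect.models.result import ModelResult

hits = {
    "read_1": {"470": 1875},
    "misclassified": {480: {"read_9": {"480": 1660}}},
}
result = ModelResult("acinetobacter", hits, {"read_1": 2000})
print(result.hits)           # {'read_1': {'470': 1875}}
print(result.misclassified)  # {480: {'read_9': {'480': 1660}}}
print(result.to_dict()["misclassified"])
```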
xspect/ncbi.py CHANGED
@@ -1,10 +1,12 @@
 """NCBI handler for the NCBI Datasets API."""

+import shutil
 from enum import Enum
 from pathlib import Path
 import time
 from loguru import logger
 import requests
+import zipfile

 # pylint: disable=line-too-long

@@ -194,8 +196,9 @@ class NCBIHandler:
         assembly_level: AssemblyLevel,
         assembly_source: AssemblySource,
         count: int,
-        min_n50: int = 10000,
-        exclude_atypical: bool = True,
+        min_n50: int,
+        exclude_atypical: bool,
+        allow_inconclusive: bool,
         exclude_paired_reports: bool = True,
         current_version_only: bool = True,
     ) -> list[str]:
@@ -211,11 +214,13 @@ class NCBIHandler:
             assembly_level (AssemblyLevel): The assembly level to get the accessions for.
             assembly_source (AssemblySource): The assembly source to get the accessions for.
             count (int): The number of accessions to get.
-            min_n50 (int, optional): The minimum contig n50 to filter the accessions. Defaults to 10000.
-            exclude_atypical (bool, optional): Whether to exclude atypical accessions. Defaults to True.
+            min_n50 (int): The minimum contig n50 to filter the accessions.
+            exclude_atypical (bool): Whether to exclude atypical accessions.
+            allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.
             exclude_paired_reports (bool, optional): Whether to exclude paired reports. Defaults to True.
             current_version_only (bool, optional): Whether to get only the current version of the accessions. Defaults to True.

+
         Returns:
             list[str]: A list containing the accessions.
         """
@@ -240,8 +245,11 @@ class NCBIHandler:
                 report["accession"]
                 for report in response["reports"]
                 if report["assembly_stats"]["contig_n50"] >= min_n50
-                and report["average_nucleotide_identity"]["taxonomy_check_status"]
-                == "OK"
+                and (
+                    allow_inconclusive
+                    or report["average_nucleotide_identity"]["taxonomy_check_status"]
+                    == "OK"
+                )
             ]
         except (IndexError, KeyError, TypeError):
             logger.debug(
@@ -251,7 +259,13 @@ class NCBIHandler:
         return accessions[:count]  # Limit to count

     def get_highest_quality_accessions(
-        self, taxon_id: int, assembly_source: AssemblySource, count: int
+        self,
+        taxon_id: int,
+        assembly_source: AssemblySource,
+        count: int,
+        min_n50: int,
+        exclude_atypical: bool,
+        allow_inconclusive: bool,
     ) -> list[str]:
         """
         Get the highest quality accessions for a given taxon id (based on the assembly level).
@@ -263,6 +277,9 @@ class NCBIHandler:
             taxon_id (int): The taxon id to get the accessions for.
             assembly_source (AssemblySource): The assembly source to get the accessions for.
             count (int): The number of accessions to get.
+            min_n50 (int): The minimum contig n50 to filter the accessions.
+            exclude_atypical (bool): Whether to exclude atypical accessions.
+            allow_inconclusive (bool): Whether to allow accessions with an inconclusive taxonomy check status.

         Returns:
             list[str]: A list containing the highest quality accessions.
@@ -274,6 +291,9 @@ class NCBIHandler:
                 assembly_level,
                 assembly_source,
                 count,
+                min_n50=min_n50,
+                exclude_atypical=exclude_atypical,
+                allow_inconclusive=allow_inconclusive,
             )
             if len(set(accessions)) >= count:
                 break
@@ -302,3 +322,58 @@ class NCBIHandler:
         with open(output_dir / "ncbi_dataset.zip", "wb") as f:
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
+
+    def download_reference_genome(self, taxon_id: int, output_dir: Path) -> Path | None:
+        """
+        Download the reference genome from the RefSeq DB for a given taxon ID.
+
+        This function queries the NCBI Datasets API for the reference genome and downloads it.
+
+        Notes:
+            Developed by Oemer Cetin as part of a BSc thesis at Goethe University Frankfurt am Main (2025).
+            (An Integration of Alignment-Free and Alignment-Based Approaches for Bacterial Taxon Assignment)
+
+        Args:
+            taxon_id (int): The taxonomy ID of the species.
+            output_dir (Path): Directory where the genome will be saved.
+
+        Returns:
+            Path | None: The path to the extracted reference FASTA file, or None if no accession was found.
+        """
+        accessions = self.get_accessions(
+            taxon_id=taxon_id,
+            assembly_level=AssemblyLevel.REFERENCE,
+            assembly_source=AssemblySource.REFSEQ,
+            count=1,  # only one reference exists
+            min_n50=0,
+            exclude_atypical=True,
+            allow_inconclusive=False,
+        )
+
+        if not accessions:
+            return None
+
+        logger.info(
+            f"Downloading reference genome for taxon {taxon_id}: {accessions[0]}"
+        )
+        self.download_assemblies(accessions, output_dir)
+
+        zip_path = output_dir / "ncbi_dataset.zip"
+
+        fna_file = ""
+        with zipfile.ZipFile(zip_path, "r") as zip_ref:
+            for file in zip_ref.namelist():
+                if file.endswith(".fna"):
+                    extracted_path = zip_ref.extract(file, path=output_dir)
+                    fna_file = output_dir / f"{taxon_id}.fna"
+                    Path(extracted_path).rename(
+                        fna_file
+                    )  # consistent file name (tax_id)
+                    logger.info(f"Extracted reference genome to {fna_file}")
+                    break
+
+        # clean up
+        zip_path.unlink()
+        shutil.rmtree(output_dir / "ncbi_dataset")
+
+        return fna_file
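A short usage sketch of the new download helper; the taxon ID (470, Acinetobacter baumannii) and output directory are examples:

```python
from pathlib import Path
from xspect.ncbi import NCBIHandler

handler = NCBIHandler()
out_dir = Path("refs") / "470"
out_dir.mkdir(parents=True, exist_ok=True)

# Returns the path to <out_dir>/470.fna on success, or None if no
# RefSeq reference accession exists for the taxon.
fna = handler.download_reference_genome(470, out_dir)
print(fna)
```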
xspect/train.py CHANGED
@@ -186,6 +186,11 @@ def train_from_ncbi(
     author: str | None = None,
     author_email: str | None = None,
     ncbi_api_key: str | None = None,
+    min_n50: int = 10000,
+    exclude_atypical: bool = True,
+    allow_inconclusive: bool = False,
+    allow_candidatus: bool = False,
+    allow_sp: bool = False,
 ):
     """
     Train a model using NCBI assembly data for a given genus.
@@ -200,6 +205,11 @@ def train_from_ncbi(
         author (str, optional): Author of the model. Defaults to None.
         author_email (str, optional): Author's email. Defaults to None.
         ncbi_api_key (str, optional): NCBI API key for accessing NCBI resources. Defaults to None.
+        min_n50 (int, optional): Minimum N50 value for assemblies. Defaults to 10000.
+        exclude_atypical (bool, optional): Exclude atypical assemblies. Defaults to True.
+        allow_inconclusive (bool, optional): Allow use of accessions with inconclusive taxonomy check status. Defaults to False.
+        allow_candidatus (bool, optional): Allow use of Candidatus species for training. Defaults to False.
+        allow_sp (bool, optional): Allow use of species with "sp." in their names. Defaults to False.

     Raises:
         TypeError: If `genus` is not a string.
@@ -221,8 +231,8 @@ def train_from_ncbi(
     filtered_species_ids = [
         tax_id
         for tax_id in species_ids
-        if "candidatus" not in species_names[tax_id].lower()
-        and " sp." not in species_names[tax_id].lower()
+        if (allow_candidatus or "candidatus" not in species_names[tax_id].lower())
+        and (allow_sp or " sp." not in species_names[tax_id].lower())
     ]
     filtered_species_names = {
         str(tax_id): species_names[tax_id] for tax_id in filtered_species_ids
@@ -231,7 +241,12 @@ def train_from_ncbi(
     accessions = {}
     for tax_id in filtered_species_ids:
         taxon_accessions = ncbi_handler.get_highest_quality_accessions(
-            tax_id, AssemblySource.REFSEQ, 8
+            tax_id,
+            AssemblySource.REFSEQ,
+            8,
+            min_n50,
+            exclude_atypical,
+            allow_inconclusive,
         )
         if not taxon_accessions:
             logger.warning(f"No assemblies found for tax_id {tax_id}. Skipping.")
@@ -241,7 +256,9 @@ def train_from_ncbi(

     if not accessions:
         raise ValueError(
-            "No species with accessions found. Please check the genus name."
+            "No species with accessions found. "
+            "Please check if the genus name is correct or if there are any data quality issues "
+            "(e.g. inconclusive taxonomy check status, atypical assemblies, low N50 values)."
         )

     with TemporaryDirectory() as tmp_dir:
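A Python-level sketch of training with the new quality filters; the genus and parameter values are illustrative, and the exact meaning of `svm_steps` is not shown in these hunks:

```python
from xspect.train import train_from_ncbi

train_from_ncbi(
    "Acinetobacter",          # genus (illustrative)
    1,                        # svm_steps (assumed: sparse sampling step for SVM training)
    author="Jane Doe",
    author_email="jane@example.org",
    min_n50=5000,             # admit more fragmented assemblies
    allow_inconclusive=True,  # also accept inconclusive taxonomy checks
    allow_candidatus=False,
    allow_sp=False,
)
```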
xspect-0.5.4.dist-info/METADATA → xspect-0.6.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: XspecT
-Version: 0.5.4
+Version: 0.6.0
 Summary: Tool to monitor and characterize pathogens using Bloom filters.
 License: MIT License

@@ -45,6 +45,9 @@ Requires-Dist: xxhash
 Requires-Dist: fastapi
 Requires-Dist: uvicorn
 Requires-Dist: python-multipart
+Requires-Dist: mappy
+Requires-Dist: pysam
+Requires-Dist: numpy
 Provides-Extra: docs
 Requires-Dist: mkdocs-material; extra == "docs"
 Requires-Dist: mkdocs-include-markdown-plugin; extra == "docs"
xspect-0.5.4.dist-info/RECORD → xspect-0.6.0.dist-info/RECORD CHANGED
@@ -1,23 +1,27 @@
 xspect/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xspect/classify.py,sha256=SOTfUsEarZVbVtBuwNTlOYQ_MvoW0ADAqH7rTASWEI8,3936
-xspect/definitions.py,sha256=8PpU8bpzcwv8PPacncywz-Na_MicMl-JsvjiX3e46yo,2734
+xspect/classify.py,sha256=fpXk8KobN9QQGSWu5pPR9g65bGxlOmcYEwZ5FJ5KSdM,4106
+xspect/definitions.py,sha256=eP2ttvoW9izCfMYqjLL6Y2oead44l74oHFKRDssTIF8,3506
 xspect/download_models.py,sha256=VALcnowzkUpR-OAvgB5BUdEq9WnyNbli0CxH3OT40Rc,1121
 xspect/file_io.py,sha256=QX2nBtlLAexBdfUr7rtHLlWOuXiaKvfRdpn1Dn0avnY,8120
 xspect/filter_sequences.py,sha256=QKjgUCk3RBY3U9hHmyvSQeQt8n1voBna-NjOoTqdp3A,5196
-xspect/main.py,sha256=vRCsSH_QVKY6usU5d5pjehPAhk4WJfZ_eyl_0xUGu5E,14137
+xspect/main.py,sha256=dhJa6mUneSI0nkwCWHqHrk9sZJgP0MkaPTSceSq1s4c,15463
 xspect/model_management.py,sha256=yWbCk6tUn7-OYpzH0BViX2oWr4cdNkEBjrvnaw5GPdQ,4893
-xspect/ncbi.py,sha256=VRbFvtfGR4WTsc3buZE9UCabE3OJUTRphDRY20g63-E,11704
-xspect/train.py,sha256=jxjK4OqzTywmd5KGPou9A-doH8Nwhlv_xF4X7M6X_jI,11588
+xspect/ncbi.py,sha256=RIwSJcPDREvk_YTNPfT34hQ4mb9nRKoXeNAS8wnLrHY,14413
+xspect/train.py,sha256=VQ5pISDtp0rlGwrYqK_4_OH9CE6n7L1DNZ3txeTty6M,12574
 xspect/web.py,sha256=kM4BZ3fA0f731EEXScAaiGrJZvjjfep1iC1iZemfazw,7039
+xspect/misclassification_detection/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xspect/misclassification_detection/mapping.py,sha256=xSibBi7-NoR74jakoYBmTW4Gy5FMfRDVF8H3ef5pElI,6698
+xspect/misclassification_detection/point_pattern_analysis.py,sha256=5ivksPAh0zGMVI33ETPrZ01tPE86ELOwI7-XpJNYsQA,3820
+xspect/misclassification_detection/simulate_reads.py,sha256=fxfKNSAwDT7Vnuh-_vgrMTLrfKihocZPU0gokNicey0,2009
 xspect/mlst_feature/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xspect/mlst_feature/mlst_helper.py,sha256=pxRX_nRbrTSIFPf_FDV3dxR_FonmGtxttFgqNS7sIxE,8130
 xspect/mlst_feature/pub_mlst_handler.py,sha256=gX0bgAqXTaW9weWgxcbsiD7UtMGuDD9veE9mj42Ffm8,7685
 xspect/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xspect/models/probabilistic_filter_mlst_model.py,sha256=w9ibUkAYA-DSOEkU8fBenlENrs8JwRRLaF5KO1HVKoM,17716
-xspect/models/probabilistic_filter_model.py,sha256=pUgkN4E2EO-gePVR4BMndgMhJcyvOfVzfjVypjIz2JA,19047
-xspect/models/probabilistic_filter_svm_model.py,sha256=n_HMARvcUMP1i-csiW8uvskcocrvhWMjue7kfsaKPpI,11146
+xspect/models/probabilistic_filter_model.py,sha256=9CyNG-wcvjz1AORzr2hCO_laerYfen1OmWKbRh02nc8,23777
+xspect/models/probabilistic_filter_svm_model.py,sha256=WRXrAWr6f35B4VxgIhhLi1uj8RUZQAAhivjQdePsWhs,11094
 xspect/models/probabilistic_single_filter_model.py,sha256=vJvKZrAybYHq_UdKQ2GvvVwgTYwqRrL-nDDQZxb6RRc,6828
-xspect/models/result.py,sha256=Wpsm9EYrvMazDO0JAqF51Sb8BJqAZwYx4G6-SUOt5-c,7070
+xspect/models/result.py,sha256=v1wslD4RHgoTSwqOVtejfSbEhg4jZ0CDPG55nyhklUg,7185
 xspect/xspect-web/.gitignore,sha256=_nGOe6uxTzy60tl_CIibnOUhXtP-DkOyuM-_s7m4ROg,253
 xspect/xspect-web/README.md,sha256=Fa5cCk66ohbqD_AAVgnXUZLhuzshnLxhlUFhxyscScc,1942
 xspect/xspect-web/components.json,sha256=5emhfq5JRW9J8Zga-1N5jAcj4B-r8VREXnH7Z6tZGNk,425
@@ -78,9 +82,9 @@ xspect/xspect-web/src/components/ui/switch.tsx,sha256=uIqRXtd41ba0eusIEUWVyYZv82
 xspect/xspect-web/src/components/ui/table.tsx,sha256=M2-TIHKwPFWuXrwysSufdQRSMJT-K9jPzGOokfU6PXo,2463
 xspect/xspect-web/src/components/ui/tabs.tsx,sha256=BImHKcdDCtrS3CCV1AGgn8qg0b65RB5P-QdH49IAhx0,1955
 xspect/xspect-web/src/lib/utils.ts,sha256=66ibdQiEHKftZBq1OMLmOKqWma1BkO-O60rc1IQYwLE,165
-xspect-0.5.4.dist-info/licenses/LICENSE,sha256=bhBGDKIRUVwYIHGOGO5hshzuVHyqFJajvSOA3XXOLKI,1094
-xspect-0.5.4.dist-info/METADATA,sha256=T1EVSE_qesDZjlSCaq3xgnUN57n0NIFjOIQCi4swsEo,4569
-xspect-0.5.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-xspect-0.5.4.dist-info/entry_points.txt,sha256=L7qliX3pIuwupQxpuOSsrBJCSHYPOPNEzH8KZKQGGUw,43
-xspect-0.5.4.dist-info/top_level.txt,sha256=hdoa4cnBv6OVzpyhMmyxpJxEydH5n2lDciy8urc1paE,7
-xspect-0.5.4.dist-info/RECORD,,
+xspect-0.6.0.dist-info/licenses/LICENSE,sha256=bhBGDKIRUVwYIHGOGO5hshzuVHyqFJajvSOA3XXOLKI,1094
+xspect-0.6.0.dist-info/METADATA,sha256=_3lkPNVzF3iUOFGFaGRavSPfuJ_wdCQODA0qfU5GXJg,4632
+xspect-0.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xspect-0.6.0.dist-info/entry_points.txt,sha256=L7qliX3pIuwupQxpuOSsrBJCSHYPOPNEzH8KZKQGGUw,43
+xspect-0.6.0.dist-info/top_level.txt,sha256=hdoa4cnBv6OVzpyhMmyxpJxEydH5n2lDciy8urc1paE,7
+xspect-0.6.0.dist-info/RECORD,,