XspecT 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT has been flagged as potentially problematic; consult the registry's advisory page for more details.

Files changed (33):
  1. xspect/classify.py +51 -38
  2. xspect/definitions.py +50 -10
  3. xspect/download_models.py +10 -2
  4. xspect/file_io.py +115 -48
  5. xspect/filter_sequences.py +36 -66
  6. xspect/main.py +44 -11
  7. xspect/mlst_feature/mlst_helper.py +3 -0
  8. xspect/mlst_feature/pub_mlst_handler.py +43 -1
  9. xspect/model_management.py +84 -14
  10. xspect/models/probabilistic_filter_mlst_model.py +75 -37
  11. xspect/models/probabilistic_filter_model.py +194 -12
  12. xspect/models/probabilistic_filter_svm_model.py +99 -6
  13. xspect/models/probabilistic_single_filter_model.py +66 -5
  14. xspect/models/result.py +77 -10
  15. xspect/ncbi.py +48 -12
  16. xspect/train.py +2 -1
  17. xspect/web.py +71 -13
  18. xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
  19. xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
  20. xspect/xspect-web/dist/index.html +2 -2
  21. xspect/xspect-web/src/App.tsx +4 -2
  22. xspect/xspect-web/src/api.tsx +23 -1
  23. xspect/xspect-web/src/components/filter-form.tsx +16 -3
  24. xspect/xspect-web/src/components/filtering-result.tsx +65 -0
  25. xspect/xspect-web/src/components/result.tsx +2 -2
  26. xspect/xspect-web/src/types.tsx +5 -0
  27. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/METADATA +1 -1
  28. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/RECORD +32 -31
  29. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/WHEEL +1 -1
  30. xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
  31. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/entry_points.txt +0 -0
  32. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/licenses/LICENSE +0 -0
  33. {xspect-0.5.1.dist-info → xspect-0.5.3.dist-info}/top_level.txt +0 -0
xspect/classify.py CHANGED
@@ -4,64 +4,77 @@ import xspect.model_management as mm
4
4
  from xspect.models.probabilistic_filter_mlst_model import (
5
5
  ProbabilisticFilterMlstSchemeModel,
6
6
  )
7
- from xspect.definitions import fasta_endings, fastq_endings
7
+ from xspect.file_io import prepare_input_output_paths
8
8
 
9
9
 
10
10
  def classify_genus(
11
11
  model_genus: str, input_path: Path, output_path: Path, step: int = 1
12
12
  ):
13
- """Classify the input file using the genus model."""
14
- model = mm.get_genus_model(model_genus)
13
+ """
14
+ Classify the genus of sequences.
15
15
 
16
- input_paths = []
17
- input_is_dir = input_path.is_dir()
18
- ending_wildcards = [f"*.{ending}" for ending in fasta_endings + fastq_endings]
16
+ This function classifies input files using the genus model.
17
+ The input path can be a file or directory
19
18
 
20
- if input_is_dir:
21
- input_paths = [p for e in ending_wildcards for p in input_path.glob(e)]
22
- elif input_path.is_file():
23
- input_paths = [input_path]
19
+ Args:
20
+ model_genus (str): The genus model slug.
21
+ input_path (Path): The path to the input file/directory containing sequences.
22
+ output_path (Path): The path to the output file where results will be saved.
23
+ step (int): The amount of kmers to be skipped.
24
+ """
25
+ model = mm.get_genus_model(model_genus)
26
+ input_paths, get_output_path = prepare_input_output_paths(input_path)
24
27
 
25
28
  for idx, current_path in enumerate(input_paths):
26
29
  result = model.predict(current_path, step=step)
27
30
  result.input_source = current_path.name
28
- output_name = (
29
- f"{output_path.stem}_{idx+1}{output_path.suffix}"
30
- if input_is_dir
31
- else output_path.name
32
- )
33
- result.save(output_path.parent / output_name)
34
- print(f"Saved result as {output_name}")
31
+ cls_path = get_output_path(idx, output_path)
32
+ result.save(cls_path)
33
+ print(f"Saved result as {cls_path.name}")
35
34
 
36
35
 
37
- def classify_species(model_genus, input_path, output_path, step=1):
38
- """Classify the input file using the species model."""
39
- model = mm.get_species_model(model_genus)
36
+ def classify_species(
37
+ model_genus: str, input_path: Path, output_path: Path, step: int = 1
38
+ ):
39
+ """
40
+ Classify the species of sequences.
40
41
 
41
- input_paths = []
42
- input_is_dir = input_path.is_dir()
43
- ending_wildcards = [f"*.{ending}" for ending in fasta_endings + fastq_endings]
42
+ This function classifies input files using the species model.
43
+ The input path can be a file or directory
44
44
 
45
- if input_is_dir:
46
- input_paths = [p for e in ending_wildcards for p in input_path.glob(e)]
47
- elif input_path.is_file():
48
- input_paths = [input_path]
45
+ Args:
46
+ model_genus (str): The genus model slug.
47
+ input_path (Path): The path to the input file/directory containing sequences.
48
+ output_path (Path): The path to the output file where results will be saved.
49
+ step (int): The amount of kmers to be skipped.
50
+ """
51
+ model = mm.get_species_model(model_genus)
52
+ input_paths, get_output_path = prepare_input_output_paths(input_path)
49
53
 
50
54
  for idx, current_path in enumerate(input_paths):
51
55
  result = model.predict(current_path, step=step)
52
56
  result.input_source = current_path.name
53
- output_name = (
54
- f"{output_path.stem}_{idx+1}{output_path.suffix}"
55
- if input_is_dir
56
- else output_path.name
57
- )
58
- result.save(output_path.parent / output_name)
59
- print(f"Saved result as {output_name}")
57
+ cls_path = get_output_path(idx, output_path)
58
+ result.save(cls_path)
59
+ print(f"Saved result as {cls_path.name}")
60
+
61
+
62
+ def classify_mlst(input_path: Path, output_path: Path, limit: bool):
63
+ """
64
+ Classify the strain type using the specific MLST model.
60
65
 
66
+ Args:
67
+ input_path (Path): The path to the input file/directory containing sequences.
68
+ output_path (Path): The path to the output file where results will be saved.
69
+ limit (bool): A limit for the highest allele_id results that are shown.
70
+ """
61
71
 
62
- def classify_mlst(input_path, output_path):
63
- """Classify the input file using the MLST model."""
64
72
  scheme_path = pick_scheme_from_models_dir()
65
73
  model = ProbabilisticFilterMlstSchemeModel.load(scheme_path)
66
- result = model.predict(scheme_path, input_path)
67
- result.save(output_path)
74
+ input_paths, get_output_path = prepare_input_output_paths(input_path)
75
+ for idx, current_path in enumerate(input_paths):
76
+ result = model.predict(scheme_path, current_path, step=1, limit=limit)
77
+ result.input_source = current_path.name
78
+ cls_path = get_output_path(idx, output_path)
79
+ result.save(cls_path)
80
+ print(f"Saved result as {cls_path.name}")
xspect/definitions.py CHANGED
@@ -7,8 +7,16 @@ fasta_endings = ["fasta", "fna", "fa", "ffn", "frn"]
7
7
  fastq_endings = ["fastq", "fq"]
8
8
 
9
9
 
10
- def get_xspect_root_path():
11
- """Return the root path for XspecT data."""
10
+ def get_xspect_root_path() -> Path:
11
+ """
12
+ Return the root path for XspecT data.
13
+
14
+ Returns the path to the XspecT data directory, which can be located either in the user's home directory or in the current working directory.
15
+ If neither exists, it creates the directory in the user's home directory.
16
+
17
+ Returns:
18
+ Path: The path to the XspecT data directory.
19
+ """
12
20
 
13
21
  home_based_dir = Path.home() / "xspect-data"
14
22
  if home_based_dir.exists():
@@ -22,29 +30,61 @@ def get_xspect_root_path():
22
30
  return home_based_dir
23
31
 
24
32
 
25
- def get_xspect_model_path():
26
- """Return the path to the XspecT models."""
33
+ def get_xspect_model_path() -> Path:
34
+ """
35
+ Return the path to the XspecT models.
36
+
37
+ Returns the path to the XspecT models directory, which is located within the XspecT data directory.
38
+ If the directory does not exist, it creates the directory.
39
+
40
+ Returns:
41
+ Path: The path to the XspecT models directory.
42
+ """
27
43
  model_path = get_xspect_root_path() / "models"
28
44
  model_path.mkdir(exist_ok=True, parents=True)
29
45
  return model_path
30
46
 
31
47
 
32
- def get_xspect_upload_path():
33
- """Return the path to the XspecT upload directory."""
48
+ def get_xspect_upload_path() -> Path:
49
+ """
50
+ Return the path to the XspecT upload directory.
51
+
52
+ Returns the path to the XspecT uploads directory, which is located within the XspecT data directory.
53
+ If the directory does not exist, it creates the directory.
54
+
55
+ Returns:
56
+ Path: The path to the XspecT uploads directory.
57
+ """
34
58
  upload_path = get_xspect_root_path() / "uploads"
35
59
  upload_path.mkdir(exist_ok=True, parents=True)
36
60
  return upload_path
37
61
 
38
62
 
39
- def get_xspect_runs_path():
40
- """Return the path to the XspecT runs directory."""
63
+ def get_xspect_runs_path() -> Path:
64
+ """
65
+ Return the path to the XspecT runs directory.
66
+
67
+ Returns the path to the XspecT runs directory, which is located within the XspecT data directory.
68
+ If the directory does not exist, it creates the directory.
69
+
70
+ Returns:
71
+ Path: The path to the XspecT runs directory.
72
+ """
41
73
  runs_path = get_xspect_root_path() / "runs"
42
74
  runs_path.mkdir(exist_ok=True, parents=True)
43
75
  return runs_path
44
76
 
45
77
 
46
- def get_xspect_mlst_path():
47
- """Return the path to the XspecT runs directory."""
78
+ def get_xspect_mlst_path() -> Path:
79
+ """
80
+ Return the path to the XspecT MLST directory.
81
+
82
+ Returns the path to the XspecT MLST directory, which is located within the XspecT data directory.
83
+ If the directory does not exist, it creates the directory.
84
+
85
+ Returns:
86
+ Path: The path to the XspecT MLST directory.
87
+ """
48
88
  mlst_path = get_xspect_root_path() / "mlst"
49
89
  mlst_path.mkdir(exist_ok=True, parents=True)
50
90
  return mlst_path
xspect/download_models.py CHANGED
@@ -8,8 +8,16 @@ import requests
8
8
  from xspect.definitions import get_xspect_model_path
9
9
 
10
10
 
11
- def download_test_models(url):
12
- """Download models."""
11
+ def download_test_models(url: str) -> None:
12
+ """
13
+ Download models from the specified URL.
14
+
15
+ This function downloads a zip file from the given URL, extracts its contents,
16
+ and copies the extracted files to the XspecT model directory.
17
+
18
+ Args:
19
+ url (str): The URL from which to download the models.
20
+ """
13
21
  with TemporaryDirectory() as tmp_dir:
14
22
  tmp_dir = Path(tmp_dir)
15
23
  download_path = tmp_dir / "models.zip"
xspect/file_io.py CHANGED
@@ -6,12 +6,20 @@ from json import loads
6
6
  import os
7
7
  from pathlib import Path
8
8
  import zipfile
9
+ from typing import Callable, Iterator
9
10
  from Bio import SeqIO
10
11
  from xspect.definitions import fasta_endings, fastq_endings
11
12
 
12
13
 
13
- def delete_zip_files(dir_path):
14
- """Delete all zip files in the given directory."""
14
+ def delete_zip_files(dir_path) -> None:
15
+ """
16
+ Delete all zip files in the given directory.
17
+
18
+ This function checks each file in the specified directory and removes it if it is a zip file.
19
+
20
+ Args:
21
+ dir_path (Path): Path to the directory where zip files should be deleted.
22
+ """
15
23
  files = os.listdir(dir_path)
16
24
  for file in files:
17
25
  if zipfile.is_zipfile(file):
@@ -19,45 +27,39 @@ def delete_zip_files(dir_path):
19
27
  os.remove(file_path)
20
28
 
21
29
 
22
- def extract_zip(zip_path: Path, unzipped_path: Path):
23
- """Extracts all files from a zip file."""
30
+ def extract_zip(zip_path: Path, unzipped_path: Path) -> None:
31
+ """
32
+ Extracts all files from a zip file.
33
+
34
+ Extracts the contents of the specified zip file to the given directory.
35
+
36
+ Args:
37
+ zip_path (Path): Path to the zip file to be extracted.
38
+ unzipped_path (Path): Path to the directory where the contents will be extracted.
39
+ """
24
40
  unzipped_path.mkdir(parents=True, exist_ok=True)
25
41
 
26
42
  with zipfile.ZipFile(zip_path) as item:
27
43
  item.extractall(unzipped_path)
28
44
 
29
45
 
30
- def concatenate_meta(path: Path, genus: str):
31
- """Concatenates all species files to one fasta file.
32
-
33
- :param path: Path to the directory with the concatenated fasta files.
34
- :type path: Path
35
- :param genus: Genus name.
36
- :type genus: str
46
+ def get_record_iterator(file_path: Path) -> Iterator:
37
47
  """
38
- files_path = path / "concatenate"
39
- meta_path = path / (genus + ".fasta")
40
- files = os.listdir(files_path)
48
+ Returns a record iterator for a fasta or fastq file.
41
49
 
42
- with open(meta_path, "w", encoding="utf-8") as meta_file:
43
- # Write the header.
44
- meta_header = f">{genus} metagenome\n"
45
- meta_file.write(meta_header)
46
-
47
- # Open each concatenated species file and write the sequence in the meta file.
48
- for file in files:
49
- file_ending = str(file).rsplit(".", maxsplit=1)[-1]
50
- if file_ending in fasta_endings:
51
- with open(
52
- (files_path / str(file)), "r", encoding="utf-8"
53
- ) as species_file:
54
- for line in species_file:
55
- if line[0] != ">":
56
- meta_file.write(line.replace("\n", ""))
57
-
58
-
59
- def get_record_iterator(file_path: Path):
60
- """Returns a record iterator for a fasta or fastq file."""
50
+ This function checks the file extension to determine if the file is in fasta or fastq format
51
+ and returns an iterator over the records in the file using Biopython's SeqIO module.
52
+
53
+ Args:
54
+ file_path (Path): Path to the fasta or fastq file.
55
+
56
+ Returns:
57
+ Iterator: An iterator over the records in the file.
58
+
59
+ Raises:
60
+ ValueError: If the file path is not a Path object, does not exist, is not a file,
61
+ or has an invalid file format.
62
+ """
61
63
  if not isinstance(file_path, Path):
62
64
  raise ValueError("Path must be a Path object")
63
65
 
@@ -76,17 +78,18 @@ def get_record_iterator(file_path: Path):
76
78
  raise ValueError("Invalid file format, must be a fasta or fastq file")
77
79
 
78
80
 
79
- def get_records_by_id(file: Path, ids: list[str]):
80
- """Return records with the specified ids."""
81
- records = get_record_iterator(file)
82
- return [record for record in records if record.id in ids]
83
-
81
+ def concatenate_species_fasta_files(
82
+ input_folders: list[Path], output_directory: Path
83
+ ) -> None:
84
+ """
85
+ Concatenate fasta files from different species into one file per species.
84
86
 
85
- def concatenate_species_fasta_files(input_folders: list[Path], output_directory: Path):
86
- """Concatenate fasta files from different species into one file per species.
87
+ This function iterates through each species folder within the given input folder,
88
+ collects all fasta files, and concatenates their contents into a single fasta file
89
+ named after the species.
87
90
 
88
91
  Args:
89
- input_species_folders (list[Path]): List of paths to species folders.
92
+ input_folders (list[Path]): List of paths to species folders.
90
93
  output_directory (Path): Path to the output directory.
91
94
  """
92
95
  for species_folder in input_folders:
@@ -105,15 +108,22 @@ def concatenate_species_fasta_files(input_folders: list[Path], output_directory:
105
108
  f.write(f_in.read())
106
109
 
107
110
 
108
- def concatenate_metagenome(fasta_dir: Path, meta_path: Path):
109
- """Concatenate all fasta files in a directory into one file.
111
+ def concatenate_metagenome(fasta_dir: Path, meta_path: Path) -> None:
112
+ """
113
+ Concatenate all fasta files in a directory into one file.
114
+
115
+ This function searches for all fasta files in the specified directory and writes their contents
116
+ into a single output file. The output file will contain the concatenated sequences from all fasta files.
110
117
 
111
118
  Args:
112
119
  fasta_dir (Path): Path to the directory with the fasta files.
113
120
  meta_path (Path): Path to the output file.
114
121
  """
122
+ fasta_files = [
123
+ file for ending in fasta_endings for file in fasta_dir.glob(f"*.{ending}")
124
+ ]
115
125
  with open(meta_path, "w", encoding="utf-8") as meta_file:
116
- for fasta_file in fasta_dir.glob("*.fasta"):
126
+ for fasta_file in fasta_files:
117
127
  with open(fasta_file, "r", encoding="utf-8") as f_in:
118
128
  meta_file.write(f_in.read())
119
129
 
@@ -121,13 +131,21 @@ def concatenate_metagenome(fasta_dir: Path, meta_path: Path):
121
131
  def get_ncbi_dataset_accession_paths(
122
132
  ncbi_dataset_path: Path,
123
133
  ) -> dict[str, Path]:
124
- """Get the paths of the NCBI dataset accessions.
134
+ """
135
+ Get the paths of the NCBI dataset accessions.
136
+
137
+ This function reads the dataset catalog from the NCBI dataset directory and returns a dictionary
138
+ mapping each accession to its corresponding file path. The first item in the dataset catalog is
139
+ assumed to be a data report, and is skipped.
125
140
 
126
141
  Args:
127
142
  ncbi_dataset_path (Path): Path to the NCBI dataset directory.
128
143
 
129
144
  Returns:
130
145
  dict[str, Path]: Dictionary with the accession as key and the path as value.
146
+
147
+ Raises:
148
+ ValueError: If the dataset path does not exist or is invalid.
131
149
  """
132
150
  data_path = ncbi_dataset_path / "ncbi_dataset" / "data"
133
151
  if not data_path.exists():
@@ -147,13 +165,19 @@ def filter_sequences(
147
165
  input_file: Path,
148
166
  output_file: Path,
149
167
  included_ids: list[str],
150
- ):
151
- """Filter sequences by IDs from an input file and save them to an output file.
168
+ ) -> None:
169
+ """
170
+ Filter sequences by IDs from an input file and save them to an output file.
171
+
172
+ This function reads a fasta or fastq file, filters the sequences based on the provided IDs,
173
+ and writes the matching sequences to an output file. If no IDs are provided, no output file
174
+ is created.
152
175
 
153
176
  Args:
154
177
  input_file (Path): Path to the input file.
155
178
  output_file (Path): Path to the output file.
156
- included_ids (list[str], optional): List of IDs to include. If None, no output file is created.
179
+ included_ids (list[str], optional): List of IDs to include. If None, no output file
180
+ is created.
157
181
  """
158
182
  if not included_ids:
159
183
  print("No IDs provided, no output file will be created.")
@@ -163,3 +187,46 @@ def filter_sequences(
163
187
  for record in get_record_iterator(input_file):
164
188
  if record.id in included_ids:
165
189
  SeqIO.write(record, out_f, "fasta")
190
+
191
+
192
+ def prepare_input_output_paths(
193
+ input_path: Path,
194
+ ) -> tuple[list[Path], Callable[[int, Path], Path]]:
195
+ """
196
+ Processes the input path into a list of input paths and a function generating output paths.
197
+
198
+ This function checks if the input path is a directory or a file. If it is a directory,
199
+ it collects all files with specified fasta and fastq endings. If it is a file, it uses that file
200
+ as the input path. It then returns a list of input file paths and a function that generates
201
+ output paths based on the index of the input file and a specified output path.
202
+
203
+ Args:
204
+ input_path (Path): Path to the directory or file.
205
+
206
+ Returns:
207
+ tuple[list[Path], Callable[[int, Path], Path]]: A tuple containing:
208
+ - A list of input file paths
209
+ - A function that takes an index and the output path,
210
+ and returns the processed output path.
211
+
212
+ Raises:
213
+ ValueError: If the input path is invalid.
214
+ """
215
+ input_is_dir = input_path.is_dir()
216
+ ending_wildcards = [f"*.{ending}" for ending in fasta_endings + fastq_endings]
217
+
218
+ if input_is_dir:
219
+ input_paths = [p for e in ending_wildcards for p in input_path.glob(e)]
220
+ elif input_path.is_file():
221
+ input_paths = [input_path]
222
+ else:
223
+ raise ValueError("Invalid input path")
224
+
225
+ def get_output_path(idx: int, output_path: Path) -> Path:
226
+ return (
227
+ output_path.parent / f"{output_path.stem}_{idx+1}{output_path.suffix}"
228
+ if input_is_dir
229
+ else output_path
230
+ )
231
+
232
+ return input_paths, get_output_path
@@ -1,7 +1,6 @@
1
1
  from pathlib import Path
2
2
  from xspect.model_management import get_genus_model, get_species_model
3
- from xspect.file_io import filter_sequences
4
- from xspect.definitions import fasta_endings, fastq_endings
3
+ from xspect.file_io import filter_sequences, prepare_input_output_paths
5
4
 
6
5
 
7
6
  def filter_species(
@@ -11,67 +10,51 @@ def filter_species(
11
10
  output_path: Path,
12
11
  threshold: float,
13
12
  classification_output_path: Path | None = None,
13
+ sparse_sampling_step: int = 1,
14
14
  ):
15
- """Filter sequences by species.
15
+ """
16
+ Filter sequences by species.
17
+
16
18
  This function filters sequences from the input file based on the species model.
17
- It uses the genus model to identify the genus of the sequences and then applies
18
- the species model to filter the sequences.
19
+ It uses the species model to identify the species of individual sequences and then applies
20
+ a threshold filter the sequences.
19
21
 
20
22
  Args:
21
- model_genus (str): The genus model slug.
22
- model_species (str): The species model slug.
23
+ model_genus (str): The genus of the species model.
24
+ model_species (str): The species to filter by.
23
25
  input_path (Path): The path to the input file containing sequences.
24
26
  output_path (Path): The path to the output file where filtered sequences will be saved.
25
- above this threshold will be included in the output file. A threshold of -1 will
26
- include only sequences if the species score is the highest among the
27
- available species scores.
28
27
  classification_output_path (Path): Optional path to save the classification results.
29
28
  threshold (float): The threshold for filtering sequences. Only sequences with a score
30
29
  above this threshold will be included in the output file. A threshold of -1 will
31
30
  include only sequences if the species score is the highest among the
32
31
  available species scores.
32
+ sparse_sampling_step (int): The step size for sparse sampling. Defaults to 1.
33
33
  """
34
34
  species_model = get_species_model(model_genus)
35
-
36
- input_paths = []
37
- input_is_dir = input_path.is_dir()
38
- ending_wildcards = [f"*.{ending}" for ending in fasta_endings + fastq_endings]
39
-
40
- if input_is_dir:
41
- input_paths = [p for e in ending_wildcards for p in input_path.glob(e)]
42
- elif input_path.is_file():
43
- input_paths = [input_path]
35
+ input_paths, get_output_path = prepare_input_output_paths(input_path)
44
36
 
45
37
  for idx, current_path in enumerate(input_paths):
46
- result = species_model.predict(current_path)
38
+ result = species_model.predict(current_path, step=sparse_sampling_step)
47
39
  result.input_source = current_path.name
48
40
 
49
41
  if classification_output_path:
50
- classification_output_name = (
51
- f"{classification_output_path.stem}_{idx+1}{classification_output_path.suffix}"
52
- if input_is_dir
53
- else classification_output_path.name
54
- )
55
- result.save(classification_output_path.parent / classification_output_name)
42
+ cls_out = get_output_path(idx, classification_output_path)
43
+ result.save(cls_out)
56
44
  print(
57
- f"Saved classification results from {current_path.name} as {classification_output_name}"
45
+ f"Saved classification results from {current_path.name} as {cls_out.name}"
58
46
  )
59
47
 
60
48
  included_ids = result.get_filtered_subsequence_labels(model_species, threshold)
61
49
  if not included_ids:
62
50
  print(f"No sequences found for the given species in {current_path.name}.")
63
51
  continue
64
- output_name = (
65
- f"{output_path.stem}_{idx+1}{output_path.suffix}"
66
- if input_is_dir
67
- else output_path.name
68
- )
69
- filter_sequences(
70
- current_path,
71
- output_path.parent / output_name,
72
- included_ids,
52
+
53
+ filter_output_path = get_output_path(idx, output_path)
54
+ filter_sequences(current_path, filter_output_path, included_ids)
55
+ print(
56
+ f"Saved filtered sequences from {current_path.name} as {filter_output_path.name}"
73
57
  )
74
- print(f"Saved filtered sequences from {current_path.name} as {output_name}")
75
58
 
76
59
 
77
60
  def filter_genus(
@@ -80,8 +63,11 @@ def filter_genus(
80
63
  output_path: Path,
81
64
  threshold: float,
82
65
  classification_output_path: Path | None = None,
66
+ sparse_sampling_step: int = 1,
83
67
  ):
84
- """Filter sequences by genus.
68
+ """
69
+ Filter sequences by genus.
70
+
85
71
  This function filters sequences from the input file based on the genus model.
86
72
  It uses the genus model to identify the genus of the sequences and then applies
87
73
  the filtering based on the provided threshold.
@@ -93,46 +79,30 @@ def filter_genus(
93
79
  threshold (float): The threshold for filtering sequences. Only sequences with a score
94
80
  above this threshold will be included in the output file.
95
81
  classification_output_path (Path): Optional path to save the classification results.
82
+ sparse_sampling_step (int): The step size for sparse sampling. Defaults to 1.
96
83
 
97
84
  """
98
- genus_model = get_genus_model(model_genus)
99
-
100
- input_paths = []
101
- input_is_dir = input_path.is_dir()
102
- ending_wildcards = [f"*.{ending}" for ending in fasta_endings + fastq_endings]
103
-
104
- if input_is_dir:
105
- input_paths = [p for e in ending_wildcards for p in input_path.glob(e)]
106
- elif input_path.is_file():
107
- input_paths = [input_path]
85
+ model = get_genus_model(model_genus)
86
+ input_paths, get_output_path = prepare_input_output_paths(input_path)
108
87
 
109
88
  for idx, current_path in enumerate(input_paths):
110
- result = genus_model.predict(current_path)
89
+ result = model.predict(current_path, step=sparse_sampling_step)
111
90
  result.input_source = current_path.name
112
91
 
113
92
  if classification_output_path:
114
- classification_output_name = (
115
- f"{classification_output_path.stem}_{idx+1}{classification_output_path.suffix}"
116
- if input_is_dir
117
- else classification_output_path.name
118
- )
119
- result.save(classification_output_path.parent / classification_output_name)
93
+ cls_out = get_output_path(idx, classification_output_path)
94
+ result.save(cls_out)
120
95
  print(
121
- f"Saved classification results from {current_path.name} as {classification_output_name}"
96
+ f"Saved classification results from {current_path.name} as {cls_out.name}"
122
97
  )
123
98
 
124
99
  included_ids = result.get_filtered_subsequence_labels(model_genus, threshold)
125
100
  if not included_ids:
126
101
  print(f"No sequences found for the given genus in {current_path.name}.")
127
102
  continue
128
- output_name = (
129
- f"{output_path.stem}_{idx+1}{output_path.suffix}"
130
- if input_is_dir
131
- else output_path.name
132
- )
133
- filter_sequences(
134
- current_path,
135
- output_path.parent / output_name,
136
- included_ids,
103
+
104
+ filter_output_path = get_output_path(idx, output_path)
105
+ filter_sequences(current_path, filter_output_path, included_ids)
106
+ print(
107
+ f"Saved filtered sequences from {current_path.name} as {filter_output_path.name}"
137
108
  )
138
- print(f"Saved filtered sequences from {current_path.name} as {output_name}")