XspecT 0.5.0-py3-none-any.whl → 0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of XspecT might be problematic.

Files changed (33)
  1. xspect/classify.py +61 -13
  2. xspect/definitions.py +61 -13
  3. xspect/download_models.py +10 -2
  4. xspect/file_io.py +115 -48
  5. xspect/filter_sequences.py +81 -29
  6. xspect/main.py +90 -39
  7. xspect/mlst_feature/mlst_helper.py +3 -0
  8. xspect/mlst_feature/pub_mlst_handler.py +43 -1
  9. xspect/model_management.py +84 -14
  10. xspect/models/probabilistic_filter_mlst_model.py +75 -37
  11. xspect/models/probabilistic_filter_model.py +201 -19
  12. xspect/models/probabilistic_filter_svm_model.py +106 -13
  13. xspect/models/probabilistic_single_filter_model.py +73 -9
  14. xspect/models/result.py +77 -10
  15. xspect/ncbi.py +48 -12
  16. xspect/train.py +19 -11
  17. xspect/web.py +68 -12
  18. xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
  19. xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
  20. xspect/xspect-web/dist/index.html +2 -2
  21. xspect/xspect-web/src/App.tsx +4 -2
  22. xspect/xspect-web/src/api.tsx +23 -1
  23. xspect/xspect-web/src/components/filter-form.tsx +16 -3
  24. xspect/xspect-web/src/components/filtering-result.tsx +65 -0
  25. xspect/xspect-web/src/components/result.tsx +2 -2
  26. xspect/xspect-web/src/types.tsx +5 -0
  27. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/METADATA +11 -5
  28. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
  29. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
  30. xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
  31. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
  32. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
  33. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
xspect/models/probabilistic_filter_svm_model.py CHANGED
@@ -15,23 +15,51 @@ from xspect.models.result import ModelResult
 
 
 class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
-    """Probabilistic filter SVM model for sequence data"""
+    """
+    Probabilistic filter SVM model for sequence data
+
+    In addition to the standard probabilistic filter model, this model uses an SVM to predict
+    labels based on their scores and training data. It requires the `scikit-learn` library
+    to be installed.
+    """
 
     def __init__(
         self,
         k: int,
         model_display_name: str,
-        author: str,
-        author_email: str,
+        author: str | None,
+        author_email: str | None,
         model_type: str,
         base_path: Path,
         kernel: str,
         c: float,
         fpr: float = 0.01,
         num_hashes: int = 7,
-        training_accessions: dict[str, list[str]] = None,
-        svm_accessions: dict[str, list[str]] = None,
+        training_accessions: dict[str, list[str]] | None = None,
+        svm_accessions: dict[str, list[str]] | None = None,
     ) -> None:
+        """
+        Initialize the SVM model with the given parameters.
+
+        In addition to the standard parameters, this model uses an SVM.
+        Therefore, it requires the `kernel` and `C` parameters to be set.
+        Furthermore, the `svm_accessions` parameter is used to store which accessions
+        are used for training the SVM.
+
+        Args:
+            k (int): The k-mer size for the probabilistic filter.
+            model_display_name (str): The display name of the model.
+            author (str | None): The author of the model.
+            author_email (str | None): The author's email address.
+            model_type (str): The type of the model.
+            base_path (Path): The base path where the model will be stored.
+            kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
+            c (float): Regularization parameter for the SVM.
+            fpr (float, optional): False positive rate for the probabilistic filter. Defaults to 0.01.
+            num_hashes (int, optional): Number of hashes for the probabilistic filter. Defaults to 7.
+            training_accessions (dict[str, list[str]] | None, optional): Accessions used for training the probabilistic filter. Defaults to None.
+            svm_accessions (dict[str, list[str]] | None, optional): Accessions used for training the SVM. Defaults to None.
+        """
         super().__init__(
             k=k,
             model_display_name=model_display_name,
@@ -48,6 +76,12 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         self.svm_accessions = svm_accessions
 
     def to_dict(self) -> dict:
+        """
+        Convert the model to a dictionary representation
+
+        Returns:
+            dict: A dictionary containing the model's parameters and state.
+        """
         return super().to_dict() | {
             "kernel": self.kernel,
             "C": self.c,
@@ -55,7 +89,13 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         }
 
     def set_svm_params(self, kernel: str, c: float) -> None:
-        """Set the parameters for the SVM"""
+        """
+        Set the parameters for the SVM
+
+        Args:
+            kernel (str): The kernel type for the SVM (e.g., 'linear', 'rbf').
+            c (float): Regularization parameter for the SVM.
+        """
         self.kernel = kernel
         self.c = c
         self.save()
@@ -64,12 +104,27 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         self,
         dir_path: Path,
         svm_path: Path,
-        display_names: dict = None,
+        display_names: dict[str, str] | None = None,
         svm_step: int = 1,
-        training_accessions: list[str] = None,
-        svm_accessions: list[str] = None,
+        training_accessions: dict[str, list[str]] | None = None,
+        svm_accessions: dict[str, list[str]] | None = None,
     ) -> None:
-        """Fit the SVM to the sequences and labels"""
+        """
+        Fit the SVM to the sequences and labels.
+
+        This method first trains the probabilistic filter model and then
+        calculates scores for the SVM training. It expects the sequences to be in
+        the specified directory and the SVM training sequences to be in the
+        specified SVM path. The scores are saved in a CSV file for later use.
+
+        Args:
+            dir_path (Path): The directory containing the training sequences.
+            svm_path (Path): The directory containing the SVM training sequences.
+            display_names (dict[str, str] | None): A mapping of accession IDs to display names.
+            svm_step (int): Step size for sparse sampling in SVM training.
+            training_accessions (dict[str, list[str]] | None): Accessions used for training the probabilistic filter.
+            svm_accessions (dict[str, list[str]] | None): Accessions used for training the SVM.
+        """
 
         # Since the SVM works with score data, we need to train
         # the underlying data structure for score generation first
@@ -124,7 +179,21 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         filter_ids: list[str] = None,
         step: int = 1,
     ) -> ModelResult:
-        """Predict the labels of the sequences"""
+        """
+        Predict the labels of the sequences.
+
+        This method uses the SVM to predict labels based on the scores generated
+        from the sequences. It expects the sequences to be in a format compatible
+        with the probabilistic filter model, and it will return a `ModelResult`.
+
+        Args:
+            sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator | SeqIO.QualityIO.FastqPhredIterator | Path): The input sequences to predict.
+            filter_ids (list[str], optional): A list of IDs to filter the predictions. Defaults to None.
+            step (int, optional): Step size for sparse sampling. Defaults to 1.
+
+        Returns:
+            ModelResult: The result of the prediction containing hits, number of kmers, and the predicted label.
+        """
         # get scores and format them for the SVM
         res = super().predict(sequence_input, filter_ids, step=step)
         svm_scores = dict(sorted(res.get_scores()["total"].items()))
@@ -140,7 +209,19 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
         )
 
     def _get_svm(self, id_keys) -> SVC:
-        """Get the SVM for the given id keys"""
+        """
+        Get the SVM for the given id keys.
+
+        This method loads the SVM model from the scores CSV file and trains it
+        using the scores from the CSV. If `id_keys` is provided, it filters the
+        training data to only include those keys.
+
+        Args:
+            id_keys (list[str] | None): A list of IDs to filter the training data. If None, all data is used.
+
+        Returns:
+            SVC: The trained SVM model.
+        """
         svm = SVC(kernel=self.kernel, C=self.c)
         # parse csv
         with open(
@@ -160,7 +241,19 @@ class ProbabilisticFilterSVMModel(ProbabilisticFilterModel):
 
     @staticmethod
     def load(path: Path) -> "ProbabilisticFilterSVMModel":
-        """Load the model from disk"""
+        """
+        Load the model from disk
+
+        Loads the model from the specified path. The path should point to a JSON file
+        containing the model's parameters and state. It also checks for the existence of
+        the COBS index file.
+
+        Args:
+            path (Path): The path to the model JSON file.
+
+        Returns:
+            ProbabilisticFilterSVMModel: The loaded model instance.
+        """
         with open(path, "r", encoding="utf-8") as file:
             json_object = file.read()
             model_json = json.loads(json_object)
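Aside on the change above: the new docstrings describe an SVC trained on per-filter score vectors. As a rough, self-contained sketch of that idea (not XspecT's actual training code; the scores and labels below are invented):

from sklearn.svm import SVC

# Each training sample is a vector of per-filter scores for one genome,
# labeled with the taxon it came from (values invented for illustration).
train_scores = [
    [0.91, 0.12, 0.08],
    [0.10, 0.88, 0.15],
]
train_labels = ["species_a", "species_b"]

# kernel and C correspond to the parameters handled by set_svm_params()
svm = SVC(kernel="linear", C=1.0)
svm.fit(train_scores, train_labels)

# At prediction time, the query's filter scores are classified by the SVM
print(svm.predict([[0.85, 0.20, 0.05]]))  # -> ['species_a']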
xspect/models/probabilistic_single_filter_model.py CHANGED
@@ -14,19 +14,39 @@ from xspect.file_io import get_record_iterator
 
 
 class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
-    """Base probabilistic filter model for sequence data"""
+    """
+    Probabilistic filter model for sequence data, with a single filter
+
+    This model uses a Bloom filter to store k-mers from the training sequences. It is designed to
+    be used with a single filter, which is suitable e. g. for genus-level classification.
+    """
 
     def __init__(
         self,
         k: int,
         model_display_name: str,
-        author: str,
-        author_email: str,
+        author: str | None,
+        author_email: str | None,
         model_type: str,
         base_path: Path,
         fpr: float = 0.01,
-        training_accessions: list[str] = None,
+        training_accessions: list[str] | None = None,
     ) -> None:
+        """Initialize probabilistic single filter model.
+
+        This model uses a Bloom filter to store k-mers from the training sequences. It is designed to
+        be used with a single filter, which is suitable e.g. for genus-level classification.
+
+        Args:
+            k (int): Length of the k-mers to use for filtering
+            model_display_name (str): Display name of the model
+            author (str | None): Author of the model
+            author_email (str | None): Email of the author
+            model_type (str): Type of the model, e.g. "probabilistic_single_filter"
+            base_path (Path): Base path where the model will be saved
+            fpr (float): False positive rate for the Bloom filter, default is 0.01
+            training_accessions (list[str] | None): List of accessions used for training, default is None
+        """
         super().__init__(
             k=k,
             model_display_name=model_display_name,
@@ -41,9 +61,22 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         self.bf = None
 
     def fit(
-        self, file_path: Path, display_name: str, training_accessions: list[str] = None
+        self,
+        file_path: Path,
+        display_name: str,
+        training_accessions: list[str] | None = None,
     ) -> None:
-        """Fit the cobs classic index to the sequences and labels"""
+        """
+        Fit the bloom filter to the sequences.
+
+        Trains the model by reading sequences from the provided file path,
+        generating k-mers, and adding them to the Bloom filter.
+
+        Args:
+            file_path (Path): Path to the file containing sequences in FASTA format
+            display_name (str): Display name for the model
+            training_accessions (list[str] | None): List of accessions used for training, default is None
+        """
         self.training_accessions = training_accessions
 
         # estimate number of kmers
@@ -65,7 +98,18 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
     def calculate_hits(
         self, sequence: Seq | SeqRecord, filter_ids=None, step: int = 1
     ) -> dict:
-        """Calculate the hits for the sequence"""
+        """
+        Calculate the hits for the sequence
+
+        Calculates the number of k-mers in the sequence that are present in the Bloom filter.
+
+        Args:
+            sequence (Seq | SeqRecord): Sequence to calculate hits for, can be a Bio.Seq or Bio.SeqRecord object
+            filter_ids (list[str] | None): List of filter IDs to use, default is None
+            step (int): Step size for generating k-mers, default is 1
+        Returns:
+            dict: Dictionary with the display name as key and the number of hits as value
+        """
         if isinstance(sequence, SeqRecord):
             sequence = sequence.seq
 
@@ -82,7 +126,17 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
 
     @staticmethod
     def load(path: Path) -> "ProbabilisticSingleFilterModel":
-        """Load the model from disk"""
+        """
+        Load the model from disk
+
+        This method reads the model's JSON file and the associated Bloom filter file,
+        reconstructing the model instance.
+
+        Args:
+            path (Path): Path to the model directory containing the JSON file
+        Returns:
+            ProbabilisticSingleFilterModel: An instance of the model loaded from disk
+        """
         with open(path, "r", encoding="utf-8") as file:
             json_object = file.read()
             model_json = json.loads(json_object)
@@ -105,7 +159,17 @@ class ProbabilisticSingleFilterModel(ProbabilisticFilterModel):
         return model
 
     def _generate_kmers(self, sequence: Seq, step: int = 1):
-        """Generate kmers from the sequence"""
+        """
+        Generate kmers from the sequence
+
+        Generates k-mers from the sequence, considering both the forward and reverse complement strands.
+
+        Args:
+            sequence (Seq): Sequence to generate k-mers from
+            step (int): Step size for generating k-mers, default is 1
+        Yields:
+            str: The minimizer k-mer (the lexicographically smallest k-mer between the forward and reverse complement)
+        """
         num_kmers = ceil((len(sequence) - self.k + 1) / step)
         for i in range(num_kmers):
            start_pos = i * step
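The _generate_kmers docstring above describes canonical (strand-independent) k-mers: each window contributes the lexicographically smaller of the k-mer and its reverse complement. A minimal standalone sketch of that scheme (not XspecT's implementation, which operates on Bio.Seq objects):

from math import ceil

COMPLEMENT = str.maketrans("ACGT", "TGCA")

def canonical_kmers(sequence: str, k: int, step: int = 1):
    # one window per step, as in num_kmers = ceil((len(seq) - k + 1) / step)
    num_kmers = ceil((len(sequence) - k + 1) / step)
    for i in range(num_kmers):
        kmer = sequence[i * step : i * step + k]
        rev_comp = kmer.translate(COMPLEMENT)[::-1]
        yield min(kmer, rev_comp)  # lexicographically smaller strand

print(list(canonical_kmers("ACGTAC", 4)))  # ['ACGT', 'CGTA', 'GTAC']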
xspect/models/result.py CHANGED
@@ -14,9 +14,22 @@ class ModelResult:
         hits: dict[str, dict[str, int]],
         num_kmers: dict[str, int],
         sparse_sampling_step: int = 1,
-        prediction: str = None,
-        input_source: str = None,
+        prediction: str | None = None,
+        input_source: str | None = None,
     ):
+        """
+        Initialize the ModelResult object.
+
+        Args:
+            model_slug (str): The slug of the model.
+            hits (dict[str, dict[str, int]]): A dictionary where keys are subsequence names
+                and values are dictionaries with labels as keys and hit counts as values.
+            num_kmers (dict[str, int]): A dictionary where keys are subsequence names
+                and values are the total number of k-mers for that subsequence.
+            sparse_sampling_step (int): The step size for sparse sampling, default is 1.
+            prediction (str | None): The prediction made by the model, default is None.
+            input_source (str | None): The source of the input data, default is None.
+        """
         if "total" in hits:
             raise ValueError(
                 "'total' is a reserved key and cannot be used as a subsequence"
@@ -29,7 +42,16 @@ class ModelResult:
         self.input_source = input_source
 
     def get_scores(self) -> dict:
-        """Return the scores of the model."""
+        """
+        Return the scores of the model.
+
+        The scores are calculated as the number of hits divided by the total number of k-mers
+        for each subsequence and label. The scores are rounded to two decimal places.
+
+        Returns:
+            dict: A dictionary where keys are subsequence names and values are dictionaries
+                with labels as keys and scores as values. Also includes a 'total' key for overall scores.
+        """
         scores = {
             subsequence: {
                 label: round(hits / self.num_kmers[subsequence], 2)
@@ -50,20 +72,37 @@ class ModelResult:
         return scores
 
     def get_total_hits(self) -> dict[str, int]:
-        """Return the total hits of the model."""
+        """
+        Return the total hits of the model.
+
+        The total hits are calculated by summing the hits for each label across all subsequences.
+
+        Returns:
+            dict: A dictionary where keys are labels and values are the total number of hits for that label.
+        """
         total_hits = {label: 0 for label in list(self.hits.values())[0]}
-        for _, subseuqence_hits in self.hits.items():
-            for label, hits in subseuqence_hits.items():
+        for _, subsequence_hits in self.hits.items():
+            for label, hits in subsequence_hits.items():
                 total_hits[label] += hits
         return total_hits
 
     def get_filter_mask(self, label: str, filter_threshold: float) -> dict[str, bool]:
-        """Return a mask for filtered subsequences.
+        """
+        Return a mask for filtered subsequences.
 
         The mask is a dictionary with subsequence names as keys and boolean values
         indicating whether the subsequence is above the filter threshold for the given label.
         A value of -1 for filter_threshold indicates that the subsequence with the maximum score
         for the given label should be returned.
+
+        Args:
+            label (str): The label for which to filter the subsequences.
+            filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
+                or -1 to return the subsequence with the maximum score for the label.
+
+        Returns:
+            dict[str, bool]: A dictionary where keys are subsequence names and values are booleans
+                indicating whether the subsequence meets the filter criteria for the given label.
         """
         if filter_threshold < 0 and not filter_threshold == -1 or filter_threshold > 1:
             raise ValueError("The filter threshold must be between 0 and 1.")
@@ -84,7 +123,19 @@ class ModelResult:
     def get_filtered_subsequence_labels(
         self, label: str, filter_threshold: float = 0.7
     ) -> list[str]:
-        """Return the labels of filtered subsequences."""
+        """
+        Return the labels of filtered subsequences.
+
+        This method filters subsequences based on the scores for a given label and a filter threshold.
+
+        Args:
+            label (str): The label for which to filter the subsequences.
+            filter_threshold (float): The threshold for filtering subsequences. Must be between 0 and 1,
+                or -1 to return the subsequence with the maximum score for the label.
+
+        Returns:
+            list[str]: A list of subsequence names that meet the filter criteria for the given label.
+        """
         return [
             subsequence
             for subsequence, mask in self.get_filter_mask(
@@ -94,7 +145,15 @@ class ModelResult:
         ]
 
     def to_dict(self) -> dict:
-        """Return the result as a dictionary."""
+        """
+        Return the result as a dictionary.
+
+        This method converts the ModelResult object into a dictionary format suitable for serialization.
+
+        Returns:
+            dict: A dictionary representation of the ModelResult object, including model slug,
+                sparse sampling step, hits, scores, number of k-mers, input source, and prediction if available.
+        """
         res = {
             "model_slug": self.model_slug,
             "sparse_sampling_step": self.sparse_sampling_step,
@@ -110,6 +169,14 @@ class ModelResult:
         return res
 
     def save(self, path: Path) -> None:
-        """Save the result as a JSON file."""
+        """
+        Save the result as a JSON file.
+
+        This method serializes the ModelResult object to a JSON file at the specified path.
+
+        Args:
+            path (Path): The path where the JSON file will be saved.
+        """
+        path.parent.mkdir(exist_ok=True, parents=True)
         with open(path, "w", encoding="utf-8") as f:
             f.write(dumps(self.to_dict(), indent=4))
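To make the get_scores arithmetic concrete: a score is the hit count divided by the subsequence's k-mer count, rounded to two decimals. A tiny worked example with invented numbers:

hits = {"contig_1": {"species_a": 90, "species_b": 12}}
num_kmers = {"contig_1": 100}

scores = {
    subsequence: {
        label: round(label_hits / num_kmers[subsequence], 2)
        for label, label_hits in subsequence_hits.items()
    }
    for subsequence, subsequence_hits in hits.items()
}
print(scores)  # {'contig_1': {'species_a': 0.9, 'species_b': 0.12}}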
xspect/ncbi.py CHANGED
@@ -3,6 +3,7 @@
 from enum import Enum
 from pathlib import Path
 import time
+from loguru import logger
 import requests
 
 # pylint: disable=line-too-long
@@ -26,26 +27,35 @@ class AssemblySource(Enum):
 
 
 class NCBIHandler:
-    """This class uses the NCBI Datasets API to get the taxonomy tree of a given Taxon.
+    """
+    This class uses the NCBI Datasets API to get data about taxa and their assemblies.
 
-    The taxonomy tree consists of only the next children to the parent taxon.
-    The children are only of the next lower rank of the parent taxon.
+    It provides methods to get taxon IDs, species, names, accessions, and download assemblies.
+    It also enforces rate limiting to comply with NCBI's API usage policies.
     """
 
     def __init__(
         self,
-        api_key: str = None,
+        api_key: str | None = None,
     ):
-        """Initialise the NCBI handler."""
+        """
+        Initialise the NCBI handler.
+
+        This method sets up the base URL for the NCBI Datasets API and initializes the rate limiting parameters.
+
+        Args:
+            api_key (str | None): The NCBI API key. If None, the handler will use the public API without an API key.
+        """
         self.api_key = api_key
         self.base_url = "https://api.ncbi.nlm.nih.gov/datasets/v2"
-        self.last_request_time = 0
+        self.last_request_time = 0.0
         self.min_interval = (
             1 / 10 if api_key else 1 / 5
         )  # NCBI allows 10 requests per second with if an API key, otherwise 5 requests per second
 
-    def _enforce_rate_limit(self):
-        """Enforce rate limiting for the NCBI Datasets API.
+    def _enforce_rate_limit(self) -> None:
+        """
+        Enforce rate limiting for the NCBI Datasets API.
 
         This method ensures that the requests to the API are limited to 5 requests per second
         without an API key and 10 requests per second with an API key.
@@ -58,7 +68,11 @@ class NCBIHandler:
             self.last_request_time = now
 
     def _make_request(self, endpoint: str, timeout: int = 15) -> dict:
-        """Make a request to the NCBI Datasets API.
+        """
+        Make a request to the NCBI Datasets API.
+
+        This method constructs the full URL for the API endpoint, adds the necessary headers (including the API key if provided),
+        and makes a GET request to the API. It also enforces rate limiting before making the request.
 
         Args:
             endpoint (str): The endpoint to make the request to.
@@ -229,7 +243,7 @@ class NCBIHandler:
                 == "OK"
             ]
         except (IndexError, KeyError, TypeError):
-            print(
+            logger.debug(
                 f"Could not get {assembly_level.value} accessions for taxon with ID: {taxon_id}. Skipping."
             )
             return []
@@ -238,7 +252,20 @@
     def get_highest_quality_accessions(
         self, taxon_id: int, assembly_source: AssemblySource, count: int
    ) -> list[str]:
-        """Get the highest quality accessions for a given taxon id (based on the assembly level)."""
+        """
+        Get the highest quality accessions for a given taxon id (based on the assembly level).
+
+        This function iterates through the assembly levels in order of quality and retrieves accessions
+        until the specified count is reached. It ensures that the accessions are unique and sorted by quality.
+
+        Args:
+            taxon_id (int): The taxon id to get the accessions for.
+            assembly_source (AssemblySource): The assembly source to get the accessions for.
+            count (int): The number of accessions to get.
+
+        Returns:
+            list[str]: A list containing the highest quality accessions.
+        """
         accessions = []
         for assembly_level in list(AssemblyLevel):
             accessions += self.get_accessions(
@@ -252,7 +279,16 @@
 
         return list(set(accessions))[:count]  # Remove duplicates and limit to count
 
     def download_assemblies(self, accessions: list[str], output_dir: Path) -> None:
-        """Download assemblies for a list of accessions."""
+        """
+        Download assemblies for a list of accessions.
+
+        This function makes a request to the NCBI Datasets API to download the assemblies for the given accessions.
+        It saves the downloaded assemblies as a zip file in the specified output directory.
+
+        Args:
+            accessions (list[str]): A list of accessions to download.
+            output_dir (Path): The directory where the downloaded assemblies will be saved.
+        """
         endpoint = f"/genome/accession/{','.join(accessions)}/download?include_annotation_type=GENOME_FASTA"
 
         self._enforce_rate_limit()
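The rate-limiting scheme referenced throughout this file keeps requests at most one per min_interval seconds (1/10 s with an API key, 1/5 s without). Since the diff only shows fragments of _enforce_rate_limit, here is a sketch of the pattern with the unshown details assumed:

import time

class RateLimiterSketch:
    # Illustrative only; not the NCBIHandler implementation.
    def __init__(self, api_key: str | None = None):
        self.min_interval = 1 / 10 if api_key else 1 / 5
        self.last_request_time = 0.0

    def wait(self) -> None:
        # Sleep just long enough so consecutive calls are >= min_interval apart
        elapsed = time.time() - self.last_request_time
        if elapsed < self.min_interval:
            time.sleep(self.min_interval - elapsed)
        self.last_request_time = time.time()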
xspect/train.py CHANGED
@@ -25,12 +25,12 @@ def train_from_directory(
     display_name: str,
     dir_path: Path,
     meta: bool = False,
-    training_accessions: dict[str, list[str]] = None,
-    svm_accessions: list[str] = None,
+    training_accessions: dict[str, list[str]] | None = None,
+    svm_accessions: dict[str, list[str]] | None = None,
     svm_step: int = 1,
-    translation_dict: dict[str, str] = None,
-    author: str = None,
-    author_email: str = None,
+    translation_dict: dict[str, str] | None = None,
+    author: str | None = None,
+    author_email: str | None = None,
 ):
     """
     Train a model from a directory containing training data.
@@ -113,10 +113,11 @@
     species_dir = tmp_dir / "species"
     species_dir.mkdir(parents=True, exist_ok=True)
 
-    # concatenate files in cobs_training_data for each species
+    logger.info("Concatenating genomes for species training...")
     concatenate_species_fasta_files(cobs_folders, species_dir)
 
     if svm_path.exists():
+        logger.info("Training species SVM model...")
         species_model = ProbabilisticFilterSVMModel(
             k=21,
             model_display_name=display_name,
@@ -136,6 +137,7 @@
             svm_accessions=svm_accessions,
         )
     else:
+        logger.info("Training species model...")
         species_model = ProbabilisticFilterModel(
             k=21,
             model_display_name=display_name,
@@ -153,9 +155,11 @@
     species_model.save()
 
     if meta:
+        logger.info("Concatenating genomes for metagenome training...")
         meta_fasta = tmp_dir / f"{display_name}.fasta"
         concatenate_metagenome(species_dir, meta_fasta)
 
+        logger.info("Training metagenome model...")
        genus_model = ProbabilisticSingleFilterModel(
             k=21,
             model_display_name=display_name,
@@ -179,10 +183,12 @@
 def train_from_ncbi(
     genus: str,
     svm_step: int = 1,
-    author: str = None,
-    author_email: str = None,
+    author: str | None = None,
+    author_email: str | None = None,
+    ncbi_api_key: str | None = None,
 ):
-    """Train a model using NCBI assembly data for a given genus.
+    """
+    Train a model using NCBI assembly data for a given genus.
 
     This function trains a probabilistic filter model using the assembly data from NCBI.
     The training data is downloaded and processed, and the model is saved to the
@@ -193,6 +199,7 @@
         svm_step (int, optional): Step size for SVM training. Defaults to 1.
         author (str, optional): Author of the model. Defaults to None.
         author_email (str, optional): Author's email. Defaults to None.
+        ncbi_api_key (str, optional): NCBI API key for accessing NCBI resources. Defaults to None.
 
     Raises:
         TypeError: If `genus` is not a string.
@@ -205,7 +212,8 @@
     if not isinstance(genus, str):
         raise TypeError("genus must be a string")
 
-    ncbi_handler = NCBIHandler()
+    logger.info("Getting NCBI metadata...")
+    ncbi_handler = NCBIHandler(api_key=ncbi_api_key)
     genus_tax_id = ncbi_handler.get_genus_taxon_id(genus)
     species_ids = ncbi_handler.get_species(genus_tax_id)
     species_names = ncbi_handler.get_taxon_names(species_ids)
@@ -243,7 +251,7 @@
     cobs_dir.mkdir(parents=True, exist_ok=True)
     svm_dir.mkdir(parents=True, exist_ok=True)
 
-    # download assemblies
+    logger.info("Downloading genomes from NCBI...")
     all_accessions = sum(accessions.values(), [])
     batch_size = 100
     accession_paths = {}
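Based on the updated signature, calling the trainer with the new ncbi_api_key parameter would look roughly like this (all values hypothetical):

from xspect.train import train_from_ncbi

train_from_ncbi(
    genus="Acinetobacter",             # hypothetical example genus
    svm_step=500,                      # hypothetical sparse-sampling step
    author="Jane Doe",                 # hypothetical
    author_email="jane@example.org",   # hypothetical
    ncbi_api_key="YOUR_NCBI_API_KEY",  # new in 0.5.2; None falls back to 5 requests/s
)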