XspecT 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of XspecT might be problematic. Click here for more details.

Files changed (33) hide show
  1. xspect/classify.py +61 -13
  2. xspect/definitions.py +61 -13
  3. xspect/download_models.py +10 -2
  4. xspect/file_io.py +115 -48
  5. xspect/filter_sequences.py +81 -29
  6. xspect/main.py +90 -39
  7. xspect/mlst_feature/mlst_helper.py +3 -0
  8. xspect/mlst_feature/pub_mlst_handler.py +43 -1
  9. xspect/model_management.py +84 -14
  10. xspect/models/probabilistic_filter_mlst_model.py +75 -37
  11. xspect/models/probabilistic_filter_model.py +201 -19
  12. xspect/models/probabilistic_filter_svm_model.py +106 -13
  13. xspect/models/probabilistic_single_filter_model.py +73 -9
  14. xspect/models/result.py +77 -10
  15. xspect/ncbi.py +48 -12
  16. xspect/train.py +19 -11
  17. xspect/web.py +68 -12
  18. xspect/xspect-web/dist/assets/index-Ceo58xui.css +1 -0
  19. xspect/xspect-web/dist/assets/{index-CMG4V7fZ.js → index-Dt_UlbgE.js} +82 -77
  20. xspect/xspect-web/dist/index.html +2 -2
  21. xspect/xspect-web/src/App.tsx +4 -2
  22. xspect/xspect-web/src/api.tsx +23 -1
  23. xspect/xspect-web/src/components/filter-form.tsx +16 -3
  24. xspect/xspect-web/src/components/filtering-result.tsx +65 -0
  25. xspect/xspect-web/src/components/result.tsx +2 -2
  26. xspect/xspect-web/src/types.tsx +5 -0
  27. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/METADATA +11 -5
  28. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/RECORD +32 -31
  29. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/WHEEL +1 -1
  30. xspect/xspect-web/dist/assets/index-jIKg1HIy.css +0 -1
  31. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/entry_points.txt +0 -0
  32. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/licenses/LICENSE +0 -0
  33. {xspect-0.5.0.dist-info → xspect-0.5.2.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ from cobs_index import DocumentList
12
12
  from collections import defaultdict
13
13
  from xspect.file_io import get_record_iterator
14
14
  from xspect.mlst_feature.mlst_helper import MlstResult
15
+ from xspect.mlst_feature.pub_mlst_handler import PubMLSTHandler
15
16
 
16
17
 
17
18
  class ProbabilisticFilterMlstSchemeModel:
@@ -19,20 +20,22 @@ class ProbabilisticFilterMlstSchemeModel:
19
20
 
20
21
  def __init__(
21
22
  self,
22
- k: int,
23
- model_display_name: str,
23
+ k_value: int,
24
+ model_name: str,
24
25
  base_path: Path,
26
+ scheme_url: str,
25
27
  fpr: float = 0.001,
26
28
  ) -> None:
27
29
  """Initialise a ProbabilisticFilterMlstSchemeModel object."""
28
- if k < 1:
30
+ if k_value < 1:
29
31
  raise ValueError("Invalid k value, must be greater than 0")
30
32
  if not isinstance(base_path, Path):
31
33
  raise ValueError("Invalid base path, must be a pathlib.Path object")
32
34
 
33
- self.k = k
34
- self.model_display_name = model_display_name
35
+ self.k_value = k_value
36
+ self.model_name = model_name
35
37
  self.base_path = base_path / "MLST"
38
+ self.scheme_url = scheme_url
36
39
  self.fpr = fpr
37
40
  self.model_type = "Strain"
38
41
  self.loci = {}
@@ -49,9 +52,10 @@ class ProbabilisticFilterMlstSchemeModel:
49
52
  dict: The dictionary containing all metadata of an object.
50
53
  """
51
54
  return {
52
- "k": self.k,
53
- "model_display_name": self.model_display_name,
55
+ "k_value": self.k_value,
56
+ "model_name": self.model_name,
54
57
  "model_type": self.model_type,
58
+ "scheme_url": str(self.scheme_url),
55
59
  "fpr": self.fpr,
56
60
  "scheme_path": str(self.scheme_path),
57
61
  "cobs_path": str(self.cobs_path),
@@ -115,7 +119,7 @@ class ProbabilisticFilterMlstSchemeModel:
115
119
  # COBS only accepts strings as paths
116
120
  doclist = DocumentList(str(locus_path))
117
121
  index_params = cobs_index.CompactIndexParameters()
118
- index_params.term_size = self.k # k-mer size
122
+ index_params.term_size = self.k_value # k-mer size
119
123
  index_params.clobber = True # overwrite output and temporary files
120
124
  index_params.false_positive_rate = self.fpr
121
125
 
@@ -130,9 +134,7 @@ class ProbabilisticFilterMlstSchemeModel:
130
134
 
131
135
  def save(self) -> None:
132
136
  """Saves the model to disk"""
133
- scheme = str(self.scheme_path).split("/")[
134
- -1
135
- ] # [-1] -> contains the scheme name
137
+ scheme = str(self.scheme_path).split("/")[-1] # [-1] contains the scheme name
136
138
  json_path = self.base_path / scheme / f"{scheme}.json"
137
139
  json_object = json.dumps(self.to_dict(), indent=4)
138
140
 
@@ -156,9 +158,10 @@ class ProbabilisticFilterMlstSchemeModel:
156
158
  json_object = file.read()
157
159
  model_json = json.loads(json_object)
158
160
  model = ProbabilisticFilterMlstSchemeModel(
159
- model_json["k"],
160
- model_json["model_display_name"],
161
+ model_json["k_value"],
162
+ model_json["model_name"],
161
163
  json_path.parent,
164
+ model_json["scheme_url"],
162
165
  model_json["fpr"],
163
166
  )
164
167
  model.scheme_path = model_json["scheme_path"]
@@ -175,7 +178,12 @@ class ProbabilisticFilterMlstSchemeModel:
175
178
  return model
176
179
 
177
180
  def calculate_hits(
178
- self, cobs_path: Path, sequence: Seq, step: int = 1
181
+ self,
182
+ cobs_path: Path,
183
+ sequence: Seq,
184
+ step: int = 1,
185
+ limit: bool = False,
186
+ limit_number: int = 5,
179
187
  ) -> list[dict]:
180
188
  """
181
189
  Calculates the hits for a sequence.
@@ -189,6 +197,8 @@ class ProbabilisticFilterMlstSchemeModel:
189
197
  cobs_path (Path): The path of the COBS-structure directory.
190
198
  sequence (Seq): The input sequence for classification.
191
199
  step (int, optional): The amount of kmers that are passed; defaults to one.
200
+ limit (bool): Applying a filter that limits the best result.
201
+ limit_number (int): The amount of results when the filter is set to true.
192
202
 
193
203
  Returns:
194
204
  list[dict]: The results of the prediction.
@@ -201,7 +211,7 @@ class ProbabilisticFilterMlstSchemeModel:
201
211
  if not isinstance(sequence, Seq):
202
212
  raise ValueError("Invalid sequence, must be a Bio.Seq object")
203
213
 
204
- if not len(sequence) > self.k:
214
+ if not len(sequence) > self.k_value:
205
215
  raise ValueError("Invalid sequence, must be longer than k")
206
216
 
207
217
  if not self.indices:
@@ -239,6 +249,10 @@ class ProbabilisticFilterMlstSchemeModel:
239
249
  sorted_counts = dict(
240
250
  sorted(all_counts.items(), key=lambda item: -item[1])
241
251
  )
252
+
253
+ if limit:
254
+ sorted_counts = dict(list(sorted_counts.items())[:limit_number])
255
+
242
256
  if not sorted_counts:
243
257
  result_dict = "A Strain type could not be detected because of no kmer matches!"
244
258
  highest_results[scheme_path_list[counter]] = {"N/A": 0}
@@ -250,25 +264,37 @@ class ProbabilisticFilterMlstSchemeModel:
250
264
  first_key: highest_result
251
265
  }
252
266
  counter += 1
253
- else:
267
+ else: # No split procedure is needed, when the sequence is short
254
268
  for index in self.indices:
255
- res = index.search(
269
+ res = index.search( # COBS can't handle Seq-Objects
256
270
  str(sequence), step=step
257
- ) # COBS can't handle Seq-Objects
258
- result_dict[scheme_path_list[counter]] = self.get_cobs_result(
259
- res, False
260
271
  )
261
- first_key, highest_result = next(
262
- iter(result_dict[scheme_path_list[counter]].items())
272
+ result = self.get_cobs_result(res, False)
273
+ result = (
274
+ dict(sorted(result.items(), key=lambda x: -x[1])[:limit_number])
275
+ if limit
276
+ else result
263
277
  )
278
+ result_dict[scheme_path_list[counter]] = result
279
+ first_key, highest_result = next(iter(result.items()))
264
280
  highest_results[scheme_path_list[counter]] = {first_key: highest_result}
265
281
  counter += 1
282
+
266
283
  # check if the strain type has sufficient amount of kmer hits
267
284
  is_valid = self.has_sufficient_score(highest_results, self.avg_locus_bp_size)
268
285
  if not is_valid:
269
286
  highest_results["Attention:"] = (
270
287
  "This strain type is not reliable due to low kmer hit rates!"
271
288
  )
289
+ else:
290
+ handler = PubMLSTHandler()
291
+ # allele_id is of type dict
292
+ flattened = {
293
+ locus: int(list(allele_id.keys())[0].split("_")[-1])
294
+ for locus, allele_id in highest_results.items()
295
+ }
296
+ strain_type_name = handler.get_strain_type_name(flattened, self.scheme_url)
297
+ highest_results["ST_Name"] = strain_type_name
272
298
  return [{"Strain type": highest_results}, {"All results": result_dict}]
273
299
 
274
300
  def predict(
@@ -282,6 +308,7 @@ class ProbabilisticFilterMlstSchemeModel:
282
308
  | Path
283
309
  ),
284
310
  step: int = 1,
311
+ limit: bool = False,
285
312
  ) -> MlstResult:
286
313
  """
287
314
  Get scores for the sequence(s) based on the filters in the model.
@@ -290,6 +317,7 @@ class ProbabilisticFilterMlstSchemeModel:
290
317
  cobs_path (Path): The path of the COBS-structure directory.
291
318
  sequence_input (Seq): The input sequence for classification
292
319
  step (int, optional): The amount of kmers that are passed; defaults to one
320
+ limit (bool, optional): Applying a filter that limits the best result.
293
321
 
294
322
  Returns:
295
323
  MlstResult: The results of the prediction.
@@ -301,13 +329,19 @@ class ProbabilisticFilterMlstSchemeModel:
301
329
  if sequence_input.id == "<unknown id>":
302
330
  sequence_input.id = "test"
303
331
  hits = {
304
- sequence_input.id: self.calculate_hits(cobs_path, sequence_input.seq)
332
+ sequence_input.id: self.calculate_hits(
333
+ cobs_path, sequence_input.seq, step, limit
334
+ )
305
335
  }
306
- return MlstResult(self.model_display_name, step, hits)
336
+ return MlstResult(self.model_name, step, hits, None)
307
337
 
308
338
  if isinstance(sequence_input, Path):
309
339
  return ProbabilisticFilterMlstSchemeModel.predict(
310
- self, cobs_path, get_record_iterator(sequence_input), step=step
340
+ self,
341
+ cobs_path,
342
+ get_record_iterator(sequence_input),
343
+ step=step,
344
+ limit=limit,
311
345
  )
312
346
 
313
347
  if isinstance(
@@ -317,33 +351,35 @@ class ProbabilisticFilterMlstSchemeModel:
317
351
  hits = {}
318
352
  # individual_seq is a SeqRecord-Object
319
353
  for individual_seq in sequence_input:
320
- individual_hits = self.calculate_hits(cobs_path, individual_seq.seq)
354
+ individual_hits = self.calculate_hits(
355
+ cobs_path, individual_seq.seq, step, limit
356
+ )
321
357
  hits[individual_seq.id] = individual_hits
322
- return MlstResult(self.model_display_name, step, hits)
323
-
358
+ return MlstResult(self.model_name, step, hits, None)
324
359
  raise ValueError(
325
360
  "Invalid sequence input, must be a Seq object, a list of Seq objects, a"
326
361
  " SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
327
362
  )
328
363
 
329
364
  def get_cobs_result(
330
- self, cobs_result: cobs_index.SearchResult, kmer_threshold: bool
365
+ self,
366
+ cobs_result: cobs_index.SearchResult,
367
+ kmer_threshold: bool,
331
368
  ) -> dict:
332
369
  """
333
370
  Get every entry in a COBS search result.
334
371
 
335
372
  Args:
336
373
  cobs_result (SearchResult): The result of the prediction.
337
- kmer_threshold (bool): Applying a kmer threshold to mitigate false positives
374
+ kmer_threshold (bool): Applying a kmer threshold to mitigate false positives.
338
375
 
339
376
  Returns:
340
377
  dict: A dictionary storing the allele id of locus as key and the score as value.
341
378
  """
342
- return {
343
- individual_result.doc_name: individual_result.score
344
- for individual_result in cobs_result
345
- if not kmer_threshold or individual_result.score > 50
346
- }
379
+ hits = [
380
+ result for result in cobs_result if not kmer_threshold or result.score > 50
381
+ ]
382
+ return {result.doc_name: result.score for result in hits}
347
383
 
348
384
  def sequence_splitter(self, input_sequence: str, allele_len: int) -> list[str]:
349
385
  """
@@ -379,13 +415,15 @@ class ProbabilisticFilterMlstSchemeModel:
379
415
 
380
416
  while start + substring_length <= sequence_len:
381
417
  substring_list.append(input_sequence[start : start + substring_length])
382
- start += substring_length - self.k + 1 # To not lose kmers when dividing
418
+ start += (
419
+ substring_length - self.k_value + 1
420
+ ) # To not lose kmers when dividing
383
421
 
384
422
  # The remaining string is either appended to the list or added to the last entry.
385
423
  if start < len(input_sequence):
386
424
  remaining_substring = input_sequence[start:]
387
425
  # A substring needs to be at least of size k for COBS.
388
- if len(remaining_substring) < self.k:
426
+ if len(remaining_substring) < self.k_value:
389
427
  substring_list[-1] += remaining_substring
390
428
  else:
391
429
  substring_list.append(remaining_substring)
@@ -3,6 +3,7 @@
3
3
  import json
4
4
  from math import ceil
5
5
  from pathlib import Path
6
+ from typing import Any
6
7
  from Bio.Seq import Seq
7
8
  from Bio.SeqRecord import SeqRecord
8
9
  from Bio import SeqIO
@@ -20,14 +21,33 @@ class ProbabilisticFilterModel:
20
21
  self,
21
22
  k: int,
22
23
  model_display_name: str,
23
- author: str,
24
- author_email: str,
24
+ author: str | None,
25
+ author_email: str | None,
25
26
  model_type: str,
26
27
  base_path: Path,
27
28
  fpr: float = 0.01,
28
29
  num_hashes: int = 7,
29
- training_accessions: dict[str, list[str]] = None,
30
+ training_accessions: dict[str, list[str]] | None = None,
30
31
  ) -> None:
32
+ """
33
+ Initializes the probabilistic filter model.
34
+
35
+ This method sets up the model with the specified parameters, including the k-mer size,
36
+ display name, author information, model type, base path for storage, false positive rate,
37
+ number of hashes, and training accessions.
38
+
39
+ Args:
40
+ k (int): The size of the k-mers to be used in the model.
41
+ model_display_name (str): The display name of the model.
42
+ author (str | None): The name of the author of the model.
43
+ author_email (str | None): The email of the author of the model.
44
+ model_type (str): The type of the model.
45
+ base_path (Path): The base path where the model will be stored.
46
+ fpr (float): The false positive rate for the model. Default is 0.01.
47
+ num_hashes (int): The number of hashes to use in the model. Default is 7.
48
+ training_accessions (dict[str, list[str]] | None): A dictionary mapping filter IDs to
49
+ lists of accession numbers used for training the model. Default is None.
50
+ """
31
51
  if k < 1:
32
52
  raise ValueError("Invalid k value, must be greater than 0")
33
53
  if not model_display_name:
@@ -49,12 +69,28 @@ class ProbabilisticFilterModel:
49
69
  self.index = None
50
70
  self.training_accessions = training_accessions
51
71
 
52
- def get_cobs_index_path(self) -> Path:
53
- """Returns the path to the cobs index"""
72
+ def get_cobs_index_path(self) -> str:
73
+ """
74
+ Returns the path to the cobs inde
75
+
76
+ This method constructs the path where the cobs index file will be stored,
77
+ based on the model's slug and the base path.
78
+
79
+ Returns:
80
+ str: The path to the cobs index file.
81
+ """
54
82
  return str(self.base_path / self.slug() / "index.cobs_classic")
55
83
 
56
84
  def to_dict(self) -> dict:
57
- """Returns a dictionary representation of the model"""
85
+ """
86
+ Returns a dictionary representation of the model
87
+
88
+ This method includes all relevant attributes of the model, such as k-mer size,
89
+ display name, author information, model type, and other parameters.
90
+
91
+ Returns:
92
+ dict: A dictionary containing the model's attributes.
93
+ """
58
94
  return {
59
95
  "model_slug": self.slug(),
60
96
  "k": self.k,
@@ -70,16 +106,40 @@ class ProbabilisticFilterModel:
70
106
  }
71
107
 
72
108
  def slug(self) -> str:
73
- """Returns a slug representation of the model"""
109
+ """
110
+ Returns a slug representation of the model
111
+
112
+ This method generates a slug based on the model's display name and type,
113
+ which can be used for file naming or identification purposes.
114
+
115
+ Returns:
116
+ str: A slug representation of the model.
117
+ """
74
118
  return slugify(self.model_display_name + "-" + str(self.model_type))
75
119
 
76
120
  def fit(
77
121
  self,
78
122
  dir_path: Path,
79
- display_names: dict = None,
80
- training_accessions: dict[str, list[str]] = None,
123
+ display_names: dict | None = None,
124
+ training_accessions: dict[str, list[str]] | None = None,
81
125
  ) -> None:
82
- """Adds filters to the model"""
126
+ """
127
+ Adds filters to the model
128
+
129
+ This method constructs the model's index from sequence files in the specified directory.
130
+ It reads files with specified extensions (fasta and fastq), constructs a document list,
131
+ and builds a cobs index for efficient searching.
132
+
133
+ Args:
134
+ dir_path (Path): The directory containing sequence files to be indexed.
135
+ display_names (dict | None): A dictionary mapping file names to display names.
136
+ If None, uses file names as display names.
137
+ training_accessions (dict[str, list[str]] | None): A dictionary mapping filter IDs to
138
+ lists of accession numbers used for training the model. If None, no training accessions
139
+ are set.
140
+ Raises:
141
+ ValueError: If the directory path is invalid, does not exist, or is not a directory.
142
+ """
83
143
 
84
144
  if display_names is None:
85
145
  display_names = {}
@@ -123,10 +183,28 @@ class ProbabilisticFilterModel:
123
183
  self.index = cobs.Search(self.get_cobs_index_path(), True)
124
184
 
125
185
  def calculate_hits(
126
- self, sequence: Seq, filter_ids: list[str] = None, step: int = 1
186
+ self, sequence: Seq, filter_ids: list[str] | None = None, step: int = 1
127
187
  ) -> dict:
128
- """Calculates the hits for a sequence"""
129
-
188
+ """
189
+ Calculates the hits for a sequence
190
+
191
+ This method searches the model's index for the given sequence and returns a dictionary
192
+ of filter IDs and their corresponding scores. If filter_ids is provided, it filters the
193
+ results to only include those IDs.
194
+
195
+ Args:
196
+ sequence (Seq): The sequence to search for in the model's index.
197
+ filter_ids (list[str] | None): A list of filter IDs to filter the results. If None,
198
+ all results are returned.
199
+ step (int): The step size for the k-mer search. Default is 1.
200
+
201
+ Returns:
202
+ dict: A dictionary where keys are filter IDs and values are scores for the sequence.
203
+
204
+ Raises:
205
+ ValueError: If the sequence is not a valid Bio.Seq or Bio.SeqRecord object,
206
+ if the sequence length is not greater than k, or if the input is invalid.
207
+ """
130
208
  if not isinstance(sequence, (Seq)):
131
209
  raise ValueError(
132
210
  "Invalid sequence, must be a Bio.Seq or a Bio.SeqRecord object"
@@ -153,7 +231,30 @@ class ProbabilisticFilterModel:
153
231
  filter_ids: list[str] = None,
154
232
  step: int = 1,
155
233
  ) -> ModelResult:
156
- """Returns scores for the sequence(s) based on the filters in the model"""
234
+ """
235
+ Returns a model result object for the sequence(s) based on the filters in the model
236
+
237
+ This method processes the input sequence(s) and calculates hits against the model's index.
238
+ It supports various input types, including single sequences, lists of sequences,
239
+ SeqIO iterators, and file paths. The results are returned as a ModelResult object.
240
+
241
+ Args:
242
+ sequence_input (SeqRecord | list[SeqRecord] | SeqIO.FastaIO.FastaIterator |
243
+ SeqIO.QualityIO.FastqPhredIterator | Path):
244
+ The input sequence(s) to be processed. Can be a single SeqRecord, a list of
245
+ SeqRecords, a SeqIO iterator, or a Path to a fasta/fastq file.
246
+ filter_ids (list[str]): A list of filter IDs to filter the results. If None,
247
+ all results are returned.
248
+ step (int): The step size for the k-mer search. Default is 1.
249
+
250
+ Returns:
251
+ ModelResult: An object containing the hits for each sequence, the number of kmers,
252
+ and the sparse sampling step.
253
+
254
+ Raises:
255
+ ValueError: If the input sequence is not valid, or if it is not a Seq object,
256
+ a list of Seq objects, a SeqIO iterator, or a Path object to a fasta/fastq file.
257
+ """
157
258
  if isinstance(sequence_input, (SeqRecord)):
158
259
  return ProbabilisticFilterModel.predict(
159
260
  self, [sequence_input], filter_ids, step=step
@@ -186,7 +287,14 @@ class ProbabilisticFilterModel:
186
287
  )
187
288
 
188
289
  def save(self) -> None:
189
- """Saves the model to disk"""
290
+ """
291
+ Saves the model to disk
292
+
293
+ This method serializes the model's attributes to a JSON file and creates a directory
294
+ for the model based on its slug. The JSON file contains all relevant information about
295
+ the model, including k-mer size, display name, author information, model type, and
296
+ other parameters. The directory structure is created if it does not already exist.
297
+ """
190
298
  json_path = self.base_path / f"{self.slug()}.json"
191
299
  filter_path = self.base_path / self.slug()
192
300
  filter_path.mkdir(exist_ok=True, parents=True)
@@ -198,7 +306,23 @@ class ProbabilisticFilterModel:
198
306
 
199
307
  @staticmethod
200
308
  def load(path: Path) -> "ProbabilisticFilterModel":
201
- """Loads the model from a file"""
309
+ """
310
+ Loads the model from a file
311
+
312
+ This static method reads a JSON file containing the model's attributes and constructs
313
+ a ProbabilisticFilterModel object. It also checks for the existence of the cobs index file
314
+ and initializes the index if it exists.
315
+
316
+ Args:
317
+ path (Path): The path to the JSON file containing the model's attributes.
318
+
319
+ Returns:
320
+ ProbabilisticFilterModel: An instance of the ProbabilisticFilterModel class
321
+ initialized with the attributes from the JSON file.
322
+
323
+ Raises:
324
+ FileNotFoundError: If the JSON file or the cobs index file does not exist.
325
+ """
202
326
  with open(path, "r", encoding="utf-8") as file:
203
327
  json_object = file.read()
204
328
  model_json = json.loads(json_object)
@@ -223,6 +347,18 @@ class ProbabilisticFilterModel:
223
347
  return model
224
348
 
225
349
  def _convert_cobs_result_to_dict(self, cobs_result: cobs.SearchResult) -> dict:
350
+ """
351
+ Converts a cobs SearchResult to a dictionary
352
+
353
+ This method takes a cobs SearchResult object and converts it into a dictionary
354
+ where the keys are document names and the values are their corresponding scores.
355
+
356
+ Args:
357
+ cobs_result (cobs.SearchResult): The result object from a cobs search.
358
+
359
+ Returns:
360
+ dict: A dictionary mapping document names to their scores.
361
+ """
226
362
  return {
227
363
  individual_result.doc_name: individual_result.score
228
364
  for individual_result in cobs_result
@@ -239,7 +375,27 @@ class ProbabilisticFilterModel:
239
375
  ),
240
376
  step: int = 1,
241
377
  ) -> int:
242
- """Counts the number of kmers in the sequence(s)"""
378
+ """
379
+ Counts the number of kmers in the sequence(s)
380
+
381
+ This method calculates the number of k-mers in a given sequence or list of sequences.
382
+ It supports various input types, including single sequences, SeqRecords, lists of sequences,
383
+ and SeqIO iterators. The step size for the k-mer search can be specified.
384
+
385
+ Args:
386
+ sequence_input (Seq | SeqRecord | list[Seq] | SeqIO.FastaIO.FastaIterator |
387
+ SeqIO.QualityIO.FastqPhredIterator):
388
+ The input sequence(s) to count k-mers in. Can be a single Seq, a SeqRecord,
389
+ a list of Seq objects, or a SeqIO iterator.
390
+ step (int): The step size for the k-mer search. Default is 1.
391
+
392
+ Returns:
393
+ int: The total number of k-mers in the input sequence(s).
394
+
395
+ Raises:
396
+ ValueError: If the input sequence is not valid, or if it is not a Seq object,
397
+ a SeqRecord, a list of Seq objects, or a SeqIO iterator.
398
+ """
243
399
  if isinstance(sequence_input, Seq):
244
400
  return self._count_kmers([sequence_input], step=step)
245
401
 
@@ -268,12 +424,38 @@ class ProbabilisticFilterModel:
268
424
  " SeqIO FastaIterator, or a SeqIO FastqPhredIterator"
269
425
  )
270
426
 
271
- def _is_sequence_list(self, sequence_input):
427
+ def _is_sequence_list(self, sequence_input: Any) -> bool:
428
+ """
429
+ Checks if the input is a list of SeqRecord objects
430
+
431
+ This method verifies if the input is a list and that all elements in the list
432
+ are instances of SeqRecord. This is useful for ensuring that the input is a valid
433
+ collection of sequence records.
434
+
435
+ Args:
436
+ sequence_input (Any): The input to check.
437
+
438
+ Returns:
439
+ bool: True if the input is a list of SeqRecord objects, False otherwise.
440
+ """
272
441
  return isinstance(sequence_input, list) and all(
273
442
  isinstance(seq, (SeqRecord)) for seq in sequence_input
274
443
  )
275
444
 
276
- def _is_sequence_iterator(self, sequence_input):
445
+ def _is_sequence_iterator(self, sequence_input: Any) -> bool:
446
+ """
447
+ Checks if the input is a SeqIO iterator
448
+
449
+ This method verifies if the input is an instance of a SeqIO iterator, such as
450
+ FastaIterator or FastqPhredIterator. This is useful for ensuring that the input
451
+ is a valid sequence iterator that can be processed by the model.
452
+
453
+ Args:
454
+ sequence_input (Any): The input to check.
455
+
456
+ Returns:
457
+ bool: True if the input is a SeqIO iterator, False otherwise.
458
+ """
277
459
  return isinstance(
278
460
  sequence_input,
279
461
  (SeqIO.FastaIO.FastaIterator, SeqIO.QualityIO.FastqPhredIterator),