protein-quest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of protein-quest has been flagged as possibly problematic.

protein_quest/parallel.py CHANGED
@@ -2,8 +2,10 @@
 
 import logging
 import os
+from collections.abc import Callable, Collection
+from typing import Concatenate, ParamSpec, cast
 
-from dask.distributed import LocalCluster
+from dask.distributed import Client, LocalCluster, progress
 from distributed.deploy.cluster import Cluster
 from psutil import cpu_count
 
@@ -66,3 +68,37 @@ def _configure_cpu_dask_scheduler(nproc: int, name: str) -> LocalCluster:
     n_workers = total_cpus // nproc
     # Use single thread per worker to prevent GIL slowing down the computations
     return LocalCluster(name=name, threads_per_worker=1, n_workers=n_workers)
+
+
+# Generic type parameters used across helpers
+P = ParamSpec("P")
+
+
+def dask_map_with_progress[T, R, **P](
+    client: Client,
+    func: Callable[Concatenate[T, P], R],
+    iterable: Collection[T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> list[R]:
+    """
+    Wrapper for map, progress, and gather of Dask that returns a correctly typed list.
+
+    Args:
+        client: Dask client.
+        func: Function to map; first parameter comes from ``iterable`` and any
+            additional parameters can be provided positionally via ``*args`` or
+            as keyword arguments via ``**kwargs``.
+        iterable: Collection of arguments to map over.
+        *args: Additional positional arguments to pass to client.map().
+        **kwargs: Additional keyword arguments to pass to client.map().
+
+    Returns:
+        List of results of type returned by `func` function.
+    """
+    if client.dashboard_link:
+        logger.info(f"Follow progress on dask dashboard at: {client.dashboard_link}")
+    futures = client.map(func, iterable, *args, **kwargs)
+    progress(futures)
+    results = client.gather(futures)
+    return cast("list[R]", results)
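A minimal usage sketch of the new `dask_map_with_progress` helper; the `double` function and worker counts below are illustrative, not part of the release:

```python
from dask.distributed import Client, LocalCluster

from protein_quest.parallel import dask_map_with_progress


def double(x: int) -> int:
    return x * 2


if __name__ == "__main__":
    # Single-threaded workers, matching the module's own scheduler configuration
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(cluster) as client:
        results = dask_map_with_progress(client, double, [1, 2, 3])
        print(results)  # [2, 4, 6]
```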
@@ -3,7 +3,7 @@
 from collections.abc import Iterable, Mapping
 from pathlib import Path
 
-from protein_quest.utils import retrieve_files
+from protein_quest.utils import retrieve_files, run_async
 
 
 def _map_id_mmcif(pdb_id: str) -> tuple[str, str]:
@@ -49,3 +49,17 @@ async def fetch(ids: Iterable[str], save_dir: Path, max_parallel_downloads: int
 
     await retrieve_files(urls, save_dir, max_parallel_downloads, desc="Downloading PDBe mmCIF files")
     return id2paths
+
+
+def sync_fetch(ids: Iterable[str], save_dir: Path, max_parallel_downloads: int = 5) -> Mapping[str, Path]:
+    """Synchronously fetches mmCIF files from the PDBe database.
+
+    Args:
+        ids: A set of PDB IDs to fetch.
+        save_dir: The directory to save the fetched mmCIF files to.
+        max_parallel_downloads: The maximum number of parallel downloads.
+
+    Returns:
+        A dict of id and paths to the downloaded mmCIF files.
+    """
+    return run_async(fetch(ids, save_dir, max_parallel_downloads))
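A sketch of the new synchronous wrapper. The file header for this hunk is missing from the diff view, so the import path below is an assumption; the PDB ids and directory are illustrative:

```python
from pathlib import Path

# assumed module path; the diff does not show which file this hunk belongs to
from protein_quest.pdbe.fetch import sync_fetch

save_dir = Path("downloads")
save_dir.mkdir(exist_ok=True)
id2path = sync_fetch(["6t5y", "3jrs"], save_dir, max_parallel_downloads=2)
for pdb_id, path in id2path.items():
    print(pdb_id, path)
```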
protein_quest/pdbe/io.py CHANGED
@@ -2,15 +2,22 @@
 
 import gzip
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterable
+from datetime import UTC, datetime
 from pathlib import Path
 
 import gemmi
 
-from protein_quest import __version__
+from protein_quest.__version__ import __version__
+from protein_quest.utils import CopyMethod, copyfile
 
 logger = logging.getLogger(__name__)
 
+# TODO remove once v0.7.4 of gemmi is released,
+# as uv pip install git+https://github.com/project-gemmi/gemmi.git installs 0.7.4.dev0 which does not print leaks
+# Swallow gemmi leaked function warnings
+gemmi.set_leak_warnings(False)
+
 
 def nr_residues_in_chain(file: Path | str, chain: str = "A") -> int:
     """Returns the number of residues in a specific chain from a mmCIF/pdb file.
@@ -23,14 +30,21 @@ def nr_residues_in_chain(file: Path | str, chain: str = "A") -> int:
         The number of residues in the specified chain.
     """
     structure = gemmi.read_structure(str(file))
-    model = structure[0]
-    gchain = find_chain_in_model(model, chain)
+    gchain = find_chain_in_structure(structure, chain)
     if gchain is None:
         logger.warning("Chain %s not found in %s. Returning 0.", chain, file)
         return 0
     return len(gchain)
 
 
+def find_chain_in_structure(structure: gemmi.Structure, wanted_chain: str) -> gemmi.Chain | None:
+    for model in structure:
+        chain = find_chain_in_model(model, wanted_chain)
+        if chain is not None:
+            return chain
+    return None
+
+
 def find_chain_in_model(model: gemmi.Model, wanted_chain: str) -> gemmi.Chain | None:
     chain = model.find_chain(wanted_chain)
     if chain is None:
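A small sketch of the new helper, which searches every model instead of only the first one; the file name is illustrative:

```python
import gemmi

from protein_quest.pdbe.io import find_chain_in_structure

structure = gemmi.read_structure("downloads/6t5y.cif")
chain = find_chain_in_structure(structure, "A")
if chain is None:
    print("Chain A not found in any model")
else:
    print(f"Chain {chain.name} has {len(chain)} residues")
```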
@@ -63,10 +77,12 @@ def write_structure(structure: gemmi.Structure, path: Path):
         with gzip.open(path, "wt") as f:
             f.write(body)
     elif path.name.endswith(".cif"):
-        doc = structure.make_mmcif_document()
+        # do not write chem_comp so it is viewable by molstar
+        # see https://github.com/project-gemmi/gemmi/discussions/362
+        doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False))
         doc.write_file(str(path))
     elif path.name.endswith(".cif.gz"):
-        doc = structure.make_mmcif_document()
+        doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False))
         cif_str = doc.as_string()
         with gzip.open(path, "wt") as f:
             f.write(cif_str)
@@ -106,14 +122,17 @@ def locate_structure_file(root: Path, pdb_id: str) -> Path:
     Raises:
         FileNotFoundError: If no structure file is found for the given PDB ID.
     """
-    exts = [".cif.gz", ".cif", ".pdb.gz", ".pdb"]
-    # files downloaded from https://www.ebi.ac.uk/pdbe/ website
-    # have file names like pdb6t5y.ent or pdb6t5y.ent.gz for a PDB formatted file.
-    # TODO support pdb6t5y.ent or pdb6t5y.ent.gz file names
+    exts = [".cif.gz", ".cif", ".pdb.gz", ".pdb", ".ent", ".ent.gz"]
     for ext in exts:
-        candidate = root / f"{pdb_id.lower()}{ext}"
-        if candidate.exists():
-            return candidate
+        candidates = (
+            root / f"{pdb_id}{ext}",
+            root / f"{pdb_id.lower()}{ext}",
+            root / f"{pdb_id.upper()}{ext}",
+            root / f"pdb{pdb_id.lower()}{ext}",
+        )
+        for candidate in candidates:
+            if candidate.exists():
+                return candidate
     msg = f"No structure file found for {pdb_id} in {root}"
     raise FileNotFoundError(msg)
 
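With the widened extension list and name variants, a lookup like the sketch below should now resolve PDBe-style `pdb<id>.ent(.gz)` names as well as mixed-case ids; the paths are illustrative:

```python
from pathlib import Path

from protein_quest.pdbe.io import locate_structure_file

# matches e.g. downloads/6t5y.cif, downloads/6T5Y.cif.gz or downloads/pdb6t5y.ent.gz
path = locate_structure_file(Path("downloads"), "6t5y")
print(path)
```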
@@ -131,55 +150,132 @@ def glob_structure_files(input_dir: Path) -> Generator[Path]:
         yield from input_dir.glob(f"*{ext}")
 
 
+class ChainNotFoundError(IndexError):
+    """Exception raised when a chain is not found in a structure."""
+
+    def __init__(self, chain: str, file: Path | str, available_chains: Iterable[str]):
+        super().__init__(f"Chain {chain} not found in {file}. Available chains are: {available_chains}")
+        self.chain_id = chain
+        self.file = file
+
+
+def _dedup_helices(structure: gemmi.Structure):
+    helix_starts: set[str] = set()
+    duplicate_helix_indexes: list[int] = []
+    for hindex, helix in enumerate(structure.helices):
+        if str(helix.start) in helix_starts:
+            logger.debug(f"Duplicate start helix found: {hindex} {helix.start}, removing")
+            duplicate_helix_indexes.append(hindex)
+        else:
+            helix_starts.add(str(helix.start))
+    for helix_index in reversed(duplicate_helix_indexes):
+        structure.helices.pop(helix_index)
+
+
+def _dedup_sheets(structure: gemmi.Structure, chain2keep: str):
+    duplicate_sheet_indexes: list[int] = []
+    for sindex, sheet in enumerate(structure.sheets):
+        if sheet.name != chain2keep:
+            duplicate_sheet_indexes.append(sindex)
+    for sheet_index in reversed(duplicate_sheet_indexes):
+        structure.sheets.pop(sheet_index)
+
+
+def _add_provenance_info(structure: gemmi.Structure, chain2keep: str, out_chain: str):
+    old_id = structure.name
+    new_id = structure.name + f"{chain2keep}2{out_chain}"
+    structure.name = new_id
+    structure.info["_entry.id"] = new_id
+    new_title = f"From {old_id} chain {chain2keep} to {out_chain}"
+    structure.info["_struct.title"] = new_title
+    structure.info["_struct_keywords.pdbx_keywords"] = new_title.upper()
+    new_si = gemmi.SoftwareItem()
+    new_si.classification = gemmi.SoftwareItem.Classification.DataExtraction
+    new_si.name = "protein-quest.pdbe.io.write_single_chain_pdb_file"
+    new_si.version = str(__version__)
+    new_si.date = str(datetime.now(tz=UTC).date())
+    structure.meta.software = [*structure.meta.software, new_si]
+
+
+def chains_in_structure(structure: gemmi.Structure) -> set[gemmi.Chain]:
+    """Get a list of chains in a structure."""
+    return {c for model in structure for c in model}
+
+
 def write_single_chain_pdb_file(
-    input_file: Path, chain2keep: str, output_dir: Path, out_chain: str = "A"
-) -> Path | None:
+    input_file: Path,
+    chain2keep: str,
+    output_dir: Path,
+    out_chain: str = "A",
+    copy_method: CopyMethod = "copy",
+) -> Path:
     """Write a single chain from a mmCIF/pdb file to a new mmCIF/pdb file.
 
+    Also
+
+    - removes ligands and waters
+    - renumbers atoms ids
+    - removes chem_comp section from cif files
+    - adds provenance information to the header like software and input file+chain
+
+    This function is equivalent to the following gemmi commands:
+
+    ```shell
+    gemmi convert --remove-lig-wat --select=B --to=cif chain-in/3JRS.cif - | \\
+    gemmi convert --from=cif --rename-chain=B:A - chain-out/3JRS_B2A.gemmi.cif
+    ```
+
     Args:
         input_file: Path to the input mmCIF/pdb file.
         chain2keep: The chain to keep.
         output_dir: Directory to save the output file.
         out_chain: The chain identifier for the output file.
+        copy_method: How to copy when no changes are needed to output file.
 
     Returns:
-        Path to the output mmCIF/pdb file or None if not created.
+        Path to the output mmCIF/pdb file
+
+    Raises:
+        FileNotFoundError: If the input file does not exist.
+        ChainNotFoundError: If the specified chain is not found in the input file.
     """
 
+    logger.debug(f"chain2keep: {chain2keep}, out_chain: {out_chain}")
     structure = gemmi.read_structure(str(input_file))
-    model = structure[0]
+    structure.setup_entities()
 
-    # Only count residues of polymer
-    model.remove_ligands_and_waters()
-
-    chain = find_chain_in_model(model, chain2keep)
+    chain = find_chain_in_structure(structure, chain2keep)
+    chainnames_in_structure = {c.name for c in chains_in_structure(structure)}
     if chain is None:
-        logger.warning(
-            "Chain %s not found in %s. Skipping.",
-            chain2keep,
-            input_file,
-        )
-        return None
+        raise ChainNotFoundError(chain2keep, input_file, chainnames_in_structure)
+    chain_name = chain.name
     name, extension = _split_name_and_extension(input_file.name)
-    output_file = output_dir / f"{name}_{chain.name}2{out_chain}{extension}"
+    output_file = output_dir / f"{name}_{chain_name}2{out_chain}{extension}"
 
-    new_structure = gemmi.Structure()
-    new_structure.resolution = structure.resolution
-    new_id = structure.name + f"{chain2keep}2{out_chain}"
-    new_structure.name = new_id
-    new_structure.info["_entry.id"] = new_id
-    new_title = f"From {structure.info['_entry.id']} chain {chain2keep} to {out_chain}"
-    new_structure.info["_struct.title"] = new_title
-    new_structure.info["_struct_keywords.pdbx_keywords"] = new_title.upper()
-    new_si = gemmi.SoftwareItem()
-    new_si.classification = gemmi.SoftwareItem.Classification.DataExtraction
-    new_si.name = "protein-quest"
-    new_si.version = str(__version__)
-    new_structure.meta.software.append(new_si)
-    new_model = gemmi.Model(1)
-    chain.name = out_chain
-    new_model.add_chain(chain)
-    new_structure.add_model(new_model)
-    write_structure(new_structure, output_file)
+    if output_file.exists():
+        logger.info("Output file %s already exists for input file %s. Skipping.", output_file, input_file)
+        return output_file
+
+    if chain_name == out_chain and len(chainnames_in_structure) == 1:
+        logger.info(
+            "%s only has chain %s and out_chain is also %s. Copying file to %s.",
+            input_file,
+            chain_name,
+            out_chain,
+            output_file,
+        )
+        copyfile(input_file, output_file, copy_method)
+        return output_file
+
+    gemmi.Selection(chain_name).remove_not_selected(structure)
+    for m in structure:
+        m.remove_ligands_and_waters()
+    structure.setup_entities()
+    structure.rename_chain(chain_name, out_chain)
+    _dedup_helices(structure)
+    _dedup_sheets(structure, out_chain)
+    _add_provenance_info(structure, chain_name, out_chain)
+
+    write_structure(structure, output_file)
 
     return output_file
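A hedged sketch of the reworked single-chain extraction, which now raises `ChainNotFoundError` instead of returning `None`; the file names are illustrative:

```python
from pathlib import Path

from protein_quest.pdbe.io import ChainNotFoundError, write_single_chain_pdb_file

output_dir = Path("chain-out")
output_dir.mkdir(exist_ok=True)
try:
    out = write_single_chain_pdb_file(Path("chain-in/3JRS.cif"), "B", output_dir, out_chain="A")
    print(f"Wrote {out}")
except ChainNotFoundError as e:
    print(f"Chain {e.chain_id} not found in {e.file}")
```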
protein_quest/ss.py ADDED
@@ -0,0 +1,264 @@
+"""Module for dealing with secondary structure."""
+
+import logging
+from collections.abc import Generator, Iterable
+from dataclasses import dataclass
+from pathlib import Path
+
+from gemmi import Structure, read_structure, set_leak_warnings
+
+from protein_quest.converter import PositiveInt, Ratio, converter
+
+logger = logging.getLogger(__name__)
+
+# TODO remove once v0.7.4 of gemmi is released,
+# as uv pip install git+https://github.com/project-gemmi/gemmi.git installs 0.7.4.dev0 which does not print leaks
+# Swallow gemmi leaked function warnings
+set_leak_warnings(False)
+
+# TODO if a structure has no secondary structure information, calculate it with `gemmi ss`.
+# https://github.com/MonomerLibrary/monomers/wiki/Installation as --monomers dir
+# gemmi executable is in https://pypi.org/project/gemmi-program/
+# `gemmi ss` only prints secondary structure to stdout with `-v` flag.
+
+
+def nr_of_residues_in_total(structure: Structure) -> int:
+    """Count the total number of residues in the structure.
+
+    Args:
+        structure: The gemmi Structure object to analyze.
+
+    Returns:
+        The total number of residues in the structure.
+    """
+    count = 0
+    for model in structure:
+        for chain in model:
+            count += len(chain)
+    return count
+
+
+def nr_of_residues_in_helix(structure: Structure) -> int:
+    """Count the number of residues in alpha helices.
+
+    Requires structure to have secondary structure information.
+
+    Args:
+        structure: The gemmi Structure object to analyze.
+
+    Returns:
+        The number of residues in alpha helices.
+    """
+    # For cif files from AlphaFold the helix.length is set to -1
+    # so use resid instead
+    count = 0
+    for helix in structure.helices:
+        end = helix.end.res_id.seqid.num
+        start = helix.start.res_id.seqid.num
+        if end is None or start is None:
+            logger.warning(f"Invalid helix coordinates: {helix.end} or {helix.start}")
+            continue
+        length = end - start + 1
+        count += length
+    return count
+
+
+def nr_of_residues_in_sheet(structure: Structure) -> int:
+    """Count the number of residues in beta sheets.
+
+    Requires structure to have secondary structure information.
+
+    Args:
+        structure: The gemmi Structure object to analyze.
+
+    Returns:
+        The number of residues in beta sheets.
+    """
+    count = 0
+    for sheet in structure.sheets:
+        for strand in sheet.strands:
+            end = strand.end.res_id.seqid.num
+            start = strand.start.res_id.seqid.num
+            if end is None or start is None:
+                logger.warning(f"Invalid strand coordinates: {strand.end} or {strand.start}")
+                continue
+            length = end - start + 1
+            count += length
+    return count
+
+
+@dataclass
+class SecondaryStructureFilterQuery:
+    """Query object to filter on secondary structure.
+
+    Parameters:
+        abs_min_helix_residues: Minimum number of residues in helices (absolute).
+        abs_max_helix_residues: Maximum number of residues in helices (absolute).
+        abs_min_sheet_residues: Minimum number of residues in sheets (absolute).
+        abs_max_sheet_residues: Maximum number of residues in sheets (absolute).
+        ratio_min_helix_residues: Minimum number of residues in helices (relative).
+        ratio_max_helix_residues: Maximum number of residues in helices (relative).
+        ratio_min_sheet_residues: Minimum number of residues in sheets (relative).
+        ratio_max_sheet_residues: Maximum number of residues in sheets (relative).
+    """
+
+    abs_min_helix_residues: PositiveInt | None = None
+    abs_max_helix_residues: PositiveInt | None = None
+    abs_min_sheet_residues: PositiveInt | None = None
+    abs_max_sheet_residues: PositiveInt | None = None
+    ratio_min_helix_residues: Ratio | None = None
+    ratio_max_helix_residues: Ratio | None = None
+    ratio_min_sheet_residues: Ratio | None = None
+    ratio_max_sheet_residues: Ratio | None = None
+
+
+def _check_range(min_val, max_val, label):
+    if min_val is not None and max_val is not None and min_val >= max_val:
+        msg = f"Invalid {label} range: min {min_val} must be smaller than max {max_val}"
+        raise ValueError(msg)
+
+
+base_query_hook = converter.get_structure_hook(SecondaryStructureFilterQuery)
+
+
+@converter.register_structure_hook
+def secondary_structure_filter_query_hook(value, _type) -> SecondaryStructureFilterQuery:
+    result: SecondaryStructureFilterQuery = base_query_hook(value, _type)
+    _check_range(result.abs_min_helix_residues, result.abs_max_helix_residues, "absolute helix residue")
+    _check_range(result.abs_min_sheet_residues, result.abs_max_sheet_residues, "absolute sheet residue")
+    _check_range(result.ratio_min_helix_residues, result.ratio_max_helix_residues, "ratio helix residue")
+    _check_range(result.ratio_min_sheet_residues, result.ratio_max_sheet_residues, "ratio sheet residue")
+    return result
+
+
+@dataclass
+class SecondaryStructureStats:
+    """Statistics about the secondary structure of a protein.
+
+    Parameters:
+        nr_residues: Total number of residues in the structure.
+        nr_helix_residues: Number of residues in helices.
+        nr_sheet_residues: Number of residues in sheets.
+        helix_ratio: Ratio of residues in helices.
+        sheet_ratio: Ratio of residues in sheets.
+    """
+
+    nr_residues: PositiveInt
+    nr_helix_residues: PositiveInt
+    nr_sheet_residues: PositiveInt
+    helix_ratio: Ratio
+    sheet_ratio: Ratio
+
+
+@dataclass
+class SecondaryStructureFilterResult:
+    """Result of filtering on secondary structure.
+
+    Parameters:
+        stats: The secondary structure statistics.
+        passed: Whether the structure passed the filtering criteria.
+    """
+
+    stats: SecondaryStructureStats
+    passed: bool = False
+
+
+def _gather_stats(structure: Structure) -> SecondaryStructureStats:
+    nr_total_residues = nr_of_residues_in_total(structure)
+    nr_helix_residues = nr_of_residues_in_helix(structure)
+    nr_sheet_residues = nr_of_residues_in_sheet(structure)
+    if nr_total_residues == 0:
+        msg = "Structure has zero residues; cannot compute secondary structure ratios."
+        raise ValueError(msg)
+    helix_ratio = nr_helix_residues / nr_total_residues
+    sheet_ratio = nr_sheet_residues / nr_total_residues
+    return SecondaryStructureStats(
+        nr_residues=nr_total_residues,
+        nr_helix_residues=nr_helix_residues,
+        nr_sheet_residues=nr_sheet_residues,
+        helix_ratio=helix_ratio,
+        sheet_ratio=sheet_ratio,
+    )
+
+
+def filter_on_secondary_structure(
+    structure: Structure,
+    query: SecondaryStructureFilterQuery,
+) -> SecondaryStructureFilterResult:
+    """Filter a structure based on secondary structure criteria.
+
+    Args:
+        structure: The gemmi Structure object to analyze.
+        query: The filtering criteria to apply.
+
+    Returns:
+        Filtering statistics and whether structure passed.
+    """
+    stats = _gather_stats(structure)
+    conditions: list[bool] = []
+
+    # Helix absolute thresholds
+    if query.abs_min_helix_residues is not None:
+        conditions.append(stats.nr_helix_residues >= query.abs_min_helix_residues)
+    if query.abs_max_helix_residues is not None:
+        conditions.append(stats.nr_helix_residues <= query.abs_max_helix_residues)
+
+    # Helix ratio thresholds
+    if query.ratio_min_helix_residues is not None:
+        conditions.append(stats.helix_ratio >= query.ratio_min_helix_residues)
+    if query.ratio_max_helix_residues is not None:
+        conditions.append(stats.helix_ratio <= query.ratio_max_helix_residues)
+
+    # Sheet absolute thresholds
+    if query.abs_min_sheet_residues is not None:
+        conditions.append(stats.nr_sheet_residues >= query.abs_min_sheet_residues)
+    if query.abs_max_sheet_residues is not None:
+        conditions.append(stats.nr_sheet_residues <= query.abs_max_sheet_residues)
+
+    # Sheet ratio thresholds
+    if query.ratio_min_sheet_residues is not None:
+        conditions.append(stats.sheet_ratio >= query.ratio_min_sheet_residues)
+    if query.ratio_max_sheet_residues is not None:
+        conditions.append(stats.sheet_ratio <= query.ratio_max_sheet_residues)
+
+    if not conditions:
+        msg = "No filtering conditions provided. Please specify at least one condition."
+        raise ValueError(msg)
+    passed = all(conditions)
+    return SecondaryStructureFilterResult(stats=stats, passed=passed)
+
+
+def filter_file_on_secondary_structure(
+    file_path: Path,
+    query: SecondaryStructureFilterQuery,
+) -> SecondaryStructureFilterResult:
+    """Filter a structure file based on secondary structure criteria.
+
+    Args:
+        file_path: The path to the structure file to analyze.
+        query: The filtering criteria to apply.
+
+    Returns:
+        Filtering statistics and whether file passed.
+    """
+    structure = read_structure(str(file_path))
+    return filter_on_secondary_structure(structure, query)
+
+
+def filter_files_on_secondary_structure(
+    file_paths: Iterable[Path],
+    query: SecondaryStructureFilterQuery,
+) -> Generator[tuple[Path, SecondaryStructureFilterResult]]:
+    """Filter multiple structure files based on secondary structure criteria.
+
+    Args:
+        file_paths: A list of paths to the structure files to analyze.
+        query: The filtering criteria to apply.
+
+    Yields:
+        For each file returns the filtering statistics and whether structure passed.
+    """
+    # TODO check if quick enough in serial mode, if not switch to dask map
+    for file_path in file_paths:
+        result = filter_file_on_secondary_structure(file_path, query)
+        yield file_path, result
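A minimal sketch of the new filtering API, keeping mostly-helical structures; the thresholds and paths are illustrative:

```python
from pathlib import Path

from protein_quest.ss import SecondaryStructureFilterQuery, filter_files_on_secondary_structure

query = SecondaryStructureFilterQuery(ratio_min_helix_residues=0.5, ratio_max_sheet_residues=0.1)
for path, result in filter_files_on_secondary_structure(Path("downloads").glob("*.cif"), query):
    if result.passed:
        print(path, f"helix ratio: {result.stats.helix_ratio:.2f}")
```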
protein_quest/taxonomy.py CHANGED
@@ -9,9 +9,9 @@ from typing import Literal, get_args
 from aiohttp.client import ClientResponse
 from aiohttp_retry import RetryClient
 from cattrs.gen import make_dict_structure_fn, override
-from cattrs.preconf.orjson import make_converter
 from yarl import URL
 
+from protein_quest.converter import converter
 from protein_quest.go import TextIOWrapper
 from protein_quest.utils import friendly_session
 
@@ -20,6 +20,16 @@ logger = logging.getLogger(__name__)
 
 @dataclass(frozen=True, slots=True)
 class Taxon:
+    """Dataclass representing a taxon.
+
+    Arguments:
+        taxon_id: The unique identifier for the taxon.
+        scientific_name: The scientific name of the taxon.
+        rank: The taxonomic rank of the taxon (e.g., species, genus).
+        common_name: The common name of the taxon (if available).
+        other_names: A set of other names for the taxon (if available).
+    """
+
     taxon_id: str
     scientific_name: str
     rank: str
@@ -32,8 +42,6 @@ class SearchTaxonResponse:
     results: list[Taxon]
 
 
-converter = make_converter()
-
 converter.register_structure_hook(
     Taxon,
     make_dict_structure_fn(
@@ -47,7 +55,9 @@ converter.register_structure_hook(
 )
 
 SearchField = Literal["tax_id", "scientific", "common", "parent"]
+"""Type of search field"""
 search_fields: set[SearchField | None] = set(get_args(SearchField)) | {None}
+"""Set of valid search fields"""
 
 
 def _get_next_page(response: ClientResponse) -> URL | str | None:
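A tiny sketch of the newly documented search-field constants; purely illustrative:

```python
from protein_quest.taxonomy import SearchField, search_fields

field: SearchField = "scientific"
assert field in search_fields  # the set also contains None
```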