virtool-workflow 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- virtool_workflow/__init__.py +13 -0
- virtool_workflow/analysis/__init__.py +1 -0
- virtool_workflow/analysis/fastqc.py +467 -0
- virtool_workflow/analysis/skewer.py +265 -0
- virtool_workflow/analysis/trimming.py +56 -0
- virtool_workflow/analysis/utils.py +27 -0
- virtool_workflow/api/__init__.py +0 -0
- virtool_workflow/api/acquire.py +66 -0
- virtool_workflow/api/client.py +132 -0
- virtool_workflow/api/utils.py +109 -0
- virtool_workflow/cli.py +66 -0
- virtool_workflow/data/__init__.py +22 -0
- virtool_workflow/data/analyses.py +106 -0
- virtool_workflow/data/hmms.py +109 -0
- virtool_workflow/data/indexes.py +319 -0
- virtool_workflow/data/jobs.py +62 -0
- virtool_workflow/data/ml.py +82 -0
- virtool_workflow/data/samples.py +190 -0
- virtool_workflow/data/subtractions.py +244 -0
- virtool_workflow/data/uploads.py +35 -0
- virtool_workflow/decorators.py +47 -0
- virtool_workflow/errors.py +62 -0
- virtool_workflow/files.py +40 -0
- virtool_workflow/hooks.py +140 -0
- virtool_workflow/pytest_plugin/__init__.py +35 -0
- virtool_workflow/pytest_plugin/data.py +197 -0
- virtool_workflow/pytest_plugin/utils.py +9 -0
- virtool_workflow/runtime/__init__.py +0 -0
- virtool_workflow/runtime/config.py +21 -0
- virtool_workflow/runtime/discover.py +95 -0
- virtool_workflow/runtime/events.py +7 -0
- virtool_workflow/runtime/hook.py +129 -0
- virtool_workflow/runtime/path.py +19 -0
- virtool_workflow/runtime/ping.py +54 -0
- virtool_workflow/runtime/redis.py +65 -0
- virtool_workflow/runtime/run.py +276 -0
- virtool_workflow/runtime/run_subprocess.py +168 -0
- virtool_workflow/runtime/sentry.py +28 -0
- virtool_workflow/utils.py +90 -0
- virtool_workflow/workflow.py +90 -0
- virtool_workflow-0.0.0.dist-info/LICENSE +21 -0
- virtool_workflow-0.0.0.dist-info/METADATA +71 -0
- virtool_workflow-0.0.0.dist-info/RECORD +45 -0
- virtool_workflow-0.0.0.dist-info/WHEEL +4 -0
- virtool_workflow-0.0.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,265 @@
|
|
1
|
+
"""Utilities and a fixture for using `Skewer <https://github.com/relipmoc/skewer>`_ to
|
2
|
+
trim reads.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import os
|
7
|
+
import shutil
|
8
|
+
from asyncio.subprocess import Process
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from enum import Enum
|
11
|
+
from pathlib import Path
|
12
|
+
from tempfile import mkdtemp
|
13
|
+
from typing import Protocol
|
14
|
+
|
15
|
+
from pyfixtures import fixture
|
16
|
+
from virtool.models.enums import LibraryType
|
17
|
+
|
18
|
+
from virtool_workflow import RunSubprocess
|
19
|
+
from virtool_workflow.analysis.utils import ReadPaths
|
20
|
+
from virtool_workflow.data.samples import WFSample
|
21
|
+
|
22
|
+
|
23
|
+
class SkewerMode(str, Enum):
    """Values accepted for Skewer's ``-m`` (mode) option."""

    #: Run Skewer in paired-end mode.
    PAIRED_END = "pe"

    #: Run Skewer in single-end mode.
    SINGLE_END = "any"
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
class SkewerConfiguration:
    """A configuration for running Skewer."""

    min_length: int
    """The minimum length of a trimmed read (Skewer ``-l``)."""

    mode: SkewerMode
    """The mode to run Skewer in (Skewer ``-m``)."""

    end_quality: int = 20
    """The minimum quality score for the end of a trimmed read (Skewer ``-q``)."""

    max_error_rate: float = 0.1
    """
    The maximum error rate for a trimmed read (Skewer ``-r``). Reads that exceed the
    rate will be discarded.
    """

    max_indel_rate: float = 0.03
    """
    The maximum indel rate for a trimmed read (Skewer ``-d``). Reads that exceed the
    rate will be discarded.
    """

    mean_quality: int = 25
    """The minimum mean quality score allowed for a read (Skewer ``-Q``)."""

    number_of_processes: int = 1
    """
    The number of processes to use when running Skewer.

    NOTE: the ``skewer`` fixture uses the workflow's ``proc`` value instead of this.
    """

    quiet: bool = True
    """
    Whether to run Skewer in quiet mode.

    NOTE: the ``skewer`` fixture always passes ``--quiet`` regardless of this value.
    """

    other_options: tuple[str, ...] = ("-n", "-z")
    """
    Other options to pass to Skewer.

    NOTE: not read by the ``skewer`` fixture.
    """

    max_length: int | None = None
    """
    The maximum length of a trimmed read, or ``None`` for no maximum.

    Set to ``22`` for sRNA libraries by ``calculate_skewer_trimming_parameters``. Only
    honoured by runners that pass it to Skewer (``-L``).
    """
|
69
|
+
|
70
|
+
|
71
|
+
@dataclass
class SkewerResult:
    """Represents the result of running Skewer to trim a paired or unpaired FASTQ dataset."""

    command: list[str]
    """The command used to run Skewer."""

    output_path: Path
    """The path to the directory containing the trimmed reads."""

    process: Process
    """The process that ran Skewer."""

    read_paths: ReadPaths
    """The paths to the trimmed reads."""

    @property
    def left(self) -> Path:
        """The path to one of:

        - the FASTQ trimming result for an unpaired Illumina dataset
        - the FASTQ trimming result for the left reads of a paired Illumina dataset

        """
        return self.read_paths[0]

    @property
    def right(self) -> Path | None:
        """The path to the right reads of a paired Illumina dataset.

        ``None`` if the dataset is unpaired (``read_paths`` has a single member).
        """
        try:
            return self.read_paths[1]
        except IndexError:
            return None
|
109
|
+
|
110
|
+
|
111
|
+
def calculate_skewer_trimming_parameters(
    sample: WFSample,
    min_read_length: int,
) -> SkewerConfiguration:
    """Calculate Skewer trimming parameters based on the sample's library type.

    The configuration starts from the :class:`SkewerConfiguration` defaults, with the
    mode chosen from whether the sample is paired and ``min_length`` set to
    ``min_read_length``. The library type then adjusts it:

    * ``normal``: the defaults are returned unchanged.
    * ``amplicon``: ``end_quality`` and ``mean_quality`` are set to ``0``.
    * ``srna``: ``min_length`` is set to ``20`` (ignoring ``min_read_length``) and
      ``max_length`` to ``22``.

    :param sample: The sample to calculate trimming parameters for.
    :param min_read_length: The minimum length of a read before it is discarded.
    :return: the trimming parameters
    :raises ValueError: if the sample's library type is not recognized
    """
    config = SkewerConfiguration(
        min_length=min_read_length,
        mode=SkewerMode.PAIRED_END if sample.paired else SkewerMode.SINGLE_END,
    )

    if sample.library_type == LibraryType.normal:
        # Normal libraries previously fell through to the ValueError below.
        return config

    if sample.library_type == LibraryType.amplicon:
        config.end_quality = 0
        config.mean_quality = 0

        return config

    if sample.library_type == LibraryType.srna:
        config.min_length = 20
        config.max_length = 22

        return config

    raise ValueError(f"Unknown library type: {sample.library_type}")
|
140
|
+
|
141
|
+
|
142
|
+
class SkewerRunner(Protocol):
    """Structural type for async callables that trim reads with Skewer.

    An implementation is awaited with a :class:`SkewerConfiguration`, the read file(s)
    to trim, and an output directory, and resolves to a :class:`SkewerResult`.
    """

    async def __call__(
        self, config: SkewerConfiguration, paths: ReadPaths, output_path: Path
    ) -> SkewerResult:
        ...
|
151
|
+
|
152
|
+
|
153
|
+
@fixture
def skewer(proc: int, run_subprocess: RunSubprocess) -> SkewerRunner:
    """Provide an asynchronous function that runs Skewer.

    The provided function takes a :class:`.SkewerConfiguration`, a tuple of paths to
    the left and right reads to trim, and the directory to write the trimmed output
    to. If a single-member tuple is provided, the dataset is assumed to be unpaired.

    Skewer is given ``proc`` threads (``-t``), the number of processes configured for
    the workflow run. The configuration's ``number_of_processes``, ``quiet``, and
    ``other_options`` are not read; the command always includes ``--quiet`` and
    ``-z``. If the configuration has a non-``None`` ``max_length``, it is passed as
    ``-L``.

    Skewer writes into a fresh temporary directory. Its log and trimmed reads are
    moved into ``output_path`` as ``trim.log``, ``reads_1.fq.gz`` and, for paired
    data, ``reads_2.fq.gz``. The temporary directory is then removed, even if
    trimming fails.

    Example:
    -------
    .. code-block:: python

        @step
        async def step_one(skewer: SkewerRunner, work_path: Path):
            config = SkewerConfiguration(
                min_length=35,
                mode=SkewerMode.PAIRED_END,
                mean_quality=30,
            )

            skewer_result = await skewer(
                config,
                (work_path / "test_1.fq.gz", work_path / "test_2.fq.gz"),
                work_path / "trimmed",
            )

    :param proc: the number of processes configured for the workflow run
    :param run_subprocess: used to launch the Skewer process
    :raises RuntimeError: if no ``skewer`` executable is found on ``PATH``
    """
    if shutil.which("skewer") is None:
        raise RuntimeError("skewer is not installed.")

    async def func(
        config: SkewerConfiguration,
        read_paths: ReadPaths,
        output_path: Path,
    ):
        temp_path = Path(await asyncio.to_thread(mkdtemp, suffix="_virtool_skewer"))

        try:
            await asyncio.to_thread(output_path.mkdir, exist_ok=True, parents=True)

            # ``max_length`` may be absent on configurations that never assigned it.
            max_length = getattr(config, "max_length", None)

            command = [
                str(a)
                for a in [
                    "skewer",
                    "-r",
                    config.max_error_rate,
                    "-d",
                    config.max_indel_rate,
                    "-m",
                    config.mode.value,
                    "-l",
                    config.min_length,
                    # Only bound the trimmed read length from above when a maximum is
                    # configured (e.g. for sRNA libraries).
                    *(() if max_length is None else ("-L", max_length)),
                    "-q",
                    config.end_quality,
                    "-Q",
                    config.mean_quality,
                    "-t",
                    proc,
                    # Skewer spams the console with progress updates. Set quiet to avoid.
                    "--quiet",
                    # Compress the trimmed output.
                    "-z",
                    "-o",
                    f"{temp_path}/reads",
                    *read_paths,
                ]
            ]

            process = await run_subprocess(
                command,
                cwd=read_paths[0].parent,
                # NOTE(review): hard-coded library path, presumably required by the
                # installed skewer binary; confirm.
                env={**os.environ, "LD_LIBRARY_PATH": "/usr/lib/x86_64-linux-gnu"},
            )

            trimmed_paths = await asyncio.to_thread(
                _rename_trimming_results,
                temp_path,
                output_path,
            )
        finally:
            # mkdtemp() leaves cleanup to the caller: remove Skewer's scratch
            # directory whether or not trimming succeeded.
            await asyncio.to_thread(shutil.rmtree, temp_path, ignore_errors=True)

        return SkewerResult(command, output_path, process, trimmed_paths)

    return func
|
236
|
+
|
237
|
+
|
238
|
+
def _rename_trimming_results(temp_path: Path, output_path: Path) -> ReadPaths:
    """Move Skewer's output out of ``temp_path`` under the file names Virtool uses.

    ``reads-trimmed.log`` is moved to ``trim.log``. Unpaired output
    (``reads-trimmed.fastq.gz``) is moved to ``reads_1.fq.gz``. When that file does
    not exist, paired output (``reads-trimmed-pair1.fastq.gz`` and
    ``reads-trimmed-pair2.fastq.gz``) is moved to ``reads_1.fq.gz`` and
    ``reads_2.fq.gz``.

    :param temp_path: the directory containing the results from Skewer
    :param output_path: the directory to move the renamed files into
    :return: the paths of the trimmed reads inside ``output_path``
    :raises FileNotFoundError: if an expected Skewer output file is missing
    """
    shutil.move(temp_path / "reads-trimmed.log", output_path / "trim.log")

    left_path = output_path / "reads_1.fq.gz"
    unpaired_path = temp_path / "reads-trimmed.fastq.gz"

    # Skewer names its output differently for unpaired and paired runs. Check which
    # file was produced instead of treating any FileNotFoundError from the move as a
    # sign of paired data.
    if unpaired_path.exists():
        shutil.move(unpaired_path, left_path)
        return (left_path,)

    right_path = output_path / "reads_2.fq.gz"

    shutil.move(temp_path / "reads-trimmed-pair1.fastq.gz", left_path)
    shutil.move(temp_path / "reads-trimmed-pair2.fastq.gz", right_path)

    return left_path, right_path
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""Calculate trimming parameters which are passed the Skewer read trimming tool."""
|
2
|
+
|
3
|
+
import hashlib
|
4
|
+
import json
|
5
|
+
|
6
|
+
from virtool.models.enums import LibraryType
|
7
|
+
|
8
|
+
from virtool_workflow.data.samples import WFSample
|
9
|
+
|
10
|
+
|
11
|
+
def calculate_trimming_cache_key(
    sample_id: str,
    trimming_parameters: dict,
    program: str = "skewer",
):
    """Derive a deterministic cache key for a set of trimmed reads.

    **This is not currently used.**

    The sample ID, trimming parameters, and program name are serialized to JSON with
    sorted keys and prefixed with ``reads-``. The result is hashed with SHA-256, so
    identical inputs give identical keys regardless of dictionary insertion order.

    :param sample_id: The ID of the sample being trimmed.
    :param trimming_parameters: The trimming parameters.
    :param program: The name of the trimming program.
    :return: the hex-encoded SHA-256 digest used as the cache key
    """
    payload = {
        "program": program,
        "parameters": trimming_parameters,
        "id": sample_id,
    }

    serialized = json.dumps(payload, sort_keys=True)
    digest = hashlib.sha256(f"reads-{serialized}".encode())

    return digest.hexdigest()
|
36
|
+
|
37
|
+
|
38
|
+
def calculate_trimming_min_length(sample: WFSample) -> int:
    """Return the minimum length a read may have after trimming.

    Amplicon libraries (:class:`.LibraryType`) use 95% of the sample's maximum
    observed read length, rounded. Other library types are tiered on that maximum
    length: below 80 gives 35, below 160 gives 100, and anything longer gives 160.

    :param sample: the sample
    :return: the minimum allowed trimmed read length
    """
    if sample.library_type == LibraryType.amplicon:
        return round(0.95 * sample.max_length)

    # (exclusive upper bound on the sample's max read length, minimum length to use)
    for upper_bound, min_length in ((80, 35), (160, 100)):
        if sample.max_length < upper_bound:
            return min_length

    return 160
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Callable, TypeAlias
|
3
|
+
|
4
|
+
# Type alias for the FASTQ file path(s) of a dataset: ``(path,)`` for unpaired data,
# ``(left, right)`` for paired data.
ReadPaths: TypeAlias = tuple[Path] | tuple[Path, Path]
"""A tuple of paths to FASTQ files. There are one or two paths, depending on whether the dataset is paired."""
|
6
|
+
|
7
|
+
|
8
|
+
def _make_paired_paths(
    dir_path: Path, paired: bool, mkstr: Callable[[int], str]
) -> ReadPaths:
    """Build one or two read file paths inside ``dir_path``.

    :param dir_path: the directory the read files live in
    :param paired: whether to also build the path for the second (right) file
    :param mkstr: maps a read number (``1`` or ``2``) to a file name
    :return: ``(dir_path / mkstr(1),)``, plus ``dir_path / mkstr(2)`` when ``paired``
    """
    first = dir_path / mkstr(1)

    if not paired:
        return (first,)

    return first, dir_path / mkstr(2)
|
13
|
+
|
14
|
+
|
15
|
+
def make_read_paths(reads_dir_path: Path, paired: bool) -> ReadPaths:
    """Return the path(s) of the compressed FASTQ files holding the read data.

    The files are named ``reads_1.fq.gz`` and, for paired data, ``reads_2.fq.gz``.

    :param reads_dir_path: The directory containing the FASTQ file(s).
    :param paired: Whether the sequence is paired (two FASTQ files).
    :return: a one-element tuple if ``paired`` is ``False``, else a two-element tuple
    """
    return _make_paired_paths(reads_dir_path, paired, "reads_{}.fq.gz".format)
|
24
|
+
|
25
|
+
|
26
|
+
def make_legacy_read_paths(reads_dir_path: Path, paired: bool) -> ReadPaths:
    """Return the path(s) of uncompressed FASTQ files using the legacy naming scheme.

    The files are named ``reads_1.fastq`` and, for paired data, ``reads_2.fastq``.

    :param reads_dir_path: The directory containing the FASTQ file(s).
    :param paired: Whether the sequence is paired (two FASTQ files).
    :return: a one-element tuple if ``paired`` is ``False``, else a two-element tuple
    """
    return _make_paired_paths(reads_dir_path, paired, "reads_{}.fastq".format)
|
File without changes
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import asyncio
|
2
|
+
|
3
|
+
from aiohttp import ClientConnectionError, ClientSession, TCPConnector
|
4
|
+
from structlog import get_logger
|
5
|
+
from virtool.jobs.models import JobAcquired
|
6
|
+
|
7
|
+
from virtool_workflow.errors import (
|
8
|
+
JobAlreadyAcquiredError,
|
9
|
+
JobsAPIError,
|
10
|
+
JobsAPIServerError,
|
11
|
+
)
|
12
|
+
|
13
|
+
# Structured logger used by job acquisition in this module.
logger = get_logger("api")
|
14
|
+
|
15
|
+
|
16
|
+
async def acquire_job_by_id(
    jobs_api_connection_string: str,
    job_id: str,
) -> JobAcquired:
    """Acquire the job with a given ID via the API.

    Sends ``PATCH {jobs_api_connection_string}/jobs/{job_id}`` with
    ``{"acquired": True}``. Connection errors are retried, for at most four attempts
    in total, with a one-second wait after each failed attempt.

    :param jobs_api_connection_string: The url for the jobs API.
    :param job_id: The id of the job to acquire
    :return: a job including its API key
    :raises JobAlreadyAcquiredError: the API responds 400 saying the job is already acquired
    :raises JobsAPIError: the API returns any other non-200 response
    :raises JobsAPIServerError: the server could not be reached in any attempt
    """
    async with ClientSession(
        connector=TCPConnector(force_close=True, limit=100),
    ) as session:
        attempts = 4

        while attempts > 0:
            try:
                async with session.patch(
                    f"{jobs_api_connection_string}/jobs/{job_id}",
                    json={"acquired": True},
                ) as resp:
                    logger.info("acquiring job", remaining_attempts=attempts, id=job_id)

                    if resp.status == 200:
                        job_json = await resp.json()
                        logger.info("acquired job", id=job_id)
                        return JobAcquired(**job_json)

                    # Read the error body once; it is used for both detection and logging.
                    body = await resp.text()

                    if resp.status == 400 and "already acquired" in body:
                        raise JobAlreadyAcquiredError(await resp.json())

                    logger.critical(
                        "unexpected api error during job acquisition",
                        status=resp.status,
                        body=body,
                    )

                    raise JobsAPIError("Unexpected API error during job acquisition")

            except ClientConnectionError:
                logger.warning(
                    "unable to connect to server. retrying in 1 second.",
                    remaining_attempts=attempts,
                    id=job_id,
                )
                await asyncio.sleep(1)

            attempts -= 1

        raise JobsAPIServerError("Unable to connect to server.")
|
@@ -0,0 +1,132 @@
|
|
1
|
+
from contextlib import asynccontextmanager
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
import aiofiles
|
5
|
+
from aiohttp import BasicAuth, ClientSession
|
6
|
+
|
7
|
+
from virtool_workflow.api.utils import (
|
8
|
+
decode_json_response,
|
9
|
+
raise_exception_by_status_code,
|
10
|
+
)
|
11
|
+
from virtool_workflow.errors import JobsAPIError
|
12
|
+
from virtool_workflow.files import VirtoolFileFormat
|
13
|
+
|
14
|
+
# Size in bytes (2 MiB) of each chunk streamed to disk by ``APIClient.get_file``.
CHUNK_SIZE = 1024 * 1024 * 2
|
15
|
+
|
16
|
+
|
17
|
+
class APIClient:
    """Make requests against the jobs API through a shared :class:`ClientSession`.

    Every ``path`` argument is appended verbatim to ``jobs_api_connection_string`` to
    form the request URL. Error statuses are converted to exceptions by
    :func:`raise_exception_by_status_code` (except in :meth:`get_file`).
    """

    def __init__(self, http: ClientSession, jobs_api_connection_string: str):
        self.http = http
        self.jobs_api_connection_string = jobs_api_connection_string

    async def get_json(self, path: str) -> dict:
        """Get the JSON response from the provided API ``path``."""
        async with self.http.get(f"{self.jobs_api_connection_string}{path}") as resp:
            await raise_exception_by_status_code(resp)
            return await decode_json_response(resp)

    async def get_file(self, path: str, target_path: Path) -> Path:
        """Download the file at URL ``path`` to the local filesystem path ``target_path``.

        The response body is streamed to disk in ``CHUNK_SIZE`` chunks.

        :param path: the API path of the file to download
        :param target_path: where to write the downloaded file
        :return: ``target_path``
        :raises JobsAPIError: if the response status is not 200
        """
        async with self.http.get(f"{self.jobs_api_connection_string}{path}") as resp:
            if resp.status != 200:
                raise JobsAPIError(
                    f"Encountered {resp.status} while downloading '{path}'",
                )
            async with aiofiles.open(target_path, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    await f.write(chunk)

        return target_path

    async def patch_json(self, path: str, data: dict) -> dict:
        """Make a patch request against the provided API ``path`` and return the response
        as a dictionary of decoded JSON.

        :param path: the API path to make the request against
        :param data: the data to send with the request
        :return: the response as a dictionary of decoded JSON
        """
        async with self.http.patch(
            f"{self.jobs_api_connection_string}{path}", json=data,
        ) as resp:
            await raise_exception_by_status_code(resp)
            return await decode_json_response(resp)

    async def post_file(
        self,
        path: str,
        file_path: Path,
        file_format: VirtoolFileFormat,
        params: dict | None = None,
    ):
        """Upload the file at ``file_path`` to API ``path`` with a POST request.

        :param path: the API path to upload to
        :param file_path: the local file to upload
        :param file_format: sent as the ``format`` query parameter unless ``None``
        :param params: query parameters; defaults to ``{"name": file_path.name}`` when
            empty. The caller's dict is not modified.
        """
        await self._send_file(self.http.post, path, file_path, file_format, params)

    async def post_json(self, path: str, data: dict) -> dict:
        """Make a POST request with a JSON body and return the decoded JSON response."""
        async with self.http.post(
            f"{self.jobs_api_connection_string}{path}", json=data,
        ) as resp:
            await raise_exception_by_status_code(resp)
            return await decode_json_response(resp)

    async def put_file(
        self,
        path: str,
        file_path: Path,
        file_format: VirtoolFileFormat,
        params: dict | None = None,
    ):
        """Upload the file at ``file_path`` to API ``path`` with a PUT request.

        Parameters are handled exactly as in :meth:`post_file`.
        """
        await self._send_file(self.http.put, path, file_path, file_format, params)

    async def put_json(self, path: str, data: dict) -> dict:
        """Make a PUT request with a JSON body and return the decoded JSON response."""
        async with self.http.put(
            f"{self.jobs_api_connection_string}{path}", json=data,
        ) as resp:
            await raise_exception_by_status_code(resp)
            return await decode_json_response(resp)

    async def delete(self, path: str) -> dict | None:
        """Make a delete request against the provided API ``path``.

        :return: the decoded JSON response, or ``None`` if the response is not JSON
        """
        async with self.http.delete(f"{self.jobs_api_connection_string}{path}") as resp:
            await raise_exception_by_status_code(resp)

            try:
                return await decode_json_response(resp)
            except ValueError:
                return None

    async def _send_file(
        self,
        send,
        path: str,
        file_path: Path,
        file_format: VirtoolFileFormat | None,
        params: dict | None,
    ):
        """Upload ``file_path`` as the ``file`` form field using the request method ``send``.

        ``send`` is a bound request method of :attr:`http` (``post`` or ``put``).
        """
        # Copy so the caller's dict is never mutated by adding ``format``.
        query = dict(params) if params else {"name": file_path.name}

        if file_format is not None:
            query["format"] = file_format

        # Close the upload handle once the request has completed.
        with open(file_path, "rb") as file:
            async with send(
                f"{self.jobs_api_connection_string}{path}",
                data={"file": file},
                params=query,
            ) as response:
                await raise_exception_by_status_code(response)
|
119
|
+
|
120
|
+
|
121
|
+
@asynccontextmanager
async def api_client(
    jobs_api_connection_string: str,
    job_id: str,
    key: str,
):
    """Yield an :class:`APIClient` authenticated as the job ``job_id``.

    The underlying :class:`ClientSession` uses HTTP basic auth, with login
    ``job-<job_id>`` and ``key`` as the password. It is closed when the context exits.
    """
    credentials = BasicAuth(login=f"job-{job_id}", password=key)

    async with ClientSession(auth=credentials) as session:
        yield APIClient(session, jobs_api_connection_string)
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import asyncio
|
2
|
+
import functools
|
3
|
+
|
4
|
+
from aiohttp import (
|
5
|
+
ClientConnectorError,
|
6
|
+
ClientResponse,
|
7
|
+
ContentTypeError,
|
8
|
+
ServerDisconnectedError,
|
9
|
+
)
|
10
|
+
from structlog import get_logger
|
11
|
+
|
12
|
+
from virtool_workflow.errors import (
|
13
|
+
JobsAPIBadRequestError,
|
14
|
+
JobsAPIConflictError,
|
15
|
+
JobsAPIForbiddenError,
|
16
|
+
JobsAPINotFoundError,
|
17
|
+
JobsAPIServerError,
|
18
|
+
)
|
19
|
+
|
20
|
+
# Module-level structured logger in the "api" namespace.
logger = get_logger("api")
|
21
|
+
|
22
|
+
|
23
|
+
def retry(func=None, *, max_retries: int = 5, delay: float = 5):
    """Retry an API call up to ``max_retries`` times when encountering the following exceptions:

    * ``ConnectionRefusedError``.
    * ``ClientConnectorError``.
    * ``ServerDisconnectedError``.

    These are probably due to transient issues in the cluster network.

    Use it bare (``@retry``) for the defaults of five retries with five-second waits, or
    with keyword arguments (``@retry(max_retries=3, delay=1)``).

    :param func: the coroutine function to wrap
    :param max_retries: how many times to retry after the first failed call
    :param delay: seconds to wait before each retry
    :return: the wrapped coroutine function (or a decorator when ``func`` is omitted)
    """
    if func is None:
        return functools.partial(retry, max_retries=max_retries, delay=delay)

    @functools.wraps(func)
    async def wrapped(*args, **kwargs):
        retries = 0

        # Loop so that every failed call is retried, not only the first one.
        while True:
            try:
                return await func(*args, **kwargs)
            except (
                ConnectionRefusedError,
                ClientConnectorError,
                ServerDisconnectedError,
            ) as err:
                if retries >= max_retries:
                    raise

                retries += 1
                get_logger("runtime").info(
                    f"Encountered {type(err).__name__}. Retrying in {delay} seconds.",
                )
                await asyncio.sleep(delay)

    return wrapped
|
56
|
+
|
57
|
+
|
58
|
+
async def decode_json_response(resp: ClientResponse) -> dict | list | None:
    """Decode a JSON response from a :class:`ClientResponse`.

    Raise a :class:`ValueError` if the response is not JSON. The original
    ``ContentTypeError`` is attached as the cause.

    :param resp: the response to decode
    :return: the decoded JSON
    :raises ValueError: if the response's content type is not JSON
    """
    try:
        return await resp.json()
    except ContentTypeError as err:
        raise ValueError(
            f"Response from {resp.url} was not JSON. {await resp.text()}",
        ) from err
|
70
|
+
|
71
|
+
|
72
|
+
async def raise_exception_by_status_code(resp: ClientResponse):
    """Raise an exception if ``resp`` does not have a 2xx status code.

    The exception message is, in order of preference:

    * the ``message`` key of a JSON object body,
    * the string form of any other JSON body,
    * the raw body text.

    :param resp: the response to check
    :raise JobsAPIBadRequestError: the response status code is 400
    :raise JobsAPIForbiddenError: the response status code is 403
    :raise JobsAPINotFoundError: the response status code is 404
    :raise JobsAPIConflictError: the response status code is 409
    :raise JobsAPIServerError: the response status code is 500
    :raise ValueError: any other non-2xx status code
    """
    # The whole 2xx range is success; ``range(200, 299)`` previously excluded 299.
    if 200 <= resp.status < 300:
        return

    status_exception_map = {
        400: JobsAPIBadRequestError,
        403: JobsAPIForbiddenError,
        404: JobsAPINotFoundError,
        409: JobsAPIConflictError,
        500: JobsAPIServerError,
    }

    try:
        resp_json = await resp.json()
    except (ContentTypeError, ValueError):
        # Not JSON, or malformed JSON: fall back to the raw body text below.
        resp_json = None

    if resp_json is None:
        try:
            message = await resp.text()
        except UnicodeDecodeError:
            message = "Could not decode response message"
    elif isinstance(resp_json, dict) and "message" in resp_json:
        message = resp_json["message"]
    else:
        message = str(resp_json)

    exception_class = status_exception_map.get(resp.status)

    if exception_class is None:
        raise ValueError(
            f"Status code {resp.status} not handled for response\n {resp}",
        )

    raise exception_class(message)
|