virtool-workflow 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. virtool_workflow/__init__.py +13 -0
  2. virtool_workflow/analysis/__init__.py +1 -0
  3. virtool_workflow/analysis/fastqc.py +467 -0
  4. virtool_workflow/analysis/skewer.py +265 -0
  5. virtool_workflow/analysis/trimming.py +56 -0
  6. virtool_workflow/analysis/utils.py +27 -0
  7. virtool_workflow/api/__init__.py +0 -0
  8. virtool_workflow/api/acquire.py +66 -0
  9. virtool_workflow/api/client.py +132 -0
  10. virtool_workflow/api/utils.py +109 -0
  11. virtool_workflow/cli.py +66 -0
  12. virtool_workflow/data/__init__.py +22 -0
  13. virtool_workflow/data/analyses.py +106 -0
  14. virtool_workflow/data/hmms.py +109 -0
  15. virtool_workflow/data/indexes.py +319 -0
  16. virtool_workflow/data/jobs.py +62 -0
  17. virtool_workflow/data/ml.py +82 -0
  18. virtool_workflow/data/samples.py +190 -0
  19. virtool_workflow/data/subtractions.py +244 -0
  20. virtool_workflow/data/uploads.py +35 -0
  21. virtool_workflow/decorators.py +47 -0
  22. virtool_workflow/errors.py +62 -0
  23. virtool_workflow/files.py +40 -0
  24. virtool_workflow/hooks.py +140 -0
  25. virtool_workflow/pytest_plugin/__init__.py +35 -0
  26. virtool_workflow/pytest_plugin/data.py +197 -0
  27. virtool_workflow/pytest_plugin/utils.py +9 -0
  28. virtool_workflow/runtime/__init__.py +0 -0
  29. virtool_workflow/runtime/config.py +21 -0
  30. virtool_workflow/runtime/discover.py +95 -0
  31. virtool_workflow/runtime/events.py +7 -0
  32. virtool_workflow/runtime/hook.py +129 -0
  33. virtool_workflow/runtime/path.py +19 -0
  34. virtool_workflow/runtime/ping.py +54 -0
  35. virtool_workflow/runtime/redis.py +65 -0
  36. virtool_workflow/runtime/run.py +276 -0
  37. virtool_workflow/runtime/run_subprocess.py +168 -0
  38. virtool_workflow/runtime/sentry.py +28 -0
  39. virtool_workflow/utils.py +90 -0
  40. virtool_workflow/workflow.py +90 -0
  41. virtool_workflow-0.0.0.dist-info/LICENSE +21 -0
  42. virtool_workflow-0.0.0.dist-info/METADATA +71 -0
  43. virtool_workflow-0.0.0.dist-info/RECORD +45 -0
  44. virtool_workflow-0.0.0.dist-info/WHEEL +4 -0
  45. virtool_workflow-0.0.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,66 @@
1
+ """Command Line Interface to virtool_workflow"""
2
+ import asyncio
3
+ from pathlib import Path
4
+
5
+ import click
6
+
7
+ from virtool_workflow.runtime.run import start_runtime
8
+
9
+
10
# NOTE(review): the option decorators sit above ``@click.command()`` (which is
# applied first). Click accepts ``@click.option`` applied to an existing
# Command object, so this stack works, though the conventional order places
# ``@click.command()`` on top.
@click.option(
    "--dev",
    help="Run in development mode.",
    is_flag=True,
)
@click.option(
    "--jobs-api-connection-string",
    help="The URL of the jobs API.",
    default="https://localhost:9950",
)
@click.option(
    "--mem",
    help="The amount of memory to use in GB.",
    type=int,
    default=8,
)
@click.option(
    "--proc",
    help="The number of processes to use.",
    type=int,
    default=2,
)
# NOTE(review): 6317 is not the conventional Redis port (6379) — confirm this
# default is intentional and not a typo.
@click.option(
    "--redis-connection-string",
    help="The URL for connecting to Redis.",
    default="redis://localhost:6317",
)
@click.option(
    "--redis-list-name",
    help="The name of the Redis list to watch for incoming jobs.",
    required=True,
)
@click.option(
    "--sentry-dsn",
    help="A Sentry DSN. Sentry will not be configured if no value is provided.",
    default=None,
)
@click.option(
    "--timeout",
    help="Maximum time to wait for an incoming job",
    default=1000,
)
@click.option(
    "--work-path",
    default="temp",
    help="The path where temporary files will be stored.",
    type=click.Path(path_type=Path),
)
@click.command()
def run_workflow(**kwargs):
    """Run a workflow.

    All command line options are forwarded as keyword arguments to
    :func:`start_runtime`, which is executed on a fresh asyncio event loop.
    """
    asyncio.run(start_runtime(**kwargs))
62
+
63
+
64
def cli_main():
    """Main pip entrypoint.

    ``auto_envvar_prefix="VT"`` lets click read defaults for options from
    environment variables such as ``VT_REDIS_LIST_NAME``.
    """
    run_workflow(auto_envvar_prefix="VT")
@@ -0,0 +1,22 @@
1
"""Re-exports of the data fixtures used by Virtool workflows.

Each name listed in ``__all__`` can be requested by workflow code from this
single module instead of the individual submodules.
"""

from virtool_workflow.data.analyses import analysis
from virtool_workflow.analysis.fastqc import fastqc
from virtool_workflow.data.hmms import hmms
from virtool_workflow.data.indexes import index
from virtool_workflow.data.jobs import job, push_status
from virtool_workflow.data.ml import ml
from virtool_workflow.data.samples import sample
from virtool_workflow.data.subtractions import subtractions
from virtool_workflow.data.uploads import uploads

__all__ = [
    "analysis",
    "fastqc",
    "hmms",
    "index",
    "job",
    "ml",
    "push_status",
    "sample",
    "subtractions",
    "uploads",
]
@@ -0,0 +1,106 @@
1
+ """A fixture and class for representing the analysis associated with a workflow run."""
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ from pyfixtures import fixture
7
+ from virtool.analyses.models import Analysis, AnalysisSample
8
+ from virtool.indexes.models import IndexNested
9
+ from virtool.jobs.models import JobNested
10
+ from virtool.ml.models import MLModelRelease
11
+ from virtool.references.models import ReferenceNested
12
+ from virtool.subtractions.models import SubtractionNested
13
+
14
+ from virtool_workflow.api.client import APIClient
15
+ from virtool_workflow.files import VirtoolFileFormat
16
+
17
+
18
class WFAnalysis:
    """The Virtool analysis being populated by the running workflow.

    Wraps the API client so workflow code can delete the analysis, attach
    files to it, or upload its final results without touching HTTP details.
    """

    def __init__(
        self,
        api: APIClient,
        analysis_id: str,
        index: IndexNested,
        ml: MLModelRelease | None,
        reference: ReferenceNested,
        sample: AnalysisSample,
        subtractions: list[SubtractionNested],
        workflow: str,
    ):
        self._api = api

        #: The unique ID for the analysis.
        self.id = analysis_id
        #: The index being used for the analysis.
        self.index = index
        #: The ML model release being used for the analysis.
        self.ml = ml
        #: The reference being used for the analysis.
        self.reference = reference
        #: The parent sample for the analysis.
        self.sample = sample
        #: The subtractions being used for the analysis.
        self.subtractions = subtractions
        #: The workflow being run to populate the analysis.
        self.workflow = workflow

    async def delete(self):
        """Delete the analysis.

        Call this if the workflow fails before a result is uploaded.
        """
        await self._api.delete(f"/analyses/{self.id}")

    async def upload_file(self, path: Path, fmt: VirtoolFileFormat = "unknown"):
        """Upload a file from the workflow environment and associate it with
        the current analysis.

        :param path: the path to the file to upload
        :param fmt: the file format
        """
        await self._api.post_file(f"/analyses/{self.id}/files", path, fmt)

    async def upload_result(self, results: dict[str, Any]):
        """Upload the results dict for the analysis.

        :param results: the analysis results
        """
        await self._api.patch_json(f"/analyses/{self.id}", {"results": results})
82
+
83
+
84
@fixture
async def analysis(
    _api: APIClient,
    job: JobNested,
) -> WFAnalysis:
    """A :class:`.WFAnalysis` object that represents the analysis associated
    with the running workflow.

    The analysis ID is taken from the job's ``analysis_id`` argument and the
    full analysis record is fetched from the API.
    """
    analysis_id = job.args["analysis_id"]

    analysis_ = Analysis(**await _api.get_json(f"/analyses/{analysis_id}"))

    return WFAnalysis(
        api=_api,
        analysis_id=analysis_id,
        index=analysis_.index,
        ml=analysis_.ml,
        reference=analysis_.reference,
        sample=analysis_.sample,
        subtractions=analysis_.subtractions,
        workflow=job.workflow,
    )
@@ -0,0 +1,109 @@
1
+ """A class and fixture for accessing Virtool HMM data for use in analysis workflows."""
2
+
3
+ import asyncio
4
+ import json
5
+ from dataclasses import dataclass
6
+ from functools import cached_property
7
+ from pathlib import Path
8
+ from shutil import which
9
+
10
+ import aiofiles
11
+ from pyfixtures import fixture
12
+ from virtool.hmm.models import HMM
13
+ from virtool.utils import decompress_file
14
+
15
+ from virtool_workflow.api.client import APIClient
16
+ from virtool_workflow.runtime.run_subprocess import RunSubprocess
17
+
18
+
19
@dataclass
class WFHMMs:
    """Exposes the HMM dataset for the running workflow:

    1. A :class:`dict` linking `HMMER <http://hmmer.org/>`_ cluster IDs to
       Virtool annotation IDs.
    2. The path to the HMM profiles file.

    """

    annotations: list[HMM]
    """All annotations in the HMM dataset."""

    path: Path
    """
    The directory in the running workflow's ``work_path`` that holds the HMM
    files, including ``profiles.hmm``.
    """

    @cached_property
    def cluster_annotation_map(self) -> dict[int, str]:
        """A :class:`dict` mapping the cluster IDs used to identify HMMs in
        `HMMER <http://hmmer.org/>`_ to the annotation IDs used in Virtool.

        Computed once on first access and cached.
        """
        mapping: dict[int, str] = {}

        for annotation in self.annotations:
            mapping[annotation.cluster] = annotation.id

        return mapping

    @property
    def profiles_path(self):
        """The path to the ``profiles.hmm`` file.

        Suitable for passing directly to HMMER.
        """
        return self.path / "profiles.hmm"

    def get_id_by_cluster(self, cluster: int) -> str:
        """Get the Virtool HMM annotation ID for a given cluster ID.

        :param cluster: a cluster ID
        :return: the corresponding annotation ID
        """
        return self.cluster_annotation_map[cluster]
60
+
61
+
62
@fixture
async def hmms(
    _api: APIClient,
    proc: int,
    run_subprocess: RunSubprocess,
    work_path: Path,
):
    """A fixture for accessing HMM data.

    Downloads the compressed annotations and the ``profiles.hmm`` file from
    the API, decompresses the annotations, and runs ``hmmpress`` so HMMER's
    binary index files exist next to the profiles.

    Returns a :class:`.WFHMMs` object containing the path to the HMM profile
    file and a `dict` that maps HMM cluster numbers to database IDs.

    :raises: :class:`RuntimeError`: hmmpress is not installed
    :raises: :class:`RuntimeError`: hmmpress command failed

    """
    # Fail fast before doing any downloads if HMMER isn't available.
    if await asyncio.to_thread(which, "hmmpress") is None:
        raise RuntimeError("hmmpress is not installed")

    hmms_path = work_path / "hmms"
    await asyncio.to_thread(hmms_path.mkdir, parents=True, exist_ok=True)

    compressed_path = hmms_path / "annotations.json.gz"
    await _api.get_file("/hmms/files/annotations.json.gz", compressed_path)

    decompressed_path = hmms_path / "annotations.json"
    await asyncio.to_thread(
        decompress_file,
        compressed_path,
        decompressed_path,
        proc,
    )

    profiles_path = hmms_path / "profiles.hmm"
    await _api.get_file("/hmms/files/profiles.hmm", profiles_path)

    async with aiofiles.open(decompressed_path) as f:
        annotations = [HMM(**annotation) for annotation in json.loads(await f.read())]

    process = await run_subprocess(["hmmpress", str(profiles_path)])

    if process.returncode != 0:
        raise RuntimeError("hmmpress command failed")

    return WFHMMs(annotations, hmms_path)
@@ -0,0 +1,319 @@
1
+ import asyncio
2
+ import json
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import aiofiles
7
+ from pyfixtures import fixture
8
+ from structlog import get_logger
9
+ from virtool.analyses.models import Analysis
10
+ from virtool.indexes.models import Index
11
+ from virtool.jobs.models import Job
12
+ from virtool.references.models import ReferenceNested
13
+ from virtool.utils import decompress_file
14
+
15
+ from virtool_workflow.api.client import APIClient
16
+ from virtool_workflow.errors import MissingJobArgumentError
17
+ from virtool_workflow.files import VirtoolFileFormat
18
+
19
+ logger = get_logger("api")
20
+
21
+
22
@dataclass
class WFIndex:
    """Represents a Virtool reference index for use in analysis workflows."""

    id: str
    """The ID of the index."""

    path: Path
    """The path to the index directory in the workflow's work directory."""

    manifest: dict[str, int | str]
    """The manifest (OTU ID: OTU Version) for the index."""

    reference: ReferenceNested
    """The parent reference."""

    sequence_lengths: dict[str, int]
    """A dictionary of the lengths of all sequences keyed by their IDs."""

    sequence_otu_map: dict[str, str]
    """A dictionary of the OTU IDs for all sequences keyed by their sequence IDs."""

    @property
    def bowtie_path(self) -> Path:
        """The path to the Bowtie2 index prefix for the Virtool index."""
        return self.path / "reference"

    @property
    def fasta_path(self) -> Path:
        """The path to the complete FASTA file for the reference index in the
        workflow's work directory.
        """
        return self.path / "ref.fa"

    @property
    def json_path(self) -> Path:
        """The path to the JSON representation of the reference index in the
        workflow's work directory.
        """
        return self.path / "otus.json"

    def get_otu_id_by_sequence_id(self, sequence_id: str) -> str:
        """Get the ID of the parent OTU for the given ``sequence_id``.

        :param sequence_id: the sequence ID
        :return: the matching OTU ID
        :raises ValueError: if the sequence ID is not in the index
        """
        try:
            return self.sequence_otu_map[sequence_id]
        except KeyError:
            raise ValueError("The sequence_id does not exist in the index")

    def get_sequence_length(self, sequence_id: str) -> int:
        """Get the sequence length for the given ``sequence_id``.

        :param sequence_id: the sequence ID
        :return: the length of the sequence
        :raises ValueError: if the sequence ID is not in the index
        """
        try:
            return self.sequence_lengths[sequence_id]
        except KeyError:
            raise ValueError("The sequence_id does not exist in the index")

    async def write_isolate_fasta(
        self,
        otu_ids: list[str],
        path: Path,
    ) -> dict[str, int]:
        """Generate a FASTA file for all the isolates of the OTUs specified by
        ``otu_ids``.

        :param otu_ids: the list of OTU IDs to include
        :param path: the path at which to write the FASTA file
        :return: a dictionary of the lengths of all written sequences keyed by
            their IDs
        """
        wanted = set(otu_ids)

        def write_fasta() -> dict[str, int]:
            # Read the full OTU JSON and keep only the requested OTUs.
            with open(self.json_path) as f:
                selected = [otu for otu in json.load(f) if otu["_id"] in wanted]

            lengths: dict[str, int] = {}

            with open(path, "w") as f:
                for otu in selected:
                    for isolate in otu["isolates"]:
                        for sequence in isolate["sequences"]:
                            f.write(f">{sequence['_id']}\n{sequence['sequence']}\n")
                            lengths[sequence["_id"]] = len(sequence["sequence"])

            return lengths

        # File I/O runs in a thread so the event loop is not blocked.
        return await asyncio.to_thread(write_fasta)
119
+
120
+
121
class WFNewIndex:
    """Represents a reference index that is being built by the current job."""

    def __init__(
        self,
        api: APIClient,
        index_id: str,
        manifest: dict[str, int | str],
        path: Path,
        reference: ReferenceNested,
    ):
        self._api = api

        #: The ID of the index.
        self.id = index_id
        #: The manifest (OTU ID: OTU Version) for the index.
        self.manifest = manifest
        #: The path to the index directory in the workflow's work directory.
        self.path = path
        #: The parent reference.
        self.reference = reference

    async def delete(self):
        """Delete the index being built."""
        await self._api.delete(f"/indexes/{self.id}")

    async def finalize(self):
        """Finalize the current index."""
        await self._api.patch_json(f"/indexes/{self.id}", {})

    async def upload(
        self,
        path: Path,
        fmt: VirtoolFileFormat = "fasta",
        name: str | None = None,
    ):
        """Upload a file to associate with the index being built.

        Allowed file names are:

        - reference.json.gz
        - reference.fa.gz
        - reference.1.bt2
        - reference.2.bt2
        - reference.3.bt2
        - reference.4.bt2
        - reference.rev.1.bt2
        - reference.rev.2.bt2

        :param path: The path to the file.
        :param fmt: The format of the file.
        :param name: An optional name for the file different that its name on disk.
        :return: A :class:`VirtoolFile` object.
        """
        return await self._api.put_file(
            f"/indexes/{self.id}/files/{name or path.name}",
            path,
            fmt,
        )

    @property
    def otus_json_path(self) -> Path:
        """The path to the JSON representation of the reference index in the
        workflow's work directory.
        """
        return self.path / "otus.json"
188
+
189
+
190
@fixture
async def index(
    _api: APIClient,
    analysis: Analysis,
    proc: int,
    work_path: Path,
) -> WFIndex:
    """The :class:`WFIndex` for the current analysis job.

    Downloads the index files into ``work_path/indexes/<id>``, decompresses
    the reference FASTA and the OTU JSON, and builds the sequence-length and
    sequence-to-OTU lookup tables from the OTU JSON.

    :param _api: the API client
    :param analysis: the analysis whose index should be loaded
    :param proc: the number of processes available for decompression
    :param work_path: the workflow's working directory
    :return: the populated :class:`WFIndex`
    """
    id_ = analysis.index.id

    log = logger.bind(id=id_, resource="index")

    log.info("loading index")

    index_json = await _api.get_json(f"/indexes/{id_}")
    index_ = Index(**index_json)

    log.info("got index json")

    index_work_path = work_path / "indexes" / index_.id
    await asyncio.to_thread(index_work_path.mkdir, parents=True, exist_ok=True)

    log.info("created index directory")

    for name in (
        "otus.json.gz",
        "reference.json.gz",
        "reference.fa.gz",
        "reference.1.bt2",
        "reference.2.bt2",
        "reference.3.bt2",
        "reference.4.bt2",
        "reference.rev.1.bt2",
        "reference.rev.2.bt2",
    ):
        await _api.get_file(f"/indexes/{id_}/files/{name}", index_work_path / name)
        log.info("downloaded index file", name=name)

    await asyncio.to_thread(
        decompress_file,
        index_work_path / "reference.fa.gz",
        index_work_path / "reference.fa",
        proc,
    )

    log.info("decompressed reference fasta")

    json_path = index_work_path / "otus.json"

    # BUG FIX: the decompression target was previously
    # ``index_work_path / json_path``. Since ``json_path`` is itself rooted at
    # ``index_work_path``, joining them again produced a doubled path whenever
    # ``work_path`` was relative, so the file was written somewhere other than
    # the ``json_path`` opened below.
    await asyncio.to_thread(
        decompress_file,
        index_work_path / "otus.json.gz",
        json_path,
        proc,
    )

    log.info("decompressed reference otus json")

    async with aiofiles.open(json_path) as f:
        data = json.loads(await f.read())

    sequence_lengths = {}
    sequence_otu_map = {}

    for otu in data:
        for isolate in otu["isolates"]:
            for sequence in isolate["sequences"]:
                sequence_id = sequence["_id"]

                sequence_otu_map[sequence_id] = otu["_id"]
                sequence_lengths[sequence_id] = len(sequence["sequence"])

    log.info("parsed and loaded maps from otus json")

    return WFIndex(
        id=id_,
        path=index_work_path,
        manifest=index_.manifest,
        reference=index_.reference,
        sequence_lengths=sequence_lengths,
        sequence_otu_map=sequence_otu_map,
    )
272
+
273
+
274
@fixture
async def new_index(
    _api: APIClient,
    job: Job,
    proc: int,
    work_path: Path,
) -> WFNewIndex:
    """The :class:`.WFNewIndex` for an index being created by the current job.

    Downloads and decompresses the OTU JSON for the index under construction.
    """
    try:
        index_id = job.args["index_id"]
    except KeyError:
        raise MissingJobArgumentError("Missing jobs args key 'index_id'")

    log = logger.bind(resource="new_index", id=index_id, job_id=job.id)
    log.info("loading index")

    index_ = Index(**await _api.get_json(f"/indexes/{index_id}"))

    log.info("got index json")

    index_work_path = work_path / "indexes" / index_.id
    await asyncio.to_thread(index_work_path.mkdir, parents=True, exist_ok=True)

    log.info("created index directory")

    compressed_path = index_work_path / "otus.json.gz"
    await _api.get_file(f"/indexes/{index_id}/files/otus.json.gz", compressed_path)
    log.info("downloaded otus json")

    await asyncio.to_thread(
        decompress_file,
        compressed_path,
        index_work_path / "otus.json",
        processes=proc,
    )

    log.info("decompressed otus json")

    return WFNewIndex(
        api=_api,
        index_id=index_id,
        manifest=index_.manifest,
        path=index_work_path,
        reference=index_.reference,
    )
@@ -0,0 +1,62 @@
1
+ import traceback
2
+
3
+ from pyfixtures import fixture
4
+ from structlog import get_logger
5
+ from virtool.jobs.models import JobAcquired, Job, JobState
6
+
7
+ from virtool_workflow import Workflow, WorkflowStep
8
+ from virtool_workflow.api.client import APIClient
9
+
10
+ MAX_TB = 50
11
+
12
+ logger = get_logger("api")
13
+
14
+
15
@fixture
async def job(_api: APIClient, _job: JobAcquired) -> Job:
    """The current job, re-validated as a :class:`Job` model.

    NOTE(review): ``parse_obj`` is the pydantic v1 API; presumably this strips
    acquisition-only fields that ``JobAcquired`` carries — confirm against the
    model definitions.
    """
    return Job.parse_obj(_job)
18
+
19
+
20
+ @fixture(scope="function")
21
+ async def push_status(
22
+ _api: APIClient,
23
+ _job: JobAcquired,
24
+ _error: Exception | None,
25
+ _state: JobState,
26
+ _step: WorkflowStep | None,
27
+ _workflow: Workflow,
28
+ ):
29
+ error = None
30
+
31
+ if _error:
32
+ error = {
33
+ "type": _error.__class__.__name__,
34
+ "traceback": traceback.format_tb(_error.__traceback__, MAX_TB),
35
+ "details": [str(arg) for arg in _error.args],
36
+ }
37
+
38
+ logger.critical("reporting error to api", error=_error)
39
+
40
+ if _state in (JobState.WAITING, JobState.PREPARING):
41
+ progress = 0
42
+ elif _state == JobState.COMPLETE:
43
+ progress = 100
44
+ else:
45
+ progress = (100 // len(_workflow.steps)) * _workflow.steps.index(_step)
46
+
47
+ step_name = _step.display_name if _step is not None else ""
48
+
49
+ payload = {
50
+ "error": error,
51
+ "progress": progress,
52
+ "stage": _step.function.__name__ if _step is not None else "",
53
+ "state": _state.value,
54
+ "step_description": _step.description if _step is not None else "",
55
+ "step_name": step_name,
56
+ }
57
+
58
+ async def func():
59
+ await _api.post_json(f"/jobs/{_job.id}/status", payload)
60
+ logger.info("reported status to api", step=step_name, state=_state)
61
+
62
+ return func