dwave-sona-core 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. dwave_sona_core-0.2.7/PKG-INFO +145 -0
  2. dwave_sona_core-0.2.7/README.md +124 -0
  3. dwave_sona_core-0.2.7/pyproject.toml +24 -0
  4. dwave_sona_core-0.2.7/sona/__init__.py +62 -0
  5. dwave_sona_core-0.2.7/sona/__main__.py +4 -0
  6. dwave_sona_core-0.2.7/sona/core/consumers/__init__.py +18 -0
  7. dwave_sona_core-0.2.7/sona/core/consumers/base.py +11 -0
  8. dwave_sona_core-0.2.7/sona/core/consumers/kafka.py +30 -0
  9. dwave_sona_core-0.2.7/sona/core/consumers/redis.py +36 -0
  10. dwave_sona_core-0.2.7/sona/core/consumers/sqs.py +34 -0
  11. dwave_sona_core-0.2.7/sona/core/messages/__init__.py +8 -0
  12. dwave_sona_core-0.2.7/sona/core/messages/base.py +15 -0
  13. dwave_sona_core-0.2.7/sona/core/messages/context.py +46 -0
  14. dwave_sona_core-0.2.7/sona/core/messages/file.py +9 -0
  15. dwave_sona_core-0.2.7/sona/core/messages/job.py +50 -0
  16. dwave_sona_core-0.2.7/sona/core/messages/result.py +17 -0
  17. dwave_sona_core-0.2.7/sona/core/messages/state.py +33 -0
  18. dwave_sona_core-0.2.7/sona/core/producers/__init__.py +25 -0
  19. dwave_sona_core-0.2.7/sona/core/producers/base.py +7 -0
  20. dwave_sona_core-0.2.7/sona/core/producers/kafka.py +23 -0
  21. dwave_sona_core-0.2.7/sona/core/producers/mock.py +8 -0
  22. dwave_sona_core-0.2.7/sona/core/producers/redis.py +15 -0
  23. dwave_sona_core-0.2.7/sona/core/producers/sqs.py +12 -0
  24. dwave_sona_core-0.2.7/sona/core/storages/__init__.py +13 -0
  25. dwave_sona_core-0.2.7/sona/core/storages/base.py +51 -0
  26. dwave_sona_core-0.2.7/sona/core/storages/local.py +32 -0
  27. dwave_sona_core-0.2.7/sona/core/storages/s3.py +53 -0
  28. dwave_sona_core-0.2.7/sona/core/utils/__init__.py +0 -0
  29. dwave_sona_core-0.2.7/sona/core/utils/cls_utils.py +13 -0
  30. dwave_sona_core-0.2.7/sona/core/utils/dict_utils.py +7 -0
  31. dwave_sona_core-0.2.7/sona/inferencers/__init__.py +3 -0
  32. dwave_sona_core-0.2.7/sona/inferencers/base.py +35 -0
  33. dwave_sona_core-0.2.7/sona/inferencers/mock.py +64 -0
  34. dwave_sona_core-0.2.7/sona/settings.py +36 -0
  35. dwave_sona_core-0.2.7/sona/workers/__init__.py +4 -0
  36. dwave_sona_core-0.2.7/sona/workers/base.py +66 -0
  37. dwave_sona_core-0.2.7/sona/workers/inferencer.py +73 -0
@@ -0,0 +1,145 @@
1
+ Metadata-Version: 2.1
2
+ Name: dwave-sona-core
3
+ Version: 0.2.7
4
+ Summary:
5
+ Author: dwave-dev
6
+ Requires-Python: >=3.7,<4.0
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.7
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: Programming Language :: Python :: 3.9
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Requires-Dist: boto3 (>=1.26.85,<2.0.0)
14
+ Requires-Dist: confluent-kafka (>=2.1.1,<3.0.0)
15
+ Requires-Dist: loguru (>=0.6.0,<0.7.0)
16
+ Requires-Dist: pydantic[dotenv] (>=1.10.5,<2.0.0)
17
+ Requires-Dist: redis (>=4.5.4,<5.0.0)
18
+ Requires-Dist: typer (>=0.7.0,<0.8.0)
19
+ Description-Content-Type: text/markdown
20
+
21
+ # Dwave SONA Core
22
+
23
+ 迪威智能 SONA 服務專用核心開發套件
24
+
25
+ ## 已使用模組
26
+ ### Inferencers
27
+ | 名稱 | 連結 |
28
+ | ------ | ---------------------------------------------------------------------------------------------------------------------- |
29
+ | 轉檔工具 | [dwave-sona-inferencer-media-tools](https://github.com/DeepWaveInc/dwave-sona-inferencer-media-tools) |
30
+ | 字幕產生 | [dwave-sona-inferencer-captioner](https://github.com/DeepWaveInc/dwave-sona-inferencer-captioner) |
31
+ | 人聲降噪 | [dwave-sona-inferencer-denoise](https://github.com/DeepWaveInc/dwave-sona-inferencer-denoise) |
32
+ | 語者分離 | [dwave-sona-inferencer-vad](https://github.com/DeepWaveInc/dwave-sona-inferencer-vad) |
33
+ | 音軌分離 | [dwave-sona-inferencer-svs](https://github.com/DeepWaveInc/dwave-sona-inferencer-svs) |
34
+ | 全曲轉譜 | [dwave-sona-inferencer-singing-transcription](https://github.com/DeepWaveInc/dwave-sona-inferencer-singing-transcription) |
35
+ | 字幕摘要 | [dwave-sona-inferencer-summarizer](https://github.com/DeepWaveInc/dwave-sona-inferencer-summarizer) |
36
+
37
+ ### Workers
38
+ | 名稱 | 連結 |
39
+ | ------ | ------------------------------------------------------------------------------------------ |
40
+ | 控制中心 |[dwave-sona-worker-controller](https://github.com/DeepWaveInc/dwave-sona-worker-controller) |
41
+ | 郵件寄送 |[dwave-sona-worker-mail](https://github.com/DeepWaveInc/dwave-sona-worker-mail) |
42
+
43
+
44
+ ### Cronjobs
45
+ | 名稱 | 連結 |
46
+ | ------- | -------------------------------------------------------------------------------- |
47
+ | 數發部排程 |[dwave-sona-cronjob-moda](https://github.com/DeepWaveInc/dwave-sona-cronjob-moda)|
48
+
49
+
50
+ ## 安裝與使用
51
+
52
+ ### 開發環境需求
53
+
54
+ - Python 3.8 或更新版本
55
+ - poetry
56
+
57
+ ### 安裝與使用
58
+
59
+ 1. 環境建構
60
+
61
+ ```sh
62
+ $ pip install poetry
63
+ ```
64
+
65
+ 2. 下載與安裝
66
+
67
+ ```sh
68
+ $ export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
69
+ $ poetry add git+ssh://git@github.com/DeepWaveInc/dwave-sona-core.git
70
+ ```
71
+
72
+ 3. 撰寫 Inferencer 模組
73
+
74
+ ```python
75
+ # example/basic.py
76
+ from pathlib import Path
77
+ from typing import Dict, List
78
+
79
+ from loguru import logger
80
+ from sona.core.messages import Context, File, Job, Result
81
+ from sona.inferencers import InferencerBase
82
+
83
+
84
+ class BasicExample(InferencerBase):
85
+ inferencer = "basic" # Inferencer 名稱
86
+
87
+ def on_load(self) -> None:
88
+ """
89
+ 載入函式,Worker 啟動時呼叫
90
+ """
91
+ logger.info(f"Download {self.__class__.__name__} models...")
92
+
93
+ def inference(self, params: Dict, files: List[File]) -> Result:
94
+ """
95
+ 訊息處理函式,Worker 接受到新訊息後呼叫
96
+ :param params: 處理參數
97
+ :param files: 處理檔案
98
+ :return: 處理結果
99
+ """
100
+ logger.info(f"Get params {params}")
101
+ logger.info(f"Get files {files}")
102
+
103
+ filname = "output.wav"
104
+ Path(filname).touch(exist_ok=True)
105
+ return Result(
106
+ files=[File(label="output", path=filname)],
107
+ data={"data_key": "data_val"},
108
+ )
109
+
110
+ def context_example(self) -> Context:
111
+ """
112
+ 範例訊息,供 Worker 測試及 API 開發時參考使用
113
+ :return: 範例訊息
114
+ """
115
+ filname = "input.wav"
116
+ Path(filname).touch(exist_ok=True)
117
+
118
+ params = {"param_key": "param_val"}
119
+ files = [File(label="input", path=filname)]
120
+ return Context(
121
+ jobs=[
122
+ Job(
123
+ name="basic_job",
124
+ topic=self.get_topic(),
125
+ params=params,
126
+ files=files,
127
+ )
128
+ ]
129
+ )
130
+
131
+ ```
132
+
133
+ 4. Worker 測試
134
+
135
+ ```sh
136
+ $ poetry run python -m sona inferencer test inferencer.basic.BasicExample
137
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_load:33 - Loading inferencer: basic
138
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:on_load:13 - Download BasicExample models...
139
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_load:36 - Susbcribe on sona.worker.inferencer.basic(MockConsumer)
140
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_context:42 - [sona.worker.inferencer.basic] recv: {"id": "5fe59e8bcd4b4efb84462cdbcad4a3b4", "header": {}, "jobs": [{"name": "basic_job", "topic": "sona.worker.inferencer.basic", "params": {"param_key": "param_val"}, "files": [{"label": "input", "path": "input.wav"}], "extra_params": {}, "extra_files": {}}], "fallbacks": [], "results": {}, "states": []}
141
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:inference:16 - Get params {'param_key': 'param_val'}
142
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:inference:17 - Get files [File(label='input', path='input.wav')]
143
+ 2023-03-22 02:52:27.393 | INFO | sona.workers.inferencer:on_context:59 - [sona.worker.inferencer.basic] success: {"id": "5fe59e8bcd4b4efb84462cdbcad4a3b4", "header": {}, "jobs": [{"name": "basic_job", "topic": "sona.worker.inferencer.basic", "params": {"param_key": "param_val"}, "files": [{"label": "input", "path": "input.wav"}], "extra_params": {}, "extra_files": {}}], "fallbacks": [], "results": {"basic_job": {"files": [{"label": "output", "path": "output.wav"}], "data": {"data_key": "data_val"}}}, "states": [{"job_name": "basic_job", "exec_time": 0.00029676100000000015, "exception": {}}]}
144
+ ```
145
+
@@ -0,0 +1,124 @@
1
+ # Dwave SONA Core
2
+
3
+ 迪威智能 SONA 服務專用核心開發套件
4
+
5
+ ## 已使用模組
6
+ ### Inferencers
7
+ | 名稱 | 連結 |
8
+ | ------ | ---------------------------------------------------------------------------------------------------------------------- |
9
+ | 轉檔工具 | [dwave-sona-inferencer-media-tools](https://github.com/DeepWaveInc/dwave-sona-inferencer-media-tools) |
10
+ | 字幕產生 | [dwave-sona-inferencer-captioner](https://github.com/DeepWaveInc/dwave-sona-inferencer-captioner) |
11
+ | 人聲降噪 | [dwave-sona-inferencer-denoise](https://github.com/DeepWaveInc/dwave-sona-inferencer-denoise) |
12
+ | 語者分離 | [dwave-sona-inferencer-vad](https://github.com/DeepWaveInc/dwave-sona-inferencer-vad) |
13
+ | 音軌分離 | [dwave-sona-inferencer-svs](https://github.com/DeepWaveInc/dwave-sona-inferencer-svs) |
14
+ | 全曲轉譜 | [dwave-sona-inferencer-singing-transcription](https://github.com/DeepWaveInc/dwave-sona-inferencer-singing-transcription) |
15
+ | 字幕摘要 | [dwave-sona-inferencer-summarizer](https://github.com/DeepWaveInc/dwave-sona-inferencer-summarizer) |
16
+
17
+ ### Workers
18
+ | 名稱 | 連結 |
19
+ | ------ | ------------------------------------------------------------------------------------------ |
20
+ | 控制中心 |[dwave-sona-worker-controller](https://github.com/DeepWaveInc/dwave-sona-worker-controller) |
21
+ | 郵件寄送 |[dwave-sona-worker-mail](https://github.com/DeepWaveInc/dwave-sona-worker-mail) |
22
+
23
+
24
+ ### Cronjobs
25
+ | 名稱 | 連結 |
26
+ | ------- | -------------------------------------------------------------------------------- |
27
+ | 數發部排程 |[dwave-sona-cronjob-moda](https://github.com/DeepWaveInc/dwave-sona-cronjob-moda)|
28
+
29
+
30
+ ## 安裝與使用
31
+
32
+ ### 開發環境需求
33
+
34
+ - Python 3.8 或更新版本
35
+ - poetry
36
+
37
+ ### 安裝與使用
38
+
39
+ 1. 環境建構
40
+
41
+ ```sh
42
+ $ pip install poetry
43
+ ```
44
+
45
+ 2. 下載與安裝
46
+
47
+ ```sh
48
+ $ export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring
49
+ $ poetry add git+ssh://git@github.com/DeepWaveInc/dwave-sona-core.git
50
+ ```
51
+
52
+ 3. 撰寫 Inferencer 模組
53
+
54
+ ```python
55
+ # example/basic.py
56
+ from pathlib import Path
57
+ from typing import Dict, List
58
+
59
+ from loguru import logger
60
+ from sona.core.messages import Context, File, Job, Result
61
+ from sona.inferencers import InferencerBase
62
+
63
+
64
+ class BasicExample(InferencerBase):
65
+ inferencer = "basic" # Inferencer 名稱
66
+
67
+ def on_load(self) -> None:
68
+ """
69
+ 載入函式,Worker 啟動時呼叫
70
+ """
71
+ logger.info(f"Download {self.__class__.__name__} models...")
72
+
73
+ def inference(self, params: Dict, files: List[File]) -> Result:
74
+ """
75
+ 訊息處理函式,Worker 接受到新訊息後呼叫
76
+ :param params: 處理參數
77
+ :param files: 處理檔案
78
+ :return: 處理結果
79
+ """
80
+ logger.info(f"Get params {params}")
81
+ logger.info(f"Get files {files}")
82
+
83
+ filname = "output.wav"
84
+ Path(filname).touch(exist_ok=True)
85
+ return Result(
86
+ files=[File(label="output", path=filname)],
87
+ data={"data_key": "data_val"},
88
+ )
89
+
90
+ def context_example(self) -> Context:
91
+ """
92
+ 範例訊息,供 Worker 測試及 API 開發時參考使用
93
+ :return: 範例訊息
94
+ """
95
+ filname = "input.wav"
96
+ Path(filname).touch(exist_ok=True)
97
+
98
+ params = {"param_key": "param_val"}
99
+ files = [File(label="input", path=filname)]
100
+ return Context(
101
+ jobs=[
102
+ Job(
103
+ name="basic_job",
104
+ topic=self.get_topic(),
105
+ params=params,
106
+ files=files,
107
+ )
108
+ ]
109
+ )
110
+
111
+ ```
112
+
113
+ 4. Worker 測試
114
+
115
+ ```sh
116
+ $ poetry run python -m sona inferencer test inferencer.basic.BasicExample
117
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_load:33 - Loading inferencer: basic
118
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:on_load:13 - Download BasicExample models...
119
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_load:36 - Susbcribe on sona.worker.inferencer.basic(MockConsumer)
120
+ 2023-03-22 02:52:27.392 | INFO | sona.workers.inferencer:on_context:42 - [sona.worker.inferencer.basic] recv: {"id": "5fe59e8bcd4b4efb84462cdbcad4a3b4", "header": {}, "jobs": [{"name": "basic_job", "topic": "sona.worker.inferencer.basic", "params": {"param_key": "param_val"}, "files": [{"label": "input", "path": "input.wav"}], "extra_params": {}, "extra_files": {}}], "fallbacks": [], "results": {}, "states": []}
121
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:inference:16 - Get params {'param_key': 'param_val'}
122
+ 2023-03-22 02:52:27.392 | INFO | inferencer.basic:inference:17 - Get files [File(label='input', path='input.wav')]
123
+ 2023-03-22 02:52:27.393 | INFO | sona.workers.inferencer:on_context:59 - [sona.worker.inferencer.basic] success: {"id": "5fe59e8bcd4b4efb84462cdbcad4a3b4", "header": {}, "jobs": [{"name": "basic_job", "topic": "sona.worker.inferencer.basic", "params": {"param_key": "param_val"}, "files": [{"label": "input", "path": "input.wav"}], "extra_params": {}, "extra_files": {}}], "fallbacks": [], "results": {"basic_job": {"files": [{"label": "output", "path": "output.wav"}], "data": {"data_key": "data_val"}}}, "states": [{"job_name": "basic_job", "exec_time": 0.00029676100000000015, "exception": {}}]}
124
+ ```
@@ -0,0 +1,24 @@
1
+ [tool.poetry]
2
+ name = "dwave-sona-core"
3
+ version = "0.2.7"
4
+ description = ""
5
+ authors = ["dwave-dev"]
6
+ license = ""
7
+ readme = "README.md"
8
+ packages = [{include = "sona"}]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "^3.7"
12
+ typer = "^0.7.0"
13
+ loguru = "^0.6.0"
14
+ pydantic = {extras = ["dotenv"], version = "^1.10.5"}
15
+ boto3 = "^1.26.85"
16
+ redis = "^4.5.4"
17
+ confluent-kafka = "^2.1.1"
18
+
19
+ [tool.poetry.group.dev.dependencies]
20
+ black = "^23.1.0"
21
+
22
+ [build-system]
23
+ requires = ["poetry-core"]
24
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,62 @@
1
+ import asyncio
2
+
3
+ import typer
4
+ from loguru import logger
5
+
6
+ from sona.core.consumers import create_consumer
7
+ from sona.core.messages.context import Context
8
+ from sona.core.producers import MockProducer, create_producer
9
+ from sona.core.storages import LocalStorage, ShareStorageBase, create_storage
10
+ from sona.inferencers import InferencerBase
11
+ from sona.settings import settings
12
+ from sona.workers import *
13
+
14
worker_app = typer.Typer()
inference_app = typer.Typer()


@worker_app.command()
def run(
    inferencer: str = settings.SONA_INFERENCER,
    worker: str = settings.SONA_WORKER,
):
    """Build the configured worker, wire its backends, and run it forever."""
    try:
        worker_cls = WorkerBase.load_class(worker)
        # Only InferencerWorker wraps an inferencer instance.
        if worker_cls == InferencerWorker:
            instance: WorkerBase = worker_cls(InferencerBase.load_class(inferencer)())
        else:
            instance: WorkerBase = worker_cls()
        instance.set_storage(create_storage())
        instance.set_producer(create_producer())
        instance.set_consumer(create_consumer())
        asyncio.run(instance.start())
    except Exception as e:
        # Log startup/runtime failures instead of letting the CLI crash bare.
        logger.exception(e)


@worker_app.command("test")
def worker_test(worker_class: str, file: str = None):
    """Dry-run an arbitrary worker class against a mock producer/local storage."""
    target: WorkerBase = WorkerBase.load_class(worker_class)()
    _test(target, file)


@inference_app.command("test")
def inferencer_test(inferencer_class: str, file: str = None):
    """Dry-run an inferencer class wrapped in an InferencerWorker."""
    target: WorkerBase = InferencerWorker(InferencerBase.load_class(inferencer_class)())
    _test(target, file)


def _test(worker: WorkerBase, file: str):
    """Shared test driver: run *worker* once on a Context.

    The Context comes from *file* (JSON) when given, otherwise from the
    worker's own ``context_example()``.
    """
    storage: ShareStorageBase = LocalStorage()
    if file:
        context: Context = Context.parse_file(file)
    else:
        context: Context = worker.context_example()
    worker.set_storage(storage)
    worker.set_producer(MockProducer())
    asyncio.run(worker.test(context))


app = typer.Typer()
app.add_typer(worker_app, name="worker")
app.add_typer(inference_app, name="inferencer")
@@ -0,0 +1,4 @@
1
from sona import app

# Module entry point: `python -m sona` dispatches into the typer CLI.
if __name__ == "__main__":
    app()
@@ -0,0 +1,18 @@
1
+ from sona.settings import settings
2
+
3
+ from .base import ConsumerBase
4
+ from .kafka import KafkaConsumer
5
+ from .redis import RedisConsumer
6
+ from .sqs import SQSConsumer
7
+
8
+ __all__ = ["ConsumerBase", "KafkaConsumer", "RedisConsumer", "SQSConsumer"]
9
+
10
+
11
def create_consumer():
    """Pick a consumer backend from settings.

    Redis takes precedence over Kafka when both are configured; raises when
    neither backend is configured.
    """
    if settings.SONA_CONSUMER_REDIS_URL:
        return RedisConsumer()
    if settings.SONA_CONSUMER_KAFKA_SETTING:
        return KafkaConsumer()
    raise Exception(
        "Consumer settings not found, please set SONA_CONSUMER_KAFKA_SETTING or SONA_CONSUMER_REDIS_URL"
    )
@@ -0,0 +1,11 @@
1
+ import abc
2
+
3
+
4
class ConsumerBase:
    """Abstract interface every message consumer backend implements."""

    @abc.abstractmethod
    def subscribe(self, topic: str) -> None:
        """Register interest in *topic*; concrete consumers must override."""
        return NotImplemented

    @abc.abstractmethod
    async def consume(self) -> str:
        """Produce raw message payloads; concrete consumers must override."""
        return NotImplemented
@@ -0,0 +1,30 @@
1
+ import asyncio
2
+
3
+ from confluent_kafka import Consumer
4
+ from loguru import logger
5
+
6
+ from sona.settings import settings
7
+
8
+ from .base import ConsumerBase
9
+
10
+ CONSUMER_SETTING = settings.SONA_CONSUMER_KAFKA_SETTING
11
+
12
+
13
class KafkaConsumer(ConsumerBase):
    """Kafka consumer that runs confluent-kafka's blocking poll in an executor."""

    def __init__(self, configs=CONSUMER_SETTING):
        self.consumer = Consumer(configs)

    def subscribe(self, topic):
        # confluent-kafka's subscribe expects a list of topics.
        self.consumer.subscribe([topic])

    async def consume(self):
        """Async-generate message payloads forever."""
        loop = asyncio.get_running_loop()
        while True:
            # poll() blocks up to 1s, so keep it off the event loop.
            message = await loop.run_in_executor(None, self.consumer.poll, 1)
            if message is None:
                continue
            if message.error():
                logger.warning(f"kafka error: {message}")
                continue
            # Commit before handing the payload to the caller.
            self.consumer.commit()
            yield message.value()
@@ -0,0 +1,36 @@
1
+ import uuid
2
+ from loguru import logger
3
+ import redis.asyncio as redis
4
+ from redis.exceptions import ResponseError
5
+
6
+ from sona.settings import settings
7
+
8
+ from .base import ConsumerBase
9
+
10
+ REDIS_URL = settings.SONA_CONSUMER_REDIS_URL
11
+ CONSUMER_GROUP= settings.SONA_CONSUMER_REDIS_GROUP
12
+
13
+
14
class RedisConsumer(ConsumerBase):
    """Redis Streams consumer reading through a consumer group (XREADGROUP)."""

    def __init__(self, url=REDIS_URL):
        self.topics = []
        self.redis = redis.from_url(url)
        # Unique consumer name inside the shared group.
        self.client_id = str(uuid.uuid4())

    def subscribe(self, topic):
        self.topics.append(topic)

    async def consume(self):
        """Async-generate decoded payloads from every subscribed stream."""
        # Ensure the consumer group exists on each stream; creation is
        # effectively idempotent (BUSYGROUP errors are only logged).
        for topic in self.topics:
            try:
                await self.redis.xgroup_create(topic, CONSUMER_GROUP, id='0', mkstream=True)
            except ResponseError as e:
                logger.warning(e)

        while True:
            pending = {topic: '>' for topic in self.topics}
            entries = await self.redis.xreadgroup(
                CONSUMER_GROUP, self.client_id, pending, block=1000 * 60, noack=True
            )
            for _stream_key, batch in entries:
                for _entry_id, payload in batch:
                    if payload is not None:
                        yield payload[b"data"].decode()
@@ -0,0 +1,34 @@
1
+ import asyncio
2
+ import functools
3
+
4
+ import boto3
5
+
6
+ from .base import ConsumerBase
7
+
8
+
9
class SQSConsumer(ConsumerBase):
    """AWS SQS consumer: long-polls one queue and deletes messages on receipt."""

    def __init__(self):
        self.sqs = boto3.resource("sqs")
        self.queue = None

    def subscribe(self, topic):
        # The topic name doubles as the SQS queue name.
        self.queue = self.sqs.get_queue_by_name(QueueName=topic)

    async def consume(self):
        """Async-generate message bodies forever (20s long-poll per round)."""
        loop = asyncio.get_running_loop()
        while True:
            batch = await loop.run_in_executor(
                None,
                functools.partial(
                    self.queue.receive_messages,
                    AttributeNames=["ApproximateReceiveCount"],
                    MaxNumberOfMessages=1,
                    WaitTimeSeconds=20,
                ),
            )
            if not batch:
                continue
            for message in batch:
                payload = message.body
                # Delete before yielding: at-most-once delivery semantics.
                message.delete()
                yield payload
@@ -0,0 +1,8 @@
1
+ from .base import MessageBase
2
+ from .context import Context
3
+ from .file import File
4
+ from .job import Job
5
+ from .result import Result
6
+ from .state import State
7
+
8
+ __all__ = ["MessageBase", "Context", "File", "Job", "Result", "State"]
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class MessageBase(BaseModel):
    """Immutable pydantic base for every SONA message type."""

    def mutate(self, **kwargs) -> MessageBase:
        """Return a copy of this message with *kwargs* overriding fields."""
        merged = {**self.dict(), **kwargs}
        return self.__class__(**merged)

    def to_message(self) -> str:
        """Serialize this message to its JSON wire format."""
        return self.json()

    class Config:
        # Messages are value objects: mutation goes through mutate() only.
        allow_mutation = False
@@ -0,0 +1,46 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import uuid
5
+ from typing import Dict, List
6
+
7
+ from pydantic import Field
8
+
9
+ from .base import MessageBase
10
+ from .job import Job
11
+ from .result import Result
12
+ from .state import State
13
+
14
+
15
class Context(MessageBase):
    """Envelope for one pipeline run: ordered jobs plus accumulated results/states.

    A job is "done" once a matching State has been appended, so progress is
    tracked purely by len(states) versus len(jobs).
    """

    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    start_time: int = Field(default_factory=time.time_ns)  # ns since epoch
    reporters: List[str] = []
    fallbacks: List[str] = []
    headers: Dict = {}
    jobs: List[Job]
    results: Dict[str, Result] = {}
    states: List[State] = []

    @property
    def duration(self):
        """Seconds elapsed since start_time.

        Uses exact integer-ns division; the previous ``* (0.1**9)`` factor is
        not exactly 1e-9 and introduced a small relative error.
        """
        return (time.time_ns() - self.start_time) / 1_000_000_000

    @property
    def current_job(self) -> Job:
        """The next job to execute, or None when every job has a state."""
        if len(self.states) == len(self.jobs):
            return None
        return self.jobs[len(self.states)]

    def find_job(self, job_name):
        """Return the job named *job_name*, or None if absent."""
        for job in self.jobs:
            if job.name == job_name:
                return job
        return None

    def next_context(self, state, result=None) -> Context:
        """Advance the pipeline: append *state* and optionally record *result*
        under the current job's name. Returns a new Context (self is immutable).
        """
        current_job = self.current_job
        context = self.mutate(states=self.states + [state])
        if result:
            context = context.mutate(results={**self.results, current_job.name: result})
        return context
@@ -0,0 +1,9 @@
1
+ from typing import Dict
2
+
3
+ from .base import MessageBase
4
+
5
+
6
class File(MessageBase):
    """A labeled file reference handed between jobs."""

    label: str  # role of the file within its job, e.g. "input" / "output"
    path: str  # local filesystem path or a share-storage URI
    metadata: Dict = {}  # free-form per-file annotations
@@ -0,0 +1,50 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List
4
+
5
+ from loguru import logger
6
+
7
+ from sona.core.messages.result import Result
8
+ from sona.core.utils.dict_utils import find_value_from_nested_keys
9
+
10
+ from .base import MessageBase
11
+ from .file import File
12
+
13
+
14
class Job(MessageBase):
    """One unit of work routed to an inferencer topic.

    ``required_result_params`` / ``required_result_files`` map a local name to
    a ``"<upstream_job>__<key>[__<key>...]"`` reference that is resolved
    against earlier jobs' results at prepare time.
    """

    name: str
    topic: str
    params: Dict = {}
    files: List[File] = []
    extra_params: Dict[str, str] = {}  # Will be deprecated in 1.0.0
    extra_files: Dict[str, str] = {}  # Will be deprecated in 1.0.0
    required_result_params: Dict[str, str] = {}
    required_result_files: Dict[str, str] = {}

    @property
    def required_params(self):
        """Merged param references (legacy extra_params plus the new field)."""
        if self.extra_params:
            logger.warning("extra_params will be deprecated in version 1.0.0")
        return {**self.extra_params, **self.required_result_params}

    @property
    def required_files(self):
        """Merged file references (legacy extra_files plus the new field)."""
        if self.extra_files:
            logger.warning("extra_files will be deprecated in version 1.0.0")
        return {**self.extra_files, **self.required_result_files}

    def prepare_params(self, results: Dict[str, Result]):
        """Resolve required params from upstream *results* into a params dict."""
        resolved = dict(self.params)
        for key, reference in self.required_params.items():
            source_job, *nested_keys = reference.split("__")
            resolved[key] = find_value_from_nested_keys(
                nested_keys, results[source_job].data
            )
        return resolved

    def prepare_files(self, results: Dict[str, Result]):
        """Combine this job's own files with files pulled from *results*.

        A pulled file is relabeled to the local label; a pulled file shadows
        an own file carrying the same label.
        """
        collected = {file.label: file for file in self.files}
        for label, reference in self.required_files.items():
            source_job, source_label = reference.split("__")
            pulled = results[source_job].find_file(source_label).mutate(label=label)
            collected[label] = pulled
        return list(collected.values())
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List
4
+
5
+ from .base import MessageBase
6
+ from .file import File
7
+
8
+
9
class Result(MessageBase):
    """Output of one job: produced files plus arbitrary result data."""

    files: List[File] = []
    data: Dict = {}

    def find_file(self, label) -> File:
        """Return the first file carrying *label*, or None when absent."""
        return next((file for file in self.files if file.label == label), None)
@@ -0,0 +1,33 @@
1
+ import socket
2
+ import time
3
+ import traceback
4
+ from typing import Dict
5
+
6
+ from pydantic import Field
7
+
8
+ from .base import MessageBase
9
+
10
+
11
class State(MessageBase):
    """Execution record for one job: where/when it ran and how it ended."""

    job_name: str
    node_name: str = socket.gethostname()  # host that executed the job
    timestamp: int = Field(default_factory=time.time_ns)  # start time, ns
    exec_time: float = 0  # seconds; filled in by complete()/fail()
    exception: Dict = {}  # empty on success

    @property
    def duration(self):
        """Seconds elapsed since the state was started.

        Uses exact integer-ns division; the previous ``* (0.1**9)`` factor is
        not exactly 1e-9 and introduced a small relative error.
        """
        return (time.time_ns() - self.timestamp) / 1_000_000_000

    @classmethod
    def start(cls, job_name):
        """Create a fresh running state for *job_name*."""
        return State(job_name=job_name)

    def complete(self):
        """Return a copy marked finished with its elapsed time recorded."""
        return self.mutate(exec_time=self.duration)

    def fail(self, exception):
        """Return a copy recording *exception* and the elapsed time."""
        return self.mutate(
            exec_time=self.duration,
            exception={"message": str(exception), "traceback": traceback.format_exc()},
        )
@@ -0,0 +1,25 @@
1
+ from sona.settings import settings
2
+
3
+ from .base import ProducerBase
4
+ from .kafka import KafkaProducer
5
+ from .mock import MockProducer
6
+ from .redis import RedisProducer
7
+ from .sqs import SQSProducer
8
+
9
+ __all__ = [
10
+ "ProducerBase",
11
+ "KafkaProducer",
12
+ "RedisProducer",
13
+ "SQSProducer",
14
+ "MockProducer",
15
+ ]
16
+
17
+
18
def create_producer():
    """Pick a producer backend from settings.

    Redis takes precedence over Kafka when both are configured; raises when
    neither backend is configured.
    """
    if settings.SONA_PRODUCER_REDIS_URL:
        return RedisProducer()
    if settings.SONA_PRODUCER_KAFKA_SETTING:
        return KafkaProducer()
    raise Exception(
        "Producer settings not found, please set SONA_PRODUCER_KAFKA_SETTING or SONA_PRODUCER_REDIS_URL"
    )
@@ -0,0 +1,7 @@
1
+ import abc
2
+
3
+
4
class ProducerBase:
    """Abstract interface every message producer backend implements."""

    @abc.abstractmethod
    def emit(self, topic, message) -> None:
        """Publish *message* on *topic*; concrete producers must override."""
        return NotImplemented
@@ -0,0 +1,23 @@
1
+ from confluent_kafka import Producer
2
+
3
+ from sona.settings import settings
4
+
5
+ from .base import ProducerBase
6
+
7
+ PRODUCER_SETTING = settings.SONA_PRODUCER_KAFKA_SETTING
8
+
9
+
10
class KafkaProducer(ProducerBase):
    """Kafka producer that flushes synchronously on every emit."""

    def __init__(self, configs=PRODUCER_SETTING):
        self.producer = Producer(configs)

    def emit(self, topic, message):
        # Serve delivery callbacks queued by earlier produce() calls.
        self.producer.poll(0)
        self.producer.produce(
            topic, message.encode("utf-8"), callback=self.__delivery_report
        )
        # Block until the broker acknowledges this message.
        self.producer.flush()

    def __delivery_report(self, err, msg):
        # NOTE(review): raises with msg.error(); presumably equal to *err* on
        # failed deliveries — confirm against confluent-kafka docs.
        if err:
            raise Exception(msg.error())
@@ -0,0 +1,8 @@
1
+ from loguru import logger
2
+
3
+ from .base import ProducerBase
4
+
5
+
6
class MockProducer(ProducerBase):
    """Producer stand-in for tests: logs instead of publishing anywhere."""

    def emit(self, topic, message) -> None:
        logger.info(f"emit [{topic}] {message}")
@@ -0,0 +1,15 @@
1
+ import redis
2
+
3
+ from sona.settings import settings
4
+
5
+ from .base import ProducerBase
6
+
7
+ REDIS_URL = settings.SONA_PRODUCER_REDIS_URL
8
+
9
+
10
class RedisProducer(ProducerBase):
    """Producer that appends messages to a Redis stream via XADD."""

    def __init__(self, url=REDIS_URL):
        self.redis = redis.from_url(url)

    def emit(self, topic, message):
        # One stream per topic; consumers read the "data" field back out.
        self.redis.xadd(topic, {"data": message})
@@ -0,0 +1,12 @@
1
+ import boto3
2
+
3
+ from .base import ProducerBase
4
+
5
+
6
class SQSProducer(ProducerBase):
    """Producer that sends messages to the AWS SQS queue named after the topic."""

    def __init__(self):
        self.sqs = boto3.resource("sqs")

    def emit(self, topic, message):
        # Queue handle is looked up on every emit.
        queue = self.sqs.get_queue_by_name(QueueName=topic)
        queue.send_message(MessageBody=message)
@@ -0,0 +1,13 @@
1
+ from sona.settings import settings
2
+
3
+ from .base import ShareStorageBase
4
+ from .local import LocalStorage
5
+ from .s3 import S3Storage
6
+
7
+ __all__ = ["ShareStorageBase", "LocalStorage", "S3Storage"]
8
+
9
+
10
def create_storage():
    """Use S3 when S3 settings are present, otherwise fall back to local disk."""
    if settings.SONA_STORAGE_SETTING:
        return S3Storage()
    return LocalStorage()
@@ -0,0 +1,51 @@
1
+ from pathlib import Path
2
+ from typing import List
3
+
4
+ from sona.core.messages.file import File
5
+
6
+
7
+ class ShareStorageBase:
8
+ def __init__(self):
9
+ self.cached_path = []
10
+
11
+ def pull_all(self, files: List[File]) -> List[File]:
12
+ return [self.pull(file) for file in files]
13
+
14
+ def pull(self, file: File) -> File:
15
+ if Path(file.path).is_file():
16
+ return file
17
+ if not self.is_valid(file.path):
18
+ raise Exception("Invalid share storage path")
19
+ path = self.on_pull(file.path)
20
+ self.cached_path.append(path)
21
+ return file.mutate(path=path)
22
+
23
+ def push_all(self, files: List[File]) -> List[File]:
24
+ return [self.push(file) for file in files]
25
+
26
+ def push(self, file: File) -> File:
27
+ if self.is_valid(file.path):
28
+ return file
29
+ if not Path(file.path).is_file():
30
+ raise Exception(f"Missing file: {file}")
31
+ self.cached_path.append(file.path)
32
+ return file.mutate(path=self.on_push(file.path))
33
+
34
+ def clean(self) -> None:
35
+ for path in self.cached_path:
36
+ file_path = Path(path)
37
+ if file_path.exists():
38
+ file_path.unlink()
39
+
40
+ def create_storage(self):
41
+ return self.__class__()
42
+
43
+ # Callbacks
44
+ def is_valid(self, path) -> bool:
45
+ return NotImplemented
46
+
47
+ def on_pull(self, path) -> str:
48
+ return NotImplemented
49
+
50
+ def on_push(self, path) -> str:
51
+ return NotImplemented
@@ -0,0 +1,32 @@
1
+ import os
2
+ import shutil
3
+ from pathlib import Path
4
+
5
+ from sona.core.storages.base import ShareStorageBase
6
+ from sona.settings import settings
7
+
8
+
9
class LocalStorage(ShareStorageBase):
    """Share storage that copies files beneath a root directory on local disk."""

    def __init__(self,
                 local_root=settings.SONA_STORAGE_LOCAL_ROOT,
                 storage_dir=settings.SONA_STORAGE_DIR):
        super().__init__()
        self.local_root = local_root
        self.storage_dir = storage_dir

    def is_valid(self, path: str) -> bool:
        # A path inside the storage root is already "shared".
        return path.startswith(self.local_root)

    def on_pull(self, path: str) -> str:
        """Copy a stored file into the working dir, keeping its relative layout."""
        working_path = Path(os.path.relpath(path, self.local_root))
        working_path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "rb") as src, open(working_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return str(working_path)

    def on_push(self, path: str) -> str:
        """Copy a local file to <local_root>/<storage_dir>/<basename>."""
        destination = Path(self.local_root) / self.storage_dir / Path(path).name
        destination.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "rb") as src, open(destination, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return str(destination)
@@ -0,0 +1,53 @@
1
+ import datetime
2
+ import hashlib
3
+ import re
4
+ from pathlib import Path
5
+
6
+ import boto3
7
+ from botocore.client import Config
8
+
9
+ from sona.core.storages.base import ShareStorageBase
10
+ from sona.settings import settings
11
+
12
+
13
class S3Storage(ShareStorageBase):
    """Share storage backed by S3; shared paths look like ``S3://bucket/key``."""

    def __init__(
        self,
        bucket=settings.SONA_STORAGE_BUCKET,
        upload_dir=settings.SONA_STORAGE_DIR,
        configs=settings.SONA_STORAGE_SETTING,
    ):
        super().__init__()
        self.bucket = bucket
        self.upload_dir = upload_dir
        self.configs = configs

    @property
    def client(self):
        """A boto3 S3 client using v4 signing.

        Copies the settings dict before injecting the signing config — the
        previous in-place update mutated the shared settings object.
        """
        configs = dict(self.configs or {})
        configs.update({"config": Config(signature_version="s3v4")})
        return boto3.resource("s3", **configs).meta.client

    def is_valid(self, path: str) -> bool:
        # bool() so the return matches the annotation (re.match gives Match|None).
        return bool(re.match(r"^[Ss]3://.*", path))

    def on_pull(self, path: str) -> str:
        """Download ``s3://bucket/key`` to a local path mirroring the key."""
        match = re.match(r"[Ss]3://([-_A-Za-z0-9]+)/(.+)", path)
        bucket = match.group(1)
        obj_key = match.group(2)
        filepath = Path(obj_key)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        self.client.download_file(bucket, obj_key, str(filepath))
        return str(filepath)

    def on_push(self, path: str) -> str:
        """Upload *path* under <upload_dir>/<YYYYMMDD>/<md5><suffixes>.

        Content-addressed naming (MD5 of the file body) deduplicates repeated
        uploads of identical content within a day.
        """
        md5 = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                md5.update(chunk)
        obj_key = f"{md5.hexdigest()}{''.join(Path(path).suffixes)}"
        obj_key = (
            Path(self.upload_dir) / datetime.date.today().strftime("%Y%m%d") / obj_key
        )
        self.client.upload_file(path, self.bucket, str(obj_key))
        return f"S3://{self.bucket}/{obj_key}"
File without changes
@@ -0,0 +1,13 @@
1
+ import sys
2
+ import traceback
3
+
4
+
5
def import_class(import_str):
    """Import and return an attribute (usually a class) by dotted path.

    Args:
        import_str: Fully qualified name, e.g. ``"pkg.module.ClassName"``.

    Returns:
        The attribute named by the last path segment.

    Raises:
        ImportError: If the module imports but lacks the attribute
            (chained from the underlying AttributeError). A missing module
            still raises the original ImportError from ``__import__``.
    """
    mod_str, _sep, class_str = import_str.rpartition(".")
    __import__(mod_str)
    try:
        return getattr(sys.modules[mod_str], class_str)
    except AttributeError as exc:
        # Chain the cause instead of interpolating the raw list that
        # traceback.format_exception returns into the message.
        raise ImportError(
            f"Class {class_str} cannot be found in module {mod_str!r}"
        ) from exc
@@ -0,0 +1,7 @@
1
def find_value_from_nested_keys(keys, dict_):
    """Walk *dict_* through the sequence of *keys*; None if any is missing.

    Args:
        keys: Iterable of keys to follow, outermost first.
        dict_: Nested mapping to traverse.

    Returns:
        The value at the end of the key path, or None when a key is absent
        or an intermediate value is not a mapping.
    """
    value = dict_
    for key in keys:
        try:
            value = value.get(key)
        except AttributeError:
            # Intermediate value is not a mapping; previously this raised.
            return None
        if value is None:
            # ``is None`` (not truthiness) so legitimate falsy values such
            # as 0, "" or False are returned instead of swallowed.
            return None
    return value
@@ -0,0 +1,3 @@
1
+ from .base import InferencerBase
2
+
3
+ __all__ = ["InferencerBase"]
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import Dict, List
5
+
6
+ from sona.core.messages import Context, File, Result
7
+ from sona.core.utils.cls_utils import import_class
8
+ from sona.settings import settings
9
+
10
+ TOPIC_PREFIX = settings.SONA_INFERENCER_TOPIC_PREFIX
11
+
12
+
13
class InferencerBase:
    """Base class for inference engines run by InferencerWorker.

    Subclasses set ``inferencer`` (a short identifier used to build the
    queue topic) and implement :meth:`inference`.
    """

    # Short identifier for this engine; subclasses must override it.
    inferencer = NotImplemented

    def on_load(self) -> None:
        """One-time setup hook (e.g. model download); default is a no-op."""
        return

    def context_example(self) -> Context:
        """Sample Context for manual testing; default None (no example)."""
        return None

    @classmethod
    def get_topic(cls):
        """Queue topic this inferencer consumes: ``<prefix>.<inferencer>``."""
        return f"{TOPIC_PREFIX}.{cls.inferencer}"

    @classmethod
    def load_class(cls, import_str):
        """Import an inferencer class by dotted path and validate it.

        Raises:
            Exception: If the imported object is not a strict subclass of
                this class.
        """
        inferencer_cls = import_class(import_str)
        # issubclass (rather than membership in __subclasses__()) also
        # accepts indirect subclasses, which the original check rejected.
        if not (
            isinstance(inferencer_cls, type)
            and issubclass(inferencer_cls, cls)
            and inferencer_cls is not cls
        ):
            raise Exception(f"Unknown inferencer class: {import_str}")
        return inferencer_cls

    @abc.abstractmethod
    def inference(self, params: Dict, files: List[File]) -> Result:
        """Run the model on *params*/*files* and return a Result."""
        return NotImplemented
@@ -0,0 +1,64 @@
1
+ from pathlib import Path
2
+ from typing import Dict, List
3
+
4
+ from loguru import logger
5
+
6
+ from sona.core.messages import Context, File, Job, Result
7
+ from sona.inferencers import InferencerBase
8
+ from sona.settings import settings
9
+
10
+
11
class MockInferencer(InferencerBase):
    """Dummy inferencer for smoke-testing the worker pipeline.

    ``inference`` only logs its inputs and emits an empty ``output.wav``;
    ``context_example`` builds a two-job-style Context with one pending job
    and one pre-existing result.
    """

    inferencer = "mock"

    def on_load(self) -> None:
        logger.info(f"Download {self.__class__.__name__} models...")

    def inference(self, params: Dict, files: List[File]) -> Result:
        logger.info(f"Get params {params}")
        logger.info(f"Get files {files}")
        output_path = "output.wav"
        Path(output_path).touch(exist_ok=True)
        output_file = File(label="output", path=output_path)
        return Result(files=[output_file], data={"output_key": "output_val"})

    def context_example(self) -> Context:
        storage = settings.SONA_STORAGE_LOCAL_ROOT
        filepath1 = f"{storage}/input1.wav"
        filepath2 = f"{storage}/input2.wav"
        Path(storage).mkdir(exist_ok=True, parents=True)
        # Touch both fixture files so pull/push steps find real paths.
        for fixture in (filepath1, filepath2):
            Path(fixture).touch(exist_ok=True)

        mock_job = Job(
            name="input1",
            topic=self.get_topic(),
            params={"input1_key": "input1_val"},
            files=[File(label="input1", path=filepath1)],
            required_result_params={
                "input2_key": "input2__input2_key1__input2_key2"
            },
            required_result_files={"input2": "input2__input2_r"},
        )
        prior_result = Result(
            files=[File(label="input2_r", path=filepath2)],
            data={"input2_key1": {"input2_key2": "input2_val"}},
        )
        return Context(jobs=[mock_job], results={"input2": prior_result})
@@ -0,0 +1,36 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ from pydantic import BaseSettings, RedisDsn, root_validator
4
+
5
+
6
class Settings(BaseSettings):
    """Environment-driven configuration (pydantic BaseSettings).

    Values are read from the process environment and an optional ``.env``
    file; every field can be overridden by an identically named env var.
    """

    # Consumer settings
    SONA_CONSUMER_KAFKA_SETTING: Optional[Dict] = None
    SONA_CONSUMER_REDIS_URL: Optional[RedisDsn] = None
    SONA_CONSUMER_REDIS_GROUP: Optional[str] = "dwave.anonymous"

    # Producer settings
    SONA_PRODUCER_KAFKA_SETTING: Optional[Dict] = None
    SONA_PRODUCER_REDIS_URL: Optional[RedisDsn] = None

    # Storage settings
    SONA_STORAGE_DIR: str = "_tmp"
    # Optional[...] so the annotation matches the None default
    # (was ``Dict = None``).
    SONA_STORAGE_SETTING: Optional[Dict] = None
    SONA_STORAGE_BUCKET: str = "sona"
    SONA_STORAGE_LOCAL_ROOT: str = "_share"

    # Inferencer settings
    SONA_WORKER: str = "sona.workers.InferencerWorker"
    # Optional[...] so the annotation matches the None default
    # (was ``str = None``).
    SONA_INFERENCER: Optional[str] = None
    SONA_INFERENCER_TOPIC_PREFIX: str = "dwave.inferencer"

    @root_validator(pre=False)
    def load_settings(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Pass-through hook kept as an extension point for derived settings.
        return values

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


# Singleton settings instance imported throughout the package.
settings = Settings()
@@ -0,0 +1,4 @@
1
+ from .base import WorkerBase
2
+ from .inferencer import InferencerWorker
3
+
4
+ __all__ = ["WorkerBase", "InferencerWorker"]
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+
5
+ from loguru import logger
6
+
7
+ from sona.core.consumers import ConsumerBase
8
+ from sona.core.messages import Context
9
+ from sona.core.producers import ProducerBase
10
+ from sona.core.storages.base import ShareStorageBase
11
+ from sona.core.utils.cls_utils import import_class
12
+
13
+
14
class WorkerBase:
    """Abstract message-loop worker.

    A worker is wired with a consumer, producer and storage, then
    :meth:`start` subscribes to its topic and dispatches each incoming
    message to :meth:`on_context`. Subclasses implement ``on_load`` and
    ``on_context``.
    """

    # Default topic; subclasses or get_topic() overrides provide the real one.
    topic: str = "dummy"

    def set_consumer(self, consumer: ConsumerBase):
        """Inject the message consumer."""
        self.consumer = consumer

    def set_producer(self, producer: ProducerBase):
        """Inject the message producer."""
        self.producer = producer

    def set_storage(self, storage: ShareStorageBase):
        """Inject the shared file storage."""
        self.storage = storage

    async def start(self):
        """Run the consume loop until the consumer is exhausted.

        Each message gets a fresh storage scope (``create_storage``) which
        is always cleaned in ``finally``; per-message failures are logged
        and do not stop the loop.
        """
        await self.on_load()
        self.topic = self.get_topic()
        # Typo fix: message previously read "Susbcribe".
        logger.info(f"Subscribe on {self.topic}({self.consumer.__class__.__name__})")
        self.consumer.subscribe(self.topic)
        async for message in self.consumer.consume():
            try:
                self.storage = self.storage.create_storage()
                context = Context.parse_raw(message)
                await self.on_context(context)
            except Exception as e:
                logger.warning(f"[{self.topic}] error: {e}, msg: {message}")
            finally:
                self.storage.clean()

    async def test(self, context: Context):
        """Run one context through the worker without a consumer loop."""
        await self.on_load()
        await self.on_context(context)

    @classmethod
    def get_topic(cls) -> str:
        """Topic this worker subscribes to; defaults to the class attribute."""
        return cls.topic

    @classmethod
    def load_class(cls, import_str):
        """Import a worker class by dotted path and validate it.

        Raises:
            Exception: If the imported object is not a direct subclass.
        """
        worker_cls = import_class(import_str)
        if worker_cls not in cls.__subclasses__():
            raise Exception(f"Unknown worker class: {import_str}")
        return worker_cls

    # Callbacks
    @abc.abstractmethod
    async def on_load(self) -> None:
        """One-time async setup before the consume loop starts."""
        return NotImplemented

    @abc.abstractmethod
    async def on_context(self, message: Context) -> Context:
        """Handle one parsed Context; return the resulting Context."""
        return NotImplemented

    def context_example(self) -> Context:
        """Sample Context for manual testing; default None."""
        return None
@@ -0,0 +1,73 @@
1
+ import asyncio
2
+
3
+ from loguru import logger
4
+
5
+ from sona.core.messages import Context, Job, State
6
+ from sona.inferencers import InferencerBase
7
+ from sona.settings import settings
8
+
9
+ from .base import WorkerBase
10
+
11
+ TOPIC_PREFIX = settings.SONA_INFERENCER_TOPIC_PREFIX
12
+ INFERENCER_CLASS = settings.SONA_INFERENCER
13
+
14
+
15
class InferencerWorker(WorkerBase):
    """Worker that drives an :class:`InferencerBase` instance.

    For each Context it prepares the current job's params/files, runs the
    (blocking) inference in an executor, pushes result files to storage and
    emits the next Context — to the next job's topic, or to the reporter
    topics when the pipeline is done. On failure it emits to the fallback
    topics instead.
    """

    def __init__(self, inferencer: InferencerBase):
        super().__init__()
        self.inferencer = inferencer

    async def on_load(self):
        logger.info(f"Loading inferencer: {self.inferencer.inferencer}")
        # on_load may block (e.g. model downloads); run it off the loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.inferencer.on_load)

    async def on_context(self, context: Context):
        logger.info(f"[{self.topic}] recv: {context.to_message()}")
        # Resolve job and state BEFORE the try: the except clause needs
        # current_state, and the original code raised UnboundLocalError
        # (masking the real failure) when State.start had not yet run.
        current_job: Job = context.current_job
        current_state: State = State.start(current_job.name)
        try:
            # Prepare process data
            params = current_job.prepare_params(context.results)
            files = current_job.prepare_files(context.results)
            files = self.storage.pull_all(files)

            # Process
            # TODO: Make process cancelable
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None, self.inferencer.inference, params, files
            )
            result = result.mutate(files=self.storage.push_all(result.files))

            # Create success context
            current_state = current_state.complete()
            next_context = context.next_context(current_state, result)
            logger.info(f"[{self.topic}] success: {next_context.to_message()}")

            # Emit message
            next_job = next_context.current_job
            if next_job:
                self.producer.emit(next_job.topic, next_context.to_message())
            else:
                for topic in next_context.reporters:
                    self.producer.emit(topic, next_context.to_message())
            return next_context

        except Exception as e:
            # Create fail context
            current_state = current_state.fail(e)
            next_context = context.next_context(current_state)
            logger.exception(f"[{self.topic}] error: {next_context.to_message()}")

            # Emit message
            for topic in next_context.fallbacks:
                self.producer.emit(topic, next_context.to_message())
            return next_context

    def context_example(self):
        """Delegate to the wrapped inferencer's example Context."""
        return self.inferencer.context_example()

    def get_topic(self):
        # NOTE(review): instance method shadowing the base classmethod;
        # start() calls it on the instance, so this works as written.
        return self.inferencer.get_topic()