oocana 0.15.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oocana-0.15.0/PKG-INFO +10 -0
- oocana-0.15.0/oocana/__init__.py +7 -0
- oocana-0.15.0/oocana/context.py +334 -0
- oocana-0.15.0/oocana/data.py +96 -0
- oocana-0.15.0/oocana/handle_data.py +55 -0
- oocana-0.15.0/oocana/mainframe.py +135 -0
- oocana-0.15.0/oocana/preview.py +79 -0
- oocana-0.15.0/oocana/schema.py +118 -0
- oocana-0.15.0/oocana/service.py +70 -0
- oocana-0.15.0/oocana/throtter.py +37 -0
- oocana-0.15.0/pyproject.toml +22 -0
- oocana-0.15.0/tests/__init__.py +0 -0
- oocana-0.15.0/tests/test_data.py +75 -0
- oocana-0.15.0/tests/test_handle_data.py +109 -0
- oocana-0.15.0/tests/test_json.py +72 -0
- oocana-0.15.0/tests/test_mainframe.py +36 -0
- oocana-0.15.0/tests/test_performance.py +22 -0
- oocana-0.15.0/tests/test_schema.py +91 -0
- oocana-0.15.0/tests/test_throtter.py +36 -0
oocana-0.15.0/PKG-INFO
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: oocana
|
3
|
+
Version: 0.15.0
|
4
|
+
Summary: python implement of oocana to give a context for oocana block
|
5
|
+
License: MIT
|
6
|
+
Requires-Python: >=3.9
|
7
|
+
Requires-Dist: paho-mqtt>=2
|
8
|
+
Requires-Dist: simplejson>=3.19.2
|
9
|
+
Requires-Dist: typing-extensions>=4.12.2; python_version < "3.11"
|
10
|
+
|
@@ -0,0 +1,7 @@
|
|
1
|
+
from .data import * # noqa: F403
|
2
|
+
from .context import * # noqa: F403
|
3
|
+
from .service import * # noqa: F403
|
4
|
+
from .handle_data import * # noqa: F403
|
5
|
+
from .preview import * # noqa: F403
|
6
|
+
from .schema import * # noqa: F403
|
7
|
+
from .mainframe import Mainframe as Mainframe # noqa: F403
|
@@ -0,0 +1,334 @@
|
|
1
|
+
from dataclasses import asdict
|
2
|
+
from json import loads
|
3
|
+
from .data import BlockInfo, StoreKey, JobDict, BlockDict, BinValueDict, VarValueDict
|
4
|
+
from .handle_data import HandleDef
|
5
|
+
from .mainframe import Mainframe
|
6
|
+
from typing import Dict, Any, TypedDict, Optional
|
7
|
+
from base64 import b64encode
|
8
|
+
from io import BytesIO
|
9
|
+
from .throtter import throttle
|
10
|
+
from .preview import PreviewPayload, TablePreviewData, DataFrame, ShapeDataFrame, PartialDataFrame
|
11
|
+
from .data import EXECUTOR_NAME
|
12
|
+
import os.path
|
13
|
+
import logging
|
14
|
+
|
15
|
+
__all__ = ["Context"]
|
16
|
+
|
17
|
+
class OnlyEqualSelf:
    """Sentinel that compares equal only to itself (identity equality).

    Used as the ``keepAlive`` marker on :class:`Context`: only the very same
    sentinel object compares equal.
    """

    def __eq__(self, value: object) -> bool:
        return self is value

    # Defining __eq__ alone implicitly sets __hash__ to None and makes the
    # sentinel unhashable; keep identity hashing to match identity equality.
    def __hash__(self) -> int:
        return id(self)
21
|
+
class OOMOL_LLM_ENV(TypedDict):
    """Environment description for the OOMOL-provided LLM service."""

    # Base URL of the LLM API endpoint.
    base_url: str
    # API key used to authenticate against the endpoint.
    api_key: str
    # Names of the models made available by the service.
    models: list[str]
|
26
|
+
class HostInfo(TypedDict):
    """GPU information about the machine the executor runs on."""

    # GPU vendor string as reported by the host environment.
    gpu_vendor: str
    # GPU renderer string as reported by the host environment.
    gpu_renderer: str
30
|
+
class Context:
    """Execution context handed to an oocana block implementation.

    Bundles the block's resolved inputs, its identity (session / job /
    stacks), the shared var store and the mainframe (MQTT) connection, and
    exposes helpers to emit outputs, previews, progress, logs and errors.
    """

    __inputs: Dict[str, Any]

    __block_info: BlockInfo
    __outputs_def: Dict[str, HandleDef]
    __store: Any
    __is_done: bool = False
    __keep_alive: OnlyEqualSelf = OnlyEqualSelf()
    __session_dir: str
    _logger: Optional[logging.Logger] = None

    def __init__(
        self, inputs: Dict[str, Any], blockInfo: BlockInfo, mainframe: Mainframe, store, outputs, session_dir: str
    ) -> None:

        self.__block_info = blockInfo

        self.__mainframe = mainframe
        self.__store = store
        self.__inputs = inputs

        outputs_defs: Dict[str, HandleDef] = {}
        if outputs is not None:
            for k, v in outputs.items():
                outputs_defs[k] = HandleDef(**v)
        self.__outputs_def = outputs_defs
        self.__session_dir = session_dir

    @property
    def logger(self) -> logging.Logger:
        """a custom logger for the block, you can use it to log the message to the block log. this logger will report the log by context report_logger api.

        :raises ValueError: if the logger has not been set up yet.
        """

        # setup after init, so the logger always exists
        if self._logger is None:
            raise ValueError("logger is not setup, please setup the logger in the block init function.")
        return self._logger

    @property
    def session_dir(self) -> str:
        """a temporary directory for the current session, all blocks in the one session will share the same directory.
        """
        return self.__session_dir

    @property
    def keepAlive(self):
        # Sentinel object; compares equal only to itself (see OnlyEqualSelf).
        return self.__keep_alive

    @property
    def inputs(self):
        return self.__inputs

    @property
    def session_id(self):
        return self.__block_info.session_id

    @property
    def job_id(self):
        return self.__block_info.job_id

    @property
    def job_info(self) -> JobDict:
        return self.__block_info.job_info()

    @property
    def block_info(self) -> BlockDict:
        return self.__block_info.block_dict()

    @property
    def node_id(self) -> str:
        # NOTE(review): assumes stacks is non-empty; an empty stack would raise
        # IndexError here — confirm against the scheduler contract.
        return self.__block_info.stacks[-1].get("node_id", None)

    @property
    def oomol_llm_env(self) -> OOMOL_LLM_ENV:
        """this is a dict contains the oomol llm environment variables
        """
        return {
            "base_url": os.getenv("OOMOL_LLM_BASE_URL", ""),
            "api_key": os.getenv("OOMOL_LLM_API_KEY", ""),
            "models": os.getenv("OOMOL_LLM_MODELS", "").split(","),
        }

    @property
    def host_info(self) -> HostInfo:
        """this is a dict contains the host information
        """
        return {
            "gpu_vendor": os.getenv("OOMOL_HOST_GPU_VENDOR", "unknown"),
            "gpu_renderer": os.getenv("OOMOL_HOST_GPU_RENDERER", "unknown"),
        }

    def __store_ref(self, handle: str):
        # Key under which a var-handle value is parked in the shared store.
        return StoreKey(
            executor=EXECUTOR_NAME,
            handle=handle,
            job_id=self.job_id,
            session_id=self.session_id,
        )

    def __is_basic_type(self, value: Any) -> bool:
        return isinstance(value, (int, float, str, bool))

    def output(self, key: str, value: Any, done: bool = False):
        """
        output the value to the next block

        key: str, the key of the output, should be defined in the block schema output defs, the field name is handle
        value: Any, the value of the output
        """

        v = value

        if self.__outputs_def is not None:
            output_def = self.__outputs_def.get(key)
            # Basic types are sent inline as JSON content even for var handles;
            # only non-basic values go through the shared store.
            if (
                output_def is not None and output_def.is_var_handle() and not self.__is_basic_type(value)
            ):
                ref = self.__store_ref(key)
                self.__store[ref] = value
                d: VarValueDict = {
                    "__OOMOL_TYPE__": "oomol/var",
                    "value": asdict(ref)
                }
                v = d
            elif output_def is not None and output_def.is_bin_handle():
                if not isinstance(value, bytes):
                    self.send_warning(
                        f"Output handle key: [{key}] is defined as binary, but the value is not bytes."
                    )
                    return

                bin_file = f"{self.session_dir}/binary/{self.session_id}/{self.job_id}/{key}"
                os.makedirs(os.path.dirname(bin_file), exist_ok=True)
                try:
                    with open(bin_file, "wb") as f:
                        f.write(value)
                except IOError as e:
                    self.send_warning(
                        f"Output handle key: [{key}] is defined as binary, but an error occurred while writing the file: {e}"
                    )
                    return

                if os.path.exists(bin_file):
                    bin_value: BinValueDict = {
                        "__OOMOL_TYPE__": "oomol/bin",
                        "value": bin_file,
                    }
                    v = bin_value
                else:
                    self.send_warning(
                        f"Output handle key: [{key}] is defined as binary, but the file is not written."
                    )
                    return

        # If the key is not present in the output defs, drop the data and do
        # not send it — but `done` still takes effect.
        if self.__outputs_def is not None and self.__outputs_def.get(key) is None:
            self.send_warning(
                f"Output handle key: [{key}] is not defined in Block outputs schema."
            )
            if done:
                self.done()
            return

        node_result = {
            "type": "BlockOutput",
            "handle": key,
            "output": v,
            "done": done,
        }
        self.__mainframe.send(self.job_info, node_result)

        if done:
            self.done()

    def done(self, error: Optional[str] = None):
        """Mark the block finished; duplicate calls are ignored with a warning.

        :param error: optional error message; when given, BlockFinished carries it.
        """
        if self.__is_done:
            self.send_warning("done has been called multiple times, will be ignored.")
            return
        self.__is_done = True
        if error is None:
            self.__mainframe.send(self.job_info, {"type": "BlockFinished"})
        else:
            self.__mainframe.send(
                self.job_info, {"type": "BlockFinished", "error": error}
            )

    def send_message(self, payload):
        """Report an arbitrary BlockMessage payload to the reporter channel."""
        self.__mainframe.report(
            self.block_info,
            {
                "type": "BlockMessage",
                "payload": payload,
            },
        )

    def __dataframe(self, payload: PreviewPayload) -> PreviewPayload:
        # Normalize a (pandas-like) DataFrame payload into a table preview:
        # small frames are sent whole; larger ones as head + "..." + tail.
        if isinstance(payload, DataFrame):
            payload = { "type": "table", "data": payload }

        if isinstance(payload, dict) and payload.get("type") == "table":
            df = payload.get("data")
            if isinstance(df, ShapeDataFrame):
                row_count = df.shape[0]
                if row_count <= 10:
                    data = df.to_dict(orient='split')
                    columns = data.get("columns", [])
                    rows = data.get("data", [])
                elif isinstance(df, PartialDataFrame):
                    data_columns = loads(df.head(5).to_json(orient='split'))
                    columns = data_columns.get("columns", [])
                    rows_head = data_columns.get("data", [])
                    data_tail = loads(df.tail(5).to_json(orient='split'))
                    rows_tail = data_tail.get("data", [])
                    rows_dots = [["..."] * len(columns)]
                    rows = rows_head + rows_dots + rows_tail
                else:
                    print("dataframe more than 10 rows but not support head and tail is not supported")
                    return payload
                data: TablePreviewData = { "rows": rows, "columns": columns, "row_count": row_count }
                payload = { "type": "table", "data": data }
            else:
                print("dataframe is not support shape property")

        return payload

    def __matplotlib(self, payload: PreviewPayload) -> PreviewPayload:
        # payload is a matplotlib Figure (duck-typed via savefig): render it
        # to an inline base64 PNG data URL.
        if hasattr(payload, 'savefig'):
            fig: Any = payload
            buffer = BytesIO()
            fig.savefig(buffer, format='png')
            buffer.seek(0)
            png = buffer.getvalue()
            buffer.close()
            url = f'data:image/png;base64,{b64encode(png).decode("utf-8")}'
            payload = { "type": "image", "data": url }

        return payload

    def preview(self, payload: PreviewPayload):
        """Send a preview (table, image, ...) of the given payload to the UI."""
        payload = self.__dataframe(payload)
        payload = self.__matplotlib(payload)

        self.__mainframe.report(
            self.block_info,
            {
                "type": "BlockPreview",
                "payload": payload,
            },
        )

    @throttle(0.3)
    def report_progress(self, progress: "float | int"):
        """report progress

        This api is used to report the progress of the block. but it just effect the ui progress not the real progress.
        This api is throttled. the minimum interval is 0.3s.
        When you first call this api, it will report the progress immediately. After it invoked once, it will report the progress at the end of the throttling period.

        |       0.25 s        |   0.2 s  |
        first call       second call    third call  4 5 6 7's calls
        |                     |          |          | | | |
        | -------- 0.3 s -------- | -------- 0.3 s -------- |
        invoke                  invoke                    invoke
        :param float | int progress: the progress of the block, the value should be in [0, 100].
        """
        self.__mainframe.report(
            self.block_info,
            {
                "type": "BlockProgress",
                "rate": progress,
            }
        )

    def report_log(self, line: str, stdio: str = "stdout"):
        """Report one log line; ``stdio`` names the stream ("stdout"/"stderr").

        BUGFIX: the payload previously used the *value* of ``stdio`` as the
        dict key ({stdio: stdio} produced {"stdout": "stdout"}); the key must
        be the literal string "stdio".
        """
        self.__mainframe.report(
            self.block_info,
            {
                "type": "BlockLog",
                "log": line,
                "stdio": stdio,
            },
        )

    def log_json(self, payload):
        """Report a structured JSON payload on the log channel."""
        self.__mainframe.report(
            self.block_info,
            {
                "type": "BlockLogJSON",
                "json": payload,
            },
        )

    def send_warning(self, warning: str):
        """Report a non-fatal warning message."""
        self.__mainframe.report(self.block_info, {"type": "BlockWarning", "warning": warning})

    def send_error(self, error: str):
        '''
        deprecated, use error(error) instead.
        consider to remove in the future.
        '''
        self.error(error)

    def error(self, error: str):
        """Send a BlockError to the scheduler for this job."""
        self.__mainframe.send(self.job_info, {"type": "BlockError", "error": error})
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from dataclasses import dataclass, asdict
|
2
|
+
from typing import TypedDict, Literal
|
3
|
+
from simplejson import JSONEncoder
|
4
|
+
import simplejson as json
|
5
|
+
|
6
|
+
EXECUTOR_NAME = "python"

# BUGFIX: duplicates removed — "JobDict", "BinValueDict" and "VarValueDict"
# appeared twice in the original list.
__all__ = ["dumps", "BinValueDict", "VarValueDict", "JobDict", "BlockDict", "StoreKey", "BlockInfo", "EXECUTOR_NAME"]

def dumps(obj, **kwargs):
    """Serialize *obj* to JSON.

    Dataclass instances are converted via ``dataclasses.asdict`` and NaN is
    emitted as ``null`` (simplejson's ``ignore_nan``).
    """
    return json.dumps(obj, cls=DataclassJSONEncoder, ignore_nan=True, **kwargs)

class DataclassJSONEncoder(JSONEncoder):
    """JSON encoder that knows how to serialize dataclass instances."""

    def default(self, o): # pyright: ignore[reportIncompatibleMethodOverride]
        # Any dataclass (detected via __dataclass_fields__) becomes a plain dict.
        if hasattr(o, '__dataclass_fields__'):
            return asdict(o)
        return JSONEncoder.default(self, o)
|
19
|
+
class BinValueDict(TypedDict):
    """Wire format for a binary output: a tagged filesystem path."""

    # Path of the file the binary payload was written to.
    value: str
    __OOMOL_TYPE__: Literal["oomol/bin"]
23
|
+
class VarValueDict(TypedDict):
    """Wire format for a var output: a tagged store-key reference."""

    # The StoreKey reference, serialized as a plain dict.
    value: dict
    __OOMOL_TYPE__: Literal["oomol/var"]
|
27
|
+
class JobDict(TypedDict):
    """Minimal identity of a running job within a session."""

    # Session the job belongs to.
    session_id: str
    # Unique id of the job inside that session.
    job_id: str
|
31
|
+
# NotRequired/Required were added to ``typing`` in Python 3.11; fall back to
# typing_extensions on older interpreters.
# BUGFIX: this try/except import previously lived *inside* the TypedDict class
# body, which injected the imported names into the TypedDict namespace.
try:
    from typing import NotRequired, Required, TypedDict # type: ignore
except ImportError:
    from typing_extensions import NotRequired, Required, TypedDict


class BlockDict(TypedDict):
    """Identity of a block run, as attached to every reporter message."""

    session_id: str
    job_id: str
    stacks: list
    # Path of the block definition on disk; absent for anonymous blocks.
    block_path: NotRequired[str]
|
44
|
+
# dataclass 默认字段必须一一匹配
|
45
|
+
# 如果多一个或者少一个字段,就会报错。
|
46
|
+
# 这里想兼容额外多余字段,所以需要自己重写 __init__ 方法,忽略处理多余字段。同时需要自己处理缺少字段的情况。
|
47
|
+
@dataclass(frozen=True, kw_only=True)
|
48
|
+
class StoreKey:
|
49
|
+
executor: str
|
50
|
+
handle: str
|
51
|
+
job_id: str
|
52
|
+
session_id: str
|
53
|
+
|
54
|
+
def __init__(self, **kwargs):
|
55
|
+
for key, value in kwargs.items():
|
56
|
+
object.__setattr__(self, key, value)
|
57
|
+
for key in self.__annotations__.keys():
|
58
|
+
if key not in kwargs:
|
59
|
+
raise ValueError(f"missing key {key}")
|
60
|
+
|
61
|
+
|
62
|
+
# The block-identity parameters that every reporter message must carry.
@dataclass(frozen=True, kw_only=True)
class BlockInfo:
    """Identity of a block run: session, job, call stacks and optional path."""

    session_id: str
    job_id: str
    stacks: list
    block_path: str | None = None

    def __init__(self, **kwargs):
        # Hand-written so extra keywords are tolerated; block_path may be
        # omitted (defaults to None), every other field is mandatory.
        for name, val in kwargs.items():
            object.__setattr__(self, name, val)
        for field_name in self.__annotations__:
            if field_name != "block_path" and field_name not in kwargs:
                raise ValueError(f"missing key {field_name}")

    def job_info(self) -> JobDict:
        """The (session_id, job_id) pair as a JobDict."""
        return {"session_id": self.session_id, "job_id": self.job_id}

    def block_dict(self) -> BlockDict:
        """This block's identity as a BlockDict; block_path only when present."""
        info: BlockDict = {
            "session_id": self.session_id,
            "job_id": self.job_id,
            "stacks": self.stacks,
        }
        if self.block_path is not None:
            info["block_path"] = self.block_path
        return info
|
96
|
+
|
@@ -0,0 +1,55 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Any, Optional
|
3
|
+
from .schema import FieldSchema, ContentMediaType
|
4
|
+
|
5
|
+
__all__ = ["HandleDef", "InputHandleDef"]
|
6
|
+
|
7
|
+
@dataclass(frozen=True, kw_only=True)
class HandleDef:
    """The base handle for output def, can be directly used for output def
    """
    # Unique name of the handle within its handle list.
    handle: str

    # Schema used to validate the handle's content; may be absent.
    json_schema: Optional[FieldSchema] = None

    # Alias of the handle's type name, for UI display and handle matching.
    name: Optional[str] = None

    def __init__(self, **kwargs):
        # Tolerate extra keywords; only 'handle' is mandatory.
        for attr, val in kwargs.items():
            object.__setattr__(self, attr, val)
        if "handle" not in kwargs:
            raise ValueError("missing attr key: 'handle'")
        # A raw (dict-like) schema is promoted to a FieldSchema instance.
        schema = self.json_schema
        if schema is not None and not isinstance(schema, FieldSchema):
            object.__setattr__(self, "json_schema", FieldSchema.generate_schema(schema))

    def check_handle_type(self, type: ContentMediaType) -> bool:
        """True iff this handle's schema declares the given content media type."""
        if self.handle is None or self.json_schema is None:
            return False
        media = self.json_schema.contentMediaType
        return media is not None and media == type

    def is_var_handle(self) -> bool:
        return self.check_handle_type("oomol/var")

    def is_secret_handle(self) -> bool:
        return self.check_handle_type("oomol/secret")

    def is_bin_handle(self) -> bool:
        return self.check_handle_type("oomol/bin")
|
47
|
+
@dataclass(frozen=True, kw_only=True)
class InputHandleDef(HandleDef):
    """An input handle: a HandleDef plus an optional default value."""

    # Default value for the input handle; None is a legitimate default.
    value: Optional[Any] = None

    def __init__(self, **kwargs):
        # Construction and validation are exactly HandleDef's keyword-based
        # init; defining __init__ here also stops @dataclass from generating
        # a stricter one that would reject extra keywords.
        super().__init__(**kwargs)
|
@@ -0,0 +1,135 @@
|
|
1
|
+
from simplejson import loads
|
2
|
+
import paho.mqtt.client as mqtt
|
3
|
+
from paho.mqtt.enums import CallbackAPIVersion
|
4
|
+
import operator
|
5
|
+
from urllib.parse import urlparse
|
6
|
+
import uuid
|
7
|
+
from .data import BlockDict, JobDict, dumps
|
8
|
+
import logging
|
9
|
+
from typing import Optional
|
10
|
+
|
11
|
+
__all__ = ["Mainframe"]
|
12
|
+
|
13
|
+
class Mainframe:
    """MQTT connection between the python executor and the oocana scheduler.

    Publishes block lifecycle messages (all with qos=1) and manages topic
    subscriptions, re-subscribing automatically after a reconnect.
    """

    address: str
    client: mqtt.Client
    client_id: str
    _subscriptions: set[str]
    _logger: logging.Logger

    def __init__(self, address: str, client_id: Optional[str] = None, logger = None) -> None:
        self.address = address
        self.client_id = client_id or f"python-executor-{uuid.uuid4().hex[:8]}"
        self._logger = logger or logging.getLogger(__name__)
        # BUGFIX: _subscriptions used to be a class-level set, shared by every
        # Mainframe instance; make it per-instance state.
        self._subscriptions = set()

    def connect(self):
        """Connect to the broker and start the network loop in a background thread."""
        connect_address = (
            self.address
            if operator.contains(self.address, "://")
            else f"mqtt://{self.address}"
        )
        url = urlparse(connect_address)
        client = self._setup_client()
        client.connect(host=url.hostname, port=url.port) # type: ignore
        client.loop_start()

    def _setup_client(self):
        # Build the paho client and wire up the lifecycle callbacks.
        self.client = mqtt.Client(
            callback_api_version=CallbackAPIVersion.VERSION2,
            client_id=self.client_id,
        )
        self.client.logger = self._logger
        self.client.on_connect = self.on_connect
        self.client.on_disconnect = self.on_disconnect
        self.client.on_connect_fail = self.on_connect_fail # type: ignore
        return self.client

    # With MQTT v5 the broker drops subscriptions and queued messages on
    # reconnect (v3 kept both when clean_session was configured at init).
    # Our broker speaks v5, so subscribing inside on_connect guarantees every
    # reconnect re-establishes the subscriptions.
    def on_connect(self, client, userdata, flags, reason_code, properties):
        if reason_code != 0:
            self._logger.error("connect to broker failed, reason_code: %s", reason_code)
            return
        else:
            self._logger.info("connect to broker success")

        for topic in self._subscriptions.copy():  # copy: may race with subscribe/unsubscribe
            self._logger.info("resubscribe to topic: {}".format(topic))
            self.client.subscribe(topic, qos=1)

    def on_connect_fail(self) -> None:
        self._logger.error("connect to broker failed")

    def on_disconnect(self, client, userdata, flags, reason_code, properties):
        self._logger.warning("disconnect to broker, reason_code: %s", reason_code)

    # Do not wait for the publish to complete; qos=1 guarantees delivery.
    def send(self, job_info: JobDict, msg) -> mqtt.MQTTMessageInfo:
        """Publish a job-scoped message on the session topic."""
        return self.client.publish(
            f'session/{job_info["session_id"]}', dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}), qos=1
        )

    def report(self, block_info: BlockDict, msg: dict) -> mqtt.MQTTMessageInfo:
        """Publish a block-scoped message on the shared report topic."""
        return self.client.publish("report", dumps({**block_info, **msg}), qos=1)

    def notify_executor_ready(self, session_id: str, executor_name: str, package: Optional[str]) -> None:
        """Announce that this executor process is ready to accept blocks."""
        self.client.publish(f"session/{session_id}", dumps({
            "type": "ExecutorReady",
            "session_id": session_id,
            "executor_name": executor_name,
            "package": package,
        }), qos=1)

    def notify_block_ready(self, session_id: str, job_id: str) -> dict:
        """Announce BlockReady and block until the scheduler replies with the inputs."""
        from time import sleep

        topic = f"inputs/{session_id}/{job_id}"
        replay = None

        def on_message_once(_client, _userdata, message):
            nonlocal replay
            self.client.unsubscribe(topic)
            replay = loads(message.payload)

        self.client.subscribe(topic, qos=1)
        self.client.message_callback_add(topic, on_message_once)

        self.client.publish(f"session/{session_id}", dumps({
            "type": "BlockReady",
            "session_id": session_id,
            "job_id": job_id,
        }), qos=1)

        # BUGFIX: the original loop spun without yielding, pinning a CPU core
        # while waiting for the reply; poll with a short sleep instead.
        while replay is None:
            sleep(0.01)
        self._logger.info("notify ready success in {} {}".format(session_id, job_id))
        return replay

    def publish(self, topic, payload):
        """Publish an arbitrary payload (JSON-encoded) on a topic."""
        self.client.publish(topic, dumps(payload), qos=1)

    def subscribe(self, topic: str, callback):
        """Register *callback* for JSON messages on *topic*.

        If the client is not yet connected, the actual MQTT subscribe happens
        in on_connect once the connection is up.
        """
        def on_message(_client, _userdata, message):
            self._logger.info("receive topic: {} payload: {}".format(topic, message.payload))
            payload = loads(message.payload)
            callback(payload)

        self.client.message_callback_add(topic, on_message)
        self._subscriptions.add(topic)

        if self.client.is_connected():
            self.client.subscribe(topic, qos=1)
            self._logger.info("subscribe to topic: {}".format(topic))
        else:
            self._logger.info("wait connected to subscribe to topic: {}".format(topic))


    def unsubscribe(self, topic):
        """Remove the callback and MQTT subscription for *topic*."""
        self.client.message_callback_remove(topic)
        self.client.unsubscribe(topic)
        self._subscriptions.remove(topic)

    def loop(self):
        """Run the MQTT network loop in the current thread (blocking)."""
        self.client.loop_forever()

    def disconnect(self):
        self.client.disconnect()
|