oocana-0.15.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oocana-0.15.0/PKG-INFO ADDED
@@ -0,0 +1,10 @@
+ Metadata-Version: 2.1
+ Name: oocana
+ Version: 0.15.0
+ Summary: python implement of oocana to give a context for oocana block
+ License: MIT
+ Requires-Python: >=3.9
+ Requires-Dist: paho-mqtt>=2
+ Requires-Dist: simplejson>=3.19.2
+ Requires-Dist: typing-extensions>=4.12.2; python_version < "3.11"
+
@@ -0,0 +1,7 @@
+ from .data import * # noqa: F403
+ from .context import * # noqa: F403
+ from .service import * # noqa: F403
+ from .handle_data import * # noqa: F403
+ from .preview import * # noqa: F403
+ from .schema import * # noqa: F403
+ from .mainframe import Mainframe as Mainframe # noqa: F403
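These wildcard re-exports flatten the package's public API, so block code can import everything from the package root instead of reaching into individual submodules. A minimal sketch of what this enables (illustrative only, not part of the package):

from oocana import Context, Mainframe, HandleDef, BlockInfo  # all reachable via the re-exports above

def describe(context: Context) -> str:
    # Names listed in the submodules' __all__ lists resolve from the top-level package.
    return f"job {context.job_id} in session {context.session_id}"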
@@ -0,0 +1,334 @@
+ from dataclasses import asdict
+ from json import loads
+ from .data import BlockInfo, StoreKey, JobDict, BlockDict, BinValueDict, VarValueDict
+ from .handle_data import HandleDef
+ from .mainframe import Mainframe
+ from typing import Dict, Any, TypedDict, Optional
+ from base64 import b64encode
+ from io import BytesIO
+ from .throtter import throttle
+ from .preview import PreviewPayload, TablePreviewData, DataFrame, ShapeDataFrame, PartialDataFrame
+ from .data import EXECUTOR_NAME
+ import os.path
+ import logging
+
+ __all__ = ["Context"]
+
+ class OnlyEqualSelf:
+     def __eq__(self, value: object) -> bool:
+         return self is value
+
+ class OOMOL_LLM_ENV(TypedDict):
+     base_url: str
+     api_key: str
+     models: list[str]
+
+ class HostInfo(TypedDict):
+     gpu_vendor: str
+     gpu_renderer: str
+
+ class Context:
+     __inputs: Dict[str, Any]
+
+     __block_info: BlockInfo
+     __outputs_def: Dict[str, HandleDef]
+     __store: Any
+     __is_done: bool = False
+     __keep_alive: OnlyEqualSelf = OnlyEqualSelf()
+     __session_dir: str
+     _logger: Optional[logging.Logger] = None
+
+     def __init__(
+         self, inputs: Dict[str, Any], blockInfo: BlockInfo, mainframe: Mainframe, store, outputs, session_dir: str
+     ) -> None:
+
+         self.__block_info = blockInfo
+
+         self.__mainframe = mainframe
+         self.__store = store
+         self.__inputs = inputs
+
+         outputs_defs = {}
+         if outputs is not None:
+             for k, v in outputs.items():
+                 outputs_defs[k] = HandleDef(**v)
+         self.__outputs_def = outputs_defs
+         self.__session_dir = session_dir
+
+     @property
+     def logger(self) -> logging.Logger:
+         """A custom logger for the block; use it to write messages to the block log. Logs are reported through the context's report_log API.
+         """
+
+         # The logger is set up right after init, so it normally always exists.
+         if self._logger is None:
+             raise ValueError("logger is not set up, please set up the logger during block initialization.")
+         return self._logger
+
+     @property
+     def session_dir(self) -> str:
+         """A temporary directory for the current session; all blocks in the same session share this directory.
+         """
+         return self.__session_dir
+
+     @property
+     def keepAlive(self):
+         return self.__keep_alive
+
+     @property
+     def inputs(self):
+         return self.__inputs
+
+     @property
+     def session_id(self):
+         return self.__block_info.session_id
+
+     @property
+     def job_id(self):
+         return self.__block_info.job_id
+
+     @property
+     def job_info(self) -> JobDict:
+         return self.__block_info.job_info()
+
+     @property
+     def block_info(self) -> BlockDict:
+         return self.__block_info.block_dict()
+
+     @property
+     def node_id(self) -> str:
+         return self.__block_info.stacks[-1].get("node_id", None)
+
+     @property
+     def oomol_llm_env(self) -> OOMOL_LLM_ENV:
+         """A dict containing the OOMOL LLM environment variables.
+         """
+         return {
+             "base_url": os.getenv("OOMOL_LLM_BASE_URL", ""),
+             "api_key": os.getenv("OOMOL_LLM_API_KEY", ""),
+             "models": os.getenv("OOMOL_LLM_MODELS", "").split(","),
+         }
+
+     @property
+     def host_info(self) -> HostInfo:
+         """A dict containing host information.
+         """
+         return {
+             "gpu_vendor": os.getenv("OOMOL_HOST_GPU_VENDOR", "unknown"),
+             "gpu_renderer": os.getenv("OOMOL_HOST_GPU_RENDERER", "unknown"),
+         }
+
+     def __store_ref(self, handle: str):
+         return StoreKey(
+             executor=EXECUTOR_NAME,
+             handle=handle,
+             job_id=self.job_id,
+             session_id=self.session_id,
+         )
+
+     def __is_basic_type(self, value: Any) -> bool:
+         return isinstance(value, (int, float, str, bool))
+
+     def output(self, key: str, value: Any, done: bool = False):
+         """
+         Output the value to the next block.
+
+         key: str, the key of the output. It should be defined in the block schema's output defs; the field name is handle.
+         value: Any, the value of the output.
+         """
+
+         v = value
+
+         if self.__outputs_def is not None:
+             output_def = self.__outputs_def.get(key)
+             if (
+                 output_def is not None and output_def.is_var_handle() and not self.__is_basic_type(value) # basic types are passed directly as JSON content instead of going through the store, even for var handles
+             ):
+                 ref = self.__store_ref(key)
+                 self.__store[ref] = value
+                 d: VarValueDict = {
+                     "__OOMOL_TYPE__": "oomol/var",
+                     "value": asdict(ref)
+                 }
+                 v = d
+             elif output_def is not None and output_def.is_bin_handle():
+                 if not isinstance(value, bytes):
+                     self.send_warning(
+                         f"Output handle key: [{key}] is defined as binary, but the value is not bytes."
+                     )
+                     return
+
+                 bin_file = f"{self.session_dir}/binary/{self.session_id}/{self.job_id}/{key}"
+                 os.makedirs(os.path.dirname(bin_file), exist_ok=True)
+                 try:
+                     with open(bin_file, "wb") as f:
+                         f.write(value)
+                 except IOError as e:
+                     self.send_warning(
+                         f"Output handle key: [{key}] is defined as binary, but an error occurred while writing the file: {e}"
+                     )
+                     return
+
+                 if os.path.exists(bin_file):
+                     bin_value: BinValueDict = {
+                         "__OOMOL_TYPE__": "oomol/bin",
+                         "value": bin_file,
+                     }
+                     v = bin_value
+                 else:
+                     self.send_warning(
+                         f"Output handle key: [{key}] is defined as binary, but the file is not written."
+                     )
+                     return
+
+         # If the given key is not defined in the outputs schema, ignore it and send no data; `done` still takes effect.
+         if self.__outputs_def is not None and self.__outputs_def.get(key) is None:
+             self.send_warning(
+                 f"Output handle key: [{key}] is not defined in Block outputs schema."
+             )
+             if done:
+                 self.done()
+             return
+
+         node_result = {
+             "type": "BlockOutput",
+             "handle": key,
+             "output": v,
+             "done": done,
+         }
+         self.__mainframe.send(self.job_info, node_result)
+
+         if done:
+             self.done()
+
+     def done(self, error: str | None = None):
+         if self.__is_done:
+             self.send_warning("done has been called multiple times, will be ignored.")
+             return
+         self.__is_done = True
+         if error is None:
+             self.__mainframe.send(self.job_info, {"type": "BlockFinished"})
+         else:
+             self.__mainframe.send(
+                 self.job_info, {"type": "BlockFinished", "error": error}
+             )
+
+     def send_message(self, payload):
+         self.__mainframe.report(
+             self.block_info,
+             {
+                 "type": "BlockMessage",
+                 "payload": payload,
+             },
+         )
+
+     def __dataframe(self, payload: PreviewPayload) -> PreviewPayload:
+         if isinstance(payload, DataFrame):
+             payload = { "type": "table", "data": payload }
+
+         if isinstance(payload, dict) and payload.get("type") == "table":
+             df = payload.get("data")
+             if isinstance(df, ShapeDataFrame):
+                 row_count = df.shape[0]
+                 if row_count <= 10:
+                     data = df.to_dict(orient='split')
+                     columns = data.get("columns", [])
+                     rows = data.get("data", [])
+                 elif isinstance(df, PartialDataFrame):
+                     data_columns = loads(df.head(5).to_json(orient='split'))
+                     columns = data_columns.get("columns", [])
+                     rows_head = data_columns.get("data", [])
+                     data_tail = loads(df.tail(5).to_json(orient='split'))
+                     rows_tail = data_tail.get("data", [])
+                     rows_dots = [["..."] * len(columns)]
+                     rows = rows_head + rows_dots + rows_tail
+                 else:
+                     print("dataframe has more than 10 rows but does not support head and tail; preview is not supported")
+                     return payload
+                 data: TablePreviewData = { "rows": rows, "columns": columns, "row_count": row_count }
+                 payload = { "type": "table", "data": data }
+             else:
+                 print("dataframe does not support the shape property")
+
+         return payload
+
+     def __matplotlib(self, payload: PreviewPayload) -> PreviewPayload:
+         # payload is a matplotlib Figure
+         if hasattr(payload, 'savefig'):
+             fig: Any = payload
+             buffer = BytesIO()
+             fig.savefig(buffer, format='png')
+             buffer.seek(0)
+             png = buffer.getvalue()
+             buffer.close()
+             url = f'data:image/png;base64,{b64encode(png).decode("utf-8")}'
+             payload = { "type": "image", "data": url }
+
+         return payload
+
+     def preview(self, payload: PreviewPayload):
+         payload = self.__dataframe(payload)
+         payload = self.__matplotlib(payload)
+
+         self.__mainframe.report(
+             self.block_info,
+             {
+                 "type": "BlockPreview",
+                 "payload": payload,
+             },
+         )
+
+     @throttle(0.3)
+     def report_progress(self, progress: float | int):
+         """Report progress.
+
+         This API reports the block's progress, but it only affects the progress shown in the UI, not the actual execution.
+         This API is throttled: the minimum interval between reports is 0.3 s.
+         The first call reports the progress immediately; afterwards, each report is sent at the end of the throttling period.
+
+             | 0.25 s |  0.2 s |
+         first call   second call   third call   4 5 6 7's calls
+             |             |            |        | | | |
+             | -------- 0.3 s -------- | -------- 0.3 s -------- |
+           invoke                    invoke                    invoke
+         :param float | int progress: the progress of the block; the value should be in [0, 100].
+         """
+         self.__mainframe.report(
+             self.block_info,
+             {
+                 "type": "BlockProgress",
+                 "rate": progress,
+             }
+         )
+
+     def report_log(self, line: str, stdio: str = "stdout"):
+         self.__mainframe.report(
+             self.block_info,
+             {
+                 "type": "BlockLog",
+                 "log": line,
+                 "stdio": stdio,
+             },
+         )
+
+     def log_json(self, payload):
+         self.__mainframe.report(
+             self.block_info,
+             {
+                 "type": "BlockLogJSON",
+                 "json": payload,
+             },
+         )
+
+     def send_warning(self, warning: str):
+         self.__mainframe.report(self.block_info, {"type": "BlockWarning", "warning": warning})
+
+     def send_error(self, error: str):
+         '''
+         Deprecated: use error() instead.
+         May be removed in the future.
+         '''
+         self.error(error)
+
+     def error(self, error: str):
+         self.__mainframe.send(self.job_info, {"type": "BlockError", "error": error})
@@ -0,0 +1,96 @@
+ from dataclasses import dataclass, asdict
+ from typing import TypedDict, Literal
+ from simplejson import JSONEncoder
+ import simplejson as json
+
+ EXECUTOR_NAME = "python"
+
+ __all__ = ["dumps", "BinValueDict", "VarValueDict", "JobDict", "BlockDict", "StoreKey", "BlockInfo", "EXECUTOR_NAME"]
+
+ def dumps(obj, **kwargs):
+     return json.dumps(obj, cls=DataclassJSONEncoder, ignore_nan=True, **kwargs)
+
+ class DataclassJSONEncoder(JSONEncoder):
+     def default(self, o): # pyright: ignore[reportIncompatibleMethodOverride]
+         if hasattr(o, '__dataclass_fields__'):
+             return asdict(o)
+         return JSONEncoder.default(self, o)
+
+ class BinValueDict(TypedDict):
+     value: str
+     __OOMOL_TYPE__: Literal["oomol/bin"]
+
+ class VarValueDict(TypedDict):
+     value: dict
+     __OOMOL_TYPE__: Literal["oomol/var"]
+
+ class JobDict(TypedDict):
+     session_id: str
+     job_id: str
+
+ class BlockDict(TypedDict):
+
+     try:
+         # NotRequired, Required was added in version 3.11
+         from typing import NotRequired, Required, TypedDict # type: ignore
+     except ImportError:
+         from typing_extensions import NotRequired, Required, TypedDict
+
+     session_id: str
+     job_id: str
+     stacks: list
+     block_path: NotRequired[str]
+
+ # By default, dataclass fields must match the passed arguments one-to-one;
+ # one extra or one missing field raises an error.
+ # To tolerate extra fields, __init__ is overridden here to skip strict handling of surplus fields, while missing fields are checked manually.
+ @dataclass(frozen=True, kw_only=True)
+ class StoreKey:
+     executor: str
+     handle: str
+     job_id: str
+     session_id: str
+
+     def __init__(self, **kwargs):
+         for key, value in kwargs.items():
+             object.__setattr__(self, key, value)
+         for key in self.__annotations__.keys():
+             if key not in kwargs:
+                 raise ValueError(f"missing key {key}")
+
+
+ # The block info fields that are always required when sending to the reporter.
+ @dataclass(frozen=True, kw_only=True)
+ class BlockInfo:
+
+     session_id: str
+     job_id: str
+     stacks: list
+     block_path: str | None = None
+
+     def __init__(self, **kwargs):
+         for key, value in kwargs.items():
+             object.__setattr__(self, key, value)
+         for key in self.__annotations__.keys():
+             if key not in kwargs and key != "block_path":
+                 raise ValueError(f"missing key {key}")
+
+     def job_info(self) -> JobDict:
+         return {"session_id": self.session_id, "job_id": self.job_id}
+
+     def block_dict(self) -> BlockDict:
+         if self.block_path is None:
+             return {
+                 "session_id": self.session_id,
+                 "job_id": self.job_id,
+                 "stacks": self.stacks,
+             }
+
+         return {
+             "session_id": self.session_id,
+             "job_id": self.job_id,
+             "stacks": self.stacks,
+             "block_path": self.block_path,
+         }
+
+
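The overridden constructors above make StoreKey and BlockInfo tolerant of extra keyword arguments while still requiring the declared fields, and dumps() routes any dataclass through DataclassJSONEncoder. A small sketch of that behavior (all values are made up):

info = BlockInfo(session_id="s1", job_id="j1", stacks=[], extra_field="tolerated")  # extra keys do not raise
ref = StoreKey(executor=EXECUTOR_NAME, handle="out", job_id=info.job_id, session_id=info.session_id)
print(dumps({"ref": ref, "job": info.job_info()}))  # the dataclass is serialized via DataclassJSONEncoder
# StoreKey(handle="out", job_id="j1", session_id="s1") would raise ValueError: missing key executor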
@@ -0,0 +1,55 @@
+ from dataclasses import dataclass
+ from typing import Any, Optional
+ from .schema import FieldSchema, ContentMediaType
+
+ __all__ = ["HandleDef", "InputHandleDef"]
+
+ @dataclass(frozen=True, kw_only=True)
+ class HandleDef:
+     """The base handle definition for outputs; it can be used directly as an output def.
+     """
+     handle: str
+     """The name of the handle. It should be unique within the handle list."""
+
+     json_schema: Optional[FieldSchema] = None
+     """The schema of the handle. It is used to validate the handle's content."""
+
+     name: Optional[str] = None
+     """An alias for the handle's type name. It is displayed in the UI and used to match connections to other handles."""
+
+     def __init__(self, **kwargs):
+         for key, value in kwargs.items():
+             object.__setattr__(self, key, value)
+         if "handle" not in kwargs:
+             raise ValueError("missing attr key: 'handle'")
+         json_schema = self.json_schema
+         if json_schema is not None and not isinstance(json_schema, FieldSchema):
+             object.__setattr__(self, "json_schema", FieldSchema.generate_schema(json_schema))
+
+     def check_handle_type(self, type: ContentMediaType) -> bool:
+         if self.handle is None:
+             return False
+         if self.json_schema is None:
+             return False
+         if self.json_schema.contentMediaType is None:
+             return False
+         return self.json_schema.contentMediaType == type
+
+     def is_var_handle(self) -> bool:
+         return self.check_handle_type("oomol/var")
+
+     def is_secret_handle(self) -> bool:
+         return self.check_handle_type("oomol/secret")
+
+     def is_bin_handle(self) -> bool:
+         return self.check_handle_type("oomol/bin")
+
+ @dataclass(frozen=True, kw_only=True)
+ class InputHandleDef(HandleDef):
+
+     value: Optional[Any] = None
+     """Default value for the input handle; may be None.
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
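These handle-type checks are what Context.output() uses to decide whether a value goes through the store ("oomol/var") or onto disk ("oomol/bin"). A rough sketch of how a handle definition resolves; this assumes FieldSchema.generate_schema accepts a plain dict carrying a contentMediaType field, since the exact schema shape lives in schema.py, which is not shown in this diff:

model_handle = HandleDef(handle="model", json_schema={"contentMediaType": "oomol/var"})
assert model_handle.is_var_handle()       # contentMediaType drives check_handle_type()
assert not model_handle.is_bin_handle()

plain_handle = HandleDef(handle="count")  # no json_schema: every handle-type check returns False
assert not plain_handle.is_var_handle()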
@@ -0,0 +1,135 @@
+ from simplejson import loads
+ import paho.mqtt.client as mqtt
+ from paho.mqtt.enums import CallbackAPIVersion
+ import operator
+ from urllib.parse import urlparse
+ import uuid
+ from .data import BlockDict, JobDict, dumps
+ import logging
+ from typing import Optional
+
+ __all__ = ["Mainframe"]
+
+ class Mainframe:
+     address: str
+     client: mqtt.Client
+     client_id: str
+     _subscriptions: set[str] = set()
+     _logger: logging.Logger
+
+     def __init__(self, address: str, client_id: Optional[str] = None, logger = None) -> None:
+         self.address = address
+         self.client_id = client_id or f"python-executor-{uuid.uuid4().hex[:8]}"
+         self._logger = logger or logging.getLogger(__name__)
+
+     def connect(self):
+         connect_address = (
+             self.address
+             if operator.contains(self.address, "://")
+             else f"mqtt://{self.address}"
+         )
+         url = urlparse(connect_address)
+         client = self._setup_client()
+         client.connect(host=url.hostname, port=url.port) # type: ignore
+         client.loop_start()
+
+     def _setup_client(self):
+         self.client = mqtt.Client(
+             callback_api_version=CallbackAPIVersion.VERSION2,
+             client_id=self.client_id,
+         )
+         self.client.logger = self._logger
+         self.client.on_connect = self.on_connect
+         self.client.on_disconnect = self.on_disconnect
+         self.client.on_connect_fail = self.on_connect_fail # type: ignore
+         return self.client
+
+     # With MQTT v5, subscriptions and queued messages are lost after a reconnect (with v3, setting clean_session at init preserves both).
+     # Our broker uses v5, so subscribing inside on_connect guarantees that every reconnect re-subscribes.
+     def on_connect(self, client, userdata, flags, reason_code, properties):
+         if reason_code != 0:
+             self._logger.error("connect to broker failed, reason_code: %s", reason_code)
+             return
+         else:
+             self._logger.info("connect to broker success")
+
+         for topic in self._subscriptions.copy(): # copy to avoid concurrent-modification conflicts
+             self._logger.info("resubscribe to topic: {}".format(topic))
+             self.client.subscribe(topic, qos=1)
+
+     def on_connect_fail(self) -> None:
+         self._logger.error("connect to broker failed")
+
+     def on_disconnect(self, client, userdata, flags, reason_code, properties):
+         self._logger.warning("disconnect to broker, reason_code: %s", reason_code)
+
+     # Do not wait for publish to finish; the qos parameter guarantees the message arrives.
+     def send(self, job_info: JobDict, msg) -> mqtt.MQTTMessageInfo:
+         return self.client.publish(
+             f'session/{job_info["session_id"]}', dumps({"job_id": job_info["job_id"], "session_id": job_info["session_id"], **msg}), qos=1
+         )
+
+     def report(self, block_info: BlockDict, msg: dict) -> mqtt.MQTTMessageInfo:
+         return self.client.publish("report", dumps({**block_info, **msg}), qos=1)
+
+     def notify_executor_ready(self, session_id: str, executor_name: str, package: str | None) -> None:
+         self.client.publish(f"session/{session_id}", dumps({
+             "type": "ExecutorReady",
+             "session_id": session_id,
+             "executor_name": executor_name,
+             "package": package,
+         }), qos=1)
+
+     def notify_block_ready(self, session_id: str, job_id: str) -> dict:
+
+         topic = f"inputs/{session_id}/{job_id}"
+         replay = None
+
+         def on_message_once(_client, _userdata, message):
+             nonlocal replay
+             self.client.unsubscribe(topic)
+             replay = loads(message.payload)
+
+         self.client.subscribe(topic, qos=1)
+         self.client.message_callback_add(topic, on_message_once)
+
+         self.client.publish(f"session/{session_id}", dumps({
+             "type": "BlockReady",
+             "session_id": session_id,
+             "job_id": job_id,
+         }), qos=1)
+
+         while True:
+             if replay is not None:
+                 self._logger.info("notify ready success in {} {}".format(session_id, job_id))
+                 return replay
+
+     def publish(self, topic, payload):
+         self.client.publish(topic, dumps(payload), qos=1)
+
+     def subscribe(self, topic: str, callback):
+         def on_message(_client, _userdata, message):
+             self._logger.info("receive topic: {} payload: {}".format(topic, message.payload))
+             payload = loads(message.payload)
+             callback(payload)
+
+         self.client.message_callback_add(topic, on_message)
+         self._subscriptions.add(topic)
+
+         if self.client.is_connected():
+             self.client.subscribe(topic, qos=1)
+             self._logger.info("subscribe to topic: {}".format(topic))
+         else:
+             self._logger.info("wait connected to subscribe to topic: {}".format(topic))
+
+
+     def unsubscribe(self, topic):
+         self.client.message_callback_remove(topic)
+         self.client.unsubscribe(topic)
+         self._subscriptions.remove(topic)
+
+     def loop(self):
+         self.client.loop_forever()
+
+     def disconnect(self):
+         self.client.disconnect()
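To round off the section, a minimal sketch of how an executor might drive Mainframe; the broker address and topic are made up, and in real use you would wait for the connection to be established before relying on the subscription:

mainframe = Mainframe("127.0.0.1:1883")                               # "mqtt://" is prepended when no scheme is given
mainframe.connect()                                                   # returns immediately; paho runs a background network loop
mainframe.subscribe("session/demo", lambda payload: print(payload))   # tracked so on_connect re-subscribes after reconnects
mainframe.notify_executor_ready("demo", "python", None)               # announces this executor on the session topic
mainframe.disconnect()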