onestep 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,93 @@
1
+ import json
2
+ import threading
3
+ from queue import Queue
4
+ from typing import Any
5
+
6
+ try:
7
+ from use_redis import useRedis
8
+ except ImportError:
9
+ ...
10
+
11
+ from ..base import BaseBroker, BaseConsumer, Message
12
+
13
+
14
class _RedisPubSubMessage(Message):
    """Message adapter for Redis Pub/Sub payloads."""

    @classmethod
    def from_broker(cls, broker_message: Any):
        """Yield a Message built from a raw pubsub payload.

        Payloads carrying a 'channel' key come from Redis pubsub; their
        'data' field may be a JSON-serialized message. Anything else is
        treated as an opaque external body.
        """
        if "channel" not in broker_message:
            # External message: treat the whole payload as the body.
            parsed = {"body": broker_message.body}
        else:
            raw = broker_message.get("data")
            try:
                parsed = json.loads(raw)  # already-serialized message
            except (json.JSONDecodeError, TypeError):
                parsed = {"body": raw}  # plain, unserialized payload

        yield cls(body=parsed.get("body"), extra=parsed.get("extra"), message=broker_message)
28
+
29
+
30
class RedisPubSubBroker(BaseBroker):
    """Redis Pub/Sub broker.

    publish()/send() pushes JSON-serialized messages onto a Redis channel;
    consume() starts a background thread that subscribes to the channel and
    feeds raw pubsub payloads into an internal queue.
    """
    message_cls = _RedisPubSubMessage

    def __init__(self, channel: str, *args, **kwargs):
        """
        :param channel: Redis channel name
        :param kwargs: forwarded to BaseBroker and also used as useRedis()
            connection parameters
        """
        super().__init__(*args, **kwargs)
        self.channel = channel
        self.queue = Queue()

        self.threads = []

        self.client = useRedis(**kwargs).connection

    def _consume(self, *args, **kwargs):
        """Blocking loop: subscribe and enqueue every real 'message' event
        (subscribe/unsubscribe notifications are ignored)."""
        def callback(message: dict):
            if message.get('type') != 'message':
                return
            self.queue.put(message)

        ps = self.client.pubsub()
        ps.subscribe(self.channel)
        for message in ps.listen():
            callback(message)

    def consume(self, *args, **kwargs):
        """Start the consuming thread and return a consumer.

        :param daemon: keyword-only, default True; run the thread as daemon
        """
        daemon = kwargs.pop('daemon', True)
        # BUG FIX: *args/**kwargs used to be passed to the Thread constructor
        # itself (colliding with its group/name/args parameters); forward them
        # to the target instead, consistent with RedisStreamBroker.consume.
        thread = threading.Thread(target=self._consume, args=args, kwargs=kwargs)
        thread.daemon = daemon
        thread.start()
        self.threads.append(thread)
        return RedisPubSubConsumer(self)

    def send(self, message: Any, *args, **kwargs):
        """Publish a message to the Redis channel, wrapping non-Message
        payloads into the broker's message class first."""
        if not isinstance(message, Message):
            message = self.message_cls(body=message)

        self.client.publish(self.channel, message.to_json(), *args, **kwargs)

    publish = send

    def confirm(self, message: Message):
        """Pub/Sub has no acknowledgement concept; nothing to do."""
        pass

    def reject(self, message: Message):
        """Pub/Sub has no rejection concept; nothing to do."""
        pass

    def requeue(self, message: Message, is_source=False):
        """
        Resend a message: reject it first, then publish it again.

        :param message: the message
        :param is_source: True — republish the original raw payload into the
            current channel; False — republish the message's latest data
        """
        self.reject(message)

        if is_source:
            self.client.publish(self.channel, message.message['data'])
        else:
            self.send(message)
90
+
91
+
92
class RedisPubSubConsumer(BaseConsumer):
    """Consumer for RedisPubSubBroker; all behavior comes from BaseConsumer."""
    ...
@@ -0,0 +1,114 @@
1
+ import json
2
+ import threading
3
+ import uuid
4
+ from queue import Queue
5
+ from typing import Optional, Dict, Any
6
+
7
+ try:
8
+ from use_redis import useRedisStreamStore, RedisStreamMessage
9
+ except ImportError:
10
+ ...
11
+
12
+ from ..base import BaseBroker, BaseConsumer, Message
13
+
14
+
15
class _RedisStreamMessage(Message):
    """Message adapter for Redis Stream entries."""

    @classmethod
    def from_broker(cls, broker_message: "RedisStreamMessage"):
        """Yield a Message built from a raw stream entry.

        Entries produced by RedisStreamBroker.send carry the serialized
        message under the "_message" field; anything else is treated as an
        opaque external body.
        """
        if "_message" in broker_message.body:
            # Sent by RedisStreamBroker.send: message.body lives in "_message".
            try:
                message = json.loads(broker_message.body.get("_message"))  # already-serialized message
            except (json.JSONDecodeError, TypeError):
                message = {"body": broker_message.body.get("_message")}  # plain, unserialized payload
        else:
            # External message: treat the whole payload as the body.
            message = {"body": broker_message.body}

        yield cls(body=message.get("body"), extra=message.get("extra"), message=broker_message)
29
+
30
+
31
class RedisStreamBroker(BaseBroker):
    """ Redis Stream Broker """
    # Adapter used to turn raw stream entries into Message objects.
    message_cls = _RedisStreamMessage

    def __init__(
            self,
            stream: str,
            group: str = "onestep",
            params: Optional[Dict] = None,
            prefetch: Optional[int] = 1,
            stream_max_entries: int = 0,
            redeliver_timeout: int = 60000,
            claim_interval: int = 1800000,
            *args,
            **kwargs):
        """
        :param stream: Redis stream key
        :param group: consumer group name (default "onestep")
        :param params: extra connection parameters for useRedisStreamStore
        :param prefetch: number of messages to prefetch per read
        :param stream_max_entries: max stream length (0 presumably means
            unlimited — TODO confirm against use_redis)
        :param redeliver_timeout: ms before a pending message is redelivered
        :param claim_interval: ms between claim scans for stale messages
        """
        super().__init__(*args, **kwargs)
        self.stream = stream
        self.group = group
        self.prefetch = prefetch
        self.queue = Queue()

        self.threads = []

        self.client = useRedisStreamStore(
            stream=stream,
            group=group,
            stream_max_entries=stream_max_entries,
            redeliver_timeout=redeliver_timeout,
            claim_interval=claim_interval,
            **(params or {})
        )

    def _consume(self, *args, **kwargs):
        """Blocking consume loop: every received stream message is pushed
        onto the internal queue. A random consumer name makes each consuming
        thread distinct within the group."""
        def callback(message):
            self.queue.put(message)

        prefetch = kwargs.pop("prefetch", self.prefetch)
        self.client.start_consuming(consumer=uuid.uuid4().hex, callback=callback, prefetch=prefetch, **kwargs)

    def consume(self, *args, **kwargs):
        """Start the background consuming thread and return a consumer.

        :param daemon: keyword-only, default True; run the thread as daemon
        """
        daemon = kwargs.pop('daemon', True)
        thread = threading.Thread(target=self._consume, args=args, kwargs=kwargs)
        thread.daemon = daemon
        thread.start()
        self.threads.append(thread)
        return RedisStreamConsumer(self)

    def send(self, message: Any, *args, **kwargs):
        """Pre-process the message (wrap non-Message payloads), then send it
        to the stream under the "_message" field."""
        if not isinstance(message, Message):
            message = self.message_cls(body=message)

        self.client.send({"_message": message.to_json()}, *args, **kwargs)

    publish = send

    def confirm(self, message: Message):
        """Acknowledge the underlying stream entry, if present."""
        broker_msg = getattr(message, "message", None)
        if broker_msg is not None:
            self.client.ack(broker_msg)

    def reject(self, message: Message):
        """Reject the underlying stream entry, if present."""
        broker_msg = getattr(message, "message", None)
        if broker_msg is not None:
            self.client.reject(broker_msg)

    def requeue(self, message: Message, is_source=False):
        """
        Resend a message: reject it first, then re-enqueue it.

        :param message: the message
        :param is_source: True — re-enqueue the original raw payload into the
            current queue; False — re-enqueue the message's latest data
        """
        self.reject(message)

        broker_msg = getattr(message, "message", None)
        if is_source and broker_msg is not None and hasattr(broker_msg, "body"):
            self.client.send(broker_msg.body)
        else:
            self.send(message)
111
+
112
+
113
class RedisStreamConsumer(BaseConsumer):
    """Consumer for RedisStreamBroker; all behavior comes from BaseConsumer."""
    ...
@@ -0,0 +1,4 @@
1
+ from .sqs import SQSBroker, SQSConsumer
2
+ from .sns import SNSBroker
3
+
# Explicit public API of the aws broker subpackage.
__all__ = ["SQSBroker", "SQSConsumer", "SNSBroker"]
@@ -0,0 +1,53 @@
1
+ from typing import Optional, Dict, Any
2
+
3
+ from onestep.broker import BaseBroker
4
+
5
+ try:
6
+ from use_sqs import SNSPublisher
7
+ except ImportError:
8
+ SNSPublisher = None
9
+
10
+
11
class SNSBroker(BaseBroker):
    """Publish-only broker backed by AWS SNS."""

    def __init__(
        self,
        topic_arn: str,
        message_group_id: Optional[str] = None,
        params: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        """
        Create an SNS broker.

        :param topic_arn: ARN of the SNS topic to publish to
        :param message_group_id: message group id (required for FIFO topics)
        :param params: connection parameters for SNSPublisher
        """
        if SNSPublisher is None:
            raise ImportError("Please install the `use-sqs` module to use SNSBroker")

        super().__init__(*args, **kwargs)
        self.topic_arn = topic_arn
        self.message_group_id = message_group_id

        # Underlying publisher client.
        self.publisher = SNSPublisher(**(params or {}))

    def publish(self, message: Any, **kwargs):
        """Publish a message to the topic.

        The instance-level message_group_id is used as a fallback when the
        caller did not pass one explicitly.
        """
        if "message_group_id" not in kwargs and self.message_group_id:
            kwargs["message_group_id"] = self.message_group_id

        self.publisher.publish(self.topic_arn, message=message, **kwargs)

    def consume(self, *args, **kwargs):
        """SNS is push-based; direct consumption is not supported."""
        raise NotImplementedError("SNS Broker does not support consumption")
@@ -0,0 +1,181 @@
1
+ import json
2
+ import threading
3
+ from queue import Queue
4
+ from typing import Optional, Dict, Any, Callable
5
+
6
+ from onestep.broker import BaseBroker, BaseConsumer
7
+ from onestep.message import Message
8
+ try:
9
+ from use_sqs import SQSStore
10
+ except ImportError:
11
+ SQSStore = None
12
+ boto3 = None
13
+
14
+
15
+
16
class _SQSMessage(Message):
    """Message adapter for raw SQS messages."""

    @classmethod
    def from_broker(cls, broker_message: "boto3.resources.factory.sqs.Message"):
        """Build a Message from a raw SQS message object.

        NOTE(review): this *returns* a single instance while the redis
        brokers' from_broker implementations are generators that *yield* —
        confirm BaseConsumer accepts both shapes.
        """
        # Fail fast when the object does not look like an SQS message.
        required_attrs = ("body", "delete", "message_id")
        if not all(hasattr(broker_message, attr) for attr in required_attrs):
            raise TypeError(
                f"Message object missing required SQS attributes: {required_attrs}"
            )

        try:
            # Use the body directly when it is already a dict; otherwise try
            # to parse it as JSON.
            if isinstance(broker_message.body, dict):
                message = broker_message.body
            else:
                message = json.loads(broker_message.body)
        except (json.JSONDecodeError, TypeError):
            message = {"body": broker_message.body}

        # Normalize to a dict that always has a "body" key.
        if not isinstance(message, dict):
            message = {"body": message}
        if "body" not in message:
            message = {"body": message}

        return cls(
            body=message.get("body"), extra=message.get("extra"), message=broker_message
        )
45
+
46
+
47
class SQSBroker(BaseBroker):
    """AWS SQS-backed message broker.

    consume() lazily starts a single background thread that pushes raw SQS
    messages onto an internal queue; confirm()/requeue() operate on the
    original SQS message objects so delete() remains available.
    """

    message_cls = _SQSMessage

    def __init__(
        self,
        queue_name: str,
        message_group_id: str,
        message_deduplication_id_func: Optional[Callable] = None,
        params: Optional[Dict] = None,
        prefetch: Optional[int] = 1,
        auto_create: bool = True,
        queue_params: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        """
        :param queue_name: name of the SQS queue
        :param message_group_id: message group id (required for FIFO queues)
        :param message_deduplication_id_func: callable receiving the message
            body (JSON string) and returning a str deduplication id
        :param params: SQS connection parameters forwarded to SQSStore
        :param prefetch: number of messages to prefetch
        :param auto_create: declare the queue on startup when True
        :param queue_params: attributes used when declaring the queue
        """
        super().__init__(*args, **kwargs)
        self.queue_name = queue_name
        self.queue = Queue()
        self.prefetch = prefetch
        self.message_group_id = message_group_id
        self.message_deduplication_id_func = message_deduplication_id_func
        self.threads = []
        self._shutdown = False
        self._consuming_started = False
        self._consume_lock = threading.Lock()

        # Create the SQSStore instance.
        if SQSStore is None:
            raise ImportError("Please install the `use-sqs` module to use SQSBroker")
        self.store = SQSStore(**(params or {}))

        # Make sure the queue exists before first use.
        if auto_create:
            self.store.declare_queue(queue_name, attributes=queue_params)

    def _consume(self, *args, **kwargs):
        """Blocking consume loop: enqueue every raw SQS message received."""
        prefetch = kwargs.pop("prefetch", self.prefetch)

        def callback(message):
            # Keep the original SQS message object so confirm()/requeue()
            # can still call delete() on it later.
            self.queue.put(message)

        self.store.start_consuming(
            self.queue_name, callback=callback, prefetch=prefetch, **kwargs
        )

    def consume(self, *args, **kwargs) -> "SQSConsumer":
        """Start the background consuming thread (at most once) and return
        a consumer bound to this broker.

        :param daemon: run the thread as daemon (default True)
        :param timeout: consumer poll timeout in ms (default 1000)
        """
        daemon = kwargs.pop("daemon", True)
        timeout = kwargs.pop("timeout", 1000)
        with self._consume_lock:
            if not self._consuming_started:
                thread_kwargs = kwargs.copy()
                thread = threading.Thread(target=self._consume, args=args, kwargs=thread_kwargs)
                thread.daemon = daemon
                thread.start()
                self.threads.append(thread)
                self._consuming_started = True
        return SQSConsumer(self, timeout=timeout)

    def publish(self, message: Any, **kwargs):
        """Publish a message, attaching a deduplication id when the
        configured generator produces a valid one (non-empty str,
        at most 128 chars after stripping)."""
        if self.message_deduplication_id_func:
            message_deduplication_id = self.message_deduplication_id_func(message)
            if (
                isinstance(message_deduplication_id, str)
                and message_deduplication_id.strip()
                and len(message_deduplication_id.strip()) <= 128
            ):
                kwargs["message_deduplication_id"] = message_deduplication_id

        self.store.send(
            self.queue_name,
            message=message,
            message_group_id=self.message_group_id,
            **kwargs,
        )

    def confirm(self, message: Message):
        """Acknowledge the message by deleting it from SQS."""
        # BUG FIX: guard against a missing broker message instead of
        # dereferencing message.message directly — consistent with
        # reject()/requeue() below.
        broker_msg = getattr(message, "message", None)
        if broker_msg is not None:
            broker_msg.delete()

    def reject(self, message: Message):
        """Reject the message: deliberately do not delete it, so it is
        redelivered and can eventually reach a DLQ."""
        pass

    def requeue(self, message: Message, is_source: bool = False):
        """
        Re-enqueue a message at the tail of the queue.

        :param message: the message object
        :param is_source: when True, resend the original broker payload;
            otherwise resend the message's current body
        """
        broker_msg = getattr(message, "message", None)
        # Delete the original first, then resend, so the message effectively
        # moves to the tail of the queue.
        if is_source and broker_msg is not None and hasattr(broker_msg, "body"):
            broker_msg.delete()
            self.store.send(
                self.queue_name,
                broker_msg.body,
                message_group_id=self.message_group_id,
            )
        else:
            # BUG FIX: the original called broker_msg.delete() unguarded here,
            # raising AttributeError for messages that never came from SQS.
            if broker_msg is not None:
                broker_msg.delete()
            self.store.send(
                self.queue_name, message.body, message_group_id=self.message_group_id
            )

    def shutdown(self):
        """Stop consuming, join worker threads and reset internal state."""
        self._shutdown = True
        self.store.shutdown()
        for thread in self.threads:
            thread.join()
        self.queue = Queue()
        self.threads.clear()
        self._consuming_started = False
179
+
180
+
181
class SQSConsumer(BaseConsumer):
    """Consumer for SQSBroker; all behavior comes from BaseConsumer."""
    ...
@@ -0,0 +1,84 @@
1
+ import logging
2
+ import threading
3
+ import collections
4
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
5
+
6
+ from .memory import MemoryBroker, MemoryConsumer
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ Server = collections.namedtuple("Server", ["path", "queue"])
11
+
12
+
13
class WebHookServer(BaseHTTPRequestHandler):
    """HTTP request handler that routes POSTed webhook bodies onto the
    queues registered for its (host, port)."""

    # (host, port) -> list of Server(path, queue) registrations
    servers = collections.defaultdict(list)

    def do_POST(self):
        """Accept a webhook POST and enqueue its raw UTF-8 body."""
        registered = WebHookServer.servers.get(self.server.server_address, [])
        # First registration whose path matches wins.
        target = next((entry.queue for entry in registered if entry.path == self.path), None)
        if target is None:
            return self.send_error(404)

        length = int(self.headers.get('content-length', 0))
        payload = self.rfile.read(length).decode("utf-8")
        target.put_nowait(payload)
        self.send_response(200)
        self.send_header('Content-type', 'application/json')
        self.end_headers()
        self.wfile.write(b'{ "status": "ok" }')
35
+
36
+
37
class WebHookBroker(MemoryBroker):
    """Broker fed by an embedded HTTP webhook endpoint.

    One ThreadingHTTPServer is shared per (host, port) by all WebHookBroker
    instances, so several webhook paths can be served from a single port.
    """
    # (host, port) -> running ThreadingHTTPServer, shared across instances
    _servers = {}

    def __init__(self,
                 path: str,
                 host: str = "0.0.0.0",
                 port: int = 8090,
                 *args,
                 **kwargs):
        """
        :param path: URL path this broker accepts POSTs on
        :param host: bind address (default: all interfaces)
        :param port: bind port (default: 8090)
        """
        super().__init__(*args, **kwargs)
        self.host = host
        self.port = port
        self.path = path
        self.threads = []

    def _create_server(self):
        """Ensure an HTTP server is running for (host, port) and register
        this broker's path -> queue mapping with the request handler."""
        if (self.host, self.port) not in self._servers:
            hs = ThreadingHTTPServer(
                (self.host, self.port),
                WebHookServer
            )
            self._servers[(self.host, self.port)] = hs
            # Start the serving thread only when the server is newly created.
            thread = threading.Thread(target=hs.serve_forever)
            thread.daemon = True
            thread.start()
            self.threads.append(thread)
        # FIX: removed the dead `else` branch that fetched the existing
        # server into a local variable that was never used.
        # NOTE(review): calling consume() repeatedly registers the same path
        # again; do_POST uses the first match, so this looks harmless — confirm.
        WebHookServer.servers[(self.host, self.port)].append(Server(self.path, self.queue))

    def consume(self, *args, **kwargs):
        """Start (or reuse) the webhook server and return a consumer."""
        self._create_server()
        logger.debug(f"WebHookBroker: {self.host}:{self.port}{self.path}")
        return WebHookConsumer(self, *args, **kwargs)

    def shutdown(self):
        """Stop the shared HTTP server for this (host, port), if any, and
        join the serving threads this instance started."""
        hs = self._servers.get((self.host, self.port))
        if hs:
            hs.shutdown()
        for thread in self.threads:
            thread.join()
81
+
82
+
83
class WebHookConsumer(MemoryConsumer):
    """Consumer for WebHookBroker; all behavior comes from MemoryConsumer."""
    ...
onestep/cli.py ADDED
@@ -0,0 +1,80 @@
1
+ import argparse
2
+ import importlib
3
+ import logging
4
+ import sys
5
+ from onestep import step, __version__
6
+
7
LOGFORMAT = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"


def setup_logging():
    """Configure process-wide logging and return the 'onestep' logger.

    The root level is INFO to keep third-party libraries quiet, the
    'amqpstorm' logger is silenced almost entirely, and 'onestep' itself
    is raised to DEBUG for easier troubleshooting.
    """
    logging.basicConfig(level=logging.INFO, format=LOGFORMAT, stream=sys.stdout)

    # Exclude amqpstorm logs.
    logging.getLogger("amqpstorm").setLevel(logging.CRITICAL)

    result = logging.getLogger("onestep")
    result.setLevel(logging.DEBUG)
    return result


logger = setup_logging()
25
+
26
+
27
def parse_args():
    """Build and parse the onestep CLI arguments.

    Exactly one of the positional `step` module or `--cron` must be given
    (they form a required, mutually exclusive group).
    """
    parser = argparse.ArgumentParser(
        description='run onestep'
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "step", nargs='?',
        help="the run step",
    )
    parser.add_argument(
        "--group", "-G", default=None,
        help="the run group",
        type=str
    )
    parser.add_argument(
        "--print",
        action="store_true",
        help="enable printing")
    parser.add_argument(
        # BUG FIX: with nargs="*" the default must be a list. The previous
        # default was the plain string ".", which main() then iterated
        # character by character (it only worked because "." is one char).
        "--path", "-P", default=["."], nargs="*", type=str,
        help="the step import path (default: current running directory)"
    )
    group.add_argument(
        "--cron",
        help="the cron expression to test",
        type=str
    )
    return parser.parse_args()
55
+
56
+
57
def main():
    """CLI entry point: run a step module, or preview a cron expression."""
    args = parse_args()
    # Prepend user-supplied import paths so the step module can be found.
    for path in args.path:
        sys.path.insert(0, path)
    if args.cron:
        # Cron preview mode: print the next 10 fire times and exit.
        from croniter import croniter
        from datetime import datetime
        cron = croniter(args.cron, datetime.now())
        for _ in range(10):
            print(cron.get_next(datetime))
        return

    logger.info(f"OneStep {__version__} is start up.")
    try:
        importlib.import_module(args.step)
        step.start(group=args.group, block=True, print_jobs=args.print)
    except ModuleNotFoundError:
        # NOTE(review): this also swallows ModuleNotFoundError raised from
        # *inside* the imported module — confirm that is intended.
        logger.error(f"Module `{args.step}` not found.")
    except KeyboardInterrupt:
        # Graceful stop on Ctrl-C.
        step.shutdown()
77
+
78
+
79
# Allow running this module as a plain script in addition to the console
# entry point; main() returns None, so this exits with status 0.
if __name__ == '__main__':
    sys.exit(main())