fast-web-core 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast-web-core-1.0.0/MANIFEST.in +0 -0
- fast-web-core-1.0.0/PKG-INFO +23 -0
- fast-web-core-1.0.0/README.rst +12 -0
- fast-web-core-1.0.0/fast_web_core/__init__.py +9 -0
- fast-web-core-1.0.0/fast_web_core/auth/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/auth/auth_cache_pool.py +88 -0
- fast-web-core-1.0.0/fast_web_core/auth/share_auth.py +223 -0
- fast-web-core-1.0.0/fast_web_core/client/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/client/async_mongo_client.py +233 -0
- fast-web-core-1.0.0/fast_web_core/client/async_mulit_tenant_mongo_client.py +69 -0
- fast-web-core-1.0.0/fast_web_core/client/async_multi_database_mongo_client.py +60 -0
- fast-web-core-1.0.0/fast_web_core/client/async_rabbit_client.py +71 -0
- fast-web-core-1.0.0/fast_web_core/client/async_redis_client.py +37 -0
- fast-web-core-1.0.0/fast_web_core/client/mongo_client.py +208 -0
- fast-web-core-1.0.0/fast_web_core/client/multi_database_mongo_client.py +59 -0
- fast-web-core-1.0.0/fast_web_core/client/multi_tenant_mongo_client.py +60 -0
- fast-web-core-1.0.0/fast_web_core/client/oss2_client.py +67 -0
- fast-web-core-1.0.0/fast_web_core/client/redis_client.py +36 -0
- fast-web-core-1.0.0/fast_web_core/context/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/context/context_vars.py +11 -0
- fast-web-core-1.0.0/fast_web_core/exception/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/exception/exceptions.py +32 -0
- fast-web-core-1.0.0/fast_web_core/lib/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/lib/cfg.py +60 -0
- fast-web-core-1.0.0/fast_web_core/lib/func.py +92 -0
- fast-web-core-1.0.0/fast_web_core/lib/logger.py +87 -0
- fast-web-core-1.0.0/fast_web_core/lib/secret.py +59 -0
- fast-web-core-1.0.0/fast_web_core/lib/settings.py +103 -0
- fast-web-core-1.0.0/fast_web_core/lib/time.py +343 -0
- fast-web-core-1.0.0/fast_web_core/middleware/__init__.py +23 -0
- fast-web-core-1.0.0/fast_web_core/middleware/auth.py +55 -0
- fast-web-core-1.0.0/fast_web_core/model/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/model/base.py +25 -0
- fast-web-core-1.0.0/fast_web_core/model/enums.py +24 -0
- fast-web-core-1.0.0/fast_web_core/model/handler.py +96 -0
- fast-web-core-1.0.0/fast_web_core/model/items.py +75 -0
- fast-web-core-1.0.0/fast_web_core/model/query.py +23 -0
- fast-web-core-1.0.0/fast_web_core/utils/__init__.py +0 -0
- fast-web-core-1.0.0/fast_web_core/utils/common.py +164 -0
- fast-web-core-1.0.0/fast_web_core/utils/context_var.py +9 -0
- fast-web-core-1.0.0/fast_web_core/utils/decator.py +13 -0
- fast-web-core-1.0.0/fast_web_core/utils/encryption.py +106 -0
- fast-web-core-1.0.0/fast_web_core/utils/sequence.py +44 -0
- fast-web-core-1.0.0/fast_web_core.egg-info/PKG-INFO +23 -0
- fast-web-core-1.0.0/fast_web_core.egg-info/SOURCES.txt +47 -0
- fast-web-core-1.0.0/fast_web_core.egg-info/dependency_links.txt +1 -0
- fast-web-core-1.0.0/fast_web_core.egg-info/top_level.txt +1 -0
- fast-web-core-1.0.0/setup.cfg +4 -0
- fast-web-core-1.0.0/setup.py +25 -0
|
File without changes
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: fast-web-core
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: fast web core for zsodata
|
|
5
|
+
Home-page: http://www.zsodata.com
|
|
6
|
+
Author: zsodata
|
|
7
|
+
Author-email: team@zso.io
|
|
8
|
+
License: BSD License
|
|
9
|
+
Platform: all
|
|
10
|
+
Requires-Python: >=3.7
|
|
11
|
+
|
|
12
|
+
pip uninstall fast-web-core
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
注意:
|
|
16
|
+
mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
|
|
17
|
+
如果需要使用旧版,切换到分支:mongo4
|
|
18
|
+
|
|
19
|
+
重点新增:
|
|
20
|
+
异步mongo,异步redis
|
|
21
|
+
|
|
22
|
+
pip install --upgrade fast-web-core -i https://pypi.org/simple
|
|
23
|
+
pip show fast-web-core
|
|
File without changes
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Set, Optional, Dict
|
|
5
|
+
from ..lib import cfg
|
|
6
|
+
from ..client.redis_client import Redis
|
|
7
|
+
from ..model.items import AuthUser, AuthApp
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)

# Shared Redis client for the auth cache. The URL and DB index come from the
# AUTH_REDIS_URL / AUTH_REDIS_DB config keys (None when unset; how the Redis
# wrapper treats None is defined in client.redis_client, not here).
rds = Redis(
    redis_uri=cfg.get('AUTH_REDIS_URL', None),
    redis_db=cfg.get('AUTH_REDIS_DB', None),
).client

# Default TTL in seconds for cached entries: 12 hours (43200 s).
default_expire = 12 * 60 * 60
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def save_cached_app(access_token: str, app: AuthApp, expire: int = None):
    """Cache the app's JSON under ``share:access:<access_token>``.

    Nothing is written when ``app`` is falsy. ``expire`` is the TTL in
    seconds; a falsy value (None or 0) falls back to ``default_expire``.
    """
    if not app:
        return
    ttl = expire if expire else default_expire
    rds.set(f'share:access:{access_token}', app.json(), ex=ttl)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_cached_user(access_token: str) -> Optional[AuthUser]:
    """Load the cached user for ``access_token``.

    The cached value is JSON. It may be double-encoded (a JSON string whose
    content is itself JSON), in which case it is decoded a second time.
    Only a JSON object with a truthy ``id`` is treated as a user.

    :param access_token: token used in the ``share:access:<token>`` key.
    :return: an ``AuthUser``, or None when nothing usable is cached.
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    js_user = json.loads(rs)
    # Tolerate double-encoded JSON payloads.
    if isinstance(js_user, str):
        js_user = json.loads(js_user)
    # Signature check: a user object must be a dict with an ``id`` field.
    # The isinstance guard avoids AttributeError on list/number payloads.
    if not isinstance(js_user, dict) or not js_user.get('id', None):
        return None
    user = AuthUser(**js_user)
    # Prefer the lower-case ``username`` key. Fall back to an existing
    # ``userName`` instead of blanking it when ``username`` is absent.
    user.userName = js_user.get('username', js_user.get('userName', ''))
    return user
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_cached_app(access_token: str) -> Optional[AuthApp]:
    """Load the cached app for ``access_token``.

    The cached value is JSON and may be double-encoded (a JSON string whose
    content is itself JSON). Only a JSON object with a truthy ``appId`` is
    treated as an app.

    :param access_token: token used in the ``share:access:<token>`` key.
    :return: an ``AuthApp``, or None when nothing usable is cached.
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    js_app = json.loads(rs)
    # Tolerate double-encoded JSON payloads.
    if isinstance(js_app, str):
        js_app = json.loads(js_app)
    # Signature check: an app object must be a dict with an ``appId`` field.
    # The isinstance guard avoids AttributeError on list/number payloads.
    if isinstance(js_app, dict) and js_app.get('appId', None):
        return AuthApp(**js_app)
    return None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_cached_access_dict(access_token: str) -> Optional[Dict]:
    """Return the raw cached access record for ``access_token`` as a dict.

    The cached value is JSON and may be double-encoded (a JSON string whose
    content is itself JSON), in which case it is decoded a second time.

    :param access_token: token used in the ``share:access:<token>`` key.
    :return: the decoded dict, or None when nothing is cached or the payload
        is not a JSON object (callers may call ``.get()`` on the result).
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    item = json.loads(rs)
    # Tolerate double-encoded JSON payloads.
    if isinstance(item, str):
        item = json.loads(item)
    # Honour the Optional[Dict] contract instead of leaking lists/scalars.
    return item if isinstance(item, dict) else None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def has_access(access_token: str) -> bool:
    """Return True when an access record is cached for ``access_token``.

    EXISTS returns the number of matching keys, so the count is coerced to
    the declared bool.
    """
    return bool(rds.exists(f'share:access:{access_token}'))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def save_cached_permission(access_token: str, permissions: Set[str], expire: int = None):
    """Add ``permissions`` to the token's Redis set and (re)set its TTL.

    Members are added alongside any already present (SADD semantics).
    Nothing is written when ``permissions`` is empty. ``expire`` is the TTL
    in seconds and falls back to ``default_expire`` when falsy.
    """
    if not permissions:
        return
    key = f'share:permissions:{access_token}'
    # One SADD carrying every member instead of one round trip per permission.
    rds.sadd(key, *permissions)
    rds.expire(key, expire or default_expire)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def has_permission(access_token: str, permission: str) -> bool:
    """Check whether ``permission`` is in the token's cached permission set.

    The bare value is tried first. If it is absent, the value wrapped in
    double quotes is tried, for compatibility with entries written by the
    Java implementation.
    """
    key = f'share:permissions:{access_token}'
    found = rds.sismember(key, f'{permission}')
    if not found:
        # Java-side entries are stored JSON-quoted, e.g. "a.b.c".
        found = rds.sismember(key, f'"{permission}"')
    return found
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def get_cached_permission(access_token: str) -> Set:
    """Return all cached permissions for ``access_token``.

    SMEMBERS returns a set (empty when the key does not exist), so the
    annotation is a set rather than the previous ``List[str]``. Whether the
    members are ``str`` or ``bytes`` depends on how the Redis client was
    configured; this is not visible here.
    """
    return rds.smembers(f'share:permissions:{access_token}')
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def clean_cache(access_token: str):
    """Remove both the access record and the permission set for a token."""
    # DEL accepts several keys, so one round trip removes both entries.
    rds.delete(f'share:access:{access_token}', f'share:permissions:{access_token}')
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
import json
|
|
5
|
+
import threading
|
|
6
|
+
from typing import Optional, List
|
|
7
|
+
from starlette.routing import BaseRoute
|
|
8
|
+
from fastapi.routing import APIRoute
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
|
|
11
|
+
from ..lib import cfg
|
|
12
|
+
from ..auth import auth_cache_pool
|
|
13
|
+
from ..client.redis_client import Redis
|
|
14
|
+
from ..exception.exceptions import NoAuthException
|
|
15
|
+
from ..model.items import AuthUser, AuthApp
|
|
16
|
+
from ..context.context_vars import tenant_context
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)

# Built-in path regex patterns that never require authentication.
# ShareAuth.reload() compiles these together with the caller-supplied
# white_list. Empty by default.
FIX_WHITELIST = []
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ShareAuth(object):
    """Singleton that performs token login checks and route permission checks.

    The access token is taken from the ``Authorization`` header, then the
    ``access_token`` cookie, then the ``access_token`` query parameter. It is
    validated against the Redis cache in ``auth_cache_pool``. A route's
    required permissions are its FastAPI ``tags``, and every tag must be
    granted.
    """
    _instance_lock = threading.Lock()
    _biz_inited = False
    _routes_inited = False

    def __new__(cls, *args, **kwargs):
        # Double-checked locking: the lock is only taken while no instance exists.
        if not hasattr(ShareAuth, "_instance"):
            with ShareAuth._instance_lock:
                if not hasattr(ShareAuth, "_instance"):
                    ShareAuth._instance = object.__new__(cls)
        return ShareAuth._instance

    def __init__(self):
        # __init__ runs on every ShareAuth() call against the same instance,
        # so the one-time setup is guarded by _biz_inited.
        if not self._biz_inited:
            self._biz_inited = True
            self.rds = Redis(redis_uri=cfg.get('AUTH_REDIS_URL', None), redis_db=cfg.get('AUTH_REDIS_DB', None)).client
            # compiled route path regex -> route tags (required permissions)
            self.regex_to_tags_map = dict()
            # compiled regexes of paths that skip authentication
            self.auth_whitelist = []

    def get_access_token(self, request: Request):
        """Extract the access token from the request.

        Lookup order: ``Authorization`` header (with a leading ``Bearer ``
        stripped), then the ``access_token`` cookie, then the ``access_token``
        query parameter.

        :param request: incoming request.
        :return: the token. Returns ``''`` when there is no request or it has
            neither headers nor query params. May return None when no source
            had a value.
        """
        if request and (request.headers or request.query_params):
            access_token = request.headers.get('Authorization', '')
            # Strip only a leading scheme prefix. str.replace would also remove
            # "Bearer " appearing anywhere inside the token.
            if access_token.startswith('Bearer '):
                access_token = access_token[len('Bearer '):]
            if not access_token:
                access_token = request.cookies.get('access_token')
            if not access_token:
                access_token = request.query_params.get('access_token')
            return access_token

        return ''

    def reload(self, routes: Optional[List[BaseRoute]], white_list=None, force: bool = False):
        """Rebuild the route->tags map and the no-auth whitelist.

        :param routes: application routes. Only ``APIRoute`` instances with a
            ``path_regex`` contribute. None is treated as no routes.
        :param white_list: extra regex patterns for paths that skip auth
            (compiled after ``FIX_WHITELIST``).
        :param force: rebuild even if a previous reload already ran.
        :return: self, for chaining.
        """
        if not force and self._routes_inited:
            return self
        # Route match regex -> permission granules (tags).
        regex_to_tags = {
            route.path_regex: route.tags
            for route in (routes or [])
            if isinstance(route, APIRoute) and route.path_regex
        }
        # Publish by rebinding only. Clearing the old dict first would empty it
        # under any auth_check still iterating it, and under any outside
        # reference to it.
        self.regex_to_tags_map = regex_to_tags
        logger.info(f'reload path to tags map: {len(self.regex_to_tags_map)}')

        # No-auth whitelist: built-in patterns first, then caller patterns.
        auth_whitelist = [re.compile(row) for row in FIX_WHITELIST]
        auth_whitelist.extend(re.compile(row) for row in (white_list or []))
        self.auth_whitelist = auth_whitelist
        logger.info(f'reload no auth whitelist: {len(self.auth_whitelist)}')

        self._routes_inited = True
        return self

    def access_check(self, request: Request):
        """Verify the request carries a token that is still cached.

        :param request: incoming request.
        :return: the access token.
        :raises NoAuthException: when there is no token or it is not cached.
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')

        if not auth_cache_pool.has_access(access_token):
            raise NoAuthException('登录信息已失效')

        return access_token

    def auth_check(self, request: Request):
        """Check login and route permissions for the request.

        Whitelisted paths pass without login. For other paths a valid token
        is required. If the first matching route has tags, every tag must be
        granted to the token.

        :param request: incoming request.
        :return: True when allowed. False when no route map is loaded or no
            route matches the path.
        :raises NoAuthException: on missing/expired login or missing permission.
        """
        if not self.regex_to_tags_map:
            logger.debug('no path to tags map')
            return False
        path = request.url.path
        # Whitelist takes precedence over everything else.
        if any(white_regex.match(path) for white_regex in self.auth_whitelist):
            return True
        # Login is required before any permission check.
        access_token = self.access_check(request)
        for regex, tags in self.regex_to_tags_map.items():
            if regex.match(path):
                # Tagged routes need the permissions; untagged routes need login only.
                if tags and not self._rds_permissions_check(tags, access_token):
                    raise NoAuthException('权限不足')
                return True

        return False

    def _rds_permissions_check(self, tags, access_token) -> bool:
        """Return True only if every tag is granted to the token in Redis."""
        for tag in tags:
            # New-style permission granule: xx.xx.xx.xx
            if auth_cache_pool.has_permission(access_token, tag):
                continue
            # Java-compatible, JSON-encoded granule: "xx.xx.xx.xx"
            if not auth_cache_pool.has_permission(access_token, json.dumps(tag)):
                return False

        return True

    def get_auth_user(self, request: Request) -> Optional[AuthUser]:
        """Return the logged-in user and set the tenant context.

        A super admin (``superAdmin`` of 1 or '1') may switch tenant via the
        ``TenantCode`` request header.

        :param request: incoming request.
        :return: the cached user, or None if the cached record is not a user.
        :raises NoAuthException: when there is no token or it is not cached.
        """
        access_token = self.access_check(request)

        auth_user = auth_cache_pool.get_cached_user(access_token)
        # Only super admins may switch tenant.
        if auth_user and auth_user.superAdmin in (1, '1'):
            tenant_code = request.headers.get('TenantCode', None)
            if tenant_code:
                auth_user.tenantCode = tenant_code

        # Propagate the tenant into the request's context variable.
        if auth_user and auth_user.tenantCode:
            tenant_context.set(auth_user.tenantCode)

        return auth_user

    def get_auth_app(self, request: Request) -> Optional[AuthApp]:
        """Return the authenticated app and set the tenant context.

        :raises NoAuthException: when there is no token or it is not cached.
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先鉴权认证')
        if not auth_cache_pool.has_access(access_token):
            raise NoAuthException('鉴权认证信息已失效')

        auth_app = auth_cache_pool.get_cached_app(access_token=access_token)
        # Propagate the tenant into the request's context variable.
        if auth_app and auth_app.tenantCode:
            tenant_context.set(auth_app.tenantCode)

        return auth_app

    def get_auth_tenant_code(self, request: Request) -> Optional[str]:
        """Return the tenant code for the request's token.

        A super admin's ``TenantCode`` header overrides the cached tenant.

        :param request: incoming request.
        :return: the tenant code, or ``''`` when there is no token or no
            cached record.
        """
        access_token = self.get_access_token(request)
        if not access_token:
            return ''
        access_item = auth_cache_pool.get_cached_access_dict(access_token)
        if not access_item:
            return ''
        # Only super admins may switch tenant.
        if access_item.get('superAdmin', None) in (1, '1'):
            tenant_code = request.headers.get('TenantCode', None)
            if tenant_code:
                return tenant_code
        return access_item.get('tenantCode', '')
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# FastAPI dependency (use with Depends in handlers) resolving the authenticated user.
async def authed_user(request: Request) -> Optional[AuthUser]:
    """Delegate to ``ShareAuth.get_auth_user`` for the current request."""
    share_auth = ShareAuth()
    return share_auth.get_auth_user(request)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# FastAPI dependency (use with Depends in handlers) resolving the authenticated app.
async def authed_app(request: Request) -> Optional[AuthApp]:
    """Delegate to ``ShareAuth.get_auth_app`` for the current request."""
    share_auth = ShareAuth()
    return share_auth.get_auth_app(request)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# FastAPI dependency (use with Depends in handlers) returning the raw access token.
async def access_token(request: Request) -> Optional[str]:
    """Delegate to ``ShareAuth.get_access_token``; performs no validation."""
    share_auth = ShareAuth()
    return share_auth.get_access_token(request)
|
|
File without changes
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
# version: pymongo==4.7.3, motor==3.5.0
|
|
3
|
+
|
|
4
|
+
from ..exception.exceptions import BizException
|
|
5
|
+
from ..exception.exceptions import NoConfigException
|
|
6
|
+
from ..lib import cfg, logger
|
|
7
|
+
|
|
8
|
+
# Logger obtained from the project's logger helper under the name 'AsyncMongoClient'.
LOGGER = logger.get('AsyncMongoClient')
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AsyncMongo(object):
    """Thin async wrapper around a Motor client bound to one database.

    Connection settings default to the MONGO_URL / MONGO_DB config values.
    Methods that return a cursor leave closing it to the caller.
    """

    def __init__(self, mongo_uri: str = None, mongo_db: str = None):
        # Imported here rather than at module level (presumably so motor is
        # only required when this client is actually constructed).
        import motor.motor_asyncio
        self.mongo_uri = mongo_uri or cfg.get('MONGO_URL')
        self.mongo_db = mongo_db or cfg.get('MONGO_DB')
        if not self.mongo_uri:
            raise NoConfigException('mongodb uri not config!')
        if not self.mongo_db:
            raise NoConfigException('mongodb database not config!')

        self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
        self.db = self.client[self.mongo_db]

        LOGGER.info(self.client)

    def get_collection(self, coll):
        """Return the collection object (synchronous; no await needed)."""
        return self.db.get_collection(coll)

    async def get(self, collection, query={}):
        """Return the first document matching ``query``, or None."""
        return await self.db[collection].find_one(query)

    async def count(self, collection, query={}):
        """Return the number of documents matching ``query``."""
        return await self.db[collection].count_documents(query)

    async def list(self, collection, query={}, fields=None, sort=[], batch_size: int = 2000):
        """Return all matching documents as a list."""
        result = []
        cursor = self.db[collection].find(filter=query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)
        try:
            async for document in cursor:
                result.append(document)
        finally:
            await cursor.close()

        return result

    async def list_with_cursor(self, collection, query={}, fields=None, sort=[], batch_size: int = 2000):
        """Return an open cursor over matching documents.

        Example::

            cursor = await aio_mongo.list_with_cursor(collection)
            async for document in cursor:
                print("document =", document)
            await cursor.close()

        Note: the caller must close the cursor when finished with it.
        """
        cursor = self.db[collection].find(filter=query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)

        return cursor

    async def page(self, collection, query={}, page_no=1, page_size=20, fields=None, sort=[], batch_size: int = 2000):
        """Return ``(rows, total)`` for 1-based page ``page_no``."""
        total = await self.db[collection].count_documents(query) or 0
        cursor = self.db[collection].find(query, fields, sort=sort).skip(page_size * (page_no - 1)).limit(page_size)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                rows.append(document)
        finally:
            await cursor.close()

        return rows, total

    async def top(self, collection, query={}, sort=[], limit=1, fields=None, batch_size: int = 2000):
        """Return the first ``limit`` matching documents.

        With ``limit == 1`` the document itself (not a list) is returned when
        one matches. If nothing matches, ``[]`` is returned.
        """
        cursor = self.db[collection].find(filter=query, projection=fields, sort=sort, limit=limit)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                if limit == 1:
                    return document
                rows.append(document)
        finally:
            # Also runs on the early return above, which previously skipped it.
            await cursor.close()

        return rows

    async def distinct(self, collection, dist_key, query={}, fields=None):
        """Return the distinct values of ``dist_key`` among matching documents."""
        return await self.db[collection].find(query, fields).distinct(dist_key)

    async def aggregate_page(self, collection, pipelines, page_no=1, page_size=20):
        """Run ``pipelines`` and return ``(rows, total)`` for 1-based page ``page_no``.

        A ``$facet`` (count + skip/limit) and a ``$project`` stage are
        appended to a copy of ``pipelines``; the caller's list is not
        modified. An empty or None pipeline pages over the whole collection.
        """
        skip = page_size * (page_no - 1)
        # Work on a copy: appending to the caller's list stacked extra $facet
        # stages when the same list was reused. The facet is also needed for an
        # empty pipeline, which previously scanned everything and returned ([], 0).
        pipelines = list(pipelines or [])
        pipelines.append({'$facet': {'total': [{'$count': 'count'}], 'rows': [{'$skip': skip}, {'$limit': page_size}]}})
        pipelines.append({'$project': {'data': '$rows', 'total': {'$arrayElemAt': ['$total.count', 0]}}})

        cursor = self.db[collection].aggregate(pipelines, session=None, allowDiskUse=True)
        cursor.batch_size(page_size)
        try:
            async for rs in cursor:
                # rs: dict as {"data": [], "total": int}; "total" is absent when nothing matched.
                if rs and 'data' in rs and 'total' in rs:
                    return rs.get('data'), rs.get('total')
        finally:
            await cursor.close()
        return [], 0

    async def aggregate(self, collection, pipelines=[], batch_size: int = 2000):
        """Return an open cursor over the aggregation results.

        Example::

            cursor = await aio_mongo.aggregate(collection="collection_name", pipelines=[])
            async for row in cursor:
                print(row)

            await cursor.close()

        Note: the caller must close the cursor when finished with it.
        """
        cursor = self.db[collection].aggregate(pipelines, session=None, allowDiskUse=True)
        cursor.batch_size(batch_size)
        return cursor

    async def list_with_page(self, collection, query={}, page_size=10000, fields=None):
        """Return all matching documents, fetched ``page_size`` at a time.

        No sort is applied, so page boundaries follow the server's natural
        order.
        """
        rows = list()
        total = await self.db[collection].count_documents(query)
        if total > 0 and page_size > 0:
            # Ceiling division. round(total / page_size) + 1 over- or under-counted pages.
            total_page = (total + page_size - 1) // page_size
            for page in range(total_page):
                # limit must be page_size. The old .limit(page) gave limit(0),
                # meaning "no limit", on the first page, which duplicated rows.
                cursor = self.db[collection].find(query, fields or None).skip(page_size * page).limit(page_size)
                try:
                    async for document in cursor:
                        rows.append(document)
                finally:
                    await cursor.close()
        return rows

    async def insert_or_update(self, collection, data, id_key='_id', update=None, upsert=True, multi=False, id_val=None):
        """Upsert by ``id_key``, with either ``data`` ($set) or raw ``update``.

        :param data: document to ``$set``. Its ``data[id_key]`` is the filter value.
        :param update: raw update document. Because there is no ``data`` to
            take the id from, pass the filter value in ``id_val``.
        :param multi: use update_many instead of update_one.
        :param id_val: filter value for ``id_key`` when only ``update`` is given.
        :raises BizException: when both or neither of data/update are given,
            or when ``update`` is given without ``id_val``.
        """
        if data and not update:
            query = {id_key: data[id_key]}
            doc = {'$set': data}
        elif not data and update:
            # Previously indexed the empty ``data`` here and always crashed.
            if id_val is None:
                raise BizException("仅传update时需要指定id_val")
            query = {id_key: id_val}
            doc = update
        else:
            # all([data, update]) or not any([data, update])
            raise BizException("data和update不能同时存在或同时为空")

        coll = self.db[collection]
        if multi:
            return await coll.update_many(query, doc, upsert=upsert)
        return await coll.update_one(query, doc, upsert=upsert)

    async def insert(self, collection, data):
        """Insert one document."""
        return await self.db[collection].insert_one(data)

    async def update(self, collection, filter, data=None, update=None, multi=False):
        """Update by ``filter`` with either ``data`` ($set) or raw ``update``.

        :raises BizException: when both or neither of data/update are given.
        """
        if multi:
            if data and not update:
                result = await self.db[collection].update_many(filter, {'$set': data})
            elif not data and update:
                result = await self.db[collection].update_many(filter, update)
            else:
                raise BizException("data和update不能同时存在或同时为空")
        else:
            if data and not update:
                result = await self.db[collection].update_one(filter, {'$set': data})
            elif not data and update:
                result = await self.db[collection].update_one(filter, update)
            else:
                raise BizException("data和update不能同时存在或同时为空")

        return result

    async def save(self, collection, filter: dict, save_data: dict, upsert=True):
        """``$set`` ``save_data`` on the first document matching ``filter`` (upsert by default)."""
        return await self.db[collection].update_one(filter, {'$set': save_data}, upsert=upsert)

    async def update_by_pk(self, collection, pk_val, data=None, update=None, upsert=False):
        """Update the document whose ``_id`` equals ``pk_val``.

        :raises BizException: when both or neither of data/update are given.
        """
        if data and not update:
            result = await self.db[collection].update_one({'_id': pk_val}, {'$set': data}, upsert=upsert)
        elif not data and update:
            result = await self.db[collection].update_one({'_id': pk_val}, update, upsert=upsert)
        else:
            raise BizException("data和update不能同时存在或同时为空")
        return result

    async def batch_update(self, collection, filter, datas=None, update=None, *args, **kwargs):
        """update_many by ``filter`` with either ``datas`` ($set) or raw ``update``.

        :raises BizException: when both or neither of datas/update are given.
        """
        if datas and not update:
            result = await self.db[collection].update_many(filter, {'$set': datas})
        elif not datas and update:
            result = await self.db[collection].update_many(filter, update)
        else:
            raise BizException("data和update不能同时存在或同时为空")
        return result

    async def delete(self, collection, filter):
        """Delete every document matching ``filter``."""
        return await self.db[collection].delete_many(filter)

    async def bulk_write(self, collection, bulk_list: list, batch_size: int = 1000):
        """Run ``bulk_list`` as unordered bulk writes in chunks of ``batch_size``.

        Returns the result of the last chunk only, or None when ``bulk_list``
        is empty.
        """
        result = None
        if bulk_list:
            bulk_lists = [bulk_list[i: i + batch_size] for i in range(0, len(bulk_list), batch_size)]
            for _bulk_list in bulk_lists:
                result = await self.db[collection].bulk_write(_bulk_list, ordered=False, bypass_document_validation=True)
        return result

    async def create_index(self, collection, fields):
        """Create an index on ``fields``."""
        return await self.db[collection].create_index(fields)

    def close(self):
        """Close the underlying client; errors are printed, not raised."""
        if self.client:
            try:
                self.client.close()
                print("close successful")
            except Exception as e:
                print("close mongo client catch err:", e)
|
|
233
|
+
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
import asyncio
|
|
3
|
+
import threading
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from ..client.async_mongo_client import AsyncMongo
|
|
6
|
+
from ..lib import cfg, logger
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Logger obtained from the project's logger helper under the name 'AsyncMultiTenantMongo'.
LOGGER = logger.get('AsyncMultiTenantMongo')
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AsyncMultiTenantMongo(AsyncMongo):
    """Async Mongo client bound to one tenant's database.

    URL resolution: ``mongo_url``, else the MONGO_URL config value, else
    ``mongodb://localhost:27017``.

    Database resolution when ``mongo_db`` is not given: ``<prefix><tenant_code>``,
    where prefix is ``db_prefix`` or the MONGO_DB_BIZ_PREFIX config value.
    Both ``tenant_code`` and a prefix must be truthy.

    Raises Exception when no database name can be determined.
    """
    def __init__(self, tenant_code=None, mongo_url=None, mongo_db=None, db_prefix=None):
        mongo_url = mongo_url or cfg.get('MONGO_URL') or 'mongodb://localhost:27017'
        if not mongo_db and tenant_code:
            prefix = db_prefix or cfg.get('MONGO_DB_BIZ_PREFIX', None)
            if prefix:
                mongo_db = f"{prefix}{tenant_code}"
        if not mongo_db:
            raise Exception('mongodb database not specified')
        super().__init__(mongo_url, mongo_db)
        LOGGER.info(f'[{tenant_code}] tenant mongodb inited~')
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AsyncMultiTenantMongoHolder(object):
    """Process-wide registry holding one AsyncMultiTenantMongo per tenant code.

    Clients are created lazily by ``get_tenant_mongo`` and cached for reuse.
    """
    # Guards singleton creation and per-tenant client creation.
    _instance_lock = threading.Lock()
    # Not referenced within this class; kept for any external users.
    _instance_async_lock = asyncio.Lock()
    # tenant_code -> AsyncMultiTenantMongo
    _tenant_instance_dict = dict()

    def __new__(cls, *args, **kwargs):
        # Double-checked locking. The inner check previously ran without the
        # lock, so concurrent first calls could each create an instance.
        if not hasattr(AsyncMultiTenantMongoHolder, "_instance"):
            with AsyncMultiTenantMongoHolder._instance_lock:
                if not hasattr(AsyncMultiTenantMongoHolder, "_instance"):
                    AsyncMultiTenantMongoHolder._instance = object.__new__(cls)

        return AsyncMultiTenantMongoHolder._instance

    @staticmethod
    def get_tenant_mongo(tenant_code: str) -> Optional['AsyncMultiTenantMongo']:
        """Return the cached client for ``tenant_code``, creating it on first use.

        :param tenant_code: tenant identifier. A falsy value returns None.
        :return: the tenant's client, or None for an empty tenant code.
        :raises Exception: propagated from client construction (e.g. when no
            database name can be derived); nothing is cached in that case.
        """
        if not tenant_code:
            return None

        instances = AsyncMultiTenantMongoHolder._tenant_instance_dict
        # Fast path: already created.
        if tenant_code in instances:
            return instances[tenant_code]

        # Slow path: create under the lock. The `with` block releases only the
        # lock this thread acquired. The old locked()/release() pairs could
        # release a lock held by another thread.
        with AsyncMultiTenantMongoHolder._instance_lock:
            if tenant_code not in instances:
                instances[tenant_code] = AsyncMultiTenantMongo(tenant_code=tenant_code)
            return instances[tenant_code]
|