fast-web-core 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_web_core-0.0.1/MANIFEST.in +0 -0
- fast_web_core-0.0.1/PKG-INFO +37 -0
- fast_web_core-0.0.1/README.rst +12 -0
- fast_web_core-0.0.1/fast_web_core/__init__.py +9 -0
- fast_web_core-0.0.1/fast_web_core/auth/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/auth/auth_cache_pool.py +109 -0
- fast_web_core-0.0.1/fast_web_core/auth/share_auth.py +216 -0
- fast_web_core-0.0.1/fast_web_core/client/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/client/async_mongo_client.py +233 -0
- fast_web_core-0.0.1/fast_web_core/client/async_mulit_tenant_mongo_client.py +69 -0
- fast_web_core-0.0.1/fast_web_core/client/async_multi_database_mongo_client.py +60 -0
- fast_web_core-0.0.1/fast_web_core/client/async_rabbit_client.py +71 -0
- fast_web_core-0.0.1/fast_web_core/client/async_redis_client.py +37 -0
- fast_web_core-0.0.1/fast_web_core/client/mongo_client.py +208 -0
- fast_web_core-0.0.1/fast_web_core/client/multi_database_mongo_client.py +59 -0
- fast_web_core-0.0.1/fast_web_core/client/multi_tenant_mongo_client.py +60 -0
- fast_web_core-0.0.1/fast_web_core/client/oss2_client.py +95 -0
- fast_web_core-0.0.1/fast_web_core/client/redis_client.py +36 -0
- fast_web_core-0.0.1/fast_web_core/context/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/context/context_vars.py +14 -0
- fast_web_core-0.0.1/fast_web_core/exception/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/exception/exceptions.py +32 -0
- fast_web_core-0.0.1/fast_web_core/lib/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/lib/cfg.py +60 -0
- fast_web_core-0.0.1/fast_web_core/lib/func.py +92 -0
- fast_web_core-0.0.1/fast_web_core/lib/log_filter.py +37 -0
- fast_web_core-0.0.1/fast_web_core/lib/logger.py +87 -0
- fast_web_core-0.0.1/fast_web_core/lib/secret.py +59 -0
- fast_web_core-0.0.1/fast_web_core/lib/settings.py +103 -0
- fast_web_core-0.0.1/fast_web_core/lib/time.py +343 -0
- fast_web_core-0.0.1/fast_web_core/lib/util.py +56 -0
- fast_web_core-0.0.1/fast_web_core/middleware/__init__.py +17 -0
- fast_web_core-0.0.1/fast_web_core/middleware/auth.py +56 -0
- fast_web_core-0.0.1/fast_web_core/middleware/trace.py +71 -0
- fast_web_core-0.0.1/fast_web_core/model/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/model/base.py +25 -0
- fast_web_core-0.0.1/fast_web_core/model/enums.py +24 -0
- fast_web_core-0.0.1/fast_web_core/model/handler.py +96 -0
- fast_web_core-0.0.1/fast_web_core/model/items.py +59 -0
- fast_web_core-0.0.1/fast_web_core/model/query.py +23 -0
- fast_web_core-0.0.1/fast_web_core/utils/__init__.py +0 -0
- fast_web_core-0.0.1/fast_web_core/utils/common.py +164 -0
- fast_web_core-0.0.1/fast_web_core/utils/context_var.py +9 -0
- fast_web_core-0.0.1/fast_web_core/utils/decator.py +13 -0
- fast_web_core-0.0.1/fast_web_core/utils/encryption.py +75 -0
- fast_web_core-0.0.1/fast_web_core/utils/sequence.py +44 -0
- fast_web_core-0.0.1/fast_web_core.egg-info/PKG-INFO +37 -0
- fast_web_core-0.0.1/fast_web_core.egg-info/SOURCES.txt +50 -0
- fast_web_core-0.0.1/fast_web_core.egg-info/dependency_links.txt +1 -0
- fast_web_core-0.0.1/fast_web_core.egg-info/top_level.txt +1 -0
- fast_web_core-0.0.1/setup.cfg +4 -0
- fast_web_core-0.0.1/setup.py +57 -0
|
File without changes
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: fast_web_core
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: fast web core for zsodata
|
|
5
|
+
Home-page: http://www.zsodata.com
|
|
6
|
+
Author: zsodata
|
|
7
|
+
Author-email: team@zso.io
|
|
8
|
+
License: BSD License
|
|
9
|
+
Project-URL: Source, http://www.zsodata.com
|
|
10
|
+
Keywords: fastapi web core framework async mongodb redis
|
|
11
|
+
Platform: all
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
15
|
+
Classifier: Operating System :: OS Independent
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.7
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
22
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
23
|
+
Requires-Python: >=3.7
|
|
24
|
+
Description-Content-Type: text/plain
|
|
25
|
+
|
|
26
|
+
pip uninstall fast-web-core
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
注意:
|
|
30
|
+
mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
|
|
31
|
+
如果需要使用旧版,切换到分支:mongo4
|
|
32
|
+
|
|
33
|
+
重点新增:
|
|
34
|
+
异步mongo,异步redis
|
|
35
|
+
|
|
36
|
+
pip install --upgrade fast-web-core -i https://pypi.python.org/simple
|
|
37
|
+
pip show fast-web-core
|
|
File without changes
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Set, Optional, Dict
|
|
5
|
+
from ..lib import cfg
|
|
6
|
+
from ..client.async_redis_client import AsyncRedis
|
|
7
|
+
from ..model.items import AuthUser, AuthApp
|
|
8
|
+
|
|
9
|
+
# Module-level logger for the auth cache helpers.
logger = logging.getLogger(__name__)

# Shared async Redis client used by every helper in this module. The connection
# URI and database index are read from the AUTH_REDIS_URL / AUTH_REDIS_DB settings.
rds = AsyncRedis(
    redis_uri=cfg.get('AUTH_REDIS_URL', None),
    redis_db=cfg.get('AUTH_REDIS_DB', None),
).client

# Fallback TTL handed to ``rds.expire`` when a caller gives no explicit expiry.
default_expire = 43200
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def _load_access_payload(access_token: str):
    """Fetch and JSON-decode the cache entry stored for *access_token*.

    :param access_token: token used to build the ``<prefix>:access:share:<token>`` key
    :return: ``None`` when the key is missing or empty, otherwise the decoded JSON value
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    raw = await rds.get(f'{series_prefix}:access:share:{access_token}')
    if not raw:
        return None

    payload = json.loads(raw)
    # Fault tolerance: some entries are double-encoded (a JSON string that
    # itself contains JSON), so decode once more in that case.
    if isinstance(payload, str):
        payload = json.loads(payload)
    return payload


async def get_cached_app(access_token: str) -> Optional[AuthApp]:
    """Return the cached application bound to *access_token*, if any.

    :param access_token: token identifying the cached session
    :return: an ``AuthApp`` built from the cached payload, or ``None`` when the
        entry is missing or does not look like an app (no truthy ``team_id``)
    """
    payload = await _load_access_payload(access_token)
    # Feature check: an app entry must be a mapping carrying a truthy ``team_id``.
    # The isinstance guard keeps a malformed (non-dict) entry from raising on ``.get``.
    if isinstance(payload, dict) and payload.get('team_id', None):
        return AuthApp(**payload)
    return None


async def get_cached_user(access_token: str) -> Optional[AuthUser]:
    """Return the cached user bound to *access_token*, if any.

    :param access_token: token identifying the cached session
    :return: an ``AuthUser`` built from the cached payload, or ``None`` when the
        entry is missing or carries no ``id`` field
    """
    payload = await _load_access_payload(access_token)
    # Feature check: a user entry must be a mapping with a non-None ``id``
    # (``0`` is accepted, hence the explicit ``is not None``).
    if isinstance(payload, dict) and payload.get('id', None) is not None:
        return AuthUser(**payload)
    return None


async def get_cached_access_dict(access_token: str) -> Optional[Dict]:
    """Return the raw decoded cache entry for *access_token*.

    :param access_token: token identifying the cached session
    :return: the decoded JSON value (no shape validation is applied), or
        ``None`` when the key is missing or empty
    """
    return await _load_access_payload(access_token)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def has_access(access_token: str) -> bool:
    """Tell whether a cache entry exists for *access_token*.

    :param access_token: token identifying the cached session
    :return: ``True`` when the ``<prefix>:access:share:<token>`` key exists
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # Redis EXISTS replies with a key count rather than a boolean; coerce it so
    # the annotated ``bool`` contract actually holds for callers.
    return bool(await rds.exists(f'{series_prefix}:access:share:{access_token}'))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def has_app_access(access_token: str) -> bool:
    """Tell whether a cache entry exists for an application *access_token*.

    Looks up the same ``<prefix>:access:share:<token>`` key as ``has_access``.

    :param access_token: token identifying the cached application session
    :return: ``True`` when the key exists
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # Redis EXISTS replies with a key count rather than a boolean; coerce it so
    # the annotated ``bool`` contract actually holds for callers.
    return bool(await rds.exists(f"{series_prefix}:access:share:{access_token}"))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
async def save_cached_permission(access_token: str, permissions: Set[str], expire: int = None):
    """Store the permission set of *access_token* and (re)set its TTL.

    Nothing is written when *permissions* is empty.

    :param access_token: token the permissions belong to
    :param permissions: permission identifiers to add to the cached set
    :param expire: TTL for the set; falls back to ``default_expire`` when falsy
    """
    if not permissions:
        return

    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    key = f'{series_prefix}:share:permissions:{access_token}'
    # SADD is variadic: add every member in one round trip instead of one
    # command per permission.
    await rds.sadd(key, *permissions)
    await rds.expire(key, expire or default_expire)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
async def has_permission(access_token: str, permission: str) -> bool:
    """Check that *access_token* holds *permission* and that it is switched on.

    The permission must be a member of the ``<prefix>:access:permission:<token>``
    set, and at least one key matching ``<prefix>:access:*<permission>:<token>``
    must hold the status value ``1``.

    :param access_token: token identifying the cached session
    :param permission: permission identifier to look up
    :return: ``True`` only when both conditions hold
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # NOTE(review): membership is read from ``access:permission`` while
    # save_cached_permission writes to ``share:permissions`` -- confirm that a
    # different writer (the Java service?) populates this key.
    _key = f'{series_prefix}:access:permission:{access_token}'
    is_member = await rds.sismember(_key, permission)
    if not is_member:
        # Compatibility with entries written by the Java implementation, which
        # stores the permission wrapped in double quotes.
        is_member = await rds.sismember(_key, f'"{permission}"')

    if not is_member:
        return False

    # The permission is granted; now verify its status flag is 1.
    # NOTE(review): KEYS scans the whole keyspace and blocks Redis while it
    # runs -- consider SCAN or a deterministic key name if this is a hot path.
    key_pattern = f"{series_prefix}:access:*{permission}:{access_token}"
    for key in await rds.keys(key_pattern):
        # Accept the bytes form too, in case the client is created without
        # response decoding. Return on the first open flag; the remaining keys
        # cannot change the answer.
        if await rds.get(key) in ("1", b"1"):
            return True
    return False
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def get_cached_permission(access_token: str) -> List[str]:
    """Return the permissions cached for *access_token*.

    :param access_token: token identifying the cached session
    :return: whatever the client's ``smembers`` yields for the
        ``<prefix>:share:permissions:<token>`` key
        (NOTE(review): presumably a set rather than a list -- confirm against the client)
    """
    prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    permissions_key = f'{prefix}:share:permissions:{access_token}'
    members = await rds.smembers(permissions_key)
    return members
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
async def clean_cache(access_token: str):
    """Remove the cached session entry and permission set of *access_token*.

    :param access_token: token whose cache entries are deleted
    """
    prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # Deleted one key per command, in this order: session entry, then permissions.
    for cache_key in (
        f'{prefix}:access:share:{access_token}',
        f'{prefix}:share:permissions:{access_token}',
    ):
        await rds.delete(cache_key)
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
import json
|
|
5
|
+
import threading
|
|
6
|
+
from typing import Optional, List
|
|
7
|
+
from starlette.routing import BaseRoute
|
|
8
|
+
from fastapi.routing import APIRoute
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
|
|
11
|
+
from ..lib import cfg
|
|
12
|
+
from ..auth import auth_cache_pool
|
|
13
|
+
from ..client.async_redis_client import AsyncRedis
|
|
14
|
+
from ..exception.exceptions import NoAuthException
|
|
15
|
+
from ..model.items import AuthUser, AuthApp
|
|
16
|
+
from ..context.context_vars import tenant_context
|
|
17
|
+
|
|
18
|
+
# Module-level logger for the shared auth component.
logger = logging.getLogger(__name__)

# Built-in regex patterns that always bypass authentication; they are compiled
# together with the caller-supplied whitelist in ShareAuth.reload. Empty by default.
FIX_WHITELIST = []
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ShareAuth(object):
    """Singleton that performs login and permission checks for incoming requests.

    It keeps a map from compiled route regexes to the route's tags (the
    permission identifiers) and a whitelist of path regexes that need no
    authentication. Session and permission data are read through
    ``auth_cache_pool``.
    """

    # Guards creation of the single shared instance.
    _instance_lock = threading.Lock()
    # Set on the instance once __init__ has fully completed.
    _biz_inited = False
    # Set on the instance once reload() has populated routes and whitelist.
    _routes_inited = False

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use (double-checked locking)."""
        if not hasattr(ShareAuth, "_instance"):
            with ShareAuth._instance_lock:
                if not hasattr(ShareAuth, "_instance"):
                    ShareAuth._instance = object.__new__(cls)
        return ShareAuth._instance

    def __init__(self):
        """Initialise the shared state once; later constructions are no-ops."""
        if self._biz_inited:
            return
        self.rds = AsyncRedis(redis_uri=cfg.get('AUTH_REDIS_URL', None), redis_db=cfg.get('AUTH_REDIS_DB', None)).client
        # route path regex -> tags
        self.regex_to_tags_map = dict()
        # no auth path whitelist
        self.auth_whitelist = []
        # Mark as initialised only after everything above succeeded, so a failed
        # Redis setup is retried by the next ShareAuth() call instead of leaving
        # a half-built singleton without its attributes.
        self._biz_inited = True

    def get_access_token(self, request: Request):
        """
        Extract the access token from a request.

        Lookup order: ``Authorization`` header (with ``Bearer `` removed), then
        the ``access_token`` cookie, then the ``access_token`` query parameter.

        :param request: incoming request
        :return: the token; ``''`` when the request has neither headers nor
            query parameters; ``None`` when no source carries a token
        """
        if request and (request.headers or request.query_params):
            access_token = request.headers.get('Authorization', '').replace('Bearer ', '')
            if not access_token:
                access_token = request.cookies.get('access_token')
            if not access_token:
                access_token = request.query_params.get('access_token')
            return access_token

        return ''

    def reload(self, routes: Optional[List[BaseRoute]], white_list=None, force: bool = False):
        """
        (Re)build the route-to-tags map and the no-auth whitelist.

        :param routes: application routes; only ``APIRoute`` entries with a path
            regex are considered, ``None`` is treated as no routes
        :param white_list: extra regex patterns for paths that skip authentication
        :param force: rebuild even if a previous reload already ran
        :return: ``self``, to allow chaining
        :raises re.error: when a whitelist pattern does not compile; the
            previously loaded state is left untouched in that case
        """
        if not force and self._routes_inited:
            return self

        # Build both structures completely before publishing either one, so
        # concurrent auth_check calls never observe an empty or half-updated state.
        regex_to_tags = {
            route.path_regex: route.tags
            for route in (routes or [])
            if isinstance(route, APIRoute) and route.path_regex
        }
        whitelist = [re.compile(row) for row in FIX_WHITELIST]
        whitelist.extend(re.compile(row) for row in (white_list or []))

        # Mapping of route matching regex -> permission identifiers (tags).
        self.regex_to_tags_map = regex_to_tags
        logger.info(f'reload path to tags map: {len(self.regex_to_tags_map)}')

        # Whitelist of paths that require no authentication.
        self.auth_whitelist = whitelist
        logger.info(f'reload no auth whitelist: {len(self.auth_whitelist)}')

        self._routes_inited = True
        return self

    async def access_check(self, request: Request):
        """
        Login check.

        :param request: incoming request
        :return: the access token of the request
        :raises NoAuthException: when no token is present or the token has no
            cached session
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')

        if not await auth_cache_pool.has_access(access_token):
            raise NoAuthException('登录信息已失效')

        return access_token

    async def auth_check(self, request: Request):
        """
        Permission check.

        :param request: incoming request
        :return: ``True`` when the path is whitelisted or the caller is allowed;
            ``False`` when no route map is loaded or no route matches the path
        :raises NoAuthException: when the caller is not logged in or lacks a
            permission required by the matched route
        """
        if not self.regex_to_tags_map:
            logger.debug('no path to tags map')
            return False

        path = request.url.path
        # Whitelisted paths are handled first and need no login.
        if any(white_regex.match(path) for white_regex in self.auth_whitelist):
            return True

        # Everything else requires a valid login.
        access_token = await self.access_check(request)
        for regex, tags in self.regex_to_tags_map.items():
            if not regex.match(path):
                continue
            # Route matched. Tags mean login plus permission check; no tags
            # mean login alone is sufficient.
            if tags and not await self._rds_permissions_check(tags, access_token):
                raise NoAuthException('权限不足')
            return True

        return False

    async def _rds_permissions_check(self, tags, access_token) -> bool:
        """Return ``True`` only when the token holds every permission in *tags*."""
        for tag in tags:
            # New-style permission identifier: xx.xx.xx.xx
            has = await auth_cache_pool.has_permission(access_token, tag)
            if not has:
                # Java-compatible identifier: the JSON-quoted form "xx.xx.xx.xx"
                has = await auth_cache_pool.has_permission(access_token, json.dumps(tag))
            if not has:
                return False

        return True

    async def get_auth_user(self, request: Request) -> Optional[AuthUser]:
        """
        Get the currently logged-in user.

        Also stores the user's ``team_id`` in ``tenant_context`` when present.

        :param request: incoming request
        :return: the cached user, or ``None`` when the cache entry is not a user
        :raises NoAuthException: when no token is present or the session expired
        """
        # Same token / session validation (and messages) as access_check.
        access_token = await self.access_check(request)
        auth_user = await auth_cache_pool.get_cached_user(access_token)

        # Set the tenant context.
        if auth_user and auth_user.team_id:
            tenant_context.set(auth_user.team_id)

        return auth_user

    async def get_auth_app(self, request: Request) -> Optional[AuthApp]:
        """
        Get the currently logged-in application.

        :param request: incoming request
        :return: the cached application, or ``None`` when the cache entry is not an app
        :raises NoAuthException: when no token is present or the session expired
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')
        if not await auth_cache_pool.has_app_access(access_token):
            raise NoAuthException('登录信息已失效')

        return await auth_cache_pool.get_cached_app(access_token)

    async def get_auth_team_sn(self, request: Request) -> Optional[str]:
        """
        Get the tenant (``team_sn``) the authorised caller belongs to.

        :param request: incoming request
        :return: the cached ``team_sn``, or ``''`` when there is no token, no
            cache entry, or no ``team_sn`` in the entry
        """
        access_token = self.get_access_token(request)
        if access_token:
            access_item = await auth_cache_pool.get_cached_access_dict(access_token)
            if access_item:
                # NOTE(review): the original comment states that only super
                # administrators may switch tenant; no such check is visible
                # here -- confirm where that rule is enforced.
                team_sn = access_item.get('team_sn', None)
                if team_sn:
                    return team_sn

        return ''
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
# Dependency for handlers (``Depends``): resolves the authorised user.
async def authed_user(request: Request) -> Optional[AuthUser]:
    """Return the logged-in user of *request*.

    :param request: incoming request
    :return: the cached ``AuthUser`` (or ``None``), as produced by
        ``ShareAuth.get_auth_user``
    """
    # get_auth_user is a coroutine function: without ``await`` this dependency
    # would hand the handler a coroutine object instead of the user.
    return await ShareAuth().get_auth_user(request)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
# Dependency for handlers (``Depends``): resolves the authorised application.
async def authed_app(request: Request) -> Optional[AuthApp]:
    """Return the logged-in application of *request*.

    :param request: incoming request
    :return: the cached ``AuthApp`` (or ``None``), as produced by
        ``ShareAuth.get_auth_app``
    """
    # get_auth_app is a coroutine function: without ``await`` this dependency
    # would hand the handler a coroutine object instead of the app.
    return await ShareAuth().get_auth_app(request)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# Dependency for handlers (``Depends``): resolves the raw access token of the request.
async def access_token(request: Request) -> Optional[str]:
    """Return the access token carried by *request*, as extracted by ``ShareAuth.get_access_token``."""
    auth = ShareAuth()
    token = auth.get_access_token(request)
    return token
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
|
|
File without changes
|
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
# encoding: utf-8
|
|
2
|
+
# version: pymongo==4.7.3, motor==3.5.0
|
|
3
|
+
|
|
4
|
+
from ..exception.exceptions import BizException
|
|
5
|
+
from ..exception.exceptions import NoConfigException
|
|
6
|
+
from ..lib import cfg, logger
|
|
7
|
+
|
|
8
|
+
LOGGER = logger.get('AsyncMongoClient')
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AsyncMongo(object):
    """Async convenience wrapper around a Motor client bound to one database.

    ``self.client`` is the ``AsyncIOMotorClient`` and ``self.db`` the selected
    database; every helper addresses a collection by name through ``self.db``.
    """

    def __init__(self, mongo_uri: str = None, mongo_db: str = None):
        """Connect using the given URI / database, falling back to the MONGO_URL / MONGO_DB settings.

        :raises NoConfigException: when the URI or the database name is not configured
        """
        import motor.motor_asyncio
        self.mongo_uri = mongo_uri or cfg.get('MONGO_URL')
        self.mongo_db = mongo_db or cfg.get('MONGO_DB')
        if not self.mongo_uri:
            raise NoConfigException('mongodb uri not config!')
        if not self.mongo_db:
            raise NoConfigException('mongodb database not config!')

        self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
        self.db = self.client[self.mongo_db]

        LOGGER.info(self.client)

    @staticmethod
    def _resolve_update(data, update):
        """Return the update document to send: ``{'$set': data}`` or the raw *update*.

        :raises BizException: when *data* and *update* are both given or both empty
        """
        if bool(data) == bool(update):
            raise BizException("data和update不能同时存在或同时为空")
        return {'$set': data} if data else update

    def get_collection(self, coll):
        """Return the collection object named *coll* (synchronous call)."""
        return self.db.get_collection(coll)

    async def get(self, collection, query=None):
        """Return the first document matching *query*, or ``None``."""
        return await self.db[collection].find_one(query or {})

    async def count(self, collection, query=None):
        """Return the number of documents matching *query*."""
        return await self.db[collection].count_documents(query or {})

    async def list(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """Return all documents matching *query* as a list.

        :param fields: projection passed to ``find``
        :param sort: sort specification passed to ``find``
        :param batch_size: cursor batch size
        """
        cursor = self.db[collection].find(filter=query or {}, projection=fields, sort=sort)
        cursor.batch_size(batch_size)
        try:
            return [document async for document in cursor]
        finally:
            # Release the cursor even when iteration fails part-way.
            await cursor.close()

    async def list_with_cursor(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """
        Return a cursor over the documents matching *query*.

        Example:
            cursor = await aio_mongo.list_with_cursor(collection)
            async for document in cursor:
                print("document =", document)
            await cursor.close()
        Note: the caller should close the cursor once done with it.
        """
        cursor = self.db[collection].find(filter=query or {}, projection=fields, sort=sort)
        cursor.batch_size(batch_size)

        return cursor

    async def page(self, collection, query=None, page_no=1, page_size=20, fields=None, sort=None, batch_size: int = 2000):
        """Return one page of matching documents together with the total match count.

        :param page_no: 1-based page number
        :param page_size: number of documents per page
        :return: ``(rows, total)``
        """
        query = query or {}
        total = await self.db[collection].count_documents(query) or 0
        cursor = self.db[collection].find(query, fields, sort=sort).skip(page_size * (page_no - 1)).limit(page_size)
        cursor.batch_size(batch_size)
        try:
            rows = [document async for document in cursor]
        finally:
            await cursor.close()

        return rows, total

    async def top(self, collection, query=None, sort=None, limit=1, fields=None, batch_size: int = 2000):
        """Return the first *limit* matching documents.

        :return: the single document itself when ``limit == 1`` and one is
            found; otherwise a list of documents (empty when nothing matches)
        """
        cursor = self.db[collection].find(filter=query or {}, projection=fields, sort=sort, limit=limit)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                if limit == 1:
                    return document
                rows.append(document)
        finally:
            # Runs on the early return above as well, so the cursor never leaks.
            await cursor.close()

        return rows

    async def distinct(self, collection, dist_key, query=None, fields=None):
        """Return the distinct values of *dist_key* among documents matching *query*."""
        return await self.db[collection].find(query or {}, fields).distinct(dist_key)

    async def aggregate_page(self, collection, pipelines, page_no=1, page_size=20):
        """Run an aggregation and return one page of its result plus the total count.

        A ``$facet`` / ``$project`` pair is appended to a copy of *pipelines*;
        the caller's list is left untouched, so it can be reused. An empty or
        ``None`` pipeline paginates the whole collection.

        :param pipelines: aggregation stages producing the rows to paginate
        :param page_no: 1-based page number
        :param page_size: number of rows per page
        :return: ``(rows, total)``; ``([], 0)`` when nothing matches
        """
        skip = page_size * (page_no - 1)
        stages = list(pipelines or [])
        stages.append({'$facet': {'total': [{'$count': 'count'}], 'rows': [{'$skip': skip}, {'$limit': page_size}]}})
        stages.append({'$project': {'data': '$rows', 'total': {'$arrayElemAt': ['$total.count', 0]}}})

        cursor = self.db[collection].aggregate(stages, session=None, allowDiskUse=True)
        cursor.batch_size(page_size)
        try:
            async for rs in cursor:
                # rs: dict as {"data": [], "total": int}; "total" is absent
                # when the facet counted no rows.
                if rs and 'data' in rs and 'total' in rs:
                    return rs.get('data'), rs.get('total')
        finally:
            await cursor.close()
        return [], 0

    async def aggregate(self, collection, pipelines=None, batch_size: int = 2000):
        """
        Run an aggregation and return its cursor.

        Example:
            cursor = await aio_mongo.aggregate(collection="collection_name", pipelines=[])
            async for row in cursor:
                print(row)

            await cursor.close()
        Note: the caller should close the cursor once done with it.
        """
        cursor = self.db[collection].aggregate(pipelines or [], session=None, allowDiskUse=True)
        cursor.batch_size(batch_size)
        return cursor

    async def list_with_page(self, collection, query=None, page_size=10000, fields=None):
        """Return all matching documents, fetched page by page.

        The original author noted that this helper has no known caller yet.

        :param page_size: number of documents fetched per round; ``<= 0`` yields ``[]``
        :param fields: projection; a falsy value means no projection
        """
        query = query or {}
        rows = []
        total = await self.db[collection].count_documents(query)
        if total > 0 and page_size > 0:
            # Ceiling division: exactly as many pages as are needed.
            total_page = (total + page_size - 1) // page_size
            for page in range(total_page):
                # Each page is limited to page_size documents (the limit must
                # not be the page index: a limit of 0 means "no limit").
                cursor = self.db[collection].find(query, fields or None).skip(page_size * page).limit(page_size)
                try:
                    async for document in cursor:
                        rows.append(document)
                finally:
                    await cursor.close()
        return rows

    async def insert_or_update(self, collection, data, id_key='_id', update=None, upsert=True, multi=False):
        """Upsert the document(s) whose *id_key* equals ``data[id_key]``.

        :param data: field values applied with ``$set``; also supplies the id
        :param update: raw update document (mutually exclusive with *data*)
        :param multi: use ``update_many`` instead of ``update_one``
        :raises BizException: when *data* and *update* are both given or both empty
        """
        change = self._resolve_update(data, update)
        # NOTE(review): the selector is always read from ``data``, so calling
        # with only ``update`` fails here (TypeError / KeyError) exactly as the
        # original did -- the update-only form needs another way to pass the id.
        selector = {id_key: data[id_key]}
        target = self.db[collection]
        if multi:
            return await target.update_many(selector, change, upsert=upsert)
        return await target.update_one(selector, change, upsert=upsert)

    async def insert(self, collection, data):
        """Insert a single document."""
        return await self.db[collection].insert_one(data)

    async def update(self, collection, filter, data=None, update=None, multi=False):
        """Update the first (or, with *multi*, every) document matching *filter*.

        :param data: field values applied with ``$set``
        :param update: raw update document (mutually exclusive with *data*)
        :raises BizException: when *data* and *update* are both given or both empty
        """
        change = self._resolve_update(data, update)
        target = self.db[collection]
        if multi:
            return await target.update_many(filter, change)
        return await target.update_one(filter, change)

    async def save(self, collection, filter: dict, save_data: dict, upsert=True):
        """``$set`` *save_data* on the first document matching *filter* (upserting by default)."""
        return await self.db[collection].update_one(filter, {'$set': save_data}, upsert=upsert)

    async def update_by_pk(self, collection, pk_val, data=None, update=None, upsert=False):
        """Update the document whose ``_id`` equals *pk_val*.

        :raises BizException: when *data* and *update* are both given or both empty
        """
        change = self._resolve_update(data, update)
        return await self.db[collection].update_one({'_id': pk_val}, change, upsert=upsert)

    async def batch_update(self, collection, filter, datas=None, update=None, *args, **kwargs):
        """Update every document matching *filter*.

        :param datas: field values applied with ``$set``
        :param update: raw update document (mutually exclusive with *datas*)
        :raises BizException: when *datas* and *update* are both given or both empty
        """
        change = self._resolve_update(datas, update)
        return await self.db[collection].update_many(filter, change)

    async def delete(self, collection, filter):
        """Delete every document matching *filter*."""
        return await self.db[collection].delete_many(filter)

    async def bulk_write(self, collection, bulk_list: list, batch_size: int = 1000):
        """Execute *bulk_list* in unordered chunks of *batch_size* operations.

        Document validation is bypassed for these writes.

        :return: the result of the last chunk, or ``None`` for an empty list
        """
        result = None
        for start in range(0, len(bulk_list or []), batch_size):
            chunk = bulk_list[start: start + batch_size]
            result = await self.db[collection].bulk_write(chunk, ordered=False, bypass_document_validation=True)
        return result

    async def create_index(self, collection, fields):
        """Create an index on *fields*."""
        return await self.db[collection].create_index(fields)

    def close(self):
        """Close the underlying client; failures are reported, not raised."""
        if self.client:
            try:
                self.client.close()
                print("close successful")
            except Exception as e:
                print("close mongo client catch err:", e)
|
|
233
|
+
|