fast-web-core 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. fast-web-core-1.0.0/MANIFEST.in +0 -0
  2. fast-web-core-1.0.0/PKG-INFO +23 -0
  3. fast-web-core-1.0.0/README.rst +12 -0
  4. fast-web-core-1.0.0/fast_web_core/__init__.py +9 -0
  5. fast-web-core-1.0.0/fast_web_core/auth/__init__.py +0 -0
  6. fast-web-core-1.0.0/fast_web_core/auth/auth_cache_pool.py +88 -0
  7. fast-web-core-1.0.0/fast_web_core/auth/share_auth.py +223 -0
  8. fast-web-core-1.0.0/fast_web_core/client/__init__.py +0 -0
  9. fast-web-core-1.0.0/fast_web_core/client/async_mongo_client.py +233 -0
  10. fast-web-core-1.0.0/fast_web_core/client/async_mulit_tenant_mongo_client.py +69 -0
  11. fast-web-core-1.0.0/fast_web_core/client/async_multi_database_mongo_client.py +60 -0
  12. fast-web-core-1.0.0/fast_web_core/client/async_rabbit_client.py +71 -0
  13. fast-web-core-1.0.0/fast_web_core/client/async_redis_client.py +37 -0
  14. fast-web-core-1.0.0/fast_web_core/client/mongo_client.py +208 -0
  15. fast-web-core-1.0.0/fast_web_core/client/multi_database_mongo_client.py +59 -0
  16. fast-web-core-1.0.0/fast_web_core/client/multi_tenant_mongo_client.py +60 -0
  17. fast-web-core-1.0.0/fast_web_core/client/oss2_client.py +67 -0
  18. fast-web-core-1.0.0/fast_web_core/client/redis_client.py +36 -0
  19. fast-web-core-1.0.0/fast_web_core/context/__init__.py +0 -0
  20. fast-web-core-1.0.0/fast_web_core/context/context_vars.py +11 -0
  21. fast-web-core-1.0.0/fast_web_core/exception/__init__.py +0 -0
  22. fast-web-core-1.0.0/fast_web_core/exception/exceptions.py +32 -0
  23. fast-web-core-1.0.0/fast_web_core/lib/__init__.py +0 -0
  24. fast-web-core-1.0.0/fast_web_core/lib/cfg.py +60 -0
  25. fast-web-core-1.0.0/fast_web_core/lib/func.py +92 -0
  26. fast-web-core-1.0.0/fast_web_core/lib/logger.py +87 -0
  27. fast-web-core-1.0.0/fast_web_core/lib/secret.py +59 -0
  28. fast-web-core-1.0.0/fast_web_core/lib/settings.py +103 -0
  29. fast-web-core-1.0.0/fast_web_core/lib/time.py +343 -0
  30. fast-web-core-1.0.0/fast_web_core/middleware/__init__.py +23 -0
  31. fast-web-core-1.0.0/fast_web_core/middleware/auth.py +55 -0
  32. fast-web-core-1.0.0/fast_web_core/model/__init__.py +0 -0
  33. fast-web-core-1.0.0/fast_web_core/model/base.py +25 -0
  34. fast-web-core-1.0.0/fast_web_core/model/enums.py +24 -0
  35. fast-web-core-1.0.0/fast_web_core/model/handler.py +96 -0
  36. fast-web-core-1.0.0/fast_web_core/model/items.py +75 -0
  37. fast-web-core-1.0.0/fast_web_core/model/query.py +23 -0
  38. fast-web-core-1.0.0/fast_web_core/utils/__init__.py +0 -0
  39. fast-web-core-1.0.0/fast_web_core/utils/common.py +164 -0
  40. fast-web-core-1.0.0/fast_web_core/utils/context_var.py +9 -0
  41. fast-web-core-1.0.0/fast_web_core/utils/decator.py +13 -0
  42. fast-web-core-1.0.0/fast_web_core/utils/encryption.py +106 -0
  43. fast-web-core-1.0.0/fast_web_core/utils/sequence.py +44 -0
  44. fast-web-core-1.0.0/fast_web_core.egg-info/PKG-INFO +23 -0
  45. fast-web-core-1.0.0/fast_web_core.egg-info/SOURCES.txt +47 -0
  46. fast-web-core-1.0.0/fast_web_core.egg-info/dependency_links.txt +1 -0
  47. fast-web-core-1.0.0/fast_web_core.egg-info/top_level.txt +1 -0
  48. fast-web-core-1.0.0/setup.cfg +4 -0
  49. fast-web-core-1.0.0/setup.py +25 -0
File without changes
@@ -0,0 +1,23 @@
1
+ Metadata-Version: 2.1
2
+ Name: fast-web-core
3
+ Version: 1.0.0
4
+ Summary: fast web core for zsodata
5
+ Home-page: http://www.zsodata.com
6
+ Author: zsodata
7
+ Author-email: team@zso.io
8
+ License: BSD License
9
+ Platform: all
10
+ Requires-Python: >=3.7
11
+
12
+ pip uninstall fast-web-core
13
+
14
+
15
+ 注意:
16
+ mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
17
+ 如果需要使用旧版,切换到分支:mongo4
18
+
19
+ 重点新增:
20
+ 异步mongo,异步redis
21
+
22
+ pip install --upgrade fast-web-core -i https://pypi.org/simple
23
+ pip show fast-web-core
@@ -0,0 +1,12 @@
1
+ pip uninstall fast-web-core
2
+
3
+
4
+ 注意:
5
+ mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
6
+ 如果需要使用旧版,切换到分支:mongo4
7
+
8
+ 重点新增:
9
+ 异步mongo,异步redis
10
+
11
+ pip install --upgrade fast-web-core -i https://pypi.org/simple
12
+ pip show fast-web-core
@@ -0,0 +1,9 @@
1
+ from __future__ import absolute_import
2
+
3
+
4
def init_core_modules(*args, **kwargs) -> None:
    """
    Initialise the core handlers, middlewares and services of a FastAPI application.

    This is currently a placeholder: every positional and keyword argument is
    accepted and ignored, and nothing is registered.

    :return: None
    """
    return None
File without changes
@@ -0,0 +1,88 @@
1
+ # -*- coding: utf-8 -*-
2
+ import json
3
+ import logging
4
+ from typing import List, Set, Optional, Dict
5
+ from ..lib import cfg
6
+ from ..client.redis_client import Redis
7
+ from ..model.items import AuthUser, AuthApp
8
+
9
logger = logging.getLogger(__name__)
# Redis client shared by every cache helper in this module. Its URL and database come from the
# AUTH_REDIS_URL / AUTH_REDIS_DB config values. NOTE(review): it is constructed when this module
# is imported; whether that opens a connection depends on the Redis wrapper.
rds = Redis(redis_uri=cfg.get('AUTH_REDIS_URL', None), redis_db=cfg.get('AUTH_REDIS_DB', None)).client
# Default TTL in seconds (43200 s = 12 h) for cached access / permission entries.
default_expire = 43200
12
+
13
+
14
def save_cached_app(access_token: str, app: AuthApp, expire: int = None):
    """
    Store an app's JSON under ``share:access:<access_token>``.

    :param access_token: token the entry is keyed by
    :param app: app to cache; nothing is written when it is falsy
    :param expire: TTL in seconds; ``None`` or ``0`` falls back to ``default_expire``
    """
    if not app:
        return
    ttl = expire if expire else default_expire
    rds.set(f'share:access:{access_token}', app.json(), ex=ttl)
18
+
19
+
20
def get_cached_user(access_token: str) -> Optional[AuthUser]:
    """
    Load the user cached under ``share:access:<access_token>``.

    The stored value may be plain JSON or a JSON string that itself wraps JSON
    (double-encoded); both are accepted. The value only counts as a user when it
    decodes to a dict with a truthy ``id`` field.

    :param access_token: token the entry is keyed by
    :return: the ``AuthUser``, or ``None`` when there is no entry or it is not a user object
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    js_user = json.loads(rs)
    # Tolerate double-encoded values (a JSON string containing the JSON object).
    if isinstance(js_user, str):
        js_user = json.loads(js_user)
    # Signature check: a user must be a JSON object with an ``id``. A non-dict payload
    # (list/number) previously raised AttributeError on ``.get``.
    if not isinstance(js_user, dict) or not js_user.get('id', None):
        return None
    user = AuthUser(**js_user)
    # Prefer the lowercase ``username`` key. Fall back to a camelCase ``userName`` key
    # instead of overwriting the parsed value with ''.
    user.userName = js_user.get('username') or js_user.get('userName') or ''
    return user
34
+
35
+
36
def get_cached_app(access_token: str) -> Optional[AuthApp]:
    """
    Load the app cached under ``share:access:<access_token>``.

    The stored value may be plain JSON or double-encoded JSON. The value only counts
    as an app when it decodes to a dict with a truthy ``appId`` field.

    :param access_token: token the entry is keyed by
    :return: the ``AuthApp``, or ``None`` when there is no entry or it is not an app object
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    js_app = json.loads(rs)
    # Tolerate double-encoded values (a JSON string containing the JSON object).
    if isinstance(js_app, str):
        js_app = json.loads(js_app)
    # Signature check: an app must be a JSON object with an ``appId``. Non-dict payloads
    # are rejected instead of raising AttributeError.
    if not isinstance(js_app, dict) or not js_app.get('appId', None):
        return None
    return AuthApp(**js_app)
49
+
50
+
51
def get_cached_access_dict(access_token: str) -> Optional[Dict]:
    """
    Return the raw decoded entry cached under ``share:access:<access_token>``.

    Plain and double-encoded JSON are both accepted.

    :param access_token: token the entry is keyed by
    :return: the decoded dict, or ``None`` when there is no entry or it does not decode
        to a JSON object (callers call ``.get`` on the result)
    """
    rs = rds.get(f'share:access:{access_token}')
    if not rs:
        return None
    item = json.loads(rs)
    # Tolerate double-encoded values (a JSON string containing the JSON object).
    if isinstance(item, str):
        item = json.loads(item)
    # Honour the Optional[Dict] contract: anything that is not a JSON object is treated as absent.
    return item if isinstance(item, dict) else None
61
+
62
+
63
def has_access(access_token: str) -> bool:
    """
    Tell whether an access entry exists for ``access_token``.

    Redis ``EXISTS`` returns the number of existing keys; it is converted to the
    annotated ``bool``.
    """
    return bool(rds.exists(f'share:access:{access_token}'))
65
+
66
+
67
def save_cached_permission(access_token: str, permissions: Set[str], expire: int = None):
    """
    Add ``permissions`` to the Redis set ``share:permissions:<access_token>`` and reset its TTL.

    Nothing is written when ``permissions`` is empty. Members already in the set are kept
    (SADD only adds).

    :param access_token: token the permission set is keyed by
    :param permissions: permission identifiers to add
    :param expire: TTL in seconds; ``None`` or ``0`` falls back to ``default_expire``
    """
    if not permissions:
        return
    key = f'share:permissions:{access_token}'
    # A single variadic SADD instead of one network round-trip per permission.
    rds.sadd(key, *permissions)
    rds.expire(key, expire or default_expire)
74
+
75
+
76
def has_permission(access_token: str, permission: str) -> bool:
    """
    Tell whether ``permission`` is in the permission set cached for ``access_token``.

    The plain value is checked first. If it is absent, the double-quoted form is tried,
    which tolerates entries written JSON-quoted by the Java implementation.
    """
    key = f'share:permissions:{access_token}'
    found = rds.sismember(key, f'{permission}')
    if found:
        return found
    # Java compatibility: permission stored as "xx.xx.xx" (with literal quotes)
    return rds.sismember(key, f'"{permission}"')
80
+
81
+
82
def get_cached_permission(access_token: str) -> Set:
    """
    Return every member of the Redis set ``share:permissions:<access_token>``.

    The value is whatever the client's ``smembers`` returns, which is a set in redis-py.
    The previous ``List[str]`` annotation did not match.
    NOTE(review): members are ``bytes`` unless the Redis client decodes responses — confirm.
    """
    return rds.smembers(f'share:permissions:{access_token}')
84
+
85
+
86
def clean_cache(access_token: str):
    """Delete both the access entry and the permission set cached for ``access_token``."""
    for prefix in ('share:access:', 'share:permissions:'):
        rds.delete(f'{prefix}{access_token}')
@@ -0,0 +1,223 @@
1
+ # -*- coding: utf-8 -*-
2
+ import logging
3
+ import re
4
+ import json
5
+ import threading
6
+ from typing import Optional, List
7
+ from starlette.routing import BaseRoute
8
+ from fastapi.routing import APIRoute
9
+ from starlette.requests import Request
10
+
11
+ from ..lib import cfg
12
+ from ..auth import auth_cache_pool
13
+ from ..client.redis_client import Redis
14
+ from ..exception.exceptions import NoAuthException
15
+ from ..model.items import AuthUser, AuthApp
16
+ from ..context.context_vars import tenant_context
17
+
18
logger = logging.getLogger(__name__)

# Built-in regex patterns for request paths that skip login/permission checks.
# ShareAuth.reload() compiles these together with the caller-supplied white_list.
# Empty by default.
FIX_WHITELIST = [
]
22
+
23
+
24
class ShareAuth(object):
    """
    Singleton that performs login and permission checks against the shared auth Redis cache.

    ``reload()`` maps each FastAPI ``APIRoute``'s compiled path regex to the route's
    ``tags``. Those tags are the permission identifiers checked by ``auth_check``.
    ``reload()`` also compiles the whitelist of paths that skip all checks.
    """
    _instance_lock = threading.Lock()
    # Flipped on the (single) instance once __init__ has run.
    _biz_inited = False
    # Flipped on the instance once reload() has built the route / whitelist tables.
    _routes_inited = False

    def __new__(cls, *args, **kwargs):
        # Double-checked locking so only one instance is ever created.
        if not hasattr(ShareAuth, "_instance"):
            with ShareAuth._instance_lock:
                if not hasattr(ShareAuth, "_instance"):
                    ShareAuth._instance = object.__new__(cls)
        return ShareAuth._instance

    def __init__(self):
        # Python calls __init__ on every ShareAuth() even though __new__ returns the same
        # instance; the flag makes the initialisation run only once.
        if not self._biz_inited:
            self._biz_inited = True
            self.rds = Redis(redis_uri=cfg.get('AUTH_REDIS_URL', None), redis_db=cfg.get('AUTH_REDIS_DB', None)).client
            # route path regex -> tags
            self.regex_to_tags_map = dict()
            # compiled regexes of paths that need no login / permission check
            self.auth_whitelist = []

    @staticmethod
    def _strip_bearer(header_value: str) -> str:
        """Remove a leading ``Bearer `` scheme (case-insensitive) from an Authorization value."""
        # Auth scheme names are case-insensitive (RFC 7235). Only a leading scheme is
        # stripped, never an occurrence elsewhere in the value.
        prefix = 'bearer '
        if header_value[:len(prefix)].lower() == prefix:
            return header_value[len(prefix):]
        return header_value

    def get_access_token(self, request: Request):
        """
        Extract the access token from the request.

        Lookup order:
        1. the ``Authorization`` header, with any leading ``Bearer `` scheme removed;
        2. the ``access_token`` cookie;
        3. the ``access_token`` query parameter.

        :param request: incoming request
        :return: the token; ``None`` when none of the sources has one; ``''`` when
            ``request`` is missing or has neither headers nor query parameters
        """
        if not request or not (request.headers or request.query_params):
            return ''
        access_token = self._strip_bearer(request.headers.get('Authorization', ''))
        if not access_token:
            access_token = request.cookies.get('access_token')
        if not access_token:
            access_token = request.query_params.get('access_token')
        return access_token

    def reload(self, routes: Optional[List[BaseRoute]], white_list=None, force: bool = False):
        """
        (Re)build the route->tags map and the compiled no-auth whitelist.

        :param routes: application routes; only ``APIRoute`` entries with a path regex are used.
            ``None`` is treated as no routes.
        :param white_list: extra regex patterns (strings) whose paths skip checks, added after
            ``FIX_WHITELIST``
        :param force: rebuild even if the tables were already built
        :return: self
        """
        if not force and self._routes_inited:
            return self
        # Route path regex -> permission tags. Rebinding (rather than mutating the old
        # dict/list) leaves any in-progress iteration over the previous tables untouched.
        self.regex_to_tags_map = {
            route.path_regex: route.tags
            for route in (routes or [])
            if isinstance(route, APIRoute) and route.path_regex
        }
        logger.info(f'reload path to tags map: {len(self.regex_to_tags_map)}')

        patterns = list(FIX_WHITELIST) + list(white_list or [])
        self.auth_whitelist = [re.compile(row) for row in patterns]
        logger.info(f'reload no auth whitelist: {len(self.auth_whitelist)}')

        self._routes_inited = True
        return self

    def access_check(self, request: Request):
        """
        Login check: require a token that has a cached access entry.

        :param request: incoming request
        :return: the access token
        :raises NoAuthException: when there is no token or its access entry is gone
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')
        if not auth_cache_pool.has_access(access_token):
            raise NoAuthException('登录信息已失效')
        return access_token

    def auth_check(self, request: Request):
        """
        Permission check for the request path.

        :param request: incoming request
        :return: ``True`` when the path is whitelisted, or when it matches a route and
            passes that route's checks. ``False`` when no route map is loaded (checked
            before the whitelist) or when no route matches.
        :raises NoAuthException: when login fails, or a matched route's tags are not all granted
        """
        if not self.regex_to_tags_map:
            logger.debug('no path to tags map')
            return False
        path = request.url.path
        # Whitelisted paths skip both login and permission checks.
        if any(white_regex.match(path) for white_regex in self.auth_whitelist):
            return True
        # Every non-whitelisted path requires login first.
        access_token = self.access_check(request)
        for regex, tags in self.regex_to_tags_map.items():
            if not regex.match(path):
                continue
            # Routes with tags need every tag granted; routes without tags only need login.
            if tags and not self._rds_permissions_check(tags, access_token):
                raise NoAuthException('权限不足')
            return True
        return False

    def _rds_permissions_check(self, tags, access_token) -> bool:
        """Return True only if every tag is in the token's cached permission set."""
        for tag in tags:
            # New-style permission identifier: xx.xx.xx.xx
            has = auth_cache_pool.has_permission(access_token, tag)
            if not has:
                # Java-compatible identifier stored JSON-encoded: "xx.xx.xx.xx"
                has = auth_cache_pool.has_permission(access_token, json.dumps(tag))
            if not has:
                return False
        return True

    def get_auth_user(self, request: Request) -> Optional[AuthUser]:
        """
        Return the logged-in user and set the tenant context from it.

        A super admin (``superAdmin`` == 1 or '1') may switch tenant with the
        ``TenantCode`` request header.

        :param request: incoming request
        :return: the cached user, or ``None`` if the access entry is not a user object
        :raises NoAuthException: when not logged in or the login has expired
        """
        access_token = self.access_check(request)

        auth_user = auth_cache_pool.get_cached_user(access_token)
        # Only super admins may switch tenant.
        if auth_user and auth_user.superAdmin in (1, '1'):
            tenant_code = request.headers.get('TenantCode', None)
            if tenant_code:
                auth_user.tenantCode = tenant_code

        if auth_user and auth_user.tenantCode:
            tenant_context.set(auth_user.tenantCode)

        return auth_user

    def get_auth_app(self, request: Request) -> Optional[AuthApp]:
        """
        Return the authenticated app and set the tenant context from it.

        :param request: incoming request
        :return: the cached app, or ``None`` if the access entry is not an app object
        :raises NoAuthException: when there is no token or its access entry is gone
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先鉴权认证')
        if not auth_cache_pool.has_access(access_token):
            raise NoAuthException('鉴权认证信息已失效')

        auth_app = auth_cache_pool.get_cached_app(access_token=access_token)
        if auth_app and auth_app.tenantCode:
            tenant_context.set(auth_app.tenantCode)

        return auth_app

    def get_auth_tenant_code(self, request: Request) -> Optional[str]:
        """
        Return the tenant code of the authenticated principal without raising.

        A super admin may override it with the ``TenantCode`` request header.

        :param request: incoming request
        :return: the tenant code, or ``''`` when there is no token or no usable cached entry
        """
        access_token = self.get_access_token(request)
        if not access_token:
            return ''
        access_item = auth_cache_pool.get_cached_access_dict(access_token)
        # Guard against payloads that did not decode to a JSON object.
        if not access_item or not isinstance(access_item, dict):
            return ''
        # Only super admins may switch tenant.
        if access_item.get('superAdmin', None) in (1, '1'):
            tenant_code = request.headers.get('TenantCode', None)
            if tenant_code:
                return tenant_code
        return access_item.get('tenantCode', '')
209
+
210
+
211
async def authed_user(request: Request) -> Optional[AuthUser]:
    """FastAPI ``Depends`` helper: resolve the logged-in user through the ShareAuth singleton."""
    auth = ShareAuth()
    return auth.get_auth_user(request)
214
+
215
+
216
async def authed_app(request: Request) -> Optional[AuthApp]:
    """FastAPI ``Depends`` helper: resolve the authenticated app through the ShareAuth singleton."""
    auth = ShareAuth()
    return auth.get_auth_app(request)
219
+
220
+
221
async def access_token(request: Request) -> Optional[str]:
    """FastAPI ``Depends`` helper: return the raw access token carried by the request."""
    auth = ShareAuth()
    return auth.get_access_token(request)
File without changes
@@ -0,0 +1,233 @@
1
+ # encoding: utf-8
2
+ # version: pymongo==4.7.3, motor==3.5.0
3
+
4
+ from ..exception.exceptions import BizException
5
+ from ..exception.exceptions import NoConfigException
6
+ from ..lib import cfg, logger
7
+
8
# Module logger for AsyncMongo, obtained from the package's own logger helper.
LOGGER = logger.get('AsyncMongoClient')
9
+
10
+
11
class AsyncMongo(object):
    """
    Thin async wrapper around a Motor ``AsyncIOMotorClient`` database handle.

    The connection URL and database name default to the ``MONGO_URL`` / ``MONGO_DB``
    config values. ``query``, ``sort`` and ``pipelines`` arguments accept ``None``,
    which means "empty".
    """
    def __init__(self, mongo_uri: str = None, mongo_db: str = None):
        import motor.motor_asyncio
        self.mongo_uri = mongo_uri or cfg.get('MONGO_URL')
        self.mongo_db = mongo_db or cfg.get('MONGO_DB')
        if not self.mongo_uri:
            raise NoConfigException('mongodb uri not config!')
        if not self.mongo_db:
            raise NoConfigException('mongodb database not config!')

        self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
        self.db = self.client[self.mongo_db]

        LOGGER.info(self.client)

    @staticmethod
    def _resolve_update(data, update):
        """
        Choose the update document for an update call.

        :param data: fields to ``$set``; used when ``update`` is empty
        :param update: raw update document with operators; used when ``data`` is empty
        :return: ``{'$set': data}`` or ``update``
        :raises BizException: when both, or neither, of ``data`` / ``update`` are given
        """
        if data and not update:
            return {'$set': data}
        if update and not data:
            return update
        raise BizException("data和update不能同时存在或同时为空")

    @staticmethod
    def _build_page_pipeline(pipelines, page_no, page_size):
        """
        Return a NEW stage list: ``pipelines`` followed by ``$facet``/``$project`` pagination stages.

        The caller's list is never mutated. The final document has the shape
        ``{"data": [...], "total": int}``; ``total`` is missing when nothing matched.
        """
        skip = page_size * (page_no - 1)
        stages = [*pipelines] if pipelines else []
        stages.append({'$facet': {'total': [{'$count': 'count'}], 'rows': [{'$skip': skip}, {'$limit': page_size}]}})
        stages.append({'$project': {'data': '$rows', 'total': {'$arrayElemAt': ['$total.count', 0]}}})
        return stages

    def get_collection(self, coll):
        """Return the Motor collection handle. This is not a coroutine."""
        return self.db.get_collection(coll)

    async def get(self, collection, query=None):
        """Return the first document matching ``query``, or ``None``."""
        return await self.db[collection].find_one({} if query is None else query)

    async def count(self, collection, query=None):
        """Return the number of documents matching ``query``."""
        return await self.db[collection].count_documents({} if query is None else query)

    async def list(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """Return all matching documents as a list. The cursor is always closed."""
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)
        result = []
        try:
            async for document in cursor:
                result.append(document)
        finally:
            await cursor.close()
        return result

    async def list_with_cursor(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """
        Return an open cursor over the matching documents.

        Example:
            cursor = await aio_mongo.list_with_cursor(collection)
            async for document in cursor:
                print("document =", document)
            await cursor.close()
        Note: the caller must close the cursor when finished with it.
        """
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)
        return cursor

    async def page(self, collection, query=None, page_no=1, page_size=20, fields=None, sort=None, batch_size: int = 2000):
        """
        Paginated query (``page_no`` starts at 1).

        :return: tuple ``(rows, total)``, where ``total`` is the full match count
        """
        query = {} if query is None else query
        coll = self.db[collection]
        total = await coll.count_documents(query) or 0
        cursor = coll.find(query, fields, sort=sort).skip(page_size * (page_no - 1)).limit(page_size)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                rows.append(document)
        finally:
            await cursor.close()
        return rows, total

    async def top(self, collection, query=None, sort=None, limit=1, fields=None, batch_size: int = 2000):
        """
        Return the first ``limit`` documents.

        When ``limit == 1`` the single document itself is returned, or ``[]`` when
        nothing matches. Otherwise a list is returned.
        """
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort, limit=limit)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                if limit == 1:
                    return document
                rows.append(document)
        finally:
            # Also runs on the early return above, which previously left the cursor open.
            await cursor.close()
        return rows

    async def distinct(self, collection, dist_key, query=None, fields=None):
        """Return the distinct values of ``dist_key`` among the matching documents."""
        return await self.db[collection].find({} if query is None else query, fields).distinct(dist_key)

    async def aggregate_page(self, collection, pipelines, page_no=1, page_size=20):
        """
        Run an aggregation with ``$facet`` pagination appended.

        ``pipelines`` is copied, never mutated. Previously the stages were appended to the
        caller's list, so repeated calls accumulated stages. An empty or ``None`` pipeline
        now paginates the whole collection; previously it scanned every document and
        returned ``([], 0)``.

        :return: tuple ``(rows, total)``; ``([], 0)`` when nothing matched
        """
        stages = self._build_page_pipeline(pipelines, page_no, page_size)
        cursor = self.db[collection].aggregate(stages, session=None, allowDiskUse=True)
        cursor.batch_size(page_size)
        try:
            async for rs in cursor:
                # rs: dict as {"data": [], "total": int}
                if rs and 'data' in rs and 'total' in rs:
                    return rs.get('data'), rs.get('total')
        finally:
            await cursor.close()
        return [], 0

    async def aggregate(self, collection, pipelines=None, batch_size: int = 2000):
        """
        Return an open aggregation cursor.

        Example:
            cursor = await aio_mongo.aggregate(collection="collection_name", pipelines=[])
            async for row in cursor:
                print(row)

            await cursor.close()
        Note: the caller must close the cursor when finished with it.
        """
        cursor = self.db[collection].aggregate([] if pipelines is None else pipelines, session=None, allowDiskUse=True)
        cursor.batch_size(batch_size)
        return cursor

    async def list_with_page(self, collection, query=None, page_size=10000, fields=None):
        """
        Fetch all matching documents in chunks of ``page_size``.

        Fixed two bugs:
        - Each chunk was limited to the page index (``limit(page)``) instead of
          ``page_size``. ``limit(0)`` on page 0 means "no limit", so every document was
          returned on the first pass and later pages added duplicates.
        - The page count came from ``round()``. It is now ``ceil(total / page_size)``.
        """
        query = {} if query is None else query
        rows = []
        total = await self.db[collection].count_documents(query)
        if total > 0 and page_size > 0:
            total_page = (total + page_size - 1) // page_size
            for page in range(total_page):
                cursor = self.db[collection].find(query, fields).skip(page_size * page).limit(page_size)
                try:
                    async for document in cursor:
                        rows.append(document)
                finally:
                    await cursor.close()
        return rows

    async def insert_or_update(self, collection, data, id_key='_id', update=None, upsert=True, multi=False):
        """
        Upsert the document(s) whose ``id_key`` equals ``data[id_key]``.

        ``data`` is ``$set``. ``update`` is a raw update document, and exactly one of the two may be given.
        NOTE(review): the filter is always built from ``data[id_key]``. The ``update``-only path
        (``data`` empty) therefore raises TypeError/KeyError; this is kept as-is — confirm the
        intended contract.

        :raises BizException: when both or neither of ``data`` / ``update`` are given
        """
        update_doc = self._resolve_update(data, update)
        coll = self.db[collection]
        operation = coll.update_many if multi else coll.update_one
        return await operation({id_key: data[id_key]}, update_doc, upsert=upsert)

    async def insert(self, collection, data):
        """Insert one document."""
        return await self.db[collection].insert_one(data)

    async def update(self, collection, filter, data=None, update=None, multi=False):
        """
        Update the first match, or all matches when ``multi`` is true.

        Pass either ``data`` (applied with ``$set``) or ``update`` (a raw update document), not both.

        :raises BizException: when both or neither of ``data`` / ``update`` are given
        """
        update_doc = self._resolve_update(data, update)
        coll = self.db[collection]
        if multi:
            return await coll.update_many(filter, update_doc)
        return await coll.update_one(filter, update_doc)

    async def save(self, collection, filter: dict, save_data: dict, upsert=True):
        """``$set`` ``save_data`` on the first document matching ``filter`` (upsert by default)."""
        return await self.db[collection].update_one(filter, {'$set': save_data}, upsert=upsert)

    async def update_by_pk(self, collection, pk_val, data=None, update=None, upsert=False):
        """
        Update the document whose ``_id`` is ``pk_val``.

        :raises BizException: when both or neither of ``data`` / ``update`` are given
        """
        update_doc = self._resolve_update(data, update)
        return await self.db[collection].update_one({'_id': pk_val}, update_doc, upsert=upsert)

    async def batch_update(self, collection, filter, datas=None, update=None, *args, **kwargs):
        """
        Update every document matching ``filter``. Extra args are accepted and ignored.

        :raises BizException: when both or neither of ``datas`` / ``update`` are given
        """
        update_doc = self._resolve_update(datas, update)
        return await self.db[collection].update_many(filter, update_doc)

    async def delete(self, collection, filter):
        """Delete every document matching ``filter``."""
        return await self.db[collection].delete_many(filter)

    # The annotation is a string: a bare ``list`` here would resolve to this class's own
    # ``list`` method, which is defined earlier in the class body.
    async def bulk_write(self, collection, bulk_list: 'list', batch_size: int = 1000):
        """
        Execute ``bulk_list`` write operations in unordered batches of ``batch_size``.

        :return: the result of the LAST batch only, or ``None`` when ``bulk_list`` is empty
        """
        result = None
        if bulk_list:
            for start in range(0, len(bulk_list), batch_size):
                result = await self.db[collection].bulk_write(
                    bulk_list[start:start + batch_size], ordered=False, bypass_document_validation=True)
        return result

    async def create_index(self, collection, fields):
        """Create an index on ``fields``."""
        return await self.db[collection].create_index(fields)

    def close(self):
        """Close the Motor client; any error is reported instead of raised."""
        if self.client:
            try:
                self.client.close()
                print("close successful")
            except Exception as e:
                print("close mongo client catch err:", e)
233
+
@@ -0,0 +1,69 @@
1
+ # encoding: utf-8
2
+ import asyncio
3
+ import threading
4
+ from typing import Optional
5
+ from ..client.async_mongo_client import AsyncMongo
6
+ from ..lib import cfg, logger
7
+
8
+
9
# Module logger for the async multi-tenant Mongo client, obtained from the package's own logger helper.
LOGGER = logger.get('AsyncMultiTenantMongo')
10
+
11
+
12
class AsyncMultiTenantMongo(AsyncMongo):
    """
    Async Mongo client bound to a single tenant's database.

    The database is ``mongo_db`` when given. Otherwise it is ``<prefix><tenant_code>``,
    where the prefix is ``db_prefix`` or the ``MONGO_DB_BIZ_PREFIX`` config value.
    The URL falls back to the ``MONGO_URL`` config, then to ``mongodb://localhost:27017``.
    """
    def __init__(self, tenant_code=None, mongo_url=None, mongo_db=None, db_prefix=None):
        url = mongo_url or cfg.get('MONGO_URL') or 'mongodb://localhost:27017'

        database = mongo_db
        if not database and tenant_code:
            if db_prefix or cfg.get('MONGO_DB_BIZ_PREFIX', None):
                database = f"{db_prefix or cfg.get('MONGO_DB_BIZ_PREFIX', None)}{tenant_code}"
        if not database:
            raise Exception('mongodb database not specified')

        super().__init__(url, database)
        LOGGER.info(f'[{tenant_code}] tenant mongodb inited~')
26
+
27
+
28
class AsyncMultiTenantMongoHolder(object):
    """
    Singleton cache of AsyncMultiTenantMongo clients, one per tenant code.
    """
    _instance_lock = threading.Lock()
    # NOTE(review): not used in the visible code; kept for backward compatibility.
    _instance_async_lock = asyncio.Lock()
    # tenant_code -> AsyncMultiTenantMongo
    _tenant_instance_dict = dict()

    def __new__(cls, *args, **kwargs):
        # Double-checked locking. The original second check ran without holding the lock.
        if not hasattr(AsyncMultiTenantMongoHolder, "_instance"):
            with AsyncMultiTenantMongoHolder._instance_lock:
                if not hasattr(AsyncMultiTenantMongoHolder, "_instance"):
                    AsyncMultiTenantMongoHolder._instance = object.__new__(cls)
        return AsyncMultiTenantMongoHolder._instance

    @staticmethod
    def get_tenant_mongo(tenant_code: str) -> Optional[AsyncMultiTenantMongo]:
        """
        Return the cached client for ``tenant_code``, creating it on first use.

        :param tenant_code: tenant identifier; a falsy value returns ``None``
        :return: the tenant's client
        """
        if not tenant_code:
            return None

        instances = AsyncMultiTenantMongoHolder._tenant_instance_dict
        # Fast path: instance already created.
        if tenant_code in instances:
            return instances.get(tenant_code)

        # Slow path: create under the lock and re-check inside it.
        # ``with`` releases exactly the acquisition made here. The previous
        # ``if locked(): release()`` logic could release a lock that another thread had just
        # acquired, which broke mutual exclusion.
        with AsyncMultiTenantMongoHolder._instance_lock:
            if tenant_code not in instances:
                instances[tenant_code] = AsyncMultiTenantMongo(tenant_code=tenant_code)
            return instances.get(tenant_code)