fast-web-core 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. fast_web_core-0.0.1/MANIFEST.in +0 -0
  2. fast_web_core-0.0.1/PKG-INFO +37 -0
  3. fast_web_core-0.0.1/README.rst +12 -0
  4. fast_web_core-0.0.1/fast_web_core/__init__.py +9 -0
  5. fast_web_core-0.0.1/fast_web_core/auth/__init__.py +0 -0
  6. fast_web_core-0.0.1/fast_web_core/auth/auth_cache_pool.py +109 -0
  7. fast_web_core-0.0.1/fast_web_core/auth/share_auth.py +216 -0
  8. fast_web_core-0.0.1/fast_web_core/client/__init__.py +0 -0
  9. fast_web_core-0.0.1/fast_web_core/client/async_mongo_client.py +233 -0
  10. fast_web_core-0.0.1/fast_web_core/client/async_mulit_tenant_mongo_client.py +69 -0
  11. fast_web_core-0.0.1/fast_web_core/client/async_multi_database_mongo_client.py +60 -0
  12. fast_web_core-0.0.1/fast_web_core/client/async_rabbit_client.py +71 -0
  13. fast_web_core-0.0.1/fast_web_core/client/async_redis_client.py +37 -0
  14. fast_web_core-0.0.1/fast_web_core/client/mongo_client.py +208 -0
  15. fast_web_core-0.0.1/fast_web_core/client/multi_database_mongo_client.py +59 -0
  16. fast_web_core-0.0.1/fast_web_core/client/multi_tenant_mongo_client.py +60 -0
  17. fast_web_core-0.0.1/fast_web_core/client/oss2_client.py +95 -0
  18. fast_web_core-0.0.1/fast_web_core/client/redis_client.py +36 -0
  19. fast_web_core-0.0.1/fast_web_core/context/__init__.py +0 -0
  20. fast_web_core-0.0.1/fast_web_core/context/context_vars.py +14 -0
  21. fast_web_core-0.0.1/fast_web_core/exception/__init__.py +0 -0
  22. fast_web_core-0.0.1/fast_web_core/exception/exceptions.py +32 -0
  23. fast_web_core-0.0.1/fast_web_core/lib/__init__.py +0 -0
  24. fast_web_core-0.0.1/fast_web_core/lib/cfg.py +60 -0
  25. fast_web_core-0.0.1/fast_web_core/lib/func.py +92 -0
  26. fast_web_core-0.0.1/fast_web_core/lib/log_filter.py +37 -0
  27. fast_web_core-0.0.1/fast_web_core/lib/logger.py +87 -0
  28. fast_web_core-0.0.1/fast_web_core/lib/secret.py +59 -0
  29. fast_web_core-0.0.1/fast_web_core/lib/settings.py +103 -0
  30. fast_web_core-0.0.1/fast_web_core/lib/time.py +343 -0
  31. fast_web_core-0.0.1/fast_web_core/lib/util.py +56 -0
  32. fast_web_core-0.0.1/fast_web_core/middleware/__init__.py +17 -0
  33. fast_web_core-0.0.1/fast_web_core/middleware/auth.py +56 -0
  34. fast_web_core-0.0.1/fast_web_core/middleware/trace.py +71 -0
  35. fast_web_core-0.0.1/fast_web_core/model/__init__.py +0 -0
  36. fast_web_core-0.0.1/fast_web_core/model/base.py +25 -0
  37. fast_web_core-0.0.1/fast_web_core/model/enums.py +24 -0
  38. fast_web_core-0.0.1/fast_web_core/model/handler.py +96 -0
  39. fast_web_core-0.0.1/fast_web_core/model/items.py +59 -0
  40. fast_web_core-0.0.1/fast_web_core/model/query.py +23 -0
  41. fast_web_core-0.0.1/fast_web_core/utils/__init__.py +0 -0
  42. fast_web_core-0.0.1/fast_web_core/utils/common.py +164 -0
  43. fast_web_core-0.0.1/fast_web_core/utils/context_var.py +9 -0
  44. fast_web_core-0.0.1/fast_web_core/utils/decator.py +13 -0
  45. fast_web_core-0.0.1/fast_web_core/utils/encryption.py +75 -0
  46. fast_web_core-0.0.1/fast_web_core/utils/sequence.py +44 -0
  47. fast_web_core-0.0.1/fast_web_core.egg-info/PKG-INFO +37 -0
  48. fast_web_core-0.0.1/fast_web_core.egg-info/SOURCES.txt +50 -0
  49. fast_web_core-0.0.1/fast_web_core.egg-info/dependency_links.txt +1 -0
  50. fast_web_core-0.0.1/fast_web_core.egg-info/top_level.txt +1 -0
  51. fast_web_core-0.0.1/setup.cfg +4 -0
  52. fast_web_core-0.0.1/setup.py +57 -0
File without changes
@@ -0,0 +1,37 @@
1
+ Metadata-Version: 2.1
2
+ Name: fast_web_core
3
+ Version: 0.0.1
4
+ Summary: fast web core for zsodata
5
+ Home-page: http://www.zsodata.com
6
+ Author: zsodata
7
+ Author-email: team@zso.io
8
+ License: BSD License
9
+ Project-URL: Source, http://www.zsodata.com
10
+ Keywords: fastapi web core framework async mongodb redis
11
+ Platform: all
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: License :: OSI Approved :: BSD License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.7
18
+ Classifier: Programming Language :: Python :: 3.8
19
+ Classifier: Programming Language :: Python :: 3.9
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
+ Requires-Python: >=3.7
24
+ Description-Content-Type: text/plain
25
+
26
+ pip uninstall fast-web-core
27
+
28
+
29
+ 注意:
30
+ mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
31
+ 如果需要使用旧版,切换到分支:mongo4
32
+
33
+ 重点新增:
34
+ 异步mongo,异步redis
35
+
36
+ pip install --upgrade fast-web-core -i https://pypi.python.org/simple
37
+ pip show fast-web-core
@@ -0,0 +1,12 @@
1
+ pip uninstall fast-web-core
2
+
3
+
4
+ 注意:
5
+ mongo依赖版本:pymongo==4.7.3 and mongo version 7.0
6
+ 如果需要使用旧版,切换到分支:mongo4
7
+
8
+ 重点新增:
9
+ 异步mongo,异步redis
10
+
11
+ pip install --upgrade fast-web-core -i https://pypi.org/simple
12
+ pip show fast-web-core
@@ -0,0 +1,9 @@
1
+ from __future__ import absolute_import
2
+
3
+
4
def init_core_modules(*args, **kwargs) -> None:
    """
    Initialize the core handlers, middlewares and services of a FastAPI app.

    Currently a placeholder: every positional and keyword argument is
    accepted and ignored, and nothing is registered.

    :return: None
    """
    return None
File without changes
@@ -0,0 +1,109 @@
1
+ # -*- coding: utf-8 -*-
2
+ import json
3
+ import logging
4
+ from typing import List, Set, Optional, Dict
5
+ from ..lib import cfg
6
+ from ..client.async_redis_client import AsyncRedis
7
+ from ..model.items import AuthUser, AuthApp
8
+
9
logger = logging.getLogger(__name__)
# Module-wide async Redis client used by every auth-cache helper below;
# connection settings come from the AUTH_REDIS_URL / AUTH_REDIS_DB config keys.
rds = AsyncRedis(
    redis_uri=cfg.get('AUTH_REDIS_URL', None),
    redis_db=cfg.get('AUTH_REDIS_DB', None),
).client
# Fallback TTL in seconds (12 hours) for cached permission sets.
default_expire = 12 * 60 * 60
12
+
13
+
14
async def get_cached_app(access_token: str) -> Optional[AuthApp]:
    """
    Load the cached application identity for ``access_token``.

    The payload lives under the same ``{prefix}:access:share:{token}`` key as
    user sessions; it is treated as an application only when the decoded
    value is a JSON object with a truthy ``team_id``.

    :param access_token: bearer token presented by the caller
    :return: an ``AuthApp`` built from the cached payload, or None when the
             key is missing or the payload does not look like an app
    """
    # Reuse the shared fetch + (double-)decode logic instead of duplicating it.
    item = await get_cached_access_dict(access_token)
    # Feature check: an app payload must be an object carrying ``team_id``.
    # The isinstance guard avoids AttributeError on non-object JSON values.
    if isinstance(item, dict) and item.get('team_id', None):
        return AuthApp(**item)
    return None
28
+
29
+
30
async def get_cached_user(access_token: str) -> Optional[AuthUser]:
    """
    Load the cached user identity for ``access_token``.

    Reads the shared ``{prefix}:access:share:{token}`` entry and treats it as
    a user when the decoded payload is a JSON object whose ``id`` is not None.

    :param access_token: bearer token presented by the caller
    :return: an ``AuthUser`` built from the cached payload, or None when the
             key is missing or the payload does not look like a user
    """
    # Reuse the shared fetch + (double-)decode logic instead of duplicating it.
    item = await get_cached_access_dict(access_token)
    # Feature check: a user payload must be an object carrying an ``id``.
    # The isinstance guard avoids AttributeError on non-object JSON values.
    if isinstance(item, dict) and item.get('id', None) is not None:
        return AuthUser(**item)
    return None
44
+
45
+
46
async def get_cached_access_dict(access_token: str) -> Optional[Dict]:
    """
    Return the raw decoded access payload cached for ``access_token``.

    The value under ``{prefix}:access:share:{token}`` is JSON-decoded; if the
    result is itself a string (double-encoded JSON), it is decoded once more
    as a fault-tolerance measure.

    :param access_token: bearer token presented by the caller
    :return: the decoded payload, or None when nothing is cached
    """
    prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    raw = await rds.get(f'{prefix}:access:share:{access_token}')
    if not raw:
        return None

    decoded = json.loads(raw)
    return json.loads(decoded) if isinstance(decoded, str) else decoded
57
+
58
+
59
async def has_access(access_token: str) -> bool:
    """
    Tell whether a shared access entry exists for ``access_token``.

    :param access_token: bearer token presented by the caller
    :return: True when the ``{prefix}:access:share:{token}`` key exists
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # Redis EXISTS replies with the number of existing keys; normalize it to
    # the declared bool return type.
    return bool(await rds.exists(f'{series_prefix}:access:share:{access_token}'))
62
+
63
+
64
async def has_app_access(access_token: str) -> bool:
    """
    Tell whether a shared access entry exists for an application token.

    Apps and users share the same ``{prefix}:access:share:{token}`` key.

    :param access_token: bearer token presented by the calling application
    :return: True when the key exists
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    # Redis EXISTS replies with the number of existing keys; normalize it to
    # the declared bool return type.
    return bool(await rds.exists(f"{series_prefix}:access:share:{access_token}"))
67
+
68
+
69
async def save_cached_permission(access_token: str, permissions: Set[str], expire: Optional[int] = None):
    """
    Add ``permissions`` to the cached permission set of ``access_token`` and
    (re)set the TTL of that set.

    Does nothing when ``permissions`` is empty.

    :param access_token: bearer token the permissions belong to
    :param permissions: permission identifiers to add to the Redis set
    :param expire: TTL in seconds; falsy values (None or 0) fall back to
                   ``default_expire``
    """
    if not permissions:
        return
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    key = f'{series_prefix}:share:permissions:{access_token}'
    # SADD accepts many members at once: one round trip instead of one per item.
    await rds.sadd(key, *permissions)
    await rds.expire(key, expire or default_expire)
77
+
78
+
79
async def has_permission(access_token: str, permission: str) -> bool:
    """
    Check that ``access_token`` holds ``permission`` and that it is enabled.

    Two steps:

    1. ``permission`` must be a member of the
       ``{prefix}:access:permission:{token}`` set. The raw value is tried
       first, then the value wrapped in double quotes (compatibility with
       the Java implementation's format).
    2. At least one key matching ``{prefix}:access:*{permission}:{token}``
       must hold the status value ``"1"``.

    :param access_token: bearer token presented by the caller
    :param permission: permission identifier to check
    :return: True only when both checks pass
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    _key = f'{series_prefix}:access:permission:{access_token}'
    is_member = await rds.sismember(_key, permission)
    if not is_member:
        # Java-compatible quoted form of the same permission.
        is_member = await rds.sismember(_key, f'"{permission}"')
    if not is_member:
        return False

    # The permission is granted; verify that its status flag is switched on.
    # NOTE(review): KEYS walks the whole keyspace and blocks Redis while doing
    # so; consider SCAN if the keyspace is large.
    key_pattern = f"{series_prefix}:access:*{permission}:{access_token}"
    for key in await rds.keys(key_pattern):
        # Stop at the first enabled status. Bytes are accepted too, in case
        # the client is not configured to decode responses.
        if await rds.get(key) in ("1", b"1"):
            return True
    return False
99
+
100
+
101
async def get_cached_permission(access_token: str) -> List[str]:
    """
    Return the members of the cached permission set of ``access_token``.

    Reads the ``{prefix}:share:permissions:{token}`` set, i.e. the key that
    ``save_cached_permission`` writes to.

    NOTE(review): the value is whatever the client's SMEMBERS returns (a set
    for redis-py), despite the ``List[str]`` annotation.
    """
    prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    key = f'{prefix}:share:permissions:{access_token}'
    members = await rds.smembers(key)
    return members
104
+
105
+
106
async def clean_cache(access_token: str):
    """
    Drop the cached access payload and the permission set of ``access_token``.

    Both keys go in one multi-key DEL command (a single round trip); DEL
    ignores keys that do not exist.

    :param access_token: bearer token whose cache entries are removed
    """
    series_prefix = cfg.get('AUTH_SERIES_PREFIX', 'fast')
    await rds.delete(
        f'{series_prefix}:access:share:{access_token}',
        f'{series_prefix}:share:permissions:{access_token}',
    )
@@ -0,0 +1,216 @@
1
+ # -*- coding: utf-8 -*-
2
+ import logging
3
+ import re
4
+ import json
5
+ import threading
6
+ from typing import Optional, List
7
+ from starlette.routing import BaseRoute
8
+ from fastapi.routing import APIRoute
9
+ from starlette.requests import Request
10
+
11
+ from ..lib import cfg
12
+ from ..auth import auth_cache_pool
13
+ from ..client.async_redis_client import AsyncRedis
14
+ from ..exception.exceptions import NoAuthException
15
+ from ..model.items import AuthUser, AuthApp
16
+ from ..context.context_vars import tenant_context
17
+
18
logger = logging.getLogger(__name__)

# Path regexes that never require authentication; always compiled into the
# whitelist ahead of the ``white_list`` passed to ``ShareAuth.reload``.
FIX_WHITELIST = [
]


class ShareAuth(object):
    """
    Singleton that extracts access tokens, checks logins and performs
    tag-based permission checks against the shared auth cache.
    """
    _instance_lock = threading.Lock()
    _biz_inited = False
    _routes_inited = False

    def __new__(cls, *args, **kwargs):
        # Check, lock, re-check: concurrent first calls still create only one
        # shared instance.
        if not hasattr(ShareAuth, "_instance"):
            with ShareAuth._instance_lock:
                if not hasattr(ShareAuth, "_instance"):
                    ShareAuth._instance = object.__new__(cls)
        return ShareAuth._instance

    def __init__(self):
        # __init__ runs on every ShareAuth() call; only the first one sets up state.
        if not self._biz_inited:
            self._biz_inited = True
            self.rds = AsyncRedis(redis_uri=cfg.get('AUTH_REDIS_URL', None), redis_db=cfg.get('AUTH_REDIS_DB', None)).client
            # route path regex -> tags
            self.regex_to_tags_map = dict()
            # compiled regexes of paths that skip auth
            self.auth_whitelist = []

    def get_access_token(self, request: Request):
        """
        Extract the access token from a request.

        Lookup order:

        1. the ``Authorization`` header, with a leading ``Bearer `` stripped;
        2. the ``access_token`` cookie;
        3. the ``access_token`` query parameter.

        :param request: incoming request
        :return: the token; '' when the request has neither headers nor query
                 params; None when no source holds a token
        """
        if request and (request.headers or request.query_params):
            access_token = request.headers.get('Authorization', '')
            # Strip only the leading scheme, not every occurrence of it.
            if access_token.startswith('Bearer '):
                access_token = access_token[len('Bearer '):]
            if not access_token:
                access_token = request.cookies.get('access_token')
            if not access_token:
                access_token = request.query_params.get('access_token')
            return access_token

        return ''

    def reload(self, routes: Optional[List[BaseRoute]], white_list=None, force: bool = False):
        """
        Rebuild the route-regex -> tags map and the no-auth whitelist.

        :param routes: application routes. Only ``APIRoute`` instances with a
                       ``path_regex`` are used; None is treated as no routes.
        :param white_list: extra regex strings for paths that skip auth
        :param force: rebuild even if a previous reload already ran
        :return: self, for chaining
        """
        if not force and self._routes_inited:
            return self
        # Map route regex -> permission granularity (tags).
        _tmp_regex_to_tags_map = {
            route.path_regex: route.tags
            for route in (routes or [])
            if isinstance(route, APIRoute) and route.path_regex
        }
        # Swap in the new map rather than clearing the old one in place, so
        # anything still holding the previous map never sees it emptied.
        self.regex_to_tags_map = _tmp_regex_to_tags_map
        logger.info(f'reload path to tags map: {len(self.regex_to_tags_map)}')

        # Build the permission whitelist: fixed entries first, then callers'.
        self.auth_whitelist = [re.compile(row) for row in list(FIX_WHITELIST) + list(white_list or [])]
        logger.info(f'reload no auth whitelist: {len(self.auth_whitelist)}')

        self._routes_inited = True

        return self

    async def access_check(self, request: Request):
        """
        Login check: require a token that still has a shared access entry.

        :param request: incoming request
        :return: the access token
        :raises NoAuthException: when no token is present, or when its cached
                                 access entry no longer exists
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')

        if not await auth_cache_pool.has_access(access_token):
            raise NoAuthException('登录信息已失效')

        return access_token

    async def auth_check(self, request: Request):
        """
        Login and permission check for the request path.

        :param request: incoming request
        :return: True when the path is whitelisted, or when it matches a route
                 whose permission check passes (or that has no tags); False
                 when no route map is loaded or no route regex matches
        :raises NoAuthException: missing/expired login or insufficient permission
        """
        if not self.regex_to_tags_map:
            logger.debug('no path to tags map')
            return False
        path = request.url.path
        # Whitelisted paths skip both login and permission checks.
        if any(white_regex.match(path) for white_regex in self.auth_whitelist):
            return True
        # Everything else requires a valid login first.
        access_token = await self.access_check(request)
        for regex, tags in self.regex_to_tags_map.items():
            if not regex.match(path):
                continue
            # Matched route: tagged routes also need permissions; untagged
            # routes only need the login checked above.
            if tags and not await self._rds_permissions_check(tags, access_token):
                raise NoAuthException('权限不足')
            return True

        return False

    async def _rds_permissions_check(self, tags, access_token) -> bool:
        """
        Return True only if ``access_token`` holds every tag in ``tags``.

        Each tag is tried in the plain ``xx.xx.xx.xx`` form first, then
        JSON-encoded (``"xx.xx.xx.xx"``) for compatibility with the Java
        version's permission format.
        """
        for tag in tags:
            if await auth_cache_pool.has_permission(access_token, tag):
                continue
            if not await auth_cache_pool.has_permission(access_token, json.dumps(tag)):
                return False

        return True

    async def get_auth_user(self, request: Request) -> Optional[AuthUser]:
        """
        Return the logged-in user for the request and set the tenant context.

        :param request: incoming request
        :return: the cached ``AuthUser``, or None when the cached payload is
                 not a user
        :raises NoAuthException: when no token is present or its access entry
                                 no longer exists
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')
        if not await auth_cache_pool.has_access(access_token):
            raise NoAuthException('登录信息已失效')

        auth_user = await auth_cache_pool.get_cached_user(access_token)

        # Expose the user's team as the tenant for the rest of the request.
        if auth_user and auth_user.team_id:
            tenant_context.set(auth_user.team_id)

        return auth_user

    async def get_auth_app(self, request: Request) -> Optional[AuthApp]:
        """
        Return the logged-in application for the request.

        :param request: incoming request
        :return: the cached ``AuthApp``, or None when the cached payload is
                 not an app
        :raises NoAuthException: when no token is present or its access entry
                                 no longer exists
        """
        access_token = self.get_access_token(request)
        if not access_token:
            raise NoAuthException('请先登录')
        if not await auth_cache_pool.has_app_access(access_token):
            raise NoAuthException('登录信息已失效')

        return await auth_cache_pool.get_cached_app(access_token)

    async def get_auth_team_sn(self, request: Request) -> Optional[str]:
        """
        Return the ``team_sn`` (tenant) stored in the cached access payload.

        :param request: incoming request
        :return: the ``team_sn`` value, or '' when there is no token, no
                 cached payload, or no truthy ``team_sn``
        """
        access_token = self.get_access_token(request)
        if access_token:
            access_item = await auth_cache_pool.get_cached_access_dict(access_token)
            if access_item:
                # NOTE(review): the original note says only super administrators
                # may switch tenants; nothing in this method enforces that.
                team_sn = access_item.get('team_sn', None)
                if team_sn:
                    return team_sn

        return ''
198
+
199
+
200
# Dependency for handlers (FastAPI ``Depends``): resolve the authenticated user.
async def authed_user(request: Request) -> Optional[AuthUser]:
    """
    Return the authenticated user of ``request``.

    :raises NoAuthException: propagated from ``ShareAuth.get_auth_user``
    """
    # get_auth_user is a coroutine function; without ``await`` the dependency
    # would hand the handler an un-awaited coroutine object, not the user.
    return await ShareAuth().get_auth_user(request)
203
+
204
+
205
# Dependency for handlers (FastAPI ``Depends``): resolve the authenticated app.
async def authed_app(request: Request) -> Optional[AuthApp]:
    """
    Return the authenticated application of ``request``.

    :raises NoAuthException: propagated from ``ShareAuth.get_auth_app``
    """
    # get_auth_app is a coroutine function; without ``await`` the dependency
    # would hand the handler an un-awaited coroutine object, not the app.
    return await ShareAuth().get_auth_app(request)
208
+
209
+
210
# Dependency for handlers (FastAPI ``Depends``): resolve the raw access token.
async def access_token(request: Request) -> Optional[str]:
    """
    Return the access token carried by ``request``.

    Delegates to ``ShareAuth.get_access_token``, which is a plain
    (synchronous) method, so no await is needed.
    """
    auth = ShareAuth()
    return auth.get_access_token(request)
213
+
214
+
215
+
216
+
File without changes
@@ -0,0 +1,233 @@
1
+ # encoding: utf-8
2
+ # version: pymongo==4.7.3, motor==3.5.0
3
+
4
+ from ..exception.exceptions import BizException
5
+ from ..exception.exceptions import NoConfigException
6
+ from ..lib import cfg, logger
7
+
8
LOGGER = logger.get('AsyncMongoClient')


class AsyncMongo(object):
    """
    Async MongoDB helper built on a Motor ``AsyncIOMotorClient``.

    An instance is bound to a single database; most methods take the
    collection name as their first argument. ``mongo_uri`` / ``mongo_db``
    fall back to the ``MONGO_URL`` / ``MONGO_DB`` config values.

    Optional ``query`` / ``sort`` / ``pipelines`` arguments default to None,
    meaning "no filter" / "no sort" / "empty pipeline". This avoids mutable
    default objects shared between calls.
    """

    def __init__(self, mongo_uri: str = None, mongo_db: str = None):
        # motor is imported here rather than at module level (presumably so the
        # module can be imported without motor installed).
        import motor.motor_asyncio
        self.mongo_uri = mongo_uri or cfg.get('MONGO_URL')
        self.mongo_db = mongo_db or cfg.get('MONGO_DB')
        if not self.mongo_uri:
            raise NoConfigException('mongodb uri not config!')
        if not self.mongo_db:
            raise NoConfigException('mongodb database not config!')

        self.client = motor.motor_asyncio.AsyncIOMotorClient(self.mongo_uri)
        self.db = self.client[self.mongo_db]

        LOGGER.info(self.client)

    def get_collection(self, coll):
        """Return the collection handle. This call is synchronous (not a coroutine)."""
        return self.db.get_collection(coll)

    async def get(self, collection, query=None):
        """Return the first document matching ``query``, or None."""
        return await self.db[collection].find_one({} if query is None else query)

    async def count(self, collection, query=None):
        """Return the number of documents matching ``query``."""
        return await self.db[collection].count_documents({} if query is None else query)

    async def list(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """Return all matching documents as a list, fetched ``batch_size`` at a time."""
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)
        result = []
        try:
            async for document in cursor:
                result.append(document)
        finally:
            # Release the server-side cursor even if iteration fails midway.
            await cursor.close()
        return result

    async def list_with_cursor(self, collection, query=None, fields=None, sort=None, batch_size: int = 2000):
        """
        Return an open cursor over the matching documents.

        Example::

            cursor = await aio_mongo.list_with_cursor(collection)
            async for document in cursor:
                print("document =", document)
            await cursor.close()

        Note: the caller must close the cursor when done with it.
        """
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort)
        cursor.batch_size(batch_size)

        return cursor

    async def page(self, collection, query=None, page_no=1, page_size=20, fields=None, sort=None, batch_size: int = 2000):
        """
        Return one page of matching documents plus the total match count.

        :return: ``(rows, total)``; ``page_no`` is 1-based
        """
        query = {} if query is None else query
        total = await self.db[collection].count_documents(query) or 0
        cursor = self.db[collection].find(query, fields, sort=sort).skip(page_size * (page_no - 1)).limit(page_size)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                rows.append(document)
        finally:
            await cursor.close()

        return rows, total

    async def top(self, collection, query=None, sort=None, limit=1, fields=None, batch_size: int = 2000):
        """
        Return the first ``limit`` matching documents.

        :return: the single document when ``limit == 1`` and something matches;
                 otherwise a list of documents (empty when nothing matches)
        """
        cursor = self.db[collection].find(filter={} if query is None else query, projection=fields, sort=sort, limit=limit)
        cursor.batch_size(batch_size)
        rows = []
        try:
            async for document in cursor:
                if limit == 1:
                    return document
                rows.append(document)
        finally:
            # Also runs on the early return above, so the cursor never leaks.
            await cursor.close()

        return rows

    async def distinct(self, collection, dist_key, query=None, fields=None):
        """Return the distinct values of ``dist_key`` among matching documents."""
        return await self.db[collection].find({} if query is None else query, fields).distinct(dist_key)

    async def aggregate_page(self, collection, pipelines, page_no=1, page_size=20):
        """
        Run ``pipelines`` and return one page of its output plus the total count.

        A ``$facet`` (count + skip/limit) and a ``$project`` stage are appended
        to a COPY of ``pipelines``, so the caller's list is never modified. An
        empty or None pipeline paginates the whole collection.

        :return: ``(rows, total)``; ``([], 0)`` when nothing matches
        """
        skip = page_size * (page_no - 1)
        stages = list(pipelines or [])
        stages.append({'$facet': {'total': [{'$count': 'count'}], 'rows': [{'$skip': skip}, {'$limit': page_size}]}})
        stages.append({'$project': {'data': '$rows', 'total': {'$arrayElemAt': ['$total.count', 0]}}})

        cursor = self.db[collection].aggregate(stages, session=None, allowDiskUse=True)
        cursor.batch_size(page_size)
        try:
            async for rs in cursor:
                # rs: dict as {"data": [], "total": int}; ``total`` is absent
                # when no document matched.
                if rs and 'data' in rs and 'total' in rs:
                    return rs.get('data'), rs.get('total')
        finally:
            await cursor.close()
        return [], 0

    async def aggregate(self, collection, pipelines=None, batch_size: int = 2000):
        """
        Run an aggregation and return its open cursor.

        Example::

            cursor = await aio_mongo.aggregate(collection="collection_name", pipelines=[])
            async for row in cursor:
                print(row)
            await cursor.close()

        Note: the caller must close the cursor when done with it.
        """
        cursor = self.db[collection].aggregate([] if pipelines is None else pipelines, session=None, allowDiskUse=True)
        cursor.batch_size(batch_size)
        return cursor

    async def list_with_page(self, collection, query=None, page_size=10000, fields=None):
        """
        Fetch every matching document by walking it ``page_size`` at a time.

        NOTE(review): without a sort, skip-based pages are not guaranteed to be
        stable if the collection changes during the walk.
        """
        query = {} if query is None else query
        projection = fields if fields else None
        rows = []
        total = await self.db[collection].count_documents(query)
        if total > 0 and page_size > 0:
            total_page = -(-total // page_size)  # ceiling division
            for page in range(total_page):
                cursor = self.db[collection].find(query, projection).skip(page_size * page).limit(page_size)
                try:
                    async for document in cursor:
                        rows.append(document)
                finally:
                    await cursor.close()
        return rows

    async def insert_or_update(self, collection, data, id_key='_id', update=None, upsert=True, multi=False):
        """
        Upsert by ``id_key``: ``$set`` ``data``, or apply the raw ``update``.

        Exactly one of ``data`` / ``update`` must be given.

        :raises BizException: when both or neither are given
        """
        coll = self.db[collection]
        write = coll.update_many if multi else coll.update_one
        if data and not update:
            return await write({id_key: data[id_key]}, {'$set': data}, upsert=upsert)
        if not data and update:
            # NOTE(review): ``data`` is empty on this branch, so ``data[id_key]``
            # raises (TypeError/KeyError). The filter needs another source; kept
            # as-is pending a decision on the intended API.
            return await write({id_key: data[id_key]}, update, upsert=upsert)
        raise BizException("data和update不能同时存在或同时为空")

    async def insert(self, collection, data):
        """Insert a single document."""
        return await self.db[collection].insert_one(data)

    async def update(self, collection, filter, data=None, update=None, multi=False):
        """
        Update documents matching ``filter``: ``$set`` ``data``, or apply ``update``.

        :raises BizException: when both or neither of ``data`` / ``update`` are given
        """
        coll = self.db[collection]
        write = coll.update_many if multi else coll.update_one
        if data and not update:
            return await write(filter, {'$set': data})
        if not data and update:
            return await write(filter, update)
        raise BizException("data和update不能同时存在或同时为空")

    async def save(self, collection, filter: dict, save_data: dict, upsert=True):
        """``$set`` ``save_data`` on the first document matching ``filter`` (upsert by default)."""
        return await self.db[collection].update_one(filter, {'$set': save_data}, upsert=upsert)

    async def update_by_pk(self, collection, pk_val, data=None, update=None, upsert=False):
        """
        Update the document whose ``_id`` is ``pk_val``.

        :raises BizException: when both or neither of ``data`` / ``update`` are given
        """
        coll = self.db[collection]
        if data and not update:
            return await coll.update_one({'_id': pk_val}, {'$set': data}, upsert=upsert)
        if not data and update:
            return await coll.update_one({'_id': pk_val}, update, upsert=upsert)
        raise BizException("data和update不能同时存在或同时为空")

    async def batch_update(self, collection, filter, datas=None, update=None, *args, **kwargs):
        """
        Update every document matching ``filter``; extra arguments are ignored.

        :raises BizException: when both or neither of ``datas`` / ``update`` are given
        """
        coll = self.db[collection]
        if datas and not update:
            return await coll.update_many(filter, {'$set': datas})
        if not datas and update:
            return await coll.update_many(filter, update)
        raise BizException("data和update不能同时存在或同时为空")

    async def delete(self, collection, filter):
        """Delete every document matching ``filter``."""
        return await self.db[collection].delete_many(filter)

    async def bulk_write(self, collection, bulk_list: list, batch_size: int = 1000):
        """
        Execute ``bulk_list`` write operations in unordered chunks of ``batch_size``.

        :return: the result of the LAST chunk only, or None when ``bulk_list``
                 is empty
        """
        result = None
        coll = self.db[collection]
        for start in range(0, len(bulk_list or []), batch_size):
            result = await coll.bulk_write(bulk_list[start:start + batch_size], ordered=False, bypass_document_validation=True)
        return result

    async def create_index(self, collection, fields):
        """Create an index on ``fields``."""
        return await self.db[collection].create_index(fields)

    def close(self):
        """Close the Motor client; failures are logged rather than raised."""
        if self.client:
            try:
                self.client.close()
                LOGGER.info("close successful")
            except Exception as e:
                LOGGER.error(f"close mongo client catch err: {e}")
233
+