rainycode 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- common/aiorequests.py +31 -0
- common/consts.py +3 -0
- common/exception.py +68 -0
- common/logging.py +84 -0
- common/response.py +63 -0
- common_depend/auth_depend.py +75 -0
- common_model/base_model.py +200 -0
- common_model/user_model.py +25 -0
- common_model/wechat_model.py +43 -0
- common_utlis/bcrypt_util.py +63 -0
- common_utlis/captcha_util.py +130 -0
- common_utlis/ip_util.py +83 -0
- common_utlis/jwt_util.py +47 -0
- common_utlis/snowflake_util.py +205 -0
- core/base_config.py +28 -0
- core/databases/aiodb.py +360 -0
- core/databases/aioredis.py +279 -0
- core/middleware/http_middleware.py +28 -0
- core/start.py +93 -0
- rainycode-1.0.0.dist-info/METADATA +15 -0
- rainycode-1.0.0.dist-info/RECORD +23 -0
- rainycode-1.0.0.dist-info/WHEEL +5 -0
- rainycode-1.0.0.dist-info/top_level.txt +5 -0
core/databases/aiodb.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from typing import Dict, List, Optional, Any, Tuple
|
|
3
|
+
from fastapi import FastAPI
|
|
4
|
+
from fastapi.encoders import jsonable_encoder
|
|
5
|
+
from tortoise import Tortoise
|
|
6
|
+
from tortoise.models import Model
|
|
7
|
+
from tortoise.contrib.fastapi import RegisterTortoise
|
|
8
|
+
from tortoise.queryset import QuerySetSingle, QuerySet
|
|
9
|
+
from contextlib import asynccontextmanager
|
|
10
|
+
|
|
11
|
+
from core.base_config import base_config
|
|
12
|
+
|
|
13
|
+
# Model modules registered under the single Tortoise app "main".
db_models = ['aerich.models', 'models']

# Tortoise ORM configuration shared by AioDb.connect() and AioDb.register_db().
db_config = dict(
    connections=dict(main=base_config.mysql_url),  # default database connection
    apps=dict(main=dict(models=db_models, default_connection='main')),
    use_tz=False,
    timezone='Asia/Shanghai',
)
|
|
28
|
+
|
|
29
|
+
class AioDb:
    """Helpers to initialise and tear down the Tortoise ORM connections.

    Two ways to run code inside a transaction:

        async with in_transaction() as connection:
            pass

        @atomic('main')  # recommended
        def func():
            pass
    """

    @classmethod
    async def connect(cls):
        """Initialise Tortoise with the module-level ``db_config``."""
        await Tortoise.init(config=db_config)

    @classmethod
    async def close(cls):
        """Close every Tortoise connection."""
        await Tortoise.close_connections()

    @classmethod
    async def register_db(cls, app: FastAPI) -> RegisterTortoise:
        """Build the Tortoise registration object for a FastAPI application.

        :param app: the FastAPI application to attach Tortoise to.
        :return: the ``RegisterTortoise`` instance.
        """
        return RegisterTortoise(
            app,
            config=db_config,
            # Do not auto-create missing tables; schemas are managed manually
            # (auto-creation is not recommended in production).
            generate_schemas=False,
            add_exception_handlers=True,  # install Tortoise's exception handlers
        )

    @classmethod
    async def init_tables(cls):
        """Create any tables that do not exist yet (requires an initialised Tortoise)."""
        await Tortoise.generate_schemas()

    @classmethod
    @asynccontextmanager
    async def execute(cls):
        """Open the database for the duration of an ``async with`` block.

        Connections are closed even if the wrapped code raises.

        :example:
            async with AioDb.execute():
                await async_func()
        """
        await cls.connect()
        try:
            yield
        finally:
            await cls.close()

    @classmethod
    def get_model_by_name(cls, name, app='main'):
        """Return the registered model class ``name`` from Tortoise app ``app``."""
        return Tortoise.apps[app][name]

    @classmethod
    async def run_func(cls, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` inside :meth:`execute` and return its result."""
        async with cls.execute():
            return await func(*args, **kwargs)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class MySQLUtils(Model):
    """Abstract Tortoise model base that adds dict-serialisation and query helpers.

    Serialisation honours an optional inner ``PydanticMeta`` (falling back to
    ``Meta``) whose ``include`` / ``exclude`` sequences are merged with the ones
    the caller passes. The ``deleted`` field is always excluded from the output.
    When the model has a ``deleted`` field, the query helpers filter on
    ``deleted=False`` unless a ``deleted`` argument is given.

    Options for related objects are passed as keyword arguments named
    ``fr_<relation>_include`` / ``fr_<relation>_exclude``.
    """

    async def model_dump(self, **kwargs):
        """Serialise this instance; ``kwargs`` are forwarded to :meth:`model2dict`."""
        # Dispatch through the concrete class so its PydanticMeta/Meta is used.
        # Calling MySQLUtils.model2dict would bind ``cls`` to this base class.
        return await type(self).model2dict(self, **kwargs)

    @staticmethod
    def json_encoder(
        obj: Any,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        **kwargs
    ):
        """Encode ``obj`` with FastAPI's ``jsonable_encoder``.

        ``datetime.datetime`` values are rendered as ``'%Y-%m-%d %H:%M:%S'``.
        Extra ``kwargs`` are accepted and ignored.
        """
        return jsonable_encoder(
            obj,
            include=set(include) if include else None,
            exclude=set(exclude) if exclude else None,
            custom_encoder={
                datetime.datetime: lambda dt: datetime.datetime.strftime(dt, '%Y-%m-%d %H:%M:%S')
            }
        )

    @staticmethod
    def _resolve_field_lists(include, exclude, meta) -> Tuple[Optional[List[str]], List[str]]:
        """Merge caller and meta include/exclude into *fresh* lists.

        Meta lists are unioned with the caller's. Explicitly included fields are
        removed from the exclusions, then ``'deleted'`` is always excluded. The
        inputs are never mutated: they may be the caller's own lists or
        class-level meta attributes shared by every call.
        """
        def merged(given, extra):
            if not extra:
                return list(given) if given else given
            return list(set(given) | set(extra)) if given else list(extra)

        include = merged(include, getattr(meta, 'include', None) if meta else None)
        exclude = merged(exclude, getattr(meta, 'exclude', None) if meta else None)
        if include and exclude:
            exclude = [field for field in exclude if field not in include]
        # The soft-delete flag is never serialised.
        return include, list(exclude or []) + ['deleted']

    @staticmethod
    def _rename_aliases(data: Any, alias_map: Dict[str, str], prefix: str = '') -> None:
        """Rename keys in place: key ``k`` becomes ``alias_map[prefix + k]`` when present.

        ``data`` may be a dict or a list of dicts; anything else is left untouched.
        """
        if not alias_map:
            return
        rows = data if isinstance(data, list) else [data]
        for row in rows:
            if not isinstance(row, dict):
                continue
            # Build the renamed mapping first. Popping and inserting keys while
            # iterating the dict raises "dictionary keys changed during iteration".
            renamed = {alias_map.get(f'{prefix}{key}', key): value for key, value in row.items()}
            row.clear()
            row.update(renamed)

    @staticmethod
    def _strip_fr_options(kwargs: Dict[str, Any], fetch_related: Optional[List[str]]) -> Dict[str, Any]:
        """Return only the query filters contained in ``kwargs``.

        When ``fetch_related`` is set, keys starting with ``fr_`` are related-object
        serialisation options (see :meth:`model2dict`) and are dropped.
        """
        if not fetch_related:
            return kwargs
        return {key: value for key, value in kwargs.items() if not key.startswith('fr_')}

    @staticmethod
    async def _await_single(qm, fetch_related: Optional[List[str]]):
        """Resolve a model instance or single-row queryset, loading ``fetch_related``."""
        if isinstance(qm, Model):
            if fetch_related:
                await qm.fetch_related(*fetch_related)
            return qm
        if fetch_related:
            qm = qm.prefetch_related(*fetch_related)
        return await qm

    @classmethod
    async def model2dict(
        cls,
        model: Model,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
        alias: List[Tuple[str, str]] | None = None,
        fetch_related: List[str] | None = None,
        flat: bool = False,  # flatten related dicts into "<relation>_<field>" keys
        **kwargs
    ) -> Dict[str, Any]:
        """Serialise ``model`` (and optionally its relations) into a dict.

        :param model: instance to serialise.
        :param include: fields to keep (default: all).
        :param exclude: fields to drop (default: none; ``deleted`` is always dropped).
        :param alias: ``(name, new_name)`` pairs. For related fields, use
            ``'<relation>.<field>'`` as ``name``.
        :param fetch_related: relation attributes to await and embed.
        :param flat: embed a related dict as ``<relation>_<field>`` keys instead
            of a nested dict.
        :param kwargs: ``fr_<relation>_include`` / ``fr_<relation>_exclude`` lists
            for related objects. Everything is also forwarded to :meth:`json_encoder`.
        :return: the serialised dict. If the model (or the first related object)
            defines a callable ``fix_data``, it is invoked on the produced dict(s).
        """
        meta = getattr(cls, 'PydanticMeta', getattr(cls, 'Meta', None))
        include, exclude = cls._resolve_field_lists(include, exclude, meta)
        alias_map = {item[0]: item[1] for item in alias} if alias else {}

        result = cls.json_encoder(obj=model, include=include, exclude=exclude, **kwargs)
        cls._rename_aliases(result, alias_map)

        fix_data = getattr(model, 'fix_data', None)
        if fix_data and callable(fix_data):
            fix_data(result)

        for fr in fetch_related or ():
            fr_field = getattr(model, fr)
            if not fr_field:
                result[fr] = None
                continue
            fr_model = await fr_field

            fr_include, fr_exclude = cls._resolve_field_lists(
                kwargs.get(f'fr_{fr}_include'),
                kwargs.get(f'fr_{fr}_exclude'),
                getattr(fr_model, 'PydanticMeta', None),
            )

            fr_fix_data = None
            if isinstance(fr_model, list) and fr_model:
                fr_fix_data = getattr(fr_model[0], 'fix_data', None)
            elif isinstance(fr_model, Model):
                fr_fix_data = getattr(fr_model, 'fix_data', None)

            fr_result = cls.json_encoder(obj=fr_model, include=fr_include, exclude=fr_exclude, **kwargs)
            # Related aliases are keyed as "<relation>.<field>" and apply to the
            # related object's own keys (or to every row of a to-many relation).
            cls._rename_aliases(fr_result, alias_map, prefix=f'{fr}.')

            if fr_fix_data and callable(fr_fix_data):
                if isinstance(fr_result, list):
                    for row in fr_result:
                        fr_fix_data(row)
                elif isinstance(fr_result, dict):
                    fr_fix_data(fr_result)

            if flat and isinstance(fr_result, dict):
                for key, value in fr_result.items():
                    result[fr + '_' + key] = value
            else:
                result[fr] = fr_result
        return result

    @classmethod
    async def get_data(
        cls,
        *args,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none)
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # relations to load and embed
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset_model: QuerySetSingle | Model | None = None,
        **kwargs
    ) -> Dict[str, Any] | None:
        """Fetch one row with ``cls.get`` (or use ``queryset_model``) and serialise it.

        ``deleted``: when the model has a ``deleted`` field and this is ``None``,
        ``deleted=False`` is added to the filters. Any other value only disables
        that default filter. NOTE(review): confirm that is intended.
        Raises whatever ``cls.get`` raises when no row matches.
        """
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        if queryset_model is not None:
            qm = queryset_model
        else:
            qm = cls.get(*args, **cls._strip_fr_options(kwargs, fetch_related))
        model = await cls._await_single(qm, fetch_related)
        if not model:
            return None
        return await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)

    @classmethod
    async def get_or_none_data(
        cls,
        *args,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none)
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # relations to load and embed
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset_model: QuerySetSingle | Model | None = None,
        **kwargs
    ) -> Dict[str, Any] | None:
        """Like :meth:`get_data` but uses ``cls.get_or_none`` and returns ``None`` when no row matches."""
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        if queryset_model is not None:
            qm = queryset_model
        else:
            # Only real filters go to the query; fr_* options are serialisation hints.
            qm = cls.get_or_none(*args, **cls._strip_fr_options(kwargs, fetch_related))
        model = await cls._await_single(qm, fetch_related)
        if not model:
            return None
        return await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)

    @classmethod
    async def query_data(
        cls,
        *args,
        page_no: int | None = 0,
        page_size: int | None = 0,
        nocount: bool = False,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none)
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # relations to load and embed
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset: QuerySet | None = None,
        order_by: List[str] | None = None,
        **kwargs
    ) -> Dict[str, Any] | List[Dict[str, Any]]:
        """Filter, optionally paginate, and serialise many rows.

        Pagination (1-based ``page_no``) is applied only when both ``page_no``
        and ``page_size`` are truthy. The total count is taken before ordering
        and pagination.

        :return: ``{'items': [...], 'count': total}``, or just the item list
            when ``nocount`` is true.
        """
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        if queryset is not None:
            qs: QuerySet = queryset
        else:
            # Only real filters go to the query; fr_* options are serialisation hints.
            qs = cls.filter(*args, **cls._strip_fr_options(kwargs, fetch_related))

        count = None if nocount else await qs.count()

        if order_by:
            qs = qs.order_by(*order_by)
        if page_no and page_size:
            qs = qs.limit(page_size).offset((page_no - 1) * page_size)
        if fetch_related:
            qs = qs.prefetch_related(*fetch_related)

        items = [
            await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)
            for model in await qs
        ]

        if nocount:
            return items
        return {'items': items, 'count': count}

    class Meta:
        abstract = True
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import asyncio
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import Optional, Union, Callable, Mapping
|
|
5
|
+
from redis.asyncio import Redis
|
|
6
|
+
from redis import asyncio as aioredis
|
|
7
|
+
from redis.typing import KeyT, EncodableT, ExpiryT, ZScoreBoundT, AnyKeyT
|
|
8
|
+
from common.logging import log_exc
|
|
9
|
+
from core.base_config import base_config
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AioRedis:
    """Async Redis client wrapper sharing one lazily-created connection pool."""

    __redis_pool: Redis | None = None
    # Serialises first-time pool creation so concurrent callers build only one client.
    # NOTE(review): asyncio.Lock binds to the first event loop that uses it.
    __connection_lock: asyncio.Lock = asyncio.Lock()

    # Delete the lock key only if it still holds our token (atomic compare-and-delete).
    _RELEASE_LOCK_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) else return 0 end"
    )
    # Reset the lock TTL only if it still holds our token (atomic compare-and-expire).
    _RENEW_LOCK_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('expire', KEYS[1], ARGV[2]) else return 0 end"
    )

    @classmethod
    async def connect(cls) -> Redis:
        """Return the shared Redis client, creating and pinging it on first use.

        The client is cached only after a successful ``PING``, so a failed
        first connection is retried on the next call instead of being reused.
        """
        if cls.__redis_pool is not None:
            return cls.__redis_pool

        async with cls.__connection_lock:
            if cls.__redis_pool is None:  # re-check after acquiring the lock
                pool = None
                try:
                    pool = aioredis.ConnectionPool.from_url(
                        base_config.redis_url,
                        encoding='utf-8',
                        decode_responses=True,
                        retry_on_timeout=True  # retry commands on timeout
                    )
                    client = Redis(connection_pool=pool)
                    await client.ping()  # verify the connection actually works
                except Exception as e:
                    log_exc(e)
                    if pool is not None:
                        try:
                            await pool.disconnect()
                        except Exception as cleanup_error:
                            log_exc(cleanup_error)
                    raise
                cls.__redis_pool = client
        return cls.__redis_pool

    @classmethod
    async def close(cls) -> None:
        """Disconnect the shared pool (errors are logged) and forget the client."""
        if cls.__redis_pool:
            try:
                await cls.__redis_pool.connection_pool.disconnect()
            except Exception as e:
                log_exc(e)
            finally:
                cls.__redis_pool = None

    @classmethod
    @asynccontextmanager
    async def session(cls):
        """Yield the shared client. The pool is not closed on exit; call :meth:`close` for that."""
        yield await cls.connect()

    @classmethod
    async def get(cls, name: KeyT) -> Optional[str]:
        # Responses are decoded (decode_responses=True), so values are str.
        redis = await cls.connect()
        return await redis.get(name)

    @classmethod
    async def set(
        cls,
        name: KeyT,
        value: EncodableT,
        ex: Union[ExpiryT, None] = None,  # expiry in seconds
        px: Union[ExpiryT, None] = None,  # expiry in milliseconds
        nx: bool = False,
        xx: bool = False
    ) -> Optional[bool]:
        redis = await cls.connect()
        return await redis.set(name, value, ex=ex, px=px, nx=nx, xx=xx)

    @classmethod
    async def exists(
        cls,
        *names: KeyT
    ) -> int:
        """Return how many of ``names`` exist."""
        redis = await cls.connect()
        return await redis.exists(*names)

    @classmethod
    async def zadd(
        cls,
        name: KeyT,
        mapping: Mapping[AnyKeyT, EncodableT],
        nx: bool = False,
        xx: bool = False,
        ch: bool = False,
        incr: bool = False,
        gt: bool = False,
        lt: bool = False
    ) -> Union[int, float]:
        redis = await cls.connect()
        return await redis.zadd(name, mapping, nx=nx, xx=xx, ch=ch, incr=incr, gt=gt, lt=lt)

    @classmethod
    async def zrem(cls, name: KeyT, values) -> int:
        """Remove one member, or every member of a list/tuple/set, from a sorted set."""
        if isinstance(values, (list, tuple, set, frozenset)):
            members = tuple(values)
        else:
            members = (values,)
        if not members:
            return 0  # ZREM with no members is a Redis argument error
        redis = await cls.connect()
        return await redis.zrem(name, *members)

    @classmethod
    async def zscore(
        cls,
        name: KeyT,
        value: EncodableT
    ) -> Optional[float]:
        redis = await cls.connect()
        return await redis.zscore(name, value)

    @classmethod
    async def zcard(cls, name: KeyT) -> int:
        redis = await cls.connect()
        return await redis.zcard(name)

    @classmethod
    async def zrevrangebyscore(
        cls,
        name: KeyT,
        max: ZScoreBoundT,
        min: ZScoreBoundT,
        start: Union[int, None] = None,
        num: Union[int, None] = None,
        withscores: bool = False,
        score_cast_func: Union[type, Callable] = float,
    ):
        redis = await cls.connect()
        return await redis.zrevrangebyscore(
            name, max, min,
            start=start, num=num, withscores=withscores, score_cast_func=score_cast_func,
        )

    @classmethod
    async def zremrangebyscore(
        cls,
        name: KeyT,
        min: ZScoreBoundT,
        max: ZScoreBoundT
    ) -> int:
        redis = await cls.connect()
        return await redis.zremrangebyscore(name, min, max)

    @classmethod
    async def delete(cls, *names: KeyT) -> int:
        redis = await cls.connect()
        return await redis.delete(*names)

    @classmethod
    async def ttl(cls, name: KeyT) -> int:
        redis = await cls.connect()
        return await redis.ttl(name)

    @classmethod
    async def expire(cls, name: KeyT, expire: ExpiryT) -> int:
        redis = await cls.connect()
        return await redis.expire(name, expire)

    @classmethod
    async def incrby(cls, name: KeyT, amount: int = 1) -> int:
        redis = await cls.connect()
        return await redis.incrby(name, amount)

    @classmethod
    async def decrby(cls, name: KeyT, amount: int = 1) -> int:
        redis = await cls.connect()
        return await redis.decrby(name, amount)

    @classmethod
    @asynccontextmanager
    async def lock(
        cls,
        name: str,
        value: EncodableT = 1,
        timeout: int = 3,  # seconds to keep retrying acquisition
        expire: int = 3  # lock TTL in seconds (kept alive while held)
    ):
        """Distributed lock on key ``lock:<name>`` using ``SET NX EX``.

        Acquisition is retried every 0.1 s for up to ``timeout`` seconds (at least
        one attempt). If it fails, an ``Exception`` is raised. While the lock is
        held, a background task refreshes its TTL every ``expire / 3`` seconds,
        but only while the key still holds ``value``. On exit the key is deleted
        atomically, and only if it still holds ``value``; a caller that never
        acquired the lock never deletes it. Pass a unique ``value`` per caller
        for strict ownership, because the default ``1`` is shared by all callers.
        Any exception raised here, or inside the ``async with`` body, is logged
        and re-raised.
        """
        if isinstance(value, int):
            value = str(value)

        lock_key = f"lock:{name}"
        acquired = False
        renewal_task: asyncio.Task | None = None

        async def renewal():
            interval = max(0.1, expire / 3)
            client = await cls.connect()
            while True:
                await asyncio.sleep(interval)
                try:
                    still_ours = await client.eval(cls._RENEW_LOCK_SCRIPT, 1, lock_key, value, expire)
                except Exception as renew_error:
                    log_exc(renew_error)
                    return
                if not still_ours:
                    return  # lock expired or taken over; stop renewing

        try:
            loop = asyncio.get_running_loop()
            deadline = loop.time() + timeout
            while True:
                # SET NX EX acquires the lock and sets its TTL atomically.
                if await cls.set(lock_key, value, nx=True, ex=expire):
                    acquired = True
                    break
                if loop.time() >= deadline:
                    break
                await asyncio.sleep(0.1)

            if not acquired:
                raise Exception(f'获取分布式锁失败:{name}')

            renewal_task = asyncio.create_task(renewal())
            yield acquired

        except Exception as e:
            log_exc(e)
            raise
        finally:
            if renewal_task is not None:
                renewal_task.cancel()
                try:
                    await renewal_task
                except asyncio.CancelledError:
                    pass
            if acquired:
                client = await cls.connect()
                await client.eval(cls._RELEASE_LOCK_SCRIPT, 1, lock_key, value)

    @classmethod
    async def pass_fix_window(cls, name: str, limit: int, expire: int) -> bool:
        """Fixed-window rate limiter: allow at most ``limit`` calls per ``expire`` seconds."""
        key = f"ratelimit:fixed:{name}"
        async with cls.lock(name):
            # INCR is atomic; the window TTL is set when the counter is created.
            count = await cls.incrby(key, 1)
            if count == 1:
                await cls.expire(key, expire)
            return count <= limit

    @classmethod
    async def pass_slide_window(cls, name, limit, expire):
        """Sliding-window rate limiter: allow at most ``limit`` calls in the last ``expire`` seconds."""
        import uuid

        key = f"slidelimit:fixed:{name}"
        async with cls.lock(name):
            now_timestamp = int(time.time() * 1000)
            old_timestamp = now_timestamp - expire * 1000
            # Drop entries that left the window before counting, so stale
            # requests do not cause a false rejection.
            await cls.zremrangebyscore(key, 0, old_timestamp)
            count = await cls.zcard(key)
            if count >= limit:
                return False
            # Unique member per request: a bare millisecond timestamp would make
            # requests in the same millisecond overwrite each other.
            member = f'{now_timestamp}:{uuid.uuid4().hex}'
            await cls.zadd(key, {member: now_timestamp})
            await cls.expire(key, expire)
            return True
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from fastapi import Request
|
|
3
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
4
|
+
from starlette.types import ASGIApp
|
|
5
|
+
from starlette.responses import Response
|
|
6
|
+
from common.logging import logger
|
|
7
|
+
|
|
8
|
+
class RequestHttpMiddleware(BaseHTTPMiddleware):
    """Time each request, expose the duration in ``X-Process-Time`` (seconds), and log it."""

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        # perf_counter is monotonic and suited to measuring intervals.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers['X-Process-Time'] = str(process_time)  # seconds
        self.write_request_log(request, response, process_time)
        return response

    @staticmethod
    def write_request_log(request: Request, response: Response, process_time: float | None = None) -> None:
        """Log ``host:port - METHOD url http/ver status duration``.

        :param process_time: request duration in seconds. When omitted, it is
            read back from the ``X-Process-Time`` response header.
        """
        http_version = f"http/{request.scope['http_version']}"
        if process_time is None:
            process_time = float(response.headers["X-Process-Time"])
        client_host = request.client.host if request.client else "unknown"
        client_port = request.client.port if request.client else "unknown"
        # The duration is measured in seconds; convert so the "ms" label is accurate.
        duration_ms = round(process_time * 1000, 2)
        content = f"{client_host}:{client_port} - {request.method} {request.url} {http_version} {response.status_code} {duration_ms}ms"
        logger.info(content)
|