rainycode 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,360 @@
1
+ import datetime
2
+ from typing import Dict, List, Optional, Any, Tuple
3
+ from fastapi import FastAPI
4
+ from fastapi.encoders import jsonable_encoder
5
+ from tortoise import Tortoise
6
+ from tortoise.models import Model
7
+ from tortoise.contrib.fastapi import RegisterTortoise
8
+ from tortoise.queryset import QuerySetSingle, QuerySet
9
+ from contextlib import asynccontextmanager
10
+
11
+ from core.base_config import base_config
12
+
13
# Model modules that Tortoise loads into the "main" app.
db_models = ['aerich.models', 'models']

# Tortoise ORM configuration: a single "main" connection used by the "main" app.
db_config = dict(
    connections={'main': base_config.mysql_url},  # default database connection
    apps={'main': dict(models=db_models, default_connection='main')},
    use_tz=False,
    timezone='Asia/Shanghai',
)
28
+
29
class AioDb:
    """Helpers for initialising and shutting down Tortoise ORM.

    Two ways to run code inside a transaction:

        async with in_transaction() as connection:
            ...

        @atomic('main')  # recommended
        async def func():
            ...
    """

    @classmethod
    async def connect(cls):
        """Initialise Tortoise ORM with the module-level ``db_config``."""
        await Tortoise.init(config=db_config)

    @classmethod
    async def close(cls):
        """Close every connection opened by Tortoise ORM."""
        await Tortoise.close_connections()

    @classmethod
    async def register_db(cls, app: FastAPI) -> RegisterTortoise:
        """Build a ``RegisterTortoise`` that wires Tortoise into a FastAPI app.

        Schema generation is disabled (tables are managed manually rather
        than created automatically, as recommended for production), and
        Tortoise's exception handlers are installed on the app.
        """
        return RegisterTortoise(
            app,
            config=db_config,
            generate_schemas=False,
            add_exception_handlers=True,
        )

    @classmethod
    async def init_tables(cls):
        """Create tables for the registered models via ``Tortoise.generate_schemas``.

        Tortoise must already be initialised (see ``connect``). Errors from
        ``generate_schemas`` propagate unchanged to the caller.
        """
        await Tortoise.generate_schemas()

    @classmethod
    @asynccontextmanager
    async def execute(cls):
        """Keep Tortoise connections open for the duration of an ``async with`` block.

        :example:
            async with AioDb.execute():
                await async_func()

        Connections are closed even when the body raises.
        """
        await cls.connect()
        try:
            yield
        finally:
            # Without try/finally an exception in the body would skip close()
            # and leak the connections.
            await cls.close()

    @classmethod
    def get_model_by_name(cls, name, app='main'):
        """Return the model class registered as *name* in the Tortoise app *app*."""
        return Tortoise.apps[app][name]

    @classmethod
    async def run_func(cls, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)`` inside ``execute()`` and return its result."""
        async with cls.execute():
            return await func(*args, **kwargs)
88
+
89
+
90
class MySQLUtils(Model):
    """Abstract Tortoise model base with dict/JSON serialization helpers.

    A subclass may declare an inner ``PydanticMeta`` (falling back to ``Meta``)
    with ``include`` / ``exclude`` sequences of field names. These are merged
    into the caller's lists on every serialization. A model may also define
    ``fix_data(data: dict)`` to post-process its serialized dict in place.
    The ``deleted`` field is always dropped from serialized output.
    """

    async def model_dump(self, **kwargs):
        """Serialize this instance; accepts the same keyword arguments as ``model2dict``."""
        # Dispatch through the concrete subclass so its PydanticMeta/Meta is used;
        # calling MySQLUtils.model2dict would read the abstract base instead.
        return await type(self).model2dict(self, **kwargs)

    @staticmethod
    def json_encoder(
        obj: Any,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        **kwargs
    ):
        """Encode *obj* with FastAPI's ``jsonable_encoder``.

        datetimes are rendered as ``'%Y-%m-%d %H:%M:%S'``. Extra ``kwargs``
        are accepted and ignored.
        """
        return jsonable_encoder(
            obj,
            include=set(include) if include else None,
            exclude=set(exclude) if exclude else None,
            custom_encoder={
                datetime.datetime: lambda dt: datetime.datetime.strftime(dt, '%Y-%m-%d %H:%M:%S')
            }
        )

    @staticmethod
    def _merge_field_names(explicit, from_meta):
        """Return a NEW list combining caller-supplied and meta-declared field names.

        Always copying means later edits never mutate the caller's arguments
        or class-level meta attributes. Any iterable is accepted, so tuple
        meta attributes work as well.
        """
        if from_meta:
            if explicit:
                return list(set(explicit) | set(from_meta))
            return list(from_meta)
        return list(explicit) if explicit else explicit

    @classmethod
    def _resolve_include_exclude(cls, meta, include, exclude):
        """Merge *meta* include/exclude into copies of the given lists.

        Fields in ``include`` take precedence over ``exclude``, and
        ``'deleted'`` is always appended to the returned exclude list.
        """
        include = cls._merge_field_names(include, getattr(meta, 'include', None) if meta else None)
        exclude = cls._merge_field_names(exclude, getattr(meta, 'exclude', None) if meta else None)
        if include and exclude:
            exclude = [field for field in exclude if field not in include]
        # Always hide the soft-delete flag.
        exclude = (exclude or []) + ['deleted']
        return include, exclude

    @staticmethod
    def _rename_aliased_keys(data, alias_dict, prefix=''):
        """Rename keys of dict *data* in place; ``prefix + key`` is looked up in *alias_dict*."""
        if not isinstance(data, dict):
            return
        # Iterate over a snapshot: popping/inserting while iterating the live
        # view raises RuntimeError.
        for attr in list(data):
            alias_key = f'{prefix}{attr}'
            if alias_key in alias_dict:
                data[alias_dict[alias_key]] = data.pop(attr)

    @staticmethod
    def _strip_fr_kwargs(kwargs, fetch_related):
        """Return the kwargs usable as ORM filters.

        When ``fetch_related`` is set, ``fr_*`` serialization options are dropped.
        """
        if not fetch_related:
            return kwargs
        return {k: v for k, v in kwargs.items() if not k.startswith('fr_')}

    @staticmethod
    async def _await_single_model(qm, fetch_related):
        """Resolve *qm* to a model instance.

        A Model instance gets its relations fetched in place. A QuerySetSingle
        gets relations prefetched and is then awaited.
        """
        if isinstance(qm, Model):
            if fetch_related:
                await qm.fetch_related(*fetch_related)
            return qm
        if fetch_related:
            qm = qm.prefetch_related(*fetch_related)
        return await qm

    @classmethod
    async def model2dict(
        cls,
        model: Model,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
        alias: List[Tuple[str, str]] | None = None,
        fetch_related: List[str] | None = None,
        flat: bool = False,  # flatten dict relations into "<relation>_<field>" keys
        **kwargs
    ) -> Dict[str, Any]:
        """Serialize *model* (and optionally its relations) into a dict.

        :param include: field names to keep (merged with the meta ``include``).
        :param exclude: field names to drop (merged with the meta ``exclude``);
            ``deleted`` is always dropped.
        :param alias: ``(field, new_name)`` pairs. Related fields are addressed
            as ``'<relation>.<field>'``.
        :param fetch_related: relation names to serialize under their own key.
            Per-relation field lists are read from the kwargs
            ``fr_<relation>_include`` / ``fr_<relation>_exclude``.
        :param flat: when a relation serializes to a dict, merge its fields into
            the result as ``<relation>_<field>`` instead of nesting them.
        :returns: the serialized dict after any ``fix_data`` post-processing.
        """
        meta = getattr(cls, 'PydanticMeta', getattr(cls, 'Meta', None))
        include, exclude = cls._resolve_include_exclude(meta, include, exclude)
        alias_dict = {item[0]: item[1] for item in alias} if alias else {}

        result = cls.json_encoder(obj=model, include=include, exclude=exclude, **kwargs)
        cls._rename_aliased_keys(result, alias_dict)

        fix_data = getattr(model, 'fix_data', None)
        if fix_data and callable(fix_data):
            fix_data(result)

        for fr in fetch_related or []:
            fr_field = getattr(model, fr)
            if not fr_field:
                result[fr] = None
                continue
            fr_model = await fr_field

            # For to-many relations, read meta/fix_data from the first related
            # instance so single and list relations are treated alike.
            sample = fr_model[0] if isinstance(fr_model, list) and fr_model else fr_model
            if not isinstance(sample, Model):
                sample = None
            fr_meta = getattr(sample, 'PydanticMeta', None)
            fr_fix_data = getattr(sample, 'fix_data', None)

            fr_include, fr_exclude = cls._resolve_include_exclude(
                fr_meta,
                kwargs.get(f'fr_{fr}_include'),
                kwargs.get(f'fr_{fr}_exclude'),
            )
            fr_result = cls.json_encoder(obj=fr_model, include=fr_include, exclude=fr_exclude, **kwargs)

            for item in fr_result if isinstance(fr_result, list) else [fr_result]:
                cls._rename_aliased_keys(item, alias_dict, prefix=f'{fr}.')
                if fr_fix_data and callable(fr_fix_data) and isinstance(item, dict):
                    fr_fix_data(item)

            if flat and isinstance(fr_result, dict):
                for k, v in fr_result.items():
                    result[fr + '_' + k] = v
            else:
                result[fr] = fr_result
        return result

    @classmethod
    async def get_data(
        cls,
        *args,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none besides 'deleted')
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # related relations to serialize
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset_model: QuerySetSingle | Model | None = None,
        **kwargs
    ) -> Dict[str, Any] | None:
        """Fetch one row with ``cls.get`` and serialize it with ``model2dict``.

        ``args``/``kwargs`` are forwarded as filters. When ``fetch_related`` is
        set, ``fr_*`` options are removed from the filters first. A given
        ``queryset_model`` is used instead of querying. When the model has a
        ``deleted`` field and ``deleted`` is None, ``deleted=False`` is added to
        the filters; any other value only disables that default filter.
        Returns None when no model is resolved.
        """
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        query_kwargs = cls._strip_fr_kwargs(kwargs, fetch_related)
        qm = queryset_model if queryset_model else cls.get(*args, **query_kwargs)
        model = await cls._await_single_model(qm, fetch_related)
        if not model:
            return None
        return await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)

    @classmethod
    async def get_or_none_data(
        cls,
        *args,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none besides 'deleted')
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # related relations to serialize
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset_model: QuerySetSingle | Model | None = None,
        **kwargs
    ) -> Dict[str, Any] | None:
        """Like ``get_data`` but queries with ``cls.get_or_none``.

        Returns None when nothing matches.
        """
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        # Use the stripped filters: fr_* serialization options are not model fields.
        query_kwargs = cls._strip_fr_kwargs(kwargs, fetch_related)
        qm = queryset_model if queryset_model else cls.get_or_none(*args, **query_kwargs)
        model = await cls._await_single_model(qm, fetch_related)
        if not model:
            return None
        return await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)

    @classmethod
    async def query_data(
        cls,
        *args,
        page_no: int | None = 0,
        page_size: int | None = 0,
        nocount: bool = False,
        include: List[str] | None = None,  # fields to include (default: all)
        exclude: List[str] | None = None,  # fields to exclude (default: none besides 'deleted')
        alias: List[Tuple[str, str]] | None = None,  # field aliases, e.g. [('id', 'uid')]
        fetch_related: List[str] | None = None,  # related relations to serialize
        flat: bool = False,  # flatten related dicts
        deleted: bool | None = None,
        queryset: QuerySet | None = None,
        order_by: List[str] | None = None,
        **kwargs
    ) -> Dict[str, Any] | List[Dict[str, Any]]:
        """Filter rows and serialize each one with ``model2dict``.

        Pagination (1-based ``page_no``) is applied only when both
        ``page_no`` and ``page_size`` are truthy. Returns
        ``{'items': [...], 'count': total}``, where ``count`` ignores
        pagination. With ``nocount`` it returns just the list of items.
        """
        if 'deleted' in cls._meta.fields and deleted is None:
            kwargs['deleted'] = False

        # Use the stripped filters: fr_* serialization options are not model fields.
        query_kwargs = cls._strip_fr_kwargs(kwargs, fetch_related)
        qs: QuerySet = queryset if queryset else cls.filter(*args, **query_kwargs)

        result: Dict[str, Any] = {'items': []}
        if not nocount:
            result['count'] = await qs.count()

        if order_by:
            qs = qs.order_by(*order_by)
        if page_no and page_size:
            qs = qs.limit(page_size).offset((page_no - 1) * page_size)
        if fetch_related:
            qs = qs.prefetch_related(*fetch_related)

        for model in await qs:
            item = await cls.model2dict(model, include, exclude, alias, fetch_related, flat, **kwargs)
            result['items'].append(item)

        return result['items'] if nocount else result

    class Meta:
        abstract = True
@@ -0,0 +1,279 @@
1
+ import time
2
+ import asyncio
3
+ from contextlib import asynccontextmanager
4
+ from typing import Optional, Union, Callable, Mapping
5
+ from redis.asyncio import Redis
6
+ from redis import asyncio as aioredis
7
+ from redis.typing import KeyT, EncodableT, ExpiryT, ZScoreBoundT, AnyKeyT
8
+ from common.logging import log_exc
9
+ from core.base_config import base_config
10
+
11
+
12
class AioRedis:
    """Async Redis client wrapper sharing one lazily created connection pool."""

    __redis_pool: Redis | None = None
    # Serialises first-time pool creation between coroutines on the event loop.
    __connection_lock: asyncio.Lock = asyncio.Lock()

    # Atomic compare-and-delete: remove the lock only if it still holds our value.
    __release_script = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) else return 0 end"
    )
    # Atomic compare-and-expire: extend the lock only if it still holds our value.
    __renew_script = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('expire', KEYS[1], ARGV[2]) else return 0 end"
    )

    @classmethod
    async def connect(cls) -> Redis:
        """Return the shared Redis client, creating it and its pool on first use.

        The new client is PINGed before it is cached. If that fails, the error
        is logged and re-raised and nothing is cached, so the next call retries.
        """
        if cls.__redis_pool is not None:
            return cls.__redis_pool

        async with cls.__connection_lock:
            if cls.__redis_pool is None:  # re-check: another coroutine may have won
                try:
                    pool = aioredis.ConnectionPool.from_url(
                        base_config.redis_url,
                        encoding='utf-8',
                        decode_responses=True,
                        retry_on_timeout=True,  # retry commands that time out
                    )
                    client = Redis(connection_pool=pool)
                    # ping() is a coroutine here; it must be awaited to actually test.
                    await client.ping()
                except Exception as e:
                    log_exc(e)
                    raise
                cls.__redis_pool = client
        return cls.__redis_pool

    @classmethod
    async def close(cls) -> None:
        """Disconnect the shared pool (errors are logged) and forget the client."""
        if cls.__redis_pool:
            try:
                await cls.__redis_pool.connection_pool.disconnect()
            except Exception as e:
                log_exc(e)
            finally:
                cls.__redis_pool = None

    @classmethod
    @asynccontextmanager
    async def session(cls):
        """Yield the shared client; the pool is managed globally, so nothing is closed here."""
        redis = await cls.connect()
        yield redis

    @classmethod
    async def get(cls, name: KeyT) -> Optional[str]:
        """GET *name* (decoded to ``str`` because the pool uses ``decode_responses=True``)."""
        redis = await cls.connect()
        return await redis.get(name)

    @classmethod
    async def set(
        cls,
        name: KeyT,
        value: EncodableT,
        ex: Union[ExpiryT, None] = None,  # expiry in seconds
        px: Union[ExpiryT, None] = None,  # expiry in milliseconds
        nx: bool = False,
        xx: bool = False
    ) -> Optional[bool]:
        """SET *name* to *value* with optional expiry and NX/XX conditions."""
        redis = await cls.connect()
        return await redis.set(name, value, ex=ex, px=px, nx=nx, xx=xx)

    @classmethod
    async def exists(cls, *names: KeyT) -> int:
        """Return how many of *names* exist."""
        redis = await cls.connect()
        return await redis.exists(*names)

    @classmethod
    async def zadd(
        cls,
        name: KeyT,
        mapping: Mapping[AnyKeyT, EncodableT],
        nx: bool = False,
        xx: bool = False,
        ch: bool = False,
        incr: bool = False,
        gt: bool = False,
        lt: bool = False
    ) -> Union[int, float]:
        """ZADD the member->score *mapping* to sorted set *name*."""
        redis = await cls.connect()
        return await redis.zadd(name, mapping, nx, xx, ch, incr, gt, lt)

    @classmethod
    async def zrem(cls, name: KeyT, values) -> int:
        """Remove one member, or every member of a list/tuple/set, from sorted set *name*."""
        members = tuple(values) if isinstance(values, (list, tuple, set)) else (values,)
        if not members:
            return 0
        redis = await cls.connect()
        return await redis.zrem(name, *members)

    @classmethod
    async def zscore(cls, name: KeyT, value: EncodableT) -> Optional[float]:
        """Return the score of *value* in sorted set *name*, or None."""
        redis = await cls.connect()
        return await redis.zscore(name, value)

    @classmethod
    async def zcard(cls, name: KeyT) -> int:
        """Return the number of members in sorted set *name*."""
        redis = await cls.connect()
        return await redis.zcard(name)

    @classmethod
    async def zrevrangebyscore(
        cls,
        name: KeyT,
        max: ZScoreBoundT,
        min: ZScoreBoundT,
        start: Union[int, None] = None,
        num: Union[int, None] = None,
        withscores: bool = False,
        score_cast_func: Union[type, Callable] = float,
    ):
        """Return members of *name* with scores between *max* and *min*, highest first."""
        redis = await cls.connect()
        return await redis.zrevrangebyscore(name, max, min, start, num, withscores, score_cast_func)

    @classmethod
    async def zremrangebyscore(cls, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> int:
        """Remove members of *name* whose score lies in [min, max]."""
        redis = await cls.connect()
        return await redis.zremrangebyscore(name, min, max)

    @classmethod
    async def delete(cls, *names: KeyT) -> int:
        """DEL *names*; returns how many keys were removed."""
        redis = await cls.connect()
        return await redis.delete(*names)

    @classmethod
    async def ttl(cls, name: KeyT) -> int:
        """Return the remaining TTL of *name* in seconds."""
        redis = await cls.connect()
        return await redis.ttl(name)

    @classmethod
    async def expire(cls, name: KeyT, expire: ExpiryT) -> int:
        """Set the TTL of *name* to *expire* seconds."""
        redis = await cls.connect()
        return await redis.expire(name, expire)

    @classmethod
    async def incrby(cls, name: KeyT, amount: int = 1) -> int:
        """Atomically increment *name* by *amount*."""
        redis = await cls.connect()
        return await redis.incrby(name, amount)

    @classmethod
    async def decrby(cls, name: KeyT, amount: int = 1) -> int:
        """Atomically decrement *name* by *amount*."""
        redis = await cls.connect()
        return await redis.decrby(name, amount)

    @classmethod
    @asynccontextmanager
    async def lock(
        cls,
        name: str,
        value: EncodableT = 1,
        timeout: int = 3,  # how long to keep trying to acquire (seconds)
        expire: int = 3  # lock TTL (seconds)
    ):
        """Distributed lock on the key ``lock:{name}`` (SET NX EX).

        Acquisition is attempted at least once, then every 0.1 s until
        *timeout* elapses; on failure an ``Exception`` is raised. While the
        lock is held, a background task re-extends its TTL to *expire*
        seconds every ``expire / 3`` seconds (minimum 0.1 s), but only while
        the key still holds *value*. On exit the key is deleted atomically,
        again only if it still holds *value* and only if this call acquired it.
        Exceptions (including those from the ``async with`` body) are logged
        and re-raised.

        NOTE: the ownership checks compare *value*. Pass a unique value per
        holder; with the shared default ``1`` they cannot tell holders apart.
        """
        if isinstance(value, int):
            value = str(value)

        lock_key = f"lock:{name}"
        acquired = False
        renewal_task: asyncio.Task | None = None

        async def renewal():
            # Renew well before expiry instead of waiting for TTL < 1: with
            # integer TTLs polled once a second that condition is never seen
            # before the key expires.
            interval = max(0.1, expire / 3)
            while True:
                await asyncio.sleep(interval)
                redis = await cls.connect()
                still_owned = await redis.eval(cls.__renew_script, 1, lock_key, value, expire)
                if not still_owned:
                    break  # lock lost or taken over; never extend someone else's lock

        try:
            deadline = time.monotonic() + timeout
            while True:
                # SET NX EX acquires the lock and sets its TTL in one atomic step.
                if await cls.set(lock_key, value, nx=True, ex=expire):
                    acquired = True
                    break
                if time.monotonic() >= deadline:
                    break
                await asyncio.sleep(0.1)

            if not acquired:
                raise Exception(f'获取分布式锁失败:{name}')

            renewal_task = asyncio.create_task(renewal())
            yield acquired

        except Exception as e:
            log_exc(e)
            raise
        finally:
            if renewal_task is not None:
                renewal_task.cancel()
                try:
                    await renewal_task
                except asyncio.CancelledError:
                    pass
                except Exception as e:
                    # A crashed renewal must not prevent releasing the lock.
                    log_exc(e)

            # Never release a lock we failed to acquire: it may belong to
            # another holder that uses the same value.
            if acquired:
                redis = await cls.connect()
                await redis.eval(cls.__release_script, 1, lock_key, value)

    @classmethod
    async def pass_fix_window(cls, name: str, limit: int, expire: int) -> bool:
        """Fixed-window rate limiter: allow at most *limit* calls per *expire*-second window."""
        key = f"ratelimit:fixed:{name}"
        async with cls.lock(name):
            # INCR is atomic and returns the new count in one round trip.
            count = await cls.incrby(key, 1)
            # The first hit of a window starts the window's TTL.
            if count == 1:
                await cls.expire(key, expire)
            return count <= limit

    @classmethod
    async def pass_slide_window(cls, name, limit, expire):
        """Sliding-window rate limiter: allow at most *limit* calls in the last *expire* seconds."""
        import uuid  # local import: only needed to build unique sorted-set members

        key = f"slidelimit:fixed:{name}"
        async with cls.lock(name):
            now_timestamp = int(time.time() * 1000)
            old_timestamp = now_timestamp - expire * 1000
            # Trim entries that fell out of the window BEFORE counting, so
            # stale hits do not cause false rejections.
            await cls.zremrangebyscore(key, 0, old_timestamp)
            count = await cls.zcard(key)
            if count and count >= limit:
                return False
            # A unique member keeps two calls in the same millisecond from
            # collapsing into one sorted-set entry.
            member = f'{now_timestamp}-{uuid.uuid4().hex}'
            await cls.zadd(key, {member: now_timestamp})
            await cls.expire(key, expire)
            return True
@@ -0,0 +1,28 @@
1
+ import time
2
+ from fastapi import Request
3
+ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
4
+ from starlette.types import ASGIApp
5
+ from starlette.responses import Response
6
+ from common.logging import logger
7
+
8
class RequestHttpMiddleware(BaseHTTPMiddleware):
    """Time each request, expose the duration and write one access-log line.

    The duration goes into the ``X-Process-Time`` response header, in seconds.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Call the next handler, set ``X-Process-Time`` (seconds) and log the request."""
        # perf_counter is monotonic; time.time() can jump with clock adjustments.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers['X-Process-Time'] = str(process_time)
        self.write_request_log(request, response)
        return response

    @staticmethod
    def write_request_log(request: Request, response: Response) -> None:
        """Log ``host:port - METHOD url http/x status <ms>ms`` at INFO level."""
        # ``http_version`` is optional in the ASGI HTTP scope; default "1.1".
        http_version = f"http/{request.scope.get('http_version', '1.1')}"
        # The header holds seconds; convert so the value matches the "ms" suffix.
        process_ms = float(response.headers["X-Process-Time"]) * 1000
        client = request.client
        client_host = client.host if client else "unknown"
        client_port = client.port if client else "unknown"
        content = f"{client_host}:{client_port} - {request.method} {request.url} {http_version} {response.status_code} {round(process_ms, 2)}ms"
        logger.info(content)