gomyck-tools 1.4.2__py3-none-any.whl → 1.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ctools/ai/llm_chat.py CHANGED
@@ -109,6 +109,7 @@ class ChatSession:
109
109
  self.full_messages.append(build_message(ROLE.ASSISTANT, llm_response)) # 不能调换顺序
110
110
  await self.add_tool_call_res_2_message(last_user_input, tool_call_result)
111
111
  await self.process_tool_call_message(get_call_id, get_event_msg_func, tool_call_result)
112
+ # 工具调用, 说明没有结束对话, 要继续执行
112
113
  final_resp = False
113
114
  else:
114
115
  self.full_messages.append(build_message(ROLE.ASSISTANT, llm_response))
@@ -141,15 +142,27 @@ class ChatSession:
141
142
  if get_event_msg_func: await get_event_msg_func(get_call_id(), ROLE.ASSISTANT, self.current_message)
142
143
 
143
144
  async def process_tool_call_message(self, get_call_id, get_event_msg_func, tool_call_result):
144
- # 实时通知前端(工具调用特殊通知一次) 如果是图片结果, 就是 user 消息(必须是 user, 否则 api 报错), 否则是 system(现在统一都改成 user 了, 看看后面有没有改回 system 的必要)
145
+ # 实时通知前端(工具调用特殊通知一次, 输出的是工具返回的结果)
146
+ # 如果是图片结果, 就是 user 消息(必须是 user, 否则 api 报错), 否则是 system(现在统一都改成 user 了, 看看后面有没有改回 system 的必要)
145
147
  self.current_message = tool_call_result["result"] if res_has_img(tool_call_result) else tool_call_result
146
148
  if get_event_msg_func: await get_event_msg_func(get_call_id(), ROLE.USER, self.current_message)
147
149
 
148
150
  async def process_full_message(self, final_resp, get_call_id, get_full_msg_func):
151
+ """
152
+ 全量消息回调函数
153
+ :param final_resp: 最终响应信息
154
+ :param get_call_id: 调用 ID
155
+ :param get_full_msg_func: 回调的函数
156
+ """
149
157
  self.current_message = self.full_messages[-1]["content"]
150
158
  if get_full_msg_func: await get_full_msg_func(get_call_id(), final_resp, self.full_messages)
151
159
 
152
160
  async def add_tool_call_res_2_message(self, last_user_input, tool_call_result: dict):
161
+ """
162
+ 添加当前会话结果, 以便于用当前 chat 对象取值
163
+ :param last_user_input: 客户端最后一次输入
164
+ :param tool_call_result: 工具调用结果
165
+ """
153
166
  if type(tool_call_result) != dict: return
154
167
  response: [] = tool_call_result.get("result")
155
168
  image_content = []
ctools/ai/llm_client.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  import httpx
5
5
 
6
6
  from ctools import sys_log, cjson, call
7
- from ctools.ai.env_config import float_env, bool_env, int_env
7
+ from ctools.util.env_config import float_env, bool_env, int_env
8
8
  from ctools.ai.llm_exception import LLMException
9
9
 
10
10
  logging.getLogger("httpcore").setLevel(logging.WARNING)
@@ -10,6 +10,7 @@ import mimetypes
10
10
  import sys
11
11
  import uuid
12
12
 
13
+ from ctools.util.env_config import bool_env
13
14
  from ctools.web.aio_web_server import get_stream_resp
14
15
 
15
16
 
@@ -60,7 +61,7 @@ def build_image_message(content: str, file: bytes = None, file_path: str = None)
60
61
  return build_message(ROLE.USER, img_content)
61
62
 
62
63
 
63
- async def build_call_back(debug=False, request=None, SSE=True):
64
+ async def build_call_back(debug=None, request=None, SSE=True):
64
65
  """
65
66
  快速构建回调函数
66
67
  Parameters
@@ -71,6 +72,7 @@ async def build_call_back(debug=False, request=None, SSE=True):
71
72
  Returns 响应对象, 消息队列, 回调函数
72
73
  -------
73
74
  """
75
+ if not debug: debug = bool_env("LLM_DEBUG", False)
74
76
  response = None
75
77
  if request: response = await get_stream_resp(request)
76
78
  call_id = uuid.uuid4()
@@ -89,7 +91,7 @@ async def build_call_back(debug=False, request=None, SSE=True):
89
91
 
90
92
  async def on_final(cid, is_final, msg):
91
93
  nonlocal response
92
- if debug: print(cid, is_final, msg, file=sys.__stdout__, flush=True)
94
+ if debug: print("\n", cid, "\n", is_final, "\n", msg, "\n", file=sys.__stdout__, flush=True)
93
95
  if is_final:
94
96
  await message_queue.put("[DONE]")
95
97
  if response:
ctools/application.py CHANGED
@@ -159,7 +159,7 @@ def sync_version(callFunc):
159
159
  shutil.rmtree(taguiPath)
160
160
  except Exception:
161
161
  pass
162
- from ctools.auto.pacth import Patch
162
+ import Patch
163
163
  patch = Patch(oldVersion='V1.0.0', newVersion=Server.version, pythonPath=pythonPath, playwrightPath=msPlayPath, driverPath=driverPath)
164
164
  patch.apply_patch()
165
165
  if callFunc: callFunc()
@@ -174,7 +174,7 @@ def sync_version(callFunc):
174
174
  oldVersion.close()
175
175
  if oldV == Server.version and '-snapshot' not in oldV: return
176
176
  print('开始升级本地程序..')
177
- from ctools.auto.pacth import Patch
177
+ import Patch
178
178
  patch = Patch(oldVersion=oldV, newVersion=Server.version, pythonPath=pythonPath, playwrightPath=msPlayPath, driverPath=driverPath)
179
179
  patch.apply_patch()
180
180
  if callFunc: callFunc()
ctools/aspect.py ADDED
@@ -0,0 +1,65 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ __author__ = 'haoyang'
4
+ __date__ = '2025/7/23 08:33'
5
+
6
+ import functools
7
+ import sys
8
+
9
+ from ctools.pools import thread_pool
10
+
11
def _ensure_list(funcs):
  """
  Normalise the hook argument of before()/after() into a list.

  :param funcs: a single callable, or a list/tuple/set of callables
  :return: list of hooks (iteration order of a set is arbitrary)
  :raises TypeError: for any other argument type
  """
  if callable(funcs):
    return [funcs]
  if isinstance(funcs, (list, tuple, set)):
    return list(funcs)
  raise TypeError("必须是可调用对象或可迭代对象")


def _run_hooks(funcs, sync):
  """Call each callable in ``funcs`` inline (sync) or submit it to ctools' thread_pool (async); skip non-callables."""
  for hook in funcs:
    if not callable(hook):
      continue
    if sync:
      hook()
    else:
      thread_pool.submit(hook)


def before(before_funcs, sync=True):
  """
  Decorator that runs zero-argument hook(s) before the target function.

  :param before_funcs: a callable or a list/tuple/set of callables
  :param sync: True runs the hooks inline before the call; False submits them to thread_pool
  :return: decorator producing a functools.wraps-wrapped function
  """
  before_funcs = _ensure_list(before_funcs)
  default_sync = sync

  def decorator(func, sync=None):
    # ``sync`` is still accepted here for backward compatibility with the former
    # ``decorator(func, sync=True)`` signature; by default the value given to before() applies.
    run_sync = default_sync if sync is None else sync

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      _run_hooks(before_funcs, run_sync)
      return func(*args, **kwargs)

    _replace_func_binding(func, wrapper)
    return wrapper

  return decorator


def after(after_funcs, sync=True):
  """
  Decorator that runs zero-argument hook(s) after the target function returns.

  Hooks are not run if the target raises.

  :param after_funcs: a callable or a list/tuple/set of callables
  :param sync: True runs the hooks inline after the call; False submits them to thread_pool
  :return: decorator producing a functools.wraps-wrapped function; the wrapper returns the target's result
  """
  after_funcs = _ensure_list(after_funcs)

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      result = func(*args, **kwargs)
      _run_hooks(after_funcs, sync)
      return result

    _replace_func_binding(func, wrapper)
    return wrapper

  return decorator


def _replace_func_binding(old_func, new_func):
  """
  Rebind ``old_func``'s module-level name to ``new_func`` so that callers using the module attribute see the wrapper.

  The rebinding only happens when that module attribute *is* ``old_func``. A method or nested
  function that shares its name with an unrelated module-level function must not clobber it.
  Callables without ``__module__``/``__name__`` (e.g. functools.partial) are left alone.
  """
  mod = sys.modules.get(getattr(old_func, "__module__", None))
  name = getattr(old_func, "__name__", None)
  if mod is None or name is None:
    return
  if getattr(mod, name, None) is old_func:
    setattr(mod, name, new_func)
ctools/call.py CHANGED
@@ -1,3 +1,4 @@
1
+ import os
1
2
  import sched
2
3
  import threading
3
4
  import time
@@ -5,63 +6,51 @@ from functools import wraps
5
6
 
6
7
 
7
8
  # annotation
9
# Results of @once functions, keyed by "<absolute source file>:<first line>".
_global_once_cache = {}
def once(func):
  """
  Decorator: run ``func`` on its first call only and return the cached result afterwards.

  The cache key is the function's absolute source file plus first line number, not the
  function object. The same definition reached through two different module import paths
  therefore shares one result, and closures created from the same ``def`` share it too.
  Arguments of later calls are ignored. If ``func`` raises, nothing is cached and the next
  call tries again. No lock is taken, so two threads racing on the first call may both run ``func``.

  :param func: function to run once
  :return: wrapper (functools.wraps applied) returning the cached result
  """
  code = func.__code__
  key = f"{os.path.abspath(code.co_filename)}:{code.co_firstlineno}"

  @wraps(func)
  def wrapper(*args, **kwargs):
    if key not in _global_once_cache:
      _global_once_cache[key] = func(*args, **kwargs)
    return _global_once_cache[key]
  return wrapper
27
18
 
28
-
29
19
  # annotation
20
# Results of @init functions, keyed by "<absolute source file>:<first line>".
_cache = {}
def init(func):
  """
  Decorator: run ``func()`` immediately, at decoration time, and make the decorated name return that result.

  The result is cached by the function's absolute source file plus first line number. Decorating
  the same definition again (e.g. after importing the module under another name) reuses the
  first result instead of running ``func`` again.

  :param func: zero-argument function to run once at import/decoration time
  :return: zero-argument wrapper (functools.wraps applied) returning the cached result
  """
  code = func.__code__
  key = f"{os.path.abspath(code.co_filename)}:{code.co_firstlineno}"
  if key not in _cache:
    _cache[key] = func()

  @wraps(func)
  def wrapper():
    return _cache[key]
  return wrapper
42
29
 
43
-
44
30
  # annotation
31
# Definitions ("<absolute source file>:<first line>") that already have a running scheduler.
_scheduler_cache = {}
# Module-level so that duplicate decorations of the same definition are serialized as well.
_scheduler_lock = threading.Lock()
def schd(interval_seconds, start_by_call=False, run_now=False):
  """
  Decorator that runs the function repeatedly, every ``interval_seconds``, on a daemon thread.

  Only the first activation per definition (source file + first line) starts a scheduler;
  later calls of the decorated function return None and do nothing. An exception raised by
  one run is printed to stderr and does not stop the following runs.

  :param interval_seconds: delay between runs, in seconds
  :param start_by_call: False starts scheduling at decoration time, with no arguments;
                        True starts it on the first call, using that call's arguments
  :param run_now: also run once right away (on the scheduler thread) before the first delay
  :return: decorator; the decorated function always returns None
  """
  import traceback  # local import: only used to report failed runs

  def decorator(func):
    key = f"{os.path.abspath(func.__code__.co_filename)}:{func.__code__.co_firstlineno}"

    @wraps(func)
    def wrapper(*args, **kwargs):
      # Claim the key atomically before starting a thread, so that concurrent or
      # duplicate activations cannot start two schedulers for the same definition.
      with _scheduler_lock:
        if _scheduler_cache.get(key):
          return  # already scheduled
        _scheduler_cache[key] = True
      scheduler = sched.scheduler(time.time, time.sleep)

      def run_once():
        # Report instead of propagating: an uncaught exception would end scheduler.run()
        # and with it every future run.
        try:
          func(*args, **kwargs)
        except Exception:
          traceback.print_exc()

      def job():
        run_once()
        scheduler.enter(interval_seconds, 1, job)

      def start_scheduler():
        if run_now:
          run_once()
        scheduler.enter(interval_seconds, 1, job)
        scheduler.run()

      threading.Thread(target=start_scheduler, daemon=True).start()

    if not start_by_call:
      wrapper()
    return wrapper
  return decorator
ctools/cid.py CHANGED
@@ -7,6 +7,9 @@ def get_snowflake_id():
7
7
  return idWorker.get_id()
8
8
 
9
9
 
10
def get_snowflake_id_str():
  """Generate a new snowflake id via get_snowflake_id() and return it converted with str()."""
  new_id = get_snowflake_id()
  return str(new_id)
12
+
10
13
  def get_random_str(size: int = 10) -> str:
11
14
  import random
12
15
  return "".join(random.sample('abcdefghjklmnpqrstuvwxyz123456789', size))
@@ -1,8 +1,9 @@
1
1
  import contextlib
2
2
  import datetime
3
3
  import math
4
+ import threading
4
5
 
5
- from sqlalchemy import create_engine, Integer, Column, event
6
+ from sqlalchemy import create_engine, BigInteger, Column, event
6
7
  from sqlalchemy.ext.declarative import declarative_base
7
8
  from sqlalchemy.orm import sessionmaker, Session
8
9
  from sqlalchemy.sql import text
@@ -58,17 +59,19 @@ and ``driver`` the name of a DBAPI such as ``psycopg2``, ``pyodbc``, ``cx_oracle
58
59
  # > PRAGMA journal_mode=WAL; 设置事务的模式, wal 允许读写并发, 但是会额外创建俩文件
59
60
  # > PRAGMA synchronous=NORMAL; 设置写盘策略, 默认是 FULL, 日志,数据都落, 设置成 NORMAL, 日志写完就算事务完成
60
61
 
61
def init_db(db_url: str, db_key: str = 'default', connect_args: dict = None, default_schema: str = None, pool_size: int = 5, max_overflow: int = 25, echo: bool = False, auto_gen_table: bool = False):
  """
  Create an engine and session factory for ``db_url`` and register them under ``db_key``.

  :param db_url: SQLAlchemy database URL; 'mysql...' URLs install pymysql as MySQLdb first
  :param db_key: registry key in engines / sessionMakers; each key may be initialised only once
  :param connect_args: extra DBAPI connect arguments; copied, never mutated
  :param default_schema: schema put on the search_path through the ``options`` connect argument
  :param pool_size: connection pool size
  :param max_overflow: connections allowed beyond pool_size
  :param echo: log SQL statements
  :param auto_gen_table: create all tables declared on Base
  :raises Exception: if ``db_key`` has already been initialised
  """
  if db_url.startswith('mysql'):
    import pymysql
    pymysql.install_as_MySQLdb()
  if inited_db.get(db_key): raise Exception('db {} already init!!!'.format(db_key))
  global engines, sessionMakers
  # Work on a copy: updating the argument in place would leak the search_path into the
  # caller's dict or, through a shared default dict, into every later init_db call.
  connect_args = dict(connect_args) if connect_args else {}
  if default_schema:
    # The schema is passed as a connection startup option. The original note says that
    # running SET search_path in a 'connect' event listener caused stalls under high concurrency.
    # NOTE(review): 'options' is a libpq/psycopg2 connect argument; other drivers (e.g. pymysql)
    # may reject it — confirm default_schema is only used with PostgreSQL-compatible URLs.
    search_path = '-csearch_path={}'.format(default_schema)
    existing_options = connect_args.get('options')
    connect_args['options'] = '{} {}'.format(existing_options, search_path) if existing_options else search_path
  engine, sessionMaker = _create_connection(db_url=db_url, connect_args=connect_args, pool_size=pool_size, max_overflow=max_overflow, echo=echo)
  engines[db_key] = engine
  sessionMakers[db_key] = sessionMaker
  inited_db[db_key] = True
  if auto_gen_table: Base.metadata.create_all(engine)
73
76
 
74
77
 
@@ -86,7 +89,7 @@ def _create_connection(db_url: str, pool_size: int = 5, max_overflow: int = 25,
86
89
  pool_pre_ping=True,
87
90
  pool_recycle=3600,
88
91
  connect_args=connect_args)
89
- sm = sessionmaker(bind=engine)
92
+ sm = sessionmaker(bind=engine, expire_on_commit=False)
90
93
  return engine, sm
91
94
 
92
95
 
@@ -96,7 +99,7 @@ def generate_custom_id():
96
99
 
97
100
  class BaseMixin(Base):
98
101
  __abstract__ = True
99
- obj_id = Column(Integer, primary_key=True, default=generate_custom_id)
102
+ obj_id = Column(BigInteger, primary_key=True, default=generate_custom_id)
100
103
 
101
104
  # ext1 = Column(String)
102
105
  # ext2 = Column(String)
@@ -153,7 +156,14 @@ class PageInfoBuilder:
153
156
  self.records = records
154
157
 
155
158
 
156
- def query_by_page(query, pageInfo):
159
+ def query_by_page(query, pageInfo) -> PageInfoBuilder:
160
+ """
161
+ 使用方法:
162
+ with database.get_session() as s:
163
+ query = s.query(AppInfoEntity).filter(AppInfoEntity.app_name.contains(params.app_name))
164
+ result = database.query_by_page(query, params.page_info)
165
+ return R.ok(result)
166
+ """
157
167
  records = query.offset((pageInfo.page_index - 1) * pageInfo.page_size).limit(pageInfo.page_size).all()
158
168
  rs = []
159
169
  for r in records:
ctools/ex.py CHANGED
@@ -2,6 +2,9 @@ import time
2
2
  import traceback
3
3
  from functools import wraps
4
4
 
5
+ """
6
+ @exception_handler(fail_return=['解析错误'], print_exc=True)
7
+ """
5
8
 
6
9
  # annotation
7
10
  def exception_handler(fail_return, retry_num=0, delay=3, catch_e=Exception, print_exc=False):
ctools/patch.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+
3
+ from sqlalchemy.sql import text
4
+
5
+ from ctools import path_info
6
+ from ctools.database import database
7
+
8
+ """
9
+ from ctools import patch
10
+ def xx():
11
+ print('hello world')
12
+ def xx1():
13
+ print('hello world1')
14
+ def xx2():
15
+ print('hello world2')
16
+ patch_funcs = {
17
+ 'V1.0.2': xx,
18
+ 'V1.0.3': xx1,
19
+ 'V1.1.4': xx2
20
+ }
21
+ patch.sync_version("kwc", "V1.1.5", patch_funcs)
22
+ """
23
+
24
class Patch:
  """
  Runs version-keyed patch functions when upgrading from ``oldVersion`` to ``newVersion``.

  ``patch_func`` maps version strings ('V1.0.2' or 'v1.0.2') to zero-argument callables;
  keys not starting with 'V'/'v' are ignored. A patch keyed X runs when
  oldVersion <= X < newVersion. A missing/empty oldVersion counts as 0, so every patch
  below newVersion runs. When either version contains '-snapshot', the highest patch is
  also run if it was not already part of that window.
  """

  def __init__(self, oldVersion, newVersion, patch_func: dict) -> None:
    super().__init__()
    self.oldV = version_to_int(oldVersion) if oldVersion else 0
    self.currentV = version_to_int(newVersion)
    self.snapshot = '-snapshot' in newVersion or (oldVersion is not None and '-snapshot' in oldVersion)
    self.patch_func = patch_func

  def apply_patch(self):
    """Run every applicable patch in ascending version order (see the class docstring)."""
    patch_methods = [name for name in self.patch_func.keys() if name and name.startswith(('V', 'v'))]
    if not patch_methods:
      return  # nothing registered; also avoids IndexError on patch_methods[-1]
    patch_methods.sort(key=version_to_int)
    max_method_name = patch_methods[-1]
    exec_max_method = False
    for method_name in patch_methods:
      slVersion = version_to_int(method_name)
      if self.currentV > slVersion >= self.oldV:
        if max_method_name == method_name: exec_max_method = True
        method = self.patch_func[method_name]
        print('start exec patch {}'.format(method_name))
        method()
        print('patch {} update success'.format(method_name))
    if self.snapshot and not exec_max_method:
      print('start exec snapshot patch {}'.format(max_method_name))
      method = self.patch_func[max_method_name]
      method()
      print('snapshot patch {} update success'.format(max_method_name))

def version_to_int(version):
  """
  Convert a version string such as 'V1.2.3', 'v1.2' or 'V1.2.3-snapshot' into a comparable int.

  Each dot-separated component gets its own three-digit slot (up to four components; missing
  trailing components count as 0). As a result 'V1.0.10' > 'V1.0.9', 'V1.1.0' > 'V1.0.10' and
  'V1.0' == 'V1.0.0'. Plain digit concatenation would give 'V1.0.10' -> 1010 > 'V1.1.0' -> 110.
  Components after the first must stay below 1000 for the ordering to hold.

  :raises ValueError: for non-numeric components or more than four components
  """
  text = version.strip().replace('-snapshot', '')
  if text[:1] in ('V', 'v'):
    text = text[1:]
  parts = text.split('.')
  if len(parts) > 4:
    raise ValueError('unsupported version format: {}'.format(version))
  parts += ['0'] * (4 - len(parts))
  value = 0
  for part in parts:
    value = value * 1000 + int(part)
  return value
57
+
58
def run_sqls(sqls):
  """
  Execute a ';'-separated batch of SQL statements, committing each statement on its own.

  Blank fragments (e.g. after a trailing ';') are skipped. A failing statement is reported
  and rolled back so that the remaining statements can still run on the same session.
  NOTE: splitting on ';' is naive — a semicolon inside a string literal or a
  procedure body will cut that statement apart.

  :param sqls: SQL script text
  """
  with database.get_session() as s:
    for sql in sqls.split(";"):
      statement = sql.strip()
      if not statement:
        continue
      try:
        s.execute(text(statement))
        s.commit()
      except Exception as e:
        # A failed statement leaves the session's transaction unusable until it is rolled back.
        s.rollback()
        # __cause__ holds the underlying driver error when one was chained; fall back to e itself.
        print('结构升级错误, 请检查!!! {}'.format(e.__cause__ or e))
66
+
67
def sync_version(app_name, new_version, patch_func: dict):
  """
  Apply pending patches for ``app_name`` and record ``new_version`` as the installed version.

  The installed version is stored in a file named 'version' under
  path_info.get_user_work_path('.ck/<app_name>').
  - No file yet: every patch below new_version runs (fresh install).
  - Otherwise patches run when the recorded version is numerically lower than new_version
    (compared with version_to_int, so 'V1.0.10' > 'V1.0.9'). They also run when either
    version is a '-snapshot' build or the recorded version is empty.

  :param app_name: application name used for the version-file directory
  :param new_version: version string of the running program, e.g. 'V1.1.5'
  :param patch_func: mapping of version string -> zero-argument patch function (see Patch)
  """
  destFilePath = os.path.join(path_info.get_user_work_path(".ck/{}".format(app_name), mkdir=True), "version")
  if not os.path.exists(destFilePath):
    patch = Patch(oldVersion=None, newVersion=new_version, patch_func=patch_func)
    patch.apply_patch()
    with open(destFilePath, 'w', encoding='utf-8') as nv:
      nv.write(new_version)
      print('初始化安装, 版本信息为: {}'.format(new_version))
      nv.flush()
    return
  with open(destFilePath, 'r', encoding='utf-8') as oldVersion:
    oldV = oldVersion.readline().strip()
  print('本地版本信息为: {}, 程序版本信息为: {}'.format(oldV, new_version))
  is_snapshot = '-snapshot' in oldV or '-snapshot' in new_version
  # Compare numerically: plain string comparison orders 'V1.0.9' above 'V1.0.10'.
  if oldV and not is_snapshot and version_to_int(oldV) >= version_to_int(new_version): return
  print('开始升级本地程序..')
  patch = Patch(oldVersion=oldV, newVersion=new_version, patch_func=patch_func)
  patch.apply_patch()
  with open(destFilePath, 'w', encoding='utf-8') as newVersion:
    newVersion.write(new_version)
    print('程序升级成功, 更新版本信息为: {}'.format(new_version))
    newVersion.flush()
ctools/pkg/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ __author__ = 'haoyang'
4
+ __date__ = '2025/7/15 11:02'
@@ -0,0 +1,38 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ __author__ = 'haoyang'
4
+ __date__ = '2025/7/15 11:03'
5
+
6
+ import importlib
7
+ import pkgutil
8
+
9
+
10
def load_modules_from_package(package, exclude=None, recursive=True):
  """
  Import and return the modules below ``package`` (the package itself is not included).

  :param package: an imported package object (e.g. mypkg.plugins)
  :param exclude: full dotted names to skip (e.g. ['mypkg.plugins.demo.mod2']); naming a
                  sub-package skips it and everything beneath it; a single string is accepted too
  :param recursive: descend into sub-packages; when False, sub-packages are imported
                    and returned like ordinary modules
  :return: list of imported module objects
  Import failures are printed and skipped rather than raised.
  """
  if isinstance(exclude, str):
    exclude = [exclude]
  excluded = set(exclude) if exclude else set()
  modules = []
  for _finder, modname, ispkg in pkgutil.iter_modules(package.__path__):
    full_modname = f"{package.__name__}.{modname}"
    # Check exclusions before recursing so an excluded sub-package is never imported.
    if full_modname in excluded:
      continue
    if ispkg and recursive:
      try:
        subpkg = importlib.import_module(full_modname)
        modules.extend(load_modules_from_package(subpkg, excluded, recursive))
      except Exception as e:
        print(f"递归子包 {full_modname} 失败:{e}")
      continue
    try:
      modules.append(importlib.import_module(full_modname))
    except Exception as e:
      print(f"!!!!!!加载模块 {full_modname} 失败:{e}!!!!!!")
  return modules
ctools/stream/credis.py CHANGED
@@ -9,6 +9,10 @@ from redis import Redis
9
9
  from ctools import cdate, cid
10
10
  from ctools.pools import thread_pool
11
11
 
12
+ # 最后一次连接的redis
13
+ _ck_redis: Redis = None
14
+
15
def get_redis():
  """Return the Redis client stored by the most recent successful init_pool() call (None until then)."""
  client = _ck_redis
  return client
12
16
 
13
17
  def init_pool(host: str = 'localhost', port: int = 6379, db: int = 0, password: str = None,
14
18
  username: str = None, decode_responses: bool = True, max_connections: int = 75,
@@ -27,7 +31,9 @@ def init_pool(host: str = 'localhost', port: int = 6379, db: int = 0, password:
27
31
  )
28
32
  if r.ping():
29
33
  print('CRedis connect {} {} success!'.format(host, port))
30
- return r
34
+ global _ck_redis
35
+ _ck_redis = r
36
+ return _ck_redis
31
37
  except redis.ConnectionError as e:
32
38
  if attempt == retry_count - 1:
33
39
  raise Exception(f"Failed to connect to Redis after {retry_count} attempts: {str(e)}")
@@ -37,8 +43,8 @@ def init_pool(host: str = 'localhost', port: int = 6379, db: int = 0, password:
37
43
  def add_lock(r: Redis, key: str, timeout: int = 30):
38
44
  if r.exists(key):
39
45
  expire_time = r.get(key)
40
- if cdate.time_diff_in_seconds(expire_time, cdate.get_date_time()) > 0:
41
- return True
46
+ if expire_time and cdate.time_diff_in_seconds(expire_time, cdate.get_date_time()) > 0:
47
+ return False
42
48
  else:
43
49
  r.delete(key)
44
50
  return r.set(key, cdate.opt_time(seconds=timeout), nx=True, ex=timeout) is not None
ctools/sys_log.py CHANGED
@@ -85,7 +85,7 @@ class StreamToLogger(io.StringIO):
85
85
  @call.init
86
86
  def _init_log() -> None:
87
87
  global flog, clog
88
- flog = _file_log(sys_log_path='{}/ck-py-log/'.format(path_info.get_user_work_path()), mixin=True, log_level=logging.DEBUG)
88
+ flog = _file_log(path_info.get_user_work_path(".ck/ck-py-log", mkdir=True), mixin=True, log_level=logging.DEBUG)
89
89
  clog = _console_log()
90
90
  sys.stdout = StreamToLogger(flog, level=logging.INFO)
91
91
  sys.stderr = StreamToLogger(flog, level=logging.ERROR)
ctools/util/cklock.py ADDED
@@ -0,0 +1,118 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: UTF-8 -*-
3
+ __author__ = 'haoyang'
4
+ __date__ = '2025/7/18 15:46'
5
+
6
+ import contextvars
7
+ import threading
8
+ from contextlib import contextmanager
9
+ from functools import wraps
10
+
11
+ from ctools.stream.credis import get_redis, add_lock, remove_lock
12
+ from ctools.web import ctoken
13
+ from ctools.web.api_result import R
14
+
15
# Global registry of in-process locks, keyed by lock name.
_lock_dict = {}
# Number of threads currently holding or waiting for each lock in _lock_dict. An entry is
# dropped only when this reaches zero. Otherwise a waiter that already fetched the Lock
# object could end up holding an orphaned lock while a newcomer creates a second Lock for
# the same key, and both would run the critical section at once.
_lock_refs = {}
_lock_dict_lock = threading.Lock()


def _checkout_lock(key: str) -> threading.Lock:
  """Return the lock for ``key`` (creating it if needed) and register the caller as one of its users."""
  with _lock_dict_lock:
    registered = _lock_dict.get(key)
    if registered is None:
      registered = _lock_dict[key] = threading.Lock()
    _lock_refs[key] = _lock_refs.get(key, 0) + 1
    return registered


def _checkin_lock(key: str) -> None:
  """Unregister one user of ``key`` and drop the entry once unused. The caller must hold _lock_dict_lock."""
  remaining = _lock_refs.get(key, 0) - 1
  if remaining > 0:
    _lock_refs[key] = remaining
  else:
    _lock_refs.pop(key, None)
    _lock_dict.pop(key, None)


def try_acquire_lock(key: str) -> bool:
  """Try to take the lock for ``key`` without waiting; return True on success."""
  registered = _checkout_lock(key)
  if registered.acquire(blocking=False):
    return True
  with _lock_dict_lock:
    _checkin_lock(key)
  return False


def try_acquire_lock_block(key: str):
  """Take the lock for ``key``, waiting as long as necessary."""
  # The blocking acquire happens outside _lock_dict_lock so other keys are not held up.
  _checkout_lock(key).acquire()


def release_lock(key: str):
  """Release the lock for ``key`` if it is currently held; no-op otherwise."""
  with _lock_dict_lock:
    registered = _lock_dict.get(key)
    if registered is None or not registered.locked():
      return
    registered.release()
    _checkin_lock(key)


@contextmanager
def try_lock(key: str = "sys_lock", block=False):
  """
  Context manager around the in-process named lock registry.

  Non-blocking (default): yields True if the lock was obtained and False otherwise. The
  body runs in both cases, so callers must check the yielded flag.
  Blocking: waits for the lock and yields None.
  The lock is released on exit whenever it was obtained.
  """
  if not block:
    acquired = try_acquire_lock(key)
    try:
      yield acquired
    finally:
      if acquired:
        release_lock(key)
  else:
    try_acquire_lock_block(key)
    try:
      yield
    finally:
      release_lock(key)
55
+
56
# annotation
# Usage: @lock("params.attr")

# Lock keys already held in the current execution context; lets @lock re-enter itself.
current_locks = contextvars.ContextVar("current_locks", default=set())


def _attr_lock_key(lock_attrs, args, kwargs):
  """
  Build the attribute-based lock key for @lock.

  Each "param.attr" entry contributes "_<value>", where value is
  getattr(kwargs.get(param) or args[0], attr, None). Any failure is re-raised as ValueError.
  """
  attr_specs = [lock_attrs] if isinstance(lock_attrs, str) else lock_attrs
  key_pieces = []
  try:
    for attr in attr_specs:
      parts = attr.split(".")
      if len(parts) != 2:
        raise ValueError(f"lock_attr: {attr} 格式错误")
      param_name, attr_name = parts
      obj = kwargs.get(param_name) or args[0]
      if obj is None:
        raise ValueError(f"参数 {param_name} 不存在")
      key_pieces.append(f"_{getattr(obj, attr_name, None)}")
  except Exception as e:
    raise ValueError(f"生成锁键失败: {e}")
  return "".join(key_pieces)


def lock(lock_attrs=None):
  """
  Decorator that lets only one caller at a time run for a given lock key.

  Key selection:
    - lock_attrs falsy: per-user key "USER_ID_LOCK_<user_id>" from ctoken.get_user_id();
      raises ValueError when there is no user id.
    - otherwise: a "param.attr" string or an iterable of them (see _attr_lock_key).
  Backend: the Redis client from get_redis() (add_lock/remove_lock) when one exists,
  otherwise the in-process try_lock registry.
  Re-entrant within one context: a nested call for a key already held runs directly.
  When the lock is busy, R.error("操作过于频繁, 请稍后再试") is returned instead of calling func.
  NOTE(review): attribute-based keys contain only the attribute values (e.g. "_42"), not the
  function or attribute names. Unrelated @lock sites whose values coincide therefore contend
  for the same lock (and the same Redis key) — confirm this is intended.
  """
  def decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
      if lock_attrs:
        lock_key = _attr_lock_key(lock_attrs, args, kwargs)
      else:
        user_id = ctoken.get_user_id()
        if not user_id:
          raise ValueError("请设置 lock_attrs 或使用 token!")
        lock_key = f"USER_ID_LOCK_{user_id}"

      held_keys = current_locks.get()
      if lock_key in held_keys:
        return func(*args, **kwargs)
      token = current_locks.set(held_keys | {lock_key})

      try:
        if get_redis():
          locked = add_lock(get_redis(), lock_key)
          try:
            if not locked:
              return R.error("操作过于频繁, 请稍后再试")
            return func(*args, **kwargs)
          finally:
            if locked:
              remove_lock(get_redis(), lock_key)
        with try_lock(lock_key) as locked:
          if locked:
            return func(*args, **kwargs)
          return R.error("操作过于频繁, 请稍后再试")
      finally:
        current_locks.reset(token)
    return wrapper
  return decorator
118
+