pro-craft 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pro-craft might be problematic. Click here for more details.

@@ -1,131 +1,124 @@
1
-
2
- from fastapi import FastAPI, HTTPException, Header
1
+ # server
2
+ from fastapi import FastAPI, HTTPException
3
3
  from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi import FastAPI, Depends
5
- from fastapi_users import FastAPIUsers
6
- from contextlib import asynccontextmanager
7
- from .utils import create_db_and_tables
8
- from .utils import get_user_manager, auth_backend
9
- from .models.models import User,UserCreate, UserRead,UserUpdate
10
- from .routers import admin_router,user_router
11
- from pydantic import BaseModel, Field, model_validator
4
+ from prompt_writing_assistant.log import Log
5
+ from dotenv import load_dotenv, find_dotenv
6
+ from contextlib import asynccontextmanager, AsyncExitStack
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.routing import APIRoute
12
9
  import argparse
13
10
  import uvicorn
14
- import uuid
15
- from umanager.log import Log
16
11
 
17
- # 应用程序生命周期事件
18
- @asynccontextmanager
19
- async def lifespan(app: FastAPI):
20
- # 应用启动时创建数据库表
21
- await create_db_and_tables()
22
- yield
12
+ from .mcp.math import mcp as fm_math
13
+ from .mcp.weather import mcp as fm_weather
14
+
23
15
 
16
+ default = 8007
24
17
 
25
- default=8008
18
+ dotenv_path = find_dotenv()
19
+ load_dotenv(dotenv_path, override=True)
20
+ logger = Log.logger
21
+
22
+
23
@asynccontextmanager
async def combined_lifespan(app: FastAPI):
    """Keep the session managers of both mounted MCP servers running.

    The ``session_manager.run()`` context of each server (math first, then
    weather) is entered on one ``AsyncExitStack``; both stay open while the
    application serves requests and are exited in reverse order on shutdown.
    """
    async with AsyncExitStack() as exit_stack:
        for mcp_server in (fm_math, fm_weather):
            await exit_stack.enter_async_context(mcp_server.session_manager.run())
        yield
26
31
 
app = FastAPI(
    lifespan=combined_lifespan,
    title="LLM Service",
    description="Provides an OpenAI-compatible API for custom large language models.",
    version="1.0.1",
    # debug=True,
    # docs_url="/api-docs",
)

# --- Configure CORS ---
# NOTE(review): "*" origins together with allow_credentials=True is very
# permissive (the previous version called "*" insecure for production);
# confirm this is intended before exposing the service publicly.
origins = [
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,  # Specifies the allowed origins
    allow_credentials=True,  # Allows cookies/authorization headers
    allow_methods=["*"],  # Allows all methods (GET, POST, OPTIONS, etc.)
    allow_headers=["*"],  # Allows all headers (Content-Type, Authorization, etc.)
)
# --- End CORS Configuration ---

# Mount each MCP server's streamable-HTTP app under its own prefix.
app.mount("/math", fm_math.streamable_http_app())  # /math/mcp
app.mount("/weather", fm_weather.streamable_http_app())  # /weather/mcp
60
+
61
+
62
+
91
63
 
92
64
 
@app.get("/")
async def root():
    """Landing endpoint confirming that the service is up and answering."""
    payload = {"message": "LLM Service is running."}
    return payload
97
69
 
98
70
 
71
@app.get("/api/status")
def status():
    """Liveness probe; always answers ``{"status": "ok"}``."""
    response = {"status": "ok"}
    return response
74
+
75
@app.get("/api/list-routes/")
async def list_fastapi_routes(request: Request):
    """List the ``APIRoute`` endpoints registered on the application.

    Routes that are not ``APIRoute`` instances (such as mounted sub-apps) are
    not included.

    Returns:
        ``{"routes": [...]}`` where each entry holds the route's path, name,
        HTTP methods and the name of its endpoint function.
    """
    api_routes = (r for r in request.app.routes if isinstance(r, APIRoute))
    return {
        "routes": [
            {
                "path": r.path,
                "name": r.name,
                "methods": list(r.methods),
                "endpoint": r.endpoint.__name__,  # name of the handler function
            }
            for r in api_routes
        ]
    }
87
+
99
88
 
100
89
 
101
90
  if __name__ == "__main__":
91
+ # 这是一个标准的 Python 入口点惯用法
92
+ # 当脚本直接运行时 (__name__ == "__main__"),这里的代码会被执行
93
+ # 当通过 python -m YourPackageName 执行 __main__.py 时,__name__ 也是 "__main__"
94
+ # 27
102
95
 
103
96
  parser = argparse.ArgumentParser(
104
97
  description="Start a simple HTTP server similar to http.server."
105
98
  )
106
99
  parser.add_argument(
107
- 'port',
108
- metavar='PORT',
100
+ "port",
101
+ metavar="PORT",
109
102
  type=int,
110
- nargs='?', # 端口是可选的
103
+ nargs="?", # 端口是可选的
111
104
  default=default,
112
- help=f'Specify alternate port [default: {default}]'
105
+ help=f"Specify alternate port [default: {default}]",
113
106
  )
114
107
  # 创建一个互斥组用于环境选择
115
108
  group = parser.add_mutually_exclusive_group()
116
109
 
117
110
  # 添加 --dev 选项
118
111
  group.add_argument(
119
- '--dev',
120
- action='store_true', # 当存在 --dev 时,该值为 True
121
- help='Run in development mode (default).'
112
+ "--dev",
113
+ action="store_true", # 当存在 --dev 时,该值为 True
114
+ help="Run in development mode (default).",
122
115
  )
123
116
 
124
117
  # 添加 --prod 选项
125
118
  group.add_argument(
126
- '--prod',
127
- action='store_true', # 当存在 --prod 时,该值为 True
128
- help='Run in production mode.'
119
+ "--prod",
120
+ action="store_true", # 当存在 --prod 时,该值为 True
121
+ help="Run in production mode.",
129
122
  )
130
123
  args = parser.parse_args()
131
124
 
@@ -136,25 +129,26 @@ if __name__ == "__main__":
136
129
  env = "dev"
137
130
 
138
131
  port = args.port
132
+
139
133
  if env == "dev":
140
134
  port += 100
141
- Log.reset_level('debug',env = env)
135
+ Log.reset_level("debug", env=env)
142
136
  reload = True
143
- app_import_string = f"{__package__}.__main__:app" # <--- 关键修改:传递导入字符串
137
+ app_import_string = (
138
+ f"{__package__}.__main__:app" # <--- 关键修改:传递导入字符串
139
+ )
144
140
  elif env == "prod":
145
- Log.reset_level('info',env = env)# ['debug', 'info', 'warning', 'error', 'critical']
141
+ Log.reset_level(
142
+ "info", env=env
143
+ ) # ['debug', 'info', 'warning', 'error', 'critical']
146
144
  reload = False
147
145
  app_import_string = app
148
146
  else:
149
147
  reload = False
150
148
  app_import_string = app
151
-
152
149
 
153
150
  # 使用 uvicorn.run() 来启动服务器
154
151
  # 参数对应于命令行选项
155
152
  uvicorn.run(
156
- app_import_string,
157
- host="0.0.0.0",
158
- port=port,
159
- reload=reload # 启用热重载
153
+ app_import_string, host="0.0.0.0", port=port, reload=reload # 启用热重载
160
154
  )
@@ -0,0 +1,6 @@
1
+
2
+ # server
3
+ from typing import Dict, Any, Optional, List
4
+ from pydantic import BaseModel, Field, model_validator
5
+
6
+
@@ -0,0 +1,283 @@
1
+ # server
2
+ # 推荐算法
3
+
4
+ from ..models import UpdateItem, DeleteResponse, DeleteRequest, QueryItem
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from typing import Dict, Any
7
+ from diglife.embedding_pool import EmbeddingPool
8
+ from diglife.log import Log
9
+ import os
10
+ import httpx
11
+
12
+ from dotenv import load_dotenv, find_dotenv
13
+
14
# Load variables from the nearest .env file; override=True lets .env values
# replace variables already present in the process environment.
dotenv_path = find_dotenv()
load_dotenv(dotenv_path, override=True)

router = APIRouter(tags=["recommended"])

logger = Log.logger

# Per-user history limits used by the search endpoints: once a user's list of
# already-recommended ids grows beyond these sizes it is cleared
# (biographies/cards and digital figures respectively).
recommended_biographies_cache_max_leng = int(
    os.getenv("recommended_biographies_cache_max_leng", "2")
)
recommended_cache_max_leng = int(os.getenv("recommended_cache_max_leng", "2"))

# Base URL of the internal user service (memory cards, avatar descriptions).
# Can be overridden with the `user_server_base_url` environment variable; the
# original address remains the default. A trailing "/" is stripped because
# request paths are appended with a leading "/".
user_server_base_url = os.getenv(
    "user_server_base_url", "http://182.92.107.224:7000"
).rstrip("/")

ep = EmbeddingPool()

# In-memory history of ids already recommended, keyed by user id. The values
# are lists of ids, not dicts.
recommended_biographies_cache: Dict[str, list] = {}
recommended_figure_cache: Dict[str, list] = {}
32
+
33
@router.post(
    "/update",  # POST is used for data updates
    summary="更新或添加文本嵌入",
    description="将给定的文本内容与一个ID关联并更新到Embedding池中。",
    response_description="表示操作是否成功。",
)
def recommended_update(item: UpdateItem):
    """Add or update the embedding of one item in the EmbeddingPool.

    ``item.type`` selects what ``item.text`` describes:

    * ``0`` - memory card: ``text`` is the card content, ``id`` the card id.
    * ``1`` - biography: ``text`` is the biography summary, ``id`` the biography id.
    * ``2`` - digital figure: ``text`` is the figure's summary and personality
      description, ``id`` the figure id.

    Returns:
        ``{"status": "success", "message": ...}`` once ``ep.update`` returns.

    Raises:
        HTTPException(400): ``item.type`` is not 0, 1 or 2, or ``ep.update``
            raised ``ValueError``.
        HTTPException(500): any other error raised while updating.
    """
    # TODO: ep.update reports no status, so success here only means "no exception".
    try:
        if item.type not in (0, 1, 2):
            # An unknown type is a client error; raising ValueError routes it to
            # the 400 handler below instead of failing with an undefined name.
            raise ValueError(
                f"Unsupported type {item.type!r} for ID '{item.id}'; expected 0, 1 or 2."
            )
        ep.update(text=item.text, id=item.id, type=item.type)
    except ValueError as e:  # validation problems (ours or from EmbeddingPool.update)
        logger.warning(f"Validation error during update: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
        ) from e
    except Exception as e:
        logger.error(f"Error updating EmbeddingPool for ID '{item.id}': {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to update embedding for ID '{item.id}': {e}",
        ) from e

    return {"status": "success", "message": f"ID '{item.id}' updated successfully."}
78
+
79
+
80
@router.post("/delete", response_model=DeleteResponse, description="delete")
async def delete_server(request: DeleteRequest):
    """Remove the embedding stored under ``request.id`` from the EmbeddingPool.

    Always answers ``DeleteResponse(status="success")`` once ``ep.delete``
    returns; an exception raised by ``ep.delete`` propagates unchanged.
    """
    logger.info("running delete_server")
    # TODO: report ep.delete's outcome instead of unconditionally claiming success.
    ep.delete(id=request.id)
    return DeleteResponse(status="success")
92
+
93
+
94
+
95
+
96
+ # async def aget_content_by_id(url = ""):
97
+ # # url = url.format(user_profile_id = user_profile_id)
98
+ # async with httpx.AsyncClient() as client:
99
+ # try:
100
+ # response = await client.get(url)
101
+ # response.raise_for_status() # 如果状态码是 4xx 或 5xx,会抛出 HTTPStatusError 异常
102
+
103
+ # print(f"Status Code: {response.status_code}")
104
+ # print(f"Response Body: {response.json()}") # 假设返回的是 JSON
105
+ # return response.json()
106
+ # except httpx.HTTPStatusError as e:
107
+ # print(f"HTTP error occurred: {e.response.status_code} - {e.response.text}")
108
+ # except httpx.RequestError as e:
109
+ # print(f"An error occurred while requesting {e.request.url!r}: {e}")
110
+ # except Exception as e:
111
+ # print(f"An unexpected error occurred: {e}")
112
+ # return None
113
+
114
async def aget_(url=""):
    """GET ``url`` and return the decoded JSON body, or ``None`` on failure.

    Every failure - a 4xx/5xx status, a transport error, a body that is not
    valid JSON, or anything else - is logged and swallowed, so callers must
    check for a ``None`` result before using it.
    """
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url)
            # Raises httpx.HTTPStatusError for 4xx/5xx responses.
            response.raise_for_status()
            # Decode the body once and reuse it (assumes a JSON response).
            payload = response.json()
        except httpx.HTTPStatusError as e:
            logger.error(
                f"HTTP error occurred: {e.response.status_code} - {e.response.text}"
            )
        except httpx.RequestError as e:
            logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
        else:
            # Bodies may contain user data; keep them out of stdout / info logs.
            logger.debug(f"Status Code: {response.status_code}")
            logger.debug(f"Response Body: {payload}")
            return payload
    return None
130
+
131
@router.post(
    "/search_biographies_and_cards",
    summary="搜索传记和记忆卡片",
    description="搜索传记和记忆卡片",
    response_description="搜索结果列表。",
)
async def recommended_biographies_and_cards(query_item: QueryItem):
    """Recommend biographies and memory cards for ``query_item.user_id``.

    Steps:
      1. Fetch the user's memory cards from the user service and join the
         content of at most the first four cards into a query text.
      2. Search the EmbeddingPool with ``ep.search_bac``.
      3. Drop ids already recommended to this user (tracked in
         ``recommended_biographies_cache``), remember the ids returned this
         round, and reset the user's history once it holds more than
         ``recommended_biographies_cache_max_leng`` ids.

    Result items are whatever ``ep.search_bac`` returns; they are expected to be
    dicts with an ``"id"`` key - e.g. ``{"id": "1916693308020916225",
    "type": 1, "order": 0}`` (example from the earlier docstring; confirm
    against EmbeddingPool).

    Returns:
        ``{"status": "success", "result": [...], "query": user_id}``.

    Raises:
        HTTPException(500): the memory cards could not be fetched or parsed, or
            the search failed.
    """
    user_id = query_item.user_id
    try:
        # TODO: a dedicated endpoint for fetching content by id is still needed.
        memory_info = await aget_(
            url=user_server_base_url
            + f"/api/inner/getMemoryCards?userProfileId={user_id}"
        )
        if not isinstance(memory_info, dict):
            # aget_ returns None on failure; report that clearly instead of
            # failing with a TypeError when subscripting None.
            raise ValueError(f"could not fetch memory cards for user '{user_id}'")

        memory_cards = memory_info["data"]["memoryCards"][:4]
        # Cards without content would make str.join raise on None.
        contents = (card.get("content") for card in memory_cards)
        user_brief = "\n".join(text for text in contents if text is not None)

        result = ep.search_bac(query=user_brief)

        # Only recommend items this user has not been shown yet.
        seen_ids = set(recommended_biographies_cache.get(user_id) or ())
        clear_result = [item for item in result if item.get("id") not in seen_ids]

        # Remember everything returned this round; once the history outgrows the
        # limit, start over so items can be recommended again.
        seen_ids.update(item.get("id") for item in result)
        if len(seen_ids) > recommended_biographies_cache_max_leng:
            recommended_biographies_cache[user_id] = []
        else:
            recommended_biographies_cache[user_id] = list(seen_ids)

        return {
            "status": "success",
            "result": clear_result,
            "query": user_id,
        }

    except Exception as e:
        logger.error(f"Error searching EmbeddingPool for query '{user_id}': {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to perform search: {e}",
        ) from e
224
+
225
+
226
+
227
@router.post(
    "/search_figure_person",
    description="搜索数字分身的",
)
async def recommended_figure_person(query_item: QueryItem):
    """Recommend digital figures for ``query_item.user_id``.

    The user's avatar description is fetched from the user service and used as
    the query for ``ep.search_figure_person``. When the service fails, answers
    with a code other than 200, or returns an empty description, a generic
    fallback description is used instead. Figures already shown to the user
    (tracked in ``recommended_figure_cache``) are filtered out, and the user's
    history is reset once it holds more than ``recommended_cache_max_leng`` ids.

    Returns:
        ``{"status": "success", "result": [...], "query": user_id}``.

    Raises:
        HTTPException(500): the search or the response handling failed.
    """
    user_id = query_item.user_id
    try:
        avatar_info = await aget_(
            url=user_server_base_url
            + f"/api/inner/getAvatarDesc?userProfileId={user_id}"
        )
        logger.debug(f"avatar_info: {avatar_info}")

        # aget_ returns None on failure; only trust a dict answer with code 200.
        user_brief = None
        if isinstance(avatar_info, dict) and avatar_info.get("code") == 200:
            user_brief = (avatar_info.get("data") or {}).get("avatarDesc")
        if not user_brief:
            user_brief = "这是一个简单的人"  # generic fallback ("a simple person")

        result = ep.search_figure_person(query=user_brief)

        # Only recommend figures this user has not been shown yet.
        seen_ids = set(recommended_figure_cache.get(user_id) or ())
        clear_result = [item for item in result if item.get("id") not in seen_ids]

        # Remember everything returned this round; reset once over the limit.
        seen_ids.update(item.get("id") for item in result)
        if len(seen_ids) > recommended_cache_max_leng:
            recommended_figure_cache[user_id] = []
        else:
            recommended_figure_cache[user_id] = list(seen_ids)

        return {
            "status": "success",
            "result": clear_result,
            "query": user_id,
        }

    except Exception as e:
        logger.error(f"Error searching EmbeddingPool for query '{user_id}': {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to perform search: {e}",
        ) from e
283
+
pro_craft/utils.py CHANGED
@@ -1 +1,161 @@
1
- from utils_tool import *
1
+ '''
2
+ Author: 823042332@qq.com 823042332@qq.com
3
+ Date: 2025-08-28 09:07:54
4
+ LastEditors: 823042332@qq.com 823042332@qq.com
5
+ LastEditTime: 2025-08-28 09:30:32
6
+ FilePath: /pro_craft/src/pro_craft/unit.py
7
+ Description: 这是默认设置,请设置`customMade`, 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE
8
+ '''
9
+ import re
10
+ import inspect
11
+ import importlib
12
+ import yaml
13
+ import zlib
14
+ from volcenginesdkarkruntime import Ark
15
+ import os
16
+
17
+
18
def extract_(text: str, pattern_key=r"json", multi=False):
    """Extract the contents of fenced code blocks tagged ``pattern_key``.

    Matches spans of the form ```` ```<pattern_key> ... ``` ```` (non-greedy,
    across newlines). ``pattern_key`` is inserted into the regex unescaped, so
    it may itself be a regular expression.

    Args:
        text: Text to search, e.g. a model response.
        pattern_key: Tag that follows the opening fence (default ``"json"``).
        multi: Return every match instead of only the first.

    Returns:
        ``multi=False``: the first match with surrounding whitespace stripped.
        ``multi=True``: a list of all matches, each stripped.
        In both modes ``""`` is returned when nothing matches (kept for
        backward compatibility, even though ``multi=True`` otherwise returns a list).
    """
    matches = re.findall(r"```" + pattern_key + r"([\s\S]*?)```", text)
    if not matches:
        return ""
    if multi:
        return [match.strip() for match in matches]
    return matches[0].strip()
32
+
33
+
34
def _describe_callable(func, name, kind):
    """Build the description dict for one function or method."""
    is_async = inspect.iscoroutinefunction(func)
    return {
        "type": kind,
        "name": name,
        "docstring": inspect.getdoc(func),
        "signature": f"{'async ' if is_async else ''}def {name}{inspect.signature(func)}:",
        "is_async": is_async,
    }


def _class_signature(cls):
    """Render ``class Name(Base, ...):``, omitting an implicit ``object`` base."""
    bases = [base for base in cls.__bases__ if base is not object]
    if not bases:
        return f"class {cls.__name__}:"
    return f"class {cls.__name__}({', '.join(base.__name__ for base in bases)}):"


def extract_from_loaded_objects(obj_list):
    """Describe the classes and functions in ``obj_list`` as plain dicts.

    Classes become ``{"type": "class", "name", "docstring", "signature",
    "methods"}``. ``methods`` describes the class's plain functions as found by
    ``inspect.getmembers`` (own and inherited, sorted by name), skipping dunder
    methods except ``__init__``. Functions, sync or async, become
    ``{"type": "function", "name", "docstring", "signature", "is_async"}``.
    Any other object is ignored.

    The class signature is built from the class's own ``__bases__``. The
    previous implementation read the bases of the first root returned by
    ``inspect.getclasstree``, which is the parent class rather than ``obj``.
    """
    results = []
    for obj in obj_list:
        if inspect.isclass(obj):
            methods = [
                _describe_callable(member, name, "method")
                for name, member in inspect.getmembers(obj, predicate=inspect.isfunction)
                # Skip magic methods but keep __init__ (the constructor signature).
                if not name.startswith("__") or name == "__init__"
            ]
            results.append({
                "type": "class",
                "name": obj.__name__,
                "docstring": inspect.getdoc(obj),
                "signature": _class_signature(obj),
                "methods": methods,
            })
        elif inspect.isfunction(obj) or inspect.iscoroutinefunction(obj):
            results.append(_describe_callable(obj, obj.__name__, "function"))
    return results
74
+
75
+
76
def get_adler32_hash(s):
    """Return the Adler-32 checksum of the UTF-8 encoding of ``s``."""
    payload = s.encode("utf-8")
    checksum = zlib.adler32(payload)
    return checksum
78
+
79
def embedding_inputs(inputs: list[str], model_name=None):
    """Embed ``inputs`` in a single call to the Volcengine Ark embeddings API.

    Args:
        inputs: Texts to embed.
        model_name: Ark embedding model; when falsy, the ``Ark_model_name``
            environment variable is used. The API key is read from
            ``Ark_api_key``.

    Returns:
        The ``embedding`` of every item in the response's ``data`` (requested as
        floats), presumably one per input in input order - confirm with the Ark docs.
    """
    chosen_model = model_name or os.getenv("Ark_model_name")
    client = Ark(api_key=os.getenv("Ark_api_key"))
    response = client.embeddings.create(
        model=chosen_model,
        input=inputs,
        encoding_format="float",
    )
    vectors = [record.embedding for record in response.data]
    return vectors
89
+
90
def load_inpackage_file(package_name: str, file_name: str, file_type='yaml'):
    """Load a data file shipped inside an importable package.

    Args:
        package_name: Importable package that contains the file.
        file_name: File name relative to the package directory.
        file_type: ``'yaml'`` (default) parses the file with ``yaml.safe_load``;
            any other value returns the raw text.

    Returns:
        The parsed YAML object, or the file contents as a ``str``.

    Raises:
        ModuleNotFoundError: ``package_name`` cannot be imported.
        FileNotFoundError: the package has no such file.
    """
    # `import importlib` alone does not guarantee that the `importlib.resources`
    # submodule is loaded, so import it explicitly here.
    from importlib import resources

    # files() (Python 3.9+) replaces the legacy open_text() helper, which emits
    # DeprecationWarning on Python 3.11/3.12.
    resource = resources.files(package_name).joinpath(file_name)
    with resource.open("r", encoding="utf-8") as f:
        if file_type == 'yaml':
            return yaml.safe_load(f)
        return f.read()
97
+
98
+
99
def super_print(s, target: str):
    """Print ``s`` inside a framed banner so it stands out in console output.

    The output is:
      * two blank lines;
      * a banner with ``target`` between 42 ``=`` characters on each side;
      * a blank line;
      * a 100-character ``=`` rule;
      * the type of ``s``;
      * a rule;
      * ``s`` itself;
      * a closing rule;
      * a trailing blank line.
    """
    rule = "==" * 50
    banner = "==" * 21 + target + "==" * 21
    for line in ("", "", banner, "", rule, type(s), rule, s, rule, ""):
        print(line)
110
+
111
+
112
+ from sqlalchemy.orm import sessionmaker
113
+
114
+ from contextlib import contextmanager
115
@contextmanager
def create_session(engine):
    """Yield a SQLAlchemy ``Session`` bound to ``engine`` and always close it.

    The session is not committed automatically; callers commit explicitly.
    If the ``with`` body raises, the session is rolled back and the exception
    is re-raised so the caller learns that the operation failed.
    """
    # sessionmaker builds a Session factory bound to the engine; the Session is
    # the main interface that tracks objects and persists them to the database.
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        yield session
    except Exception as e:
        print(f"An error occurred: {e}")
        session.rollback()  # undo the failed transaction
        # A @contextmanager generator that does not re-raise suppresses the
        # exception, which made failed operations look successful to callers.
        raise
    finally:
        session.close()  # close the session and release its resources
129
+
130
+ from contextlib import contextmanager
131
+ from sqlalchemy import create_engine, Column, Integer, String, UniqueConstraint
132
+ from sqlalchemy.orm import declarative_base, sessionmaker
133
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine # 异步核心
134
+ import asyncio
135
+
136
+ from contextlib import contextmanager
137
+
138
+ from contextlib import asynccontextmanager # 注意这里是 asynccontextmanager
139
+ import asyncio
140
+
141
+
142
@asynccontextmanager
async def create_async_session(async_engine):
    """Yield an ``AsyncSession`` that commits on success and rolls back on error.

    The session factory uses ``expire_on_commit=False`` so loaded attributes are
    not expired by the commit. If the body or the commit raises, the session is
    rolled back and the exception is re-raised; the session is always closed.
    """
    session = sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )()
    try:
        yield session
        # Commit stays inside the try so that a failing commit is rolled back too.
        await session.commit()
    except Exception as err:
        print(f"An error occurred: {err}")
        await session.rollback()
        raise
    finally:
        await session.close()
161
+