local-coze 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- local_coze/__init__.py +110 -0
- local_coze/cli/__init__.py +3 -0
- local_coze/cli/chat.py +126 -0
- local_coze/cli/cli.py +34 -0
- local_coze/cli/constants.py +7 -0
- local_coze/cli/db.py +81 -0
- local_coze/cli/embedding.py +193 -0
- local_coze/cli/image.py +162 -0
- local_coze/cli/knowledge.py +195 -0
- local_coze/cli/search.py +198 -0
- local_coze/cli/utils.py +41 -0
- local_coze/cli/video.py +191 -0
- local_coze/cli/video_edit.py +888 -0
- local_coze/cli/voice.py +351 -0
- local_coze/core/__init__.py +25 -0
- local_coze/core/client.py +253 -0
- local_coze/core/config.py +58 -0
- local_coze/core/exceptions.py +67 -0
- local_coze/database/__init__.py +29 -0
- local_coze/database/client.py +170 -0
- local_coze/database/migration.py +342 -0
- local_coze/embedding/__init__.py +31 -0
- local_coze/embedding/client.py +350 -0
- local_coze/embedding/models.py +130 -0
- local_coze/image/__init__.py +19 -0
- local_coze/image/client.py +110 -0
- local_coze/image/models.py +163 -0
- local_coze/knowledge/__init__.py +19 -0
- local_coze/knowledge/client.py +148 -0
- local_coze/knowledge/models.py +45 -0
- local_coze/llm/__init__.py +25 -0
- local_coze/llm/client.py +317 -0
- local_coze/llm/models.py +48 -0
- local_coze/memory/__init__.py +14 -0
- local_coze/memory/client.py +176 -0
- local_coze/s3/__init__.py +12 -0
- local_coze/s3/client.py +580 -0
- local_coze/s3/models.py +18 -0
- local_coze/search/__init__.py +19 -0
- local_coze/search/client.py +183 -0
- local_coze/search/models.py +57 -0
- local_coze/video/__init__.py +17 -0
- local_coze/video/client.py +347 -0
- local_coze/video/models.py +39 -0
- local_coze/video_edit/__init__.py +23 -0
- local_coze/video_edit/examples.py +340 -0
- local_coze/video_edit/frame_extractor.py +176 -0
- local_coze/video_edit/models.py +362 -0
- local_coze/video_edit/video_edit.py +631 -0
- local_coze/voice/__init__.py +17 -0
- local_coze/voice/asr.py +82 -0
- local_coze/voice/models.py +86 -0
- local_coze/voice/tts.py +94 -0
- local_coze-0.0.1.dist-info/METADATA +636 -0
- local_coze-0.0.1.dist-info/RECORD +58 -0
- local_coze-0.0.1.dist-info/WHEEL +4 -0
- local_coze-0.0.1.dist-info/entry_points.txt +3 -0
- local_coze-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from typing import Optional, Any
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CozeSDKError(Exception):
    """Root of the SDK exception hierarchy.

    Attributes:
        message: Human-readable description of the failure.
        code: Optional machine-readable error code. When truthy it is
            rendered as a ``[code]`` prefix by ``str(error)``.
        details: Extra structured context. Always a dict, never ``None``.
    """

    def __init__(
        self,
        message: str,
        code: Optional[str] = None,
        details: Optional[dict] = None
    ):
        self.message = message
        self.code = code
        # Normalise a missing/empty mapping to a fresh dict so callers can
        # index into ``details`` without a None check.
        self.details = details if details else {}
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"[{self.code}] {self.message}" if self.code else self.message
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ConfigurationError(CozeSDKError):
    """Raised when required SDK configuration is missing or invalid.

    ``code`` is always ``"CONFIGURATION_ERROR"``. The name of the absent
    setting, if known, is kept on ``missing_key`` and mirrored into
    ``details["missing_key"]``.
    """

    def __init__(self, message: str, missing_key: Optional[str] = None):
        self.missing_key = missing_key
        extra = {"missing_key": missing_key}
        super().__init__(message=message, code="CONFIGURATION_ERROR", details=extra)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class APIError(CozeSDKError):
    """Raised when a remote API call reports an error.

    Attributes:
        status_code: HTTP status code, if one was supplied.
        response_data: Response payload. Normalised to ``{}`` when missing.

    Note: ``details["response_data"]`` stores the payload exactly as passed
    (possibly ``None``), whereas ``self.response_data`` is normalised.
    """

    def __init__(
        self,
        message: str,
        code: Optional[str] = None,
        status_code: Optional[int] = None,
        response_data: Optional[dict] = None
    ):
        self.status_code = status_code
        self.response_data = response_data if response_data else {}
        context = {"status_code": status_code, "response_data": response_data}
        super().__init__(message=message, code=code, details=context)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class NetworkError(CozeSDKError):
    """Raised when a network request cannot be completed.

    The stored ``message`` is the given text prefixed with a fixed Chinese
    label meaning "network request failed: ". ``code`` is ``"NETWORK_ERROR"``.
    ``details["original_error"]`` holds ``str()`` of the underlying
    exception, or ``None`` when none was given.
    """

    def __init__(self, message: str, original_error: Optional[Exception] = None):
        self.original_error = original_error
        cause_text = None
        if original_error:
            cause_text = str(original_error)
        super().__init__(
            message=f"网络请求失败: {message}",
            code="NETWORK_ERROR",
            details={"original_error": cause_text},
        )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ValidationError(CozeSDKError):
    """Raised when an input value fails validation.

    Args:
        message: Description of the validation failure.
        field: Name of the offending field, if known.
        value: The rejected value. Its ``str()`` is stored in
            ``details["value"]``. Only ``None`` is recorded as ``None``.
            Falsy values such as ``0``, ``False`` or ``""`` are preserved,
            because they are frequently the very values that failed.
    """

    def __init__(self, message: str, field: Optional[str] = None, value: Optional[Any] = None):
        self.field = field
        self.value = value
        super().__init__(
            message=message,
            code="VALIDATION_ERROR",
            # `is not None` rather than truthiness: a rejected 0 / "" / False
            # must still show up in the error details.
            details={"field": field, "value": str(value) if value is not None else None},
        )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""
Database module.

Provides PostgreSQL connection handling, session management, the ORM base
class, and migration helpers.
"""

from .client import (
    Base,
    get_db_url,
    get_engine,
    get_session,
    get_sessionmaker,
)
from .migration import (
    generate_models,
    upgrade,
)

__all__ = [
    # ORM base class
    "Base",
    # Database connection / session helpers
    "get_db_url",
    "get_engine",
    "get_session",
    "get_sessionmaker",
    # Migration helpers
    "generate_models",
    "upgrade",
]
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
"""
Database connection module.

Provides PostgreSQL connection handling, session management and the ORM base
class. The engine and the sessionmaker are created lazily on first use and
cached in module-level globals.
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from sqlalchemy import create_engine, text
|
|
12
|
+
from sqlalchemy.orm import sessionmaker, DeclarativeBase, Session
|
|
13
|
+
from sqlalchemy.exc import OperationalError
|
|
14
|
+
from sqlalchemy.engine import Engine
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)

# Upper bound on how long engine creation keeps retrying the initial
# connectivity check before giving up.
MAX_RETRY_TIME = 20  # maximum connection retry time (seconds)

# Lazily-initialised module-level singletons. They are populated on first use
# by get_engine() / get_sessionmaker() and reused afterwards.
_engine: Optional[Engine] = None
_SessionLocal = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Base(DeclarativeBase):
    """Declarative base class shared by all SQLAlchemy ORM models of the SDK."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _load_env() -> None:
|
|
31
|
+
"""加载环境变量(内部使用)"""
|
|
32
|
+
try:
|
|
33
|
+
from dotenv import load_dotenv
|
|
34
|
+
|
|
35
|
+
load_dotenv()
|
|
36
|
+
except ImportError:
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_db_url() -> str:
    """
    Return the PostgreSQL connection URL.

    Resolution order:
      1. ``PGDATABASE_URL`` from the process environment, after loading a
         ``.env`` file when python-dotenv is available.
      2. ``PGDATABASE_URL`` from the project env vars exposed by the optional
         ``coze_workload_identity`` package.

    A legacy ``postgres://`` scheme is rewritten to ``postgresql://``, the
    dialect name SQLAlchemy 1.4+ requires. This matches the normalisation
    done for migrations.

    Returns:
        str: Database connection URL.

    Raises:
        ValueError: If no database URL can be found. This includes the case
            where ``coze_workload_identity`` is not installed.
        Exception: Errors raised while querying ``coze_workload_identity``
            are logged and re-raised unchanged.
    """
    _load_env()

    url = os.getenv("PGDATABASE_URL") or ""
    if not url:
        url = _db_url_from_workload_identity()

    if not url:
        logger.error("PGDATABASE_URL is not set")
        raise ValueError("PGDATABASE_URL is not set")

    if url.startswith("postgres://"):
        url = "postgresql://" + url[len("postgres://"):]
    return url


def _db_url_from_workload_identity() -> str:
    """Fetch PGDATABASE_URL via coze_workload_identity; return "" if unavailable or absent."""
    try:
        from coze_workload_identity import Client
    except ImportError:
        # Optional dependency: without it there is no second source to
        # consult, so the caller reports the URL as not set.
        logger.debug("coze_workload_identity not available")
        return ""

    try:
        client = Client()
        try:
            env_vars = client.get_project_env_vars()
        finally:
            # Release the client even if the lookup fails.
            client.close()
        for env_var in env_vars:
            if env_var.key == "PGDATABASE_URL":
                # Use the value verbatim. It is passed straight to SQLAlchemy,
                # not to a shell, so shell-style quote escaping would corrupt
                # credentials that contain "'".
                return env_var.value or ""
    except Exception as e:
        logger.error("Error loading PGDATABASE_URL: %s", e)
        raise
    return ""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _create_engine_with_retry() -> Engine:
    """
    Create the SQLAlchemy engine and verify connectivity, retrying on failure.

    ``SELECT 1`` is attempted repeatedly, with at most one second between
    attempts, until it succeeds or ``MAX_RETRY_TIME`` seconds have elapsed.
    At least one attempt is always made.

    Returns:
        Engine: An engine whose connectivity check succeeded.

    Raises:
        ValueError: If no database URL is configured.
        OperationalError: The last connection error, if the database did not
            become reachable within ``MAX_RETRY_TIME``.
    """
    url = get_db_url()
    if not url:
        logger.error("PGDATABASE_URL is not set")
        raise ValueError("PGDATABASE_URL is not set")

    engine = create_engine(
        url,
        pool_size=100,
        max_overflow=100,
        pool_pre_ping=True,
        pool_recycle=1800,  # seconds
        pool_timeout=30,  # seconds
    )

    # Monotonic clock: wall-clock adjustments must not shrink or stretch the
    # retry window.
    deadline = time.monotonic() + MAX_RETRY_TIME
    last_error: Optional[OperationalError] = None
    while True:
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return engine
        except OperationalError as e:
            last_error = e
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            logger.warning(
                "Database connection failed, retrying... (elapsed: %.1fs)",
                MAX_RETRY_TIME - remaining,
            )
            # `remaining` is positive here. A slow failed attempt can overrun
            # the deadline, and time.sleep() rejects negative durations.
            time.sleep(min(1, remaining))

    # Release the unused pool before propagating the failure.
    engine.dispose()
    logger.error("Database connection failed after %ss: %s", MAX_RETRY_TIME, last_error)
    raise last_error
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def get_engine() -> Engine:
    """
    Return the module-wide SQLAlchemy engine, creating it on the first call.

    The first call builds the engine via ``_create_engine_with_retry``. Later
    calls return the cached instance stored in ``_engine``.

    Returns:
        Engine: The shared SQLAlchemy engine.
    """
    global _engine
    if _engine is not None:
        return _engine
    _engine = _create_engine_with_retry()
    return _engine
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def get_sessionmaker():
    """
    Return the shared session factory bound to ``get_engine()``, creating it once.

    The factory is configured with ``autocommit=False`` and
    ``autoflush=False``, and is cached in ``_SessionLocal``.

    Returns:
        sessionmaker: SQLAlchemy sessionmaker.
    """
    global _SessionLocal
    if _SessionLocal is None:
        bound_engine = get_engine()
        _SessionLocal = sessionmaker(bind=bound_engine, autoflush=False, autocommit=False)
    return _SessionLocal
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_session() -> Session:
    """
    Open a new database session from the shared session factory.

    The caller is responsible for closing the returned session.

    Returns:
        Session: A fresh SQLAlchemy session.

    Example:
        >>> from local_coze.database import get_session
        >>> db = get_session()
        >>> try:
        ...     # run database operations
        ...     pass
        ... finally:
        ...     db.close()
    """
    session_factory = get_sessionmaker()
    return session_factory()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# Public API of the database client module.
__all__ = [
    "Base",
    "get_db_url",
    "get_engine",
    "get_sessionmaker",
    "get_session",
]
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
Database migration module.

Python API wrapper around Alembic migrations, plus ORM model generation via
the ``sqlacodegen`` command-line tool.
"""

import os
import sys
import tempfile
import logging
from typing import Optional

logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
# Alembic ``script.py.mako`` template, written into the script directory and
# rendered by Alembic for every generated revision file. The ``${...}``
# placeholders are filled in at render time. The upgrade/downgrade bodies are
# indented so the rendered revision is valid Python.
SCRIPT_MAKO_TEMPLATE = '''"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
'''
|
|
42
|
+
|
|
43
|
+
# Cached path of the Alembic script directory (set by _get_alembic_dir()).
_alembic_dir: Optional[str] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _load_env() -> None:
    """
    Populate ``os.environ`` from the available configuration sources.

    Best effort, in order:
      1. a ``.env`` file, when python-dotenv is installed;
      2. project env vars from the optional ``coze_workload_identity``
         package. Variables already present in ``os.environ`` are never
         overwritten.

    Any failure in step 2, including the package not being installed, is
    logged at DEBUG level and otherwise ignored.
    """
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    try:
        from coze_workload_identity import Client
        client = Client()
        try:
            env_vars = client.get_project_env_vars()
        finally:
            # Release the client even when the lookup fails.
            client.close()
        for env_var in env_vars:
            # Existing process environment takes precedence.
            os.environ.setdefault(env_var.key, env_var.value)
    except Exception as e:
        logger.debug("coze_workload_identity not available: %s", e)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_db_url() -> str:
    """
    Return the database URL used for migrations.

    Loads configuration via ``_load_env()``, then reads ``PGDATABASE_URL``.
    A legacy ``postgres://`` scheme is rewritten to ``postgresql://``.

    Raises:
        ValueError: If ``PGDATABASE_URL`` is unset or empty.
    """
    _load_env()

    db_url = os.getenv("PGDATABASE_URL")
    if not db_url:
        raise ValueError(
            "Database URL not configured. Set PGDATABASE_URL environment variable.\n"
            "Did you create a database? You can create one via the Coze Coding platform."
        )

    legacy_scheme = "postgres://"
    if db_url.startswith(legacy_scheme):
        db_url = "postgresql://" + db_url[len(legacy_scheme):]
    return db_url
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _get_alembic_dir() -> str:
    """
    Return the Alembic script directory, creating it and its ``versions/`` subdirectory if needed.

    The path is a fixed ``coze_sdk_alembic`` directory under the system temp
    directory, cached in ``_alembic_dir``. The cache is bypassed if the
    directory has since been removed.

    NOTE(review): because the path is fixed and shared, generated revision
    files persist across runs and are visible to every process/user using the
    same temp directory. A pre-existing directory is reused as-is. Confirm
    this is intended.
    """
    global _alembic_dir
    cached = _alembic_dir
    if cached and os.path.exists(cached):
        return cached

    # Fixed location so repeated calls reuse one directory instead of
    # creating a new one each time.
    _alembic_dir = os.path.join(tempfile.gettempdir(), "coze_sdk_alembic")
    for directory in (_alembic_dir, os.path.join(_alembic_dir, "versions")):
        os.makedirs(directory, exist_ok=True)
    return _alembic_dir
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _ensure_alembic_env(model_import_path: str) -> str:
    """
    Write the Alembic environment files (``env.py`` and ``script.py.mako``).

    Both files are rewritten on every call into the directory returned by
    ``_get_alembic_dir()``. The generated ``env.py``:

    - imports ``local_coze.database.Base`` and ``from <model_import_path>
      import *`` so the user's models register on ``Base.metadata``;
    - honours the optional ``-x version_table=...`` and
      ``-x version_table_schema=...`` arguments;
    - drops autogenerated revisions that contain no operations, so repeated
      runs without model changes do not accumulate empty revision files.

    Args:
        model_import_path: Dotted module path of the models, e.g.
            "storage.database.shared.model".

    Returns:
        str: Path of the Alembic script directory.

    Raises:
        ValueError: If ``model_import_path`` is not a dotted Python module path.
    """
    # The path is interpolated verbatim into generated Python source, so accept
    # only a plain dotted module path (e.g. "pkg.sub.module").
    if not model_import_path or not all(
        part.isidentifier() for part in model_import_path.split(".")
    ):
        raise ValueError(f"Invalid model_import_path: {model_import_path!r}")

    script_location = _get_alembic_dir()

    # Generated env.py. It reads version_table / version_table_schema from -x.
    env_py_content = f'''"""Alembic Environment Configuration (Auto-generated)"""

import os
import sys
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context

config = context.config

if config.config_file_name:
    fileConfig(config.config_file_name)

# 获取 -x 参数
x_args = context.get_x_argument(as_dictionary=True)
version_table = x_args.get("version_table")
version_table_schema = x_args.get("version_table_schema")

# 导入 Base
from local_coze.database import Base

# 导入用户模型以注册到 metadata
from {model_import_path} import *

target_metadata = Base.metadata


def process_revision_directives(context, revision, directives):
    # Do not write a revision file when autogenerate detected no changes.
    if directives and directives[0].upgrade_ops.is_empty():
        directives[:] = []


def run_migrations_offline():
    url = config.get_main_option("sqlalchemy.url")
    kwargs = {{
        "url": url,
        "target_metadata": target_metadata,
        "literal_binds": True,
        "compare_type": True,
        "compare_server_default": True,
        "compare_nullable": True,
    }}
    if version_table:
        kwargs["version_table"] = version_table
    if version_table_schema:
        kwargs["version_table_schema"] = version_table_schema
    context.configure(**kwargs)
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online():
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with connectable.connect() as connection:
        kwargs = {{
            "connection": connection,
            "target_metadata": target_metadata,
            "compare_type": True,
            "compare_server_default": True,
            "compare_nullable": True,
            "process_revision_directives": process_revision_directives,
        }}
        if version_table:
            kwargs["version_table"] = version_table
        if version_table_schema:
            kwargs["version_table_schema"] = version_table_schema
        context.configure(**kwargs)
        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
'''

    # env.py contains non-ASCII comments. Write it as UTF-8 explicitly so it
    # matches Python's default source encoding regardless of the OS locale.
    env_py_path = os.path.join(script_location, "env.py")
    with open(env_py_path, "w", encoding="utf-8") as f:
        f.write(env_py_content)

    # Template Alembic renders each new revision file from.
    mako_path = os.path.join(script_location, "script.py.mako")
    with open(mako_path, "w", encoding="utf-8") as f:
        f.write(SCRIPT_MAKO_TEMPLATE)

    return script_location
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _get_alembic_config(model_import_path: str, model_path: Optional[str] = None):
    """
    Build an in-memory Alembic ``Config``.

    Args:
        model_import_path: Dotted import path of the models module.
        model_path: Directory containing the models. It is prepended to
            ``sys.path`` (once) so ``model_import_path`` can be imported.

    Returns:
        alembic.config.Config: Config whose ``script_location`` is the
        generated Alembic environment and whose ``sqlalchemy.url`` is the
        database URL.
    """
    from alembic.config import Config

    # Make the user's model package importable from the generated env.py.
    if model_path and model_path not in sys.path:
        sys.path.insert(0, model_path)

    script_location = _ensure_alembic_env(model_import_path)

    config = Config()
    # Alembic stores option values through ConfigParser interpolation, so a
    # literal "%" (common in percent-encoded passwords) must be doubled.
    # get_main_option()/get_section() return the original value.
    config.set_main_option("script_location", script_location.replace("%", "%%"))
    config.set_main_option("sqlalchemy.url", _get_db_url().replace("%", "%%"))

    return config
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def generate_models(output_path: str, verbose: bool = False) -> None:
    """
    Generate SQLAlchemy ORM models from the live database schema.

    Runs the ``sqlacodegen`` CLI against the configured database. The
    generated file is then rewritten so its models use
    ``local_coze.database.Base`` instead of a locally defined
    ``DeclarativeBase`` subclass.

    Args:
        output_path: Destination file, e.g.
            "src/storage/database/shared/model.py". Missing parent
            directories are created.
        verbose: Print progress. The database password is masked in output.

    Raises:
        ValueError: If the database URL is not configured.
        RuntimeError: If ``sqlacodegen`` is not installed or exits with an error.

    Example:
        >>> from local_coze.database import generate_models
        >>> generate_models("src/storage/database/shared/model.py")
    """
    import subprocess

    db_url = _get_db_url()

    # Make sure the output directory exists.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    cmd = ["sqlacodegen", db_url, "--outfile", output_path]
    if verbose:
        # Never echo database credentials to stdout / CI logs.
        shown = ["sqlacodegen", _mask_db_url_password(db_url), "--outfile", output_path]
        print(f"Running: {' '.join(shown)}")

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        raise RuntimeError(
            "sqlacodegen failed: executable not found (install it with 'pip install sqlacodegen')"
        ) from err

    if result.returncode != 0:
        raise RuntimeError(f"sqlacodegen failed: {result.stderr}")

    # Explicit UTF-8 so non-ASCII table/column comments survive regardless of
    # the OS locale (assumes sqlacodegen writes UTF-8; confirm for the
    # installed version).
    with open(output_path, "r", encoding="utf-8") as f:
        content = f.read()

    content = _use_sdk_base(content)

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(content)

    if verbose:
        print(f"Models generated at {output_path}")


def _mask_db_url_password(url: str) -> str:
    """Return *url* with the password in its user-info part replaced by ``***``."""
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    userinfo, at_sign, hostinfo = parts.netloc.rpartition("@")
    if not at_sign or ":" not in userinfo:
        return url
    user = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{user}:***@{hostinfo}"))


def _use_sdk_base(source: str) -> str:
    """Rewrite sqlacodegen output to import the SDK ``Base`` instead of defining its own."""
    import re

    # Remove the generated "class Base(DeclarativeBase):\n    pass" definition.
    source = re.sub(r"class Base\(DeclarativeBase\):\s*\n\s*pass\s*\n*", "", source)

    # Remove DeclarativeBase from import lists. The \b anchors stop the
    # patterns from eating part of longer names such as DeclarativeBaseNoMeta.
    source = re.sub(r",\s*DeclarativeBase\b", "", source)
    source = re.sub(r"\bDeclarativeBase,\s*", "", source)
    source = re.sub(r"from sqlalchemy\.orm import DeclarativeBase\n", "", source)

    # Import the SDK's Base at the top of the file.
    return "from local_coze.database import Base\n\n" + source
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def upgrade(
    model_import_path: str = "storage.database.shared.model",
    model_path: Optional[str] = None,
    verbose: bool = False,
    version_table: str = "schema_version",
    version_table_schema: str = "internal",
) -> None:
    """
    Run database migrations, auto-generating a migration revision first.

    Equivalent to:
        1. alembic revision --autogenerate
        2. alembic upgrade head

    The ``version_table_schema`` schema is created first if it does not exist.

    Args:
        model_import_path: Models import path, e.g. "storage.database.shared.model".
        model_path: Directory containing the models (added to ``sys.path``).
        verbose: Print progress details.
        version_table: Name of the Alembic version table. Default "schema_version".
        version_table_schema: Schema holding the version table. Default "internal".

    Raises:
        ValueError: If the database URL is not configured.
        Exception: Errors from the upgrade step propagate. Errors from the
            revision step are logged as warnings and the upgrade still runs.

    Example:
        >>> from local_coze.database import upgrade
        >>> upgrade()  # default configuration
        >>> upgrade(model_import_path="myapp.models", model_path="/path/to/src")
    """
    from types import SimpleNamespace

    from alembic import command
    from sqlalchemy import create_engine, text

    # Ensure the version-table schema exists. The identifier is always quoted,
    # with embedded double quotes doubled, so an unusual name cannot break out
    # of the DDL statement.
    quoted_schema = '"' + version_table_schema.replace('"', '""') + '"'
    engine = create_engine(_get_db_url(), isolation_level="AUTOCOMMIT")
    try:
        with engine.connect() as conn:
            conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}"))
    finally:
        engine.dispose()

    config = _get_alembic_config(model_import_path, model_path)

    # Equivalent of "-x key=value" on the alembic command line. The generated
    # env.py reads these through context.get_x_argument().
    config.cmd_opts = SimpleNamespace(x=[
        f"version_table={version_table}",
        f"version_table_schema={version_table_schema}",
    ])

    # Strip the leading indentation Alembic puts on its progress messages.
    config.print_stdout = lambda msg, *args: print(msg.lstrip() % args if args else msg.lstrip())

    # Auto-generate a migration revision.
    try:
        command.revision(config, message="auto migration", autogenerate=True)
        if verbose:
            print("Migration revision generated")
    except Exception as e:
        # Best effort: still attempt the upgrade, but surface the reason (for
        # example a model import error) instead of hiding it at DEBUG level.
        logger.warning("Revision generation skipped: %s", e)

    # Apply all pending revisions.
    command.upgrade(config, "head")
    if verbose:
        print("Database upgraded to head")
    logger.info("Database upgraded to head")
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
# Public API of the migration module.
__all__ = [
    "generate_models",
    "upgrade",
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from .client import EmbeddingClient
|
|
2
|
+
from .models import (
|
|
3
|
+
EmbeddingConfig,
|
|
4
|
+
EmbeddingInputItem,
|
|
5
|
+
EmbeddingInputImageURL,
|
|
6
|
+
EmbeddingInputVideoURL,
|
|
7
|
+
EmbeddingRequest,
|
|
8
|
+
EmbeddingResponse,
|
|
9
|
+
EmbeddingData,
|
|
10
|
+
EmbeddingUsage,
|
|
11
|
+
MultiEmbeddingConfig,
|
|
12
|
+
SparseEmbeddingConfig,
|
|
13
|
+
SparseEmbeddingItem,
|
|
14
|
+
PromptTokensDetails,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Public API of the embedding package: the client plus its request/response
# and configuration models.
__all__ = [
    "EmbeddingClient",
    "EmbeddingConfig",
    "EmbeddingInputItem",
    "EmbeddingInputImageURL",
    "EmbeddingInputVideoURL",
    "EmbeddingRequest",
    "EmbeddingResponse",
    "EmbeddingData",
    "EmbeddingUsage",
    "MultiEmbeddingConfig",
    "SparseEmbeddingConfig",
    "SparseEmbeddingItem",
    "PromptTokensDetails",
]
|