sql-assistant 1.0.0 (sql_assistant-1.0.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_assistant/__init__.py +3 -0
- sql_assistant/api/__init__.py +1 -0
- sql_assistant/api/backup.py +116 -0
- sql_assistant/api/config.py +183 -0
- sql_assistant/api/conversation.py +71 -0
- sql_assistant/api/dependencies.py +22 -0
- sql_assistant/api/history.py +61 -0
- sql_assistant/api/models.py +221 -0
- sql_assistant/api/query.py +275 -0
- sql_assistant/api/routes.py +19 -0
- sql_assistant/api/schema.py +21 -0
- sql_assistant/config.py +144 -0
- sql_assistant/database/__init__.py +1 -0
- sql_assistant/database/backup.py +568 -0
- sql_assistant/database/connectors/__init__.py +1 -0
- sql_assistant/database/connectors/base.py +185 -0
- sql_assistant/database/connectors/exceptions.py +88 -0
- sql_assistant/database/connectors/mongodb.py +194 -0
- sql_assistant/database/connectors/mysql.py +110 -0
- sql_assistant/database/connectors/postgresql.py +133 -0
- sql_assistant/database/connectors/redis.py +132 -0
- sql_assistant/database/connectors/sqlserver.py +140 -0
- sql_assistant/database/history.py +290 -0
- sql_assistant/database/manager.py +178 -0
- sql_assistant/database/security.py +230 -0
- sql_assistant/llm/__init__.py +1 -0
- sql_assistant/llm/base.py +28 -0
- sql_assistant/llm/exceptions.py +96 -0
- sql_assistant/llm/manager.py +82 -0
- sql_assistant/llm/prompts.py +29 -0
- sql_assistant/llm/providers/__init__.py +1 -0
- sql_assistant/llm/providers/claude.py +132 -0
- sql_assistant/llm/providers/gemini.py +127 -0
- sql_assistant/llm/providers/openai_compatible.py +103 -0
- sql_assistant/llm/retry.py +88 -0
- sql_assistant/main.py +94 -0
- sql_assistant/settings.py +219 -0
- sql_assistant/web/__init__.py +1 -0
- sql_assistant/web/static/css/base.css +25 -0
- sql_assistant/web/static/css/components/backup.css +146 -0
- sql_assistant/web/static/css/components/chat.css +465 -0
- sql_assistant/web/static/css/components/modal.css +143 -0
- sql_assistant/web/static/css/components/settings.css +358 -0
- sql_assistant/web/static/css/components/sidebar.css +235 -0
- sql_assistant/web/static/css/components/toast.css +30 -0
- sql_assistant/web/static/css/style.css +10 -0
- sql_assistant/web/static/css/theme.css +200 -0
- sql_assistant/web/static/js/api.js +38 -0
- sql_assistant/web/static/js/app.js +161 -0
- sql_assistant/web/static/js/backup.js +216 -0
- sql_assistant/web/static/js/chat.js +238 -0
- sql_assistant/web/static/js/color-theme-manager.js +121 -0
- sql_assistant/web/static/js/confirm.js +95 -0
- sql_assistant/web/static/js/conversations.js +182 -0
- sql_assistant/web/static/js/settings.js +425 -0
- sql_assistant/web/static/js/state.js +43 -0
- sql_assistant/web/static/js/theme-manager.js +64 -0
- sql_assistant/web/static/js/ui.js +53 -0
- sql_assistant/web/templates/index.html +373 -0
- sql_assistant-1.0.0.dist-info/METADATA +24 -0
- sql_assistant-1.0.0.dist-info/RECORD +64 -0
- sql_assistant-1.0.0.dist-info/WHEEL +4 -0
- sql_assistant-1.0.0.dist-info/entry_points.txt +2 -0
- sql_assistant-1.0.0.dist-info/licenses/LICENSE +21 -0
sql_assistant/llm/providers/claude.py
ADDED
@@ -0,0 +1,132 @@
"""Anthropic Claude Provider"""

import httpx
from typing import AsyncGenerator, Optional

from ..base import BaseLLMProvider
from ..retry import async_retry
from ..exceptions import LLMConnectionError, LLMResponseError


class ClaudeProvider(BaseLLMProvider):
    """Anthropic Claude API implementation"""

    def __init__(self, api_key: str, base_url: str, model: str):
        super().__init__(api_key, base_url, model)
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                headers={
                    "x-api-key": self.api_key,
                    "anthropic-version": "2023-06-01",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _convert_messages(self, messages: list[dict]) -> tuple[list[dict], Optional[str]]:
        """Convert OpenAI-style messages to the Claude format"""
        system_content = None
        claude_messages = []

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_content = content
            elif role == "user":
                claude_messages.append({"role": "user", "content": content})
            elif role == "assistant":
                claude_messages.append({"role": "assistant", "content": content})

        return claude_messages, system_content

    @async_retry(max_attempts=3, base_delay=1.0, retryable_exceptions=(httpx.HTTPError,))
    async def chat(self, messages: list[dict], temperature: float = 0.1) -> str:
        client = await self._get_client()
        claude_messages, system = self._convert_messages(messages)

        url = "https://api.anthropic.com/v1/messages"
        payload = {
            "model": self.model,
            "max_tokens": 4096,
            "temperature": temperature,
            "messages": claude_messages,
        }
        if system:
            payload["system"] = system

        response = await client.post(url, json=payload)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise LLMConnectionError(f"Claude request failed: {e.response.status_code} - {e.response.text}", "claude")

        data = response.json()
        try:
            return data["content"][0]["text"]
        except (KeyError, IndexError) as e:
            raise LLMResponseError(f"Malformed Claude response: {e}", "claude", data)

    async def chat_stream(self, messages: list[dict], temperature: float = 0.1) -> AsyncGenerator[str, None]:
        client = await self._get_client()
        claude_messages, system = self._convert_messages(messages)

        url = "https://api.anthropic.com/v1/messages"
        payload = {
            "model": self.model,
            "max_tokens": 4096,
            "temperature": temperature,
            "messages": claude_messages,
            "stream": True,
        }
        if system:
            payload["system"] = system

        async with client.stream("POST", url, json=payload) as response:
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise LLMConnectionError(f"Claude streaming request failed: {e.response.status_code}", "claude")

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:]
                    try:
                        import json
                        data = json.loads(data_str)
                        if data.get("type") == "content_block_delta":
                            delta = data.get("delta", {})
                            text = delta.get("text", "")
                            if text:
                                yield text
                    except Exception:
                        continue

    async def test_connection(self) -> dict:
        from ..exceptions import format_llm_result
        try:
            client = await self._get_client()
            url = "https://api.anthropic.com/v1/messages"
            payload = {
                "model": self.model,
                "max_tokens": 1,
                "temperature": 0,
                "messages": [{"role": "user", "content": "Hello"}],
            }
            response = await client.post(url, json=payload, timeout=30.0)
            if response.status_code == 200:
                return format_llm_result(True, data={"message": "Claude connection test succeeded"})
            else:
                return format_llm_result(False, error=f"Connection failed: {response.status_code}", provider="claude", code="HTTP_ERROR")
        except Exception as e:
            return format_llm_result(False, error=f"Connection failed: {str(e)}", provider="claude", code="CONNECTION_ERROR")

    async def close(self):
        if self._client:
            await self._client.aclose()
            self._client = None
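One quirk worth flagging in the file above: the constructor accepts a base_url, but chat, chat_stream, and test_connection all hard-code https://api.anthropic.com/v1/messages, so the argument is stored and never used. A minimal driver sketch, not part of the package — the environment variable and model id are illustrative assumptions:

import asyncio
import os

from sql_assistant.llm.providers.claude import ClaudeProvider

async def demo() -> None:
    provider = ClaudeProvider(
        api_key=os.environ["ANTHROPIC_API_KEY"],  # assumed env var
        base_url="https://api.anthropic.com",     # stored but unused (see note above)
        model="claude-3-5-sonnet-20241022",       # hypothetical model id
    )
    try:
        # Single-shot completion; the system message is hoisted into payload["system"]
        reply = await provider.chat([
            {"role": "system", "content": "Translate questions into SQL."},
            {"role": "user", "content": "How many orders were placed today?"},
        ])
        print(reply)
        # Streaming variant yields content_block_delta text fragments as they arrive
        async for chunk in provider.chat_stream([{"role": "user", "content": "Hello"}]):
            print(chunk, end="", flush=True)
    finally:
        await provider.close()

asyncio.run(demo())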
sql_assistant/llm/providers/gemini.py
ADDED
@@ -0,0 +1,127 @@
"""Google Gemini Provider"""

import httpx
from typing import AsyncGenerator, Optional

from ..base import BaseLLMProvider
from ..retry import async_retry
from ..exceptions import LLMConnectionError, LLMResponseError


class GeminiProvider(BaseLLMProvider):
    """Google Gemini API implementation"""

    def __init__(self, api_key: str, base_url: str, model: str):
        super().__init__(api_key, base_url, model)
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(60.0))
        return self._client

    def _convert_messages(self, messages: list[dict]) -> tuple[list[dict], Optional[str]]:
        """Convert OpenAI-style messages to the Gemini format"""
        gemini_contents = []
        system_instruction = None

        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                system_instruction = content
            elif role == "user":
                gemini_contents.append({"role": "user", "parts": [{"text": content}]})
            elif role == "assistant":
                gemini_contents.append({"role": "model", "parts": [{"text": content}]})

        return gemini_contents, system_instruction

    @async_retry(max_attempts=3, base_delay=1.0, retryable_exceptions=(httpx.HTTPError,))
    async def chat(self, messages: list[dict], temperature: float = 0.1) -> str:
        client = await self._get_client()
        contents, system_instruction = self._convert_messages(messages)

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent"
        params = {"key": self.api_key}

        payload = {
            "contents": contents,
            "generationConfig": {"temperature": temperature},
        }
        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        response = await client.post(url, params=params, json=payload)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise LLMConnectionError(f"Gemini request failed: {e.response.status_code} - {e.response.text}", "gemini")

        data = response.json()
        try:
            return data["candidates"][0]["content"]["parts"][0]["text"]
        except (KeyError, IndexError) as e:
            raise LLMResponseError(f"Malformed Gemini response: {e}", "gemini", data)

    async def chat_stream(self, messages: list[dict], temperature: float = 0.1) -> AsyncGenerator[str, None]:
        client = await self._get_client()
        contents, system_instruction = self._convert_messages(messages)

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:streamGenerateContent"
        params = {"key": self.api_key, "alt": "sse"}

        payload = {
            "contents": contents,
            "generationConfig": {"temperature": temperature},
        }
        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        async with client.stream("POST", url, params=params, json=payload) as response:
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise LLMConnectionError(f"Gemini streaming request failed: {e.response.status_code}", "gemini")

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:]
                    try:
                        import json
                        data = json.loads(data_str)
                        parts = data.get("candidates", [{}])[0].get("content", {}).get("parts", [])
                        for part in parts:
                            text = part.get("text", "")
                            if text:
                                yield text
                    except Exception:
                        continue

    async def test_connection(self) -> dict:
        from ..exceptions import format_llm_result
        try:
            client = await self._get_client()
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent"
            params = {"key": self.api_key}
            payload = {
                "contents": [{"role": "user", "parts": [{"text": "Hello"}]}],
                "generationConfig": {"temperature": 0, "maxOutputTokens": 1},
            }
            response = await client.post(url, params=params, json=payload, timeout=30.0)
            if response.status_code == 200:
                return format_llm_result(True, data={"message": "Gemini connection test succeeded"})
            else:
                return format_llm_result(False, error=f"Connection failed: {response.status_code}", provider="gemini", code="HTTP_ERROR")
        except Exception as e:
            return format_llm_result(False, error=f"Connection failed: {str(e)}", provider="gemini", code="CONNECTION_ERROR")

    async def close(self):
        if self._client:
            await self._client.aclose()
            self._client = None
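The _convert_messages helper is where the Gemini wire format diverges from OpenAI's: the assistant role becomes model, content is wrapped in parts, and the system prompt moves out of the message list into a separate systemInstruction field. A worked example of the mapping, using made-up conversation data and assuming BaseLLMProvider.__init__ simply stores its arguments:

from sql_assistant.llm.providers.gemini import GeminiProvider

messages = [
    {"role": "system", "content": "Answer with SQL only."},
    {"role": "user", "content": "List all tables."},
    {"role": "assistant", "content": "SHOW TABLES;"},
]
provider = GeminiProvider(api_key="", base_url="", model="")
contents, system_instruction = provider._convert_messages(messages)
# contents == [
#     {"role": "user", "parts": [{"text": "List all tables."}]},
#     {"role": "model", "parts": [{"text": "SHOW TABLES;"}]},
# ]
# system_instruction == "Answer with SQL only."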
sql_assistant/llm/providers/openai_compatible.py
ADDED
@@ -0,0 +1,103 @@
"""OpenAI-compatible Provider - supports DeepSeek / Doubao / Kimi / Qwen / OpenAI"""

import httpx
from typing import AsyncGenerator, Optional

from ..base import BaseLLMProvider
from ..retry import async_retry
from ..exceptions import LLMConnectionError, LLMResponseError


class OpenAICompatibleProvider(BaseLLMProvider):
    """OpenAI-compatible API implementation"""

    def __init__(self, api_key: str, base_url: str, model: str, provider_name: str = ""):
        super().__init__(api_key, base_url, model)
        self.provider_name = provider_name
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0),
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    @async_retry(max_attempts=3, base_delay=1.0, retryable_exceptions=(httpx.HTTPError,))
    async def chat(self, messages: list[dict], temperature: float = 0.1) -> str:
        client = await self._get_client()
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
        }
        response = await client.post(url, json=payload)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise LLMConnectionError(f"LLM request failed: {e.response.status_code} - {e.response.text}", self.provider_name)

        data = response.json()
        try:
            return data["choices"][0]["message"]["content"]
        except (KeyError, IndexError) as e:
            raise LLMResponseError(f"Malformed LLM response: {e}", self.provider_name, data)

    async def chat_stream(self, messages: list[dict], temperature: float = 0.1) -> AsyncGenerator[str, None]:
        client = await self._get_client()
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "stream": True,
        }
        async with client.stream("POST", url, json=payload) as response:
            try:
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                raise LLMConnectionError(f"LLM streaming request failed: {e.response.status_code}", self.provider_name)

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:]
                    if data_str == "[DONE]":
                        break
                    try:
                        import json
                        data = json.loads(data_str)
                        delta = data["choices"][0].get("delta", {})
                        content = delta.get("content", "")
                        if content:
                            yield content
                    except Exception:
                        continue

    async def test_connection(self) -> dict:
        from ..exceptions import format_llm_result
        try:
            client = await self._get_client()
            url = f"{self.base_url}/chat/completions"
            payload = {
                "model": self.model,
                "messages": [{"role": "user", "content": "Hello"}],
                "temperature": 0,
                "max_tokens": 1,
            }
            response = await client.post(url, json=payload, timeout=30.0)
            if response.status_code == 200:
                return format_llm_result(True, data={"message": "LLM connection test succeeded"})
            else:
                return format_llm_result(False, error=f"Connection failed: {response.status_code}", provider=self.provider_name, code="HTTP_ERROR")
        except Exception as e:
            return format_llm_result(False, error=f"Connection failed: {str(e)}", provider=self.provider_name, code="CONNECTION_ERROR")

    async def close(self):
        if self._client:
            await self._client.aclose()
            self._client = None
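Since this class only needs an OpenAI-style /chat/completions route, pointing it at another vendor is a matter of base_url and model. A configuration sketch — the DeepSeek endpoint and model id below are assumptions for illustration, not defaults shipped with the package:

from sql_assistant.llm.providers.openai_compatible import OpenAICompatibleProvider

provider = OpenAICompatibleProvider(
    api_key="sk-...",                        # your vendor API key
    base_url="https://api.deepseek.com/v1",  # assumed OpenAI-compatible base URL
    model="deepseek-chat",                   # assumed model id
    provider_name="deepseek",                # only used to label error messages
)
# chat() and chat_stream() POST to f"{base_url}/chat/completions", i.e.
# https://api.deepseek.com/v1/chat/completions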
sql_assistant/llm/retry.py
ADDED
@@ -0,0 +1,88 @@
"""Retry machinery for LLM requests"""

import asyncio
import random
from functools import wraps
from typing import Callable, TypeVar, Any

T = TypeVar("T")


class RetryError(Exception):
    """Raised when all retry attempts are exhausted"""
    def __init__(self, message: str, last_exception: Exception):
        super().__init__(message)
        self.last_exception = last_exception


def async_retry(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retryable_exceptions: tuple = (Exception,),
):
    """Async retry decorator

    Args:
        max_attempts: maximum number of attempts
        base_delay: base delay in seconds
        max_delay: maximum delay in seconds
        exponential_base: base of the exponential backoff
        jitter: whether to add random jitter
        retryable_exceptions: tuple of exception types that trigger a retry
    """
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            last_exception = None

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except retryable_exceptions as e:
                    last_exception = e

                    if attempt == max_attempts:
                        break

                    delay = min(base_delay * (exponential_base ** (attempt - 1)), max_delay)

                    if jitter:
                        delay = delay * (0.5 + random.random() * 0.5)

                    await asyncio.sleep(delay)

            raise RetryError(
                f"Still failing after {max_attempts} attempts: {last_exception}",
                last_exception
            )

        return wrapper
    return decorator


class RateLimiter:
    """Simple rate limiter (token bucket)"""

    def __init__(self, calls: int, period: float):
        self.calls = calls
        self.period = period
        self._tokens = calls
        self._last_update = asyncio.get_event_loop().time()

    async def acquire(self) -> None:
        loop = asyncio.get_event_loop()
        now = loop.time()

        elapsed = now - self._last_update
        self._tokens = min(self.calls, self._tokens + elapsed * (self.calls / self.period))
        self._last_update = now

        if self._tokens < 1:
            wait_time = (1 - self._tokens) * (self.period / self.calls)
            await asyncio.sleep(wait_time)
            self._tokens = 0
        else:
            self._tokens -= 1
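Note that the decorator retries the entire wrapped coroutine, so anything awaited inside it — including a rate-limit acquire — runs once per attempt. A usage sketch combining the two utilities around a hypothetical fetch function:

import httpx

from sql_assistant.llm.retry import RateLimiter, async_retry

limiter = RateLimiter(calls=5, period=1.0)  # refills at roughly 5 tokens per second

@async_retry(max_attempts=3, base_delay=0.5, retryable_exceptions=(httpx.HTTPError,))
async def fetch(url: str) -> str:
    await limiter.acquire()  # consulted on every attempt, including retries
    async with httpx.AsyncClient() as client:
        resp = await client.get(url)
        resp.raise_for_status()  # 4xx/5xx raises httpx.HTTPStatusError -> retried with backoff
        return resp.text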
sql_assistant/main.py
ADDED
@@ -0,0 +1,94 @@
"""SQL Assistant - main entry point"""

import sys
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware

from .api.routes import router as api_router
from .database.history import close_history_manager
from .llm.manager import get_llm_manager
from .database.manager import get_db_manager

# Template and static file paths
BASE_DIR = Path(__file__).parent
TEMPLATES_DIR = BASE_DIR / "web" / "templates"
STATIC_DIR = BASE_DIR / "web" / "static"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management"""
    yield
    await get_llm_manager().close_all()
    await get_db_manager().close_all()
    await close_history_manager()


app = FastAPI(
    title="SQL Assistant",
    description="Natural-language-to-SQL query tool",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Static files
if STATIC_DIR.exists():
    app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")

# Jinja2 templates
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))

# API routes
app.include_router(api_router)


@app.get("/")
async def index(request: Request):
    return templates.TemplateResponse(request, "index.html", {"request": request})


def main():
    """CLI entry point"""
    import uvicorn

    # Make sure the console supports UTF-8
    try:
        sys.stdout.reconfigure(encoding='utf-8')
    except Exception:
        pass

    print("=" * 60)
    print(" SQL Assistant v1.0.0")
    print(" Natural Language -> SQL -> Results")
    print("=" * 60)
    print()
    print(" URL: http://localhost:5010")
    print(" Docs: http://localhost:5010/docs")
    print()
    print(" Configure LLM and Database in Settings (gear icon)")
    print("=" * 60)
    print()

    uvicorn.run(
        "sql_assistant.main:app",
        host="0.0.0.0",
        port=5010,
        reload=False,
        log_level="info",
    )


if __name__ == "__main__":
    main()
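Because app is a plain ASGI application, it can be exercised in-process without starting uvicorn. A minimal smoke-test sketch using httpx's ASGITransport (it assumes the package is installed so the bundled templates resolve; note this transport does not run the lifespan hooks):

import asyncio

import httpx

from sql_assistant.main import app

async def smoke_test() -> None:
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.get("/")  # renders web/templates/index.html
        print(resp.status_code)       # expect 200

asyncio.run(smoke_test())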