erlangshen 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. package/.claude/agents/equity-agent.md +26 -0
  2. package/.claude/agents/macro-agent.md +25 -0
  3. package/.claude/commands/analyze.md +40 -0
  4. package/.claude/commands/macro.md +29 -0
  5. package/.claude/settings.json +12 -0
  6. package/CODEX_GOAL.md +46 -0
  7. package/README.md +206 -0
  8. package/bin/cli.js +67 -0
  9. package/bin/erlangshen +2 -0
  10. package/bin/xiaoergod +2 -0
  11. package/frontend/index.html +700 -0
  12. package/knowledge/crypto_guide.md +147 -0
  13. package/knowledge/economic_indicators.md +125 -0
  14. package/knowledge/financial_glossary.md +148 -0
  15. package/knowledge/first_principles.md +50 -0
  16. package/knowledge/first_principles_deep.md +115 -0
  17. package/knowledge/global_markets.md +173 -0
  18. package/knowledge/insights.md +141 -0
  19. package/knowledge/market_basics.md +116 -0
  20. package/knowledge/memos/session_20260513_003616.json +6 -0
  21. package/knowledge/memos/session_20260513_003822.json +6 -0
  22. package/knowledge/risk_management.md +151 -0
  23. package/knowledge/team_context.md +42 -0
  24. package/knowledge/trading_strategies.md +114 -0
  25. package/package.json +42 -0
  26. package/requirements.txt +14 -0
  27. package/scripts/postinstall.js +188 -0
  28. package/scripts/preuninstall.js +22 -0
  29. package/src/__init__.py +4 -0
  30. package/src/__pycache__/__init__.cpython-313.pyc +0 -0
  31. package/src/agents/__init__.py +3 -0
  32. package/src/agents/base.py +103 -0
  33. package/src/agents/base_agent.py +86 -0
  34. package/src/agents/equity.py +136 -0
  35. package/src/agents/equity_agent.py +91 -0
  36. package/src/agents/erlang.py +165 -0
  37. package/src/agents/macro.py +137 -0
  38. package/src/agents/macro_agent.py +81 -0
  39. package/src/agents/multi_asset.py +147 -0
  40. package/src/agents/multi_asset_agent.py +87 -0
  41. package/src/api/__init__.py +1 -0
  42. package/src/api/__pycache__/__init__.cpython-313.pyc +0 -0
  43. package/src/api/__pycache__/server.cpython-313.pyc +0 -0
  44. package/src/api/cli.py +435 -0
  45. package/src/api/cli_enhanced.py +537 -0
  46. package/src/api/server.py +266 -0
  47. package/src/brain.py +200 -0
  48. package/src/cli.py +153 -0
  49. package/src/commands/__init__.py +3 -0
  50. package/src/commands/analyze.py +131 -0
  51. package/src/commands/macro.py +100 -0
  52. package/src/commands/memo.py +216 -0
  53. package/src/commands/portfolio.py +154 -0
  54. package/src/commands/report.py +228 -0
  55. package/src/commands/risk.py +183 -0
  56. package/src/commands/search.py +183 -0
  57. package/src/commands/stock.py +124 -0
  58. package/src/config.py +327 -0
  59. package/src/core/__init__.py +1 -0
  60. package/src/core/brain.py +645 -0
  61. package/src/core/cerebellum.py +175 -0
  62. package/src/core/investment_universe.py +423 -0
  63. package/src/core/knowledge.py +207 -0
  64. package/src/core/memory.py +115 -0
  65. package/src/hooks/__init__.py +3 -0
  66. package/src/hooks/session_end.py +57 -0
  67. package/src/hooks/session_start.py +75 -0
  68. package/src/knowledge/__init__.py +1 -0
  69. package/src/mcp/__init__.py +3 -0
  70. package/src/mcp/feishu.py +331 -0
  71. package/src/mcp/fund_tools.py +323 -0
  72. package/src/mcp/macro.py +452 -0
  73. package/src/mcp/market.py +331 -0
  74. package/src/mcp/registry.py +168 -0
  75. package/src/network/__init__.py +15 -0
  76. package/src/network/detector.py +125 -0
  77. package/src/network/proxy.py +199 -0
  78. package/src/network/router.py +103 -0
  79. package/src/prompts/__init__.py +1 -0
  80. package/src/prompts/analysis_framework.md +164 -0
  81. package/src/prompts/persona.md +65 -0
  82. package/src/prompts/report_template.md +144 -0
  83. package/src/skills/__init__.py +3 -0
  84. package/src/skills/framework.py +105 -0
  85. package/src/skills/templates.py +342 -0
  86. package/src/tools/__init__.py +1 -0
  87. package/src/tools/file_tools.py +209 -0
  88. package/src/tools/macro_tools.py +152 -0
  89. package/src/tools/market_tools.py +1172 -0
  90. package/src/tools/registry.py +398 -0
  91. package/src/tools/search_tools.py +777 -0
  92. package/tests/__init__.py +1 -0
  93. package/tests/test_erlangshen.py +140 -0
@@ -0,0 +1,331 @@
1
+ """
2
+ MCP 工具 - 市场行情数据
3
+ 使用真实数据库数据
4
+ """
5
+
6
+ import sys
7
+ from typing import List, Dict, Any, Optional
8
+ from datetime import datetime, timedelta
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # 导入投资系统数据库连接
14
+ sys.path.insert(0, '/Users/wanghui/.openclaw-agent-06/workspace/investment-strategy')
15
+ from backend.core.database import execute_remote_query
16
+
17
+
18
+ class MarketMCP:
19
+ """
20
+ 行情数据 MCP
21
+
22
+ 提供股票、指数、期货等行情数据接口
23
+ 数据来源:远程数据库 (193.112.183.130)
24
+ """
25
+
26
+ def __init__(self):
27
+ self.name = "market"
28
+
29
+ async def get_stock_price(self, symbol: str) -> Dict[str, Any]:
30
+ """
31
+ 获取股票实时价格
32
+
33
+ Args:
34
+ symbol: 股票代码,如 "000001", "600519"
35
+
36
+ Returns:
37
+ 价格数据字典
38
+ """
39
+ try:
40
+ sql = """
41
+ SELECT 日期, 代码, 开盘, 收盘, 最高, 最低, 成交量, 成交额, 涨跌幅, 涨跌额, 换手率
42
+ FROM A股历史行情表
43
+ WHERE 代码 = %s
44
+ ORDER BY 日期 DESC
45
+ LIMIT 1
46
+ """
47
+ rows = execute_remote_query('stock', sql, (symbol,))
48
+ if rows:
49
+ r = rows[0]
50
+ return {
51
+ "symbol": r['代码'],
52
+ "date": str(r['日期']),
53
+ "price": float(r['收盘']) if r['收盘'] else None,
54
+ "open": float(r['开盘']) if r['开盘'] else None,
55
+ "high": float(r['最高']) if r['最高'] else None,
56
+ "low": float(r['最低']) if r['最低'] else None,
57
+ "change": float(r['涨跌额']) if r['涨跌额'] else 0.0,
58
+ "change_pct": float(r['涨跌幅']) if r['涨跌幅'] else 0.0,
59
+ "volume": float(r['成交量']) if r['成交量'] else 0,
60
+ "amount": float(r['成交额']) if r['成交额'] else 0.0,
61
+ "turnover": float(r['换手率']) if r['换手率'] else None,
62
+ "source": "remote_db"
63
+ }
64
+ return {"error": f"未找到股票 {symbol} 的数据", "symbol": symbol}
65
+ except Exception as e:
66
+ logger.error(f"获取股票 {symbol} 价格失败: {e}")
67
+ return {"error": f"数据库查询失败: {str(e)}", "symbol": symbol}
68
+
69
+ async def get_stock_history(
70
+ self,
71
+ symbol: str,
72
+ days: int = 30,
73
+ end_date: Optional[str] = None
74
+ ) -> List[Dict[str, Any]]:
75
+ """
76
+ 获取股票历史行情
77
+
78
+ Args:
79
+ symbol: 股票代码
80
+ days: 历史天数
81
+ end_date: 结束日期 (YYYY-MM-DD)
82
+
83
+ Returns:
84
+ 历史行情列表
85
+ """
86
+ try:
87
+ sql = """
88
+ SELECT 日期, 代码, 开盘, 收盘, 最高, 最低, 成交量, 成交额, 涨跌幅, 涨跌额
89
+ FROM A股历史行情表
90
+ WHERE 代码 = %s
91
+ ORDER BY 日期 DESC
92
+ LIMIT %s
93
+ """
94
+ rows = execute_remote_query('stock', sql, (symbol, days))
95
+ if rows:
96
+ return [
97
+ {
98
+ "date": str(r['日期']),
99
+ "symbol": r['代码'],
100
+ "open": float(r['开盘']) if r['开盘'] else None,
101
+ "close": float(r['收盘']) if r['收盘'] else None,
102
+ "high": float(r['最高']) if r['最高'] else None,
103
+ "low": float(r['最低']) if r['最低'] else None,
104
+ "volume": float(r['成交量']) if r['成交量'] else 0,
105
+ "amount": float(r['成交额']) if r['成交额'] else 0.0,
106
+ "change_pct": float(r['涨跌幅']) if r['涨跌幅'] else 0.0,
107
+ "change": float(r['涨跌额']) if r['涨跌额'] else 0.0,
108
+ "source": "remote_db"
109
+ }
110
+ for r in rows
111
+ ]
112
+ return [{"error": f"未找到股票 {symbol} 的历史数据"}]
113
+ except Exception as e:
114
+ logger.error(f"获取股票 {symbol} 历史失败: {e}")
115
+ return [{"error": f"数据库查询失败: {str(e)}"}]
116
+
117
+ async def get_index_quote(self, index_name: str) -> Dict[str, Any]:
118
+ """
119
+ 获取指数行情
120
+
121
+ Args:
122
+ index_name: 指数名称,如 "上证指数", "沪深300", "创业板指"
123
+
124
+ Returns:
125
+ 指数行情数据
126
+ """
127
+ index_names = {
128
+ "上证指数": "000001",
129
+ "深证成指": "399001",
130
+ "创业板指": "399006",
131
+ "沪深300": "000300",
132
+ "上证50": "000016",
133
+ "中证500": "000905",
134
+ "科创50": "科创50",
135
+ }
136
+ try:
137
+ sql = """
138
+ SELECT 指数名称, 日期, 开盘价, 最高价, 最低价, 收盘价, 成交量, 成交额, 涨跌幅, 涨跌额
139
+ FROM 国内宽基指数行情表
140
+ WHERE 指数名称 = %s
141
+ ORDER BY 日期 DESC
142
+ LIMIT 1
143
+ """
144
+ rows = execute_remote_query('index', sql, (index_name,))
145
+ if rows:
146
+ r = rows[0]
147
+ return {
148
+ "index_name": r['指数名称'],
149
+ "date": str(r['日期']),
150
+ "price": float(r['收盘价']) if r['收盘价'] else None,
151
+ "open": float(r['开盘价']) if r['开盘价'] else None,
152
+ "high": float(r['最高价']) if r['最高价'] else None,
153
+ "low": float(r['最低价']) if r['最低价'] else None,
154
+ "change": float(r['涨跌额']) if r['涨跌额'] else 0.0,
155
+ "change_pct": float(r['涨跌幅']) if r['涨跌幅'] else 0.0,
156
+ "volume": float(r['成交量']) if r['成交量'] else 0,
157
+ "amount": float(r['成交额']) if r['成交额'] else 0.0,
158
+ "source": "remote_db"
159
+ }
160
+ return {"error": f"未找到指数 {index_name} 的数据", "index_name": index_name}
161
+ except Exception as e:
162
+ logger.error(f"获取指数 {index_name} 行情失败: {e}")
163
+ return {"error": f"数据库查询失败: {str(e)}", "index_name": index_name}
164
+
165
+ async def get_index_history(
166
+ self,
167
+ index_name: str,
168
+ days: int = 30
169
+ ) -> List[Dict[str, Any]]:
170
+ """
171
+ 获取指数历史行情
172
+
173
+ Args:
174
+ index_name: 指数名称
175
+ days: 历史天数
176
+
177
+ Returns:
178
+ 历史行情列表
179
+ """
180
+ try:
181
+ sql = """
182
+ SELECT 指数名称, 日期, 开盘价, 最高价, 最低价, 收盘价, 成交量, 成交额, 涨跌幅, 涨跌额
183
+ FROM 国内宽基指数行情表
184
+ WHERE 指数名称 = %s
185
+ ORDER BY 日期 DESC
186
+ LIMIT %s
187
+ """
188
+ rows = execute_remote_query('index', sql, (index_name, days))
189
+ if rows:
190
+ return [
191
+ {
192
+ "index_name": r['指数名称'],
193
+ "date": str(r['日期']),
194
+ "open": float(r['开盘价']) if r['开盘价'] else None,
195
+ "high": float(r['最高价']) if r['最高价'] else None,
196
+ "low": float(r['最低价']) if r['最低价'] else None,
197
+ "close": float(r['收盘价']) if r['收盘价'] else None,
198
+ "volume": float(r['成交量']) if r['成交量'] else 0,
199
+ "amount": float(r['成交额']) if r['成交额'] else 0.0,
200
+ "change_pct": float(r['涨跌幅']) if r['涨跌幅'] else 0.0,
201
+ "change": float(r['涨跌额']) if r['涨跌额'] else 0.0,
202
+ "source": "remote_db"
203
+ }
204
+ for r in rows
205
+ ]
206
+ return [{"error": f"未找到指数 {index_name} 的历史数据"}]
207
+ except Exception as e:
208
+ logger.error(f"获取指数 {index_name} 历史失败: {e}")
209
+ return [{"error": f"数据库查询失败: {str(e)}"}]
210
+
211
+ async def get_futures_price(self, contract: str) -> Dict[str, Any]:
212
+ """
213
+ 获取期货价格
214
+
215
+ Args:
216
+ contract: 期货合约名称
217
+
218
+ Returns:
219
+ 期货行情数据
220
+ """
221
+ try:
222
+ sql = """
223
+ SELECT contract_code, trade_date, open, high, low, close, volume, open_interest
224
+ FROM 全部期货合约历史行情
225
+ WHERE contract_code = %s
226
+ ORDER BY trade_date DESC
227
+ LIMIT 1
228
+ """
229
+ rows = execute_remote_query('futures', sql, (contract,))
230
+ if rows:
231
+ r = rows[0]
232
+ return {
233
+ "contract": r['contract_code'],
234
+ "date": str(r['trade_date']),
235
+ "price": float(r['close']),
236
+ "open": float(r['open']) if r['open'] else None,
237
+ "high": float(r['high']) if r['high'] else None,
238
+ "low": float(r['low']) if r['low'] else None,
239
+ "volume": int(r['volume']) if r['volume'] else 0,
240
+ "open_interest": int(r['open_interest']) if r['open_interest'] else 0,
241
+ "source": "remote_db"
242
+ }
243
+ return {"error": f"未找到期货合约 {contract} 的数据", "contract": contract}
244
+ except Exception as e:
245
+ logger.error(f"获取期货 {contract} 价格失败: {e}")
246
+ return {"error": f"数据库查询失败: {str(e)}", "contract": contract}
247
+
248
+ async def get_etf_quote(self, symbol: str) -> Dict[str, Any]:
249
+ """
250
+ 获取ETF行情 (暂无独立ETF表,暂用股票表)
251
+
252
+ Args:
253
+ symbol: ETF代码
254
+
255
+ Returns:
256
+ ETF行情数据
257
+ """
258
+ return await self.get_stock_price(symbol)
259
+
260
+ async def get_realtime_quotes(self, symbols: List[str]) -> List[Dict[str, Any]]:
261
+ """
262
+ 批量获取实时行情
263
+
264
+ Args:
265
+ symbols: 股票代码列表
266
+
267
+ Returns:
268
+ 行情列表
269
+ """
270
+ results = []
271
+ for symbol in symbols:
272
+ quote = await self.get_stock_price(symbol)
273
+ results.append(quote)
274
+ return results
275
+
276
+ def list_tools(self) -> List[Dict[str, Any]]:
277
+ """列出所有可用工具"""
278
+ return [
279
+ {
280
+ "name": "get_stock_price",
281
+ "description": "获取股票实时价格",
282
+ "parameters": {
283
+ "symbol": "股票代码,如 000001, 600519"
284
+ }
285
+ },
286
+ {
287
+ "name": "get_stock_history",
288
+ "description": "获取股票历史行情",
289
+ "parameters": {
290
+ "symbol": "股票代码",
291
+ "days": "历史天数 (默认30)",
292
+ "end_date": "结束日期 (YYYY-MM-DD)"
293
+ }
294
+ },
295
+ {
296
+ "name": "get_index_quote",
297
+ "description": "获取指数行情",
298
+ "parameters": {
299
+ "index_name": "指数名称,如 上证指数, 沪深300, 创业板指"
300
+ }
301
+ },
302
+ {
303
+ "name": "get_index_history",
304
+ "description": "获取指数历史行情",
305
+ "parameters": {
306
+ "index_name": "指数名称",
307
+ "days": "历史天数"
308
+ }
309
+ },
310
+ {
311
+ "name": "get_futures_price",
312
+ "description": "获取期货价格",
313
+ "parameters": {
314
+ "contract": "期货合约代码"
315
+ }
316
+ },
317
+ {
318
+ "name": "get_etf_quote",
319
+ "description": "获取ETF行情",
320
+ "parameters": {
321
+ "symbol": "ETF代码"
322
+ }
323
+ },
324
+ {
325
+ "name": "get_realtime_quotes",
326
+ "description": "批量获取实时行情",
327
+ "parameters": {
328
+ "symbols": "股票代码列表"
329
+ }
330
+ },
331
+ ]
@@ -0,0 +1,168 @@
1
+ """
2
+ MCP 注册表 - 统一管理所有MCP工具
3
+ """
4
+
5
+ from typing import Dict, Any, List, Optional, Callable, Awaitable
6
+ from dataclasses import dataclass
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ @dataclass
13
+ class MCPTool:
14
+ """MCP 工具定义"""
15
+ name: str
16
+ description: str
17
+ mcp_name: str
18
+ handler: Callable[..., Awaitable[Any]]
19
+ parameters: Dict[str, Any]
20
+
21
+
22
+ class MCPRegistry:
23
+ """
24
+ MCP 注册表
25
+
26
+ 统一管理所有 MCP 工具的注册和调用
27
+ """
28
+
29
+ def __init__(self):
30
+ self._mcps: Dict[str, Any] = {}
31
+ self._tools: Dict[str, MCPTool] = {}
32
+ self._register_default_mcps()
33
+
34
+ def _register_default_mcps(self):
35
+ """注册默认 MCP"""
36
+ # Market MCP
37
+ try:
38
+ from src.mcp.market import MarketMCP
39
+ self.register_mcp("market", MarketMCP())
40
+ except Exception as e:
41
+ logger.warning(f"Market MCP 注册失败: {e}")
42
+
43
+ # Macro MCP
44
+ try:
45
+ from src.mcp.macro import MacroMCP
46
+ self.register_mcp("macro", MacroMCP())
47
+ except Exception as e:
48
+ logger.warning(f"Macro MCP 注册失败: {e}")
49
+
50
+ # Feishu MCP
51
+ try:
52
+ from src.mcp.feishu import FeishuMCP
53
+ self.register_mcp("feishu", FeishuMCP())
54
+ except Exception as e:
55
+ logger.warning(f"Feishu MCP 注册失败: {e}")
56
+
57
+ # Fund MCP
58
+ try:
59
+ from src.mcp.fund_tools import FundMCP
60
+ self.register_mcp("fund", FundMCP())
61
+ except Exception as e:
62
+ logger.warning(f"Fund MCP 注册失败: {e}")
63
+
64
+ def register_mcp(self, name: str, mcp_instance: Any):
65
+ """
66
+ 注册 MCP 实例
67
+
68
+ Args:
69
+ name: MCP 名称
70
+ mcp_instance: MCP 实例
71
+ """
72
+ self._mcps[name] = mcp_instance
73
+
74
+ # 自动注册 MCP 的工具
75
+ if hasattr(mcp_instance, "list_tools"):
76
+ for tool in mcp_instance.list_tools():
77
+ tool_name = tool["name"]
78
+ handler = getattr(mcp_instance, tool_name, None)
79
+
80
+ if handler and callable(handler):
81
+ self._tools[tool_name] = MCPTool(
82
+ name=tool_name,
83
+ description=tool["description"],
84
+ mcp_name=name,
85
+ handler=handler,
86
+ parameters=tool.get("parameters", {})
87
+ )
88
+
89
+ logger.info(f"MCP '{name}' 已注册,包含 {len([t for t in self._tools.values() if t.mcp_name == name])} 个工具")
90
+
91
+ def get_mcp(self, name: str) -> Optional[Any]:
92
+ """
93
+ 获取 MCP 实例
94
+
95
+ Args:
96
+ name: MCP 名称
97
+
98
+ Returns:
99
+ MCP 实例
100
+ """
101
+ return self._mcps.get(name)
102
+
103
+ def list_mcps(self) -> List[Dict[str, Any]]:
104
+ """列出所有已注册的 MCP"""
105
+ return [
106
+ {
107
+ "name": name,
108
+ "tools": len([t for t in self._tools.values() if t.mcp_name == name])
109
+ }
110
+ for name, mcp in self._mcps.items()
111
+ ]
112
+
113
+ def list_tools(self) -> List[Dict[str, Any]]:
114
+ """列出所有可用工具"""
115
+ return [
116
+ {
117
+ "name": tool.name,
118
+ "description": tool.description,
119
+ "mcp": tool.mcp_name,
120
+ "parameters": tool.parameters
121
+ }
122
+ for tool in self._tools.values()
123
+ ]
124
+
125
+ async def call_tool(self, tool_name: str, **kwargs) -> Any:
126
+ """
127
+ 调用 MCP 工具
128
+
129
+ Args:
130
+ tool_name: 工具名称
131
+ **kwargs: 工具参数
132
+
133
+ Returns:
134
+ 工具执行结果
135
+ """
136
+ tool = self._tools.get(tool_name)
137
+
138
+ if not tool:
139
+ return {
140
+ "success": False,
141
+ "error": f"未知工具: {tool_name}",
142
+ "available_tools": list(self._tools.keys())
143
+ }
144
+
145
+ try:
146
+ result = await tool.handler(**kwargs)
147
+ return result
148
+ except Exception as e:
149
+ logger.error(f"工具 {tool_name} 执行失败: {e}")
150
+ return {
151
+ "success": False,
152
+ "error": str(e),
153
+ "tool": tool_name
154
+ }
155
+
156
+ def get_tool_info(self, tool_name: str) -> Optional[Dict[str, Any]]:
157
+ """获取工具信息"""
158
+ tool = self._tools.get(tool_name)
159
+
160
+ if not tool:
161
+ return None
162
+
163
+ return {
164
+ "name": tool.name,
165
+ "description": tool.description,
166
+ "mcp": tool.mcp_name,
167
+ "parameters": tool.parameters
168
+ }
@@ -0,0 +1,15 @@
1
+ """
2
+ 网络自适应模块
3
+ 根据目标自动选择最优网络路径(直连/代理)
4
+ """
5
+
6
+ from .proxy import ProxyManager, ProxyConfig
7
+ from .detector import NetworkDetector
8
+ from .router import NetworkRouter
9
+
10
+ __all__ = [
11
+ "ProxyManager",
12
+ "ProxyConfig",
13
+ "NetworkDetector",
14
+ "NetworkRouter",
15
+ ]
@@ -0,0 +1,125 @@
1
+ """
2
+ 网络环境检测
3
+ 检测当前是否使用VPN、直连等
4
+ """
5
+
6
+ import asyncio
7
+ import socket
8
+ import os
9
+ from typing import Dict, Optional
10
+
11
+
12
+ class NetworkDetector:
13
+ """网络环境检测器"""
14
+
15
+ # 中国测试主机
16
+ CHINA_TEST_HOSTS = [
17
+ ("www.baidu.com", 80),
18
+ ("www.aliyun.com", 80),
19
+ ("api.binance.com", 443), # Binance中国可直连
20
+ ("api.okx.com", 443), # OKX中国可直连
21
+ ]
22
+
23
+ # 国际测试主机
24
+ GLOBAL_TEST_HOSTS = [
25
+ ("www.google.com", 443),
26
+ ("www.cloudflare.com", 443),
27
+ ("api.coingecko.com", 443),
28
+ ("api.openai.com", 443),
29
+ ]
30
+
31
+ async def detect_environment(self) -> Dict[str, bool]:
32
+ """
33
+ 检测网络环境
34
+ 返回: {
35
+ "china_reachable": True, # 中国网络是否可达
36
+ "global_reachable": True, # 全球网络是否可达
37
+ "vpn_active": False, # VPN是否激活
38
+ "proxy_active": False, # 代理是否激活
39
+ }
40
+ """
41
+ results = {
42
+ "china_reachable": False,
43
+ "global_reachable": False,
44
+ "vpn_active": False,
45
+ "proxy_active": False,
46
+ }
47
+
48
+ # 并发检测中国和国际网络
49
+ china_task = self._test_hosts(self.CHINA_TEST_HOSTS)
50
+ global_task = self._test_hosts(self.GLOBAL_TEST_HOSTS)
51
+
52
+ china_results, global_results = await asyncio.gather(china_task, global_task)
53
+
54
+ results["china_reachable"] = any(china_results)
55
+ results["global_reachable"] = any(global_results)
56
+
57
+ # 检测VPN
58
+ results["vpn_active"] = await self._detect_vpn()
59
+
60
+ # 检测代理环境变量
61
+ results["proxy_active"] = self._detect_proxy_env()
62
+
63
+ return results
64
+
65
+ async def _test_hosts(self, hosts: list) -> list:
66
+ """测试一组主机是否可达"""
67
+ tasks = [self._test_host(host, port) for host, port in hosts]
68
+ results = await asyncio.gather(*tasks, return_exceptions=True)
69
+ return [r for r in results if r is True]
70
+
71
+ async def _test_host(self, host: str, port: int, timeout: float = 3.0) -> bool:
72
+ """测试单个主机"""
73
+ try:
74
+ reader, writer = await asyncio.wait_for(
75
+ asyncio.open_connection(host, port),
76
+ timeout=timeout
77
+ )
78
+ writer.close()
79
+ await writer.wait_closed()
80
+ return True
81
+ except Exception:
82
+ return False
83
+
84
+ async def _detect_vpn(self) -> bool:
85
+ """检测VPN是否激活"""
86
+ import subprocess
87
+ try:
88
+ # macOS检测VPN接口
89
+ result = subprocess.run(
90
+ ["networksetup", "-listallnetworkservices"],
91
+ capture_output=True, text=True, timeout=5
92
+ )
93
+ # 检查是否有VPN相关的网络服务
94
+ return "VPN" in result.stdout or "utun" in result.stdout
95
+ except Exception:
96
+ return False
97
+
98
+ def _detect_proxy_env(self) -> bool:
99
+ """检测代理环境变量"""
100
+ proxy_vars = [
101
+ "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
102
+ "http_proxy", "https_proxy", "all_proxy",
103
+ "SOCKS_PROXY", "socks_proxy"
104
+ ]
105
+ for var in proxy_vars:
106
+ if os.environ.get(var):
107
+ return True
108
+ return False
109
+
110
+ def get_recommended_mode(self) -> str:
111
+ """获取推荐的代理模式"""
112
+ env = asyncio.run(self.detect_environment())
113
+
114
+ if env["vpn_active"] and env["global_reachable"]:
115
+ return "vpn"
116
+ elif env["china_reachable"] and not env["global_reachable"]:
117
+ return "china_only"
118
+ elif env["global_reachable"] and env["china_reachable"]:
119
+ return "auto"
120
+ else:
121
+ return "direct"
122
+
123
+ async def quick_test(self, host: str = "www.baidu.com", port: int = 80) -> bool:
124
+ """快速网络测试"""
125
+ return await self._test_host(host, port, timeout=2.0)