aigroup-econ-mcp 0.3.8__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aigroup_econ_mcp/__init__.py +18 -18
- aigroup_econ_mcp/server.py +284 -3291
- aigroup_econ_mcp/server_v1_backup.py +1250 -0
- aigroup_econ_mcp/server_v1_old.py +1250 -0
- aigroup_econ_mcp/server_with_file_support.py +259 -0
- aigroup_econ_mcp/tools/__init__.py +3 -2
- aigroup_econ_mcp/tools/data_loader.py +171 -0
- aigroup_econ_mcp/tools/decorators.py +178 -0
- aigroup_econ_mcp/tools/file_input_handler.py +268 -0
- aigroup_econ_mcp/tools/file_parser.py +560 -0
- aigroup_econ_mcp/tools/machine_learning.py +14 -14
- aigroup_econ_mcp/tools/panel_data.py +10 -6
- aigroup_econ_mcp/tools/time_series.py +54 -127
- aigroup_econ_mcp/tools/tool_handlers.py +378 -0
- aigroup_econ_mcp/tools/tool_registry.py +170 -0
- {aigroup_econ_mcp-0.3.8.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/METADATA +287 -22
- aigroup_econ_mcp-0.4.0.dist-info/RECORD +30 -0
- aigroup_econ_mcp-0.3.8.dist-info/RECORD +0 -21
- {aigroup_econ_mcp-0.3.8.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/WHEEL +0 -0
- {aigroup_econ_mcp-0.3.8.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/entry_points.txt +0 -0
- {aigroup_econ_mcp-0.3.8.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/licenses/LICENSE +0 -0
aigroup_econ_mcp/server_with_file_support.py
@@ -0,0 +1,259 @@
+"""
+AIGroup econometrics MCP server with CSV file path input support
+Provides professional econometric analysis tools built on the latest MCP features
+"""
+
+from typing import List, Dict, Any, Optional, Annotated, Union
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+import json
+from pathlib import Path
+
+import pandas as pd
+import numpy as np
+import statsmodels.api as sm
+from statsmodels.tsa import stattools
+from scipy import stats
+from pydantic import BaseModel, Field
+
+from mcp.server.fastmcp import FastMCP, Context
+from mcp.server.session import ServerSession
+from mcp.types import CallToolResult, TextContent
+
+
+# Data model definitions
+class DescriptiveStatsResult(BaseModel):
+    """Descriptive statistics result"""
+    count: int = Field(description="Number of observations")
+    mean: float = Field(description="Mean")
+    std: float = Field(description="Standard deviation")
+    min: float = Field(description="Minimum")
+    max: float = Field(description="Maximum")
+    median: float = Field(description="Median")
+    skewness: float = Field(description="Skewness")
+    kurtosis: float = Field(description="Kurtosis")
+
+
+# Application context
+@dataclass
+class AppContext:
+    """Application context holding shared resources"""
+    config: Dict[str, Any]
+    version: str = "0.2.0"
+
+
+@asynccontextmanager
+async def lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
+    """Server lifespan management"""
+    config = {
+        "max_sample_size": 10000,
+        "default_significance_level": 0.05,
+        "supported_tests": ["t_test", "f_test", "chi_square", "adf"],
+        "data_types": ["cross_section", "time_series", "panel"]
+    }
+    try:
+        yield AppContext(config=config, version="0.2.0")
+    finally:
+        pass
+
+
+# Create the MCP server instance
+mcp = FastMCP(
+    name="aigroup-econ-mcp",
+    instructions="Econometrics MCP Server with CSV file path support",
+    lifespan=lifespan
+)
+
+
+# Helper: load data from a CSV file
+async def load_data_from_path(file_path: str, ctx: Context) -> Dict[str, List[float]]:
+    """Load data from a CSV file path"""
+    await ctx.info(f"Loading data from file: {file_path}")
+
+    try:
+        # Read the CSV file
+        df = pd.read_csv(file_path)
+
+        # Convert to dict format
+        data = {col: df[col].tolist() for col in df.columns}
+
+        await ctx.info(f"✅ File loaded: {len(df.columns)} variables, {len(df)} observations")
+        return data
+
+    except FileNotFoundError:
+        raise ValueError(f"File not found: {file_path}")
+    except Exception as e:
+        raise ValueError(f"Failed to read file: {str(e)}")
+
+
+@mcp.tool()
+async def descriptive_statistics(
+    ctx: Context[ServerSession, AppContext],
+    data: Annotated[
+        Union[Dict[str, List[float]], str],
+        Field(
+            description="""Data input, two formats supported:
+
+            📊 Format 1: data dict
+            {
+                "gdp_growth": [3.2, 2.8, 3.5, 2.9],
+                "inflation": [2.1, 2.3, 1.9, 2.4],
+                "unemployment": [4.5, 4.2, 4.0, 4.3]
+            }
+
+            📁 Format 2: CSV file path (recommended)
+            "d:/data/economic_data.csv"
+            or
+            "./test_data.csv"
+
+            CSV file requirements:
+            - the first row holds the variable names (header)
+            - the remaining rows hold numeric data
+            - every column must be numeric"""
+        )
+    ]
+) -> CallToolResult:
+    """Compute descriptive statistics
+
+    📊 What it does:
+    Runs a full descriptive analysis of the input data, covering central tendency, dispersion and distribution shape.
+
+    💡 When to use:
+    - get a first look at how the data are distributed
+    - check data quality and spot outliers
+    - provide baseline information for later modelling
+    """
+    try:
+        # Detect the input type and load the data
+        if isinstance(data, str):
+            # File path input
+            data_dict = await load_data_from_path(data, ctx)
+        else:
+            # Dict input
+            data_dict = data
+
+        await ctx.info(f"Computing descriptive statistics for {len(data_dict)} variables")
+
+        # Validate the data
+        if not data_dict:
+            raise ValueError("Data must not be empty")
+
+        df = pd.DataFrame(data_dict)
+
+        # Basic statistics
+        result = DescriptiveStatsResult(
+            count=len(df),
+            mean=df.mean().mean(),
+            std=df.std().mean(),
+            min=df.min().min(),
+            max=df.max().max(),
+            median=df.median().mean(),
+            skewness=df.skew().mean(),
+            kurtosis=df.kurtosis().mean()
+        )
+
+        # Correlation matrix
+        correlation_matrix = df.corr().round(4)
+
+        await ctx.info(f"Descriptive statistics done, sample size: {len(df)}")
+
+        return CallToolResult(
+            content=[
+                TextContent(
+                    type="text",
+                    text=f"Descriptive statistics:\n"
+                         f"Count: {result.count}\n"
+                         f"Mean: {result.mean:.4f}\n"
+                         f"Std: {result.std:.4f}\n"
+                         f"Min: {result.min:.4f}\n"
+                         f"Max: {result.max:.4f}\n"
+                         f"Median: {result.median:.4f}\n"
+                         f"Skewness: {result.skewness:.4f}\n"
+                         f"Kurtosis: {result.kurtosis:.4f}\n\n"
+                         f"Correlation matrix:\n{correlation_matrix.to_string()}"
+                )
+            ],
+            structuredContent=result.model_dump()
+        )
+
+    except Exception as e:
+        await ctx.error(f"Error while computing descriptive statistics: {str(e)}")
+        return CallToolResult(
+            content=[TextContent(type="text", text=f"Error: {str(e)}")],
+            isError=True
+        )
+
+
+@mcp.tool()
+async def correlation_analysis(
+    ctx: Context[ServerSession, AppContext],
+    data: Annotated[
+        Union[Dict[str, List[float]], str],
+        Field(
+            description="""Data input, two formats supported:
+
+            📊 Format 1: data dict
+            {
+                "sales": [12000, 13500, 11800, 14200],
+                "ad_spend": [800, 900, 750, 1000],
+                "price": [99, 95, 102, 98]
+            }
+
+            📁 Format 2: CSV file path
+            "d:/data/marketing_data.csv"
+
+            Requirements:
+            - at least 2 variables
+            - all variables must have the same number of data points"""
+        )
+    ],
+    method: Annotated[
+        str,
+        Field(
+            default="pearson",
+            description="Correlation type: pearson/spearman/kendall"
+        )
+    ] = "pearson"
+) -> CallToolResult:
+    """Correlation analysis between variables"""
+    try:
+        # Detect the input type and load the data
+        if isinstance(data, str):
+            data_dict = await load_data_from_path(data, ctx)
+        else:
+            data_dict = data
+
+        await ctx.info(f"Starting correlation analysis: {method}")
+
+        # Validate the data
+        if not data_dict:
+            raise ValueError("Data must not be empty")
+        if len(data_dict) < 2:
+            raise ValueError("At least 2 variables are required for correlation analysis")
+
+        df = pd.DataFrame(data_dict)
+        correlation_matrix = df.corr(method=method)
+
+        await ctx.info("Correlation analysis complete")
+
+        return CallToolResult(
+            content=[
+                TextContent(
+                    type="text",
+                    text=f"{method.title()} correlation matrix:\n{correlation_matrix.round(4).to_string()}"
+                )
+            ]
+        )
+
+    except Exception as e:
+        await ctx.error(f"Correlation analysis error: {str(e)}")
+        return CallToolResult(
+            content=[TextContent(type="text", text=f"Error: {str(e)}")],
+            isError=True
+        )
+
+
+def create_mcp_server() -> FastMCP:
+    """Create and return the MCP server instance"""
+    return mcp
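
The tools above accept either an in-memory dict or a CSV path. As a minimal sketch (illustrative, not part of the published diff), the following prepares a CSV that meets the stated requirements and shows the equivalent dict form; the column names and file name are assumptions:

    import pandas as pd

    # Format 1: dict of numeric columns (keys are illustrative)
    data_dict = {
        "gdp_growth": [3.2, 2.8, 3.5, 2.9],
        "inflation": [2.1, 2.3, 1.9, 2.4],
    }

    # Format 2: CSV path -- header row with variable names, numeric columns only
    pd.DataFrame(data_dict).to_csv("./test_data.csv", index=False)
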
aigroup_econ_mcp/tools/__init__.py
@@ -3,7 +3,7 @@
 """
 
 from . import regression, statistics, time_series, machine_learning, panel_data
-from . import validation, cache, monitoring
+from . import validation, cache, monitoring, file_parser
 
 __all__ = [
     "regression",
@@ -13,5 +13,6 @@ __all__ = [
     "panel_data",
     "validation",
     "cache",
-    "monitoring"
+    "monitoring",
+    "file_parser"
 ]
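
With this change the new parser module is exported from the tools package; a one-line sketch of the resulting import:

    from aigroup_econ_mcp.tools import file_parser
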
aigroup_econ_mcp/tools/data_loader.py
@@ -0,0 +1,171 @@
+"""
+Data loading helper module
+Provides generic CSV file loading utilities
+"""
+
+from typing import Dict, List, Union
+from pathlib import Path
+import pandas as pd
+
+
+async def load_data_if_path(
+    data: Union[Dict[str, List[float]], str],
+    ctx = None
+) -> Dict[str, List[float]]:
+    """
+    Smart data loading: a string is treated as a file path and loaded, anything else is returned as-is
+
+    Args:
+        data: data dict or CSV file path
+        ctx: MCP context object (optional, used for logging)
+
+    Returns:
+        Data dict
+
+    Raises:
+        ValueError: the file does not exist or could not be read
+    """
+    # Already a dict: return it unchanged
+    if isinstance(data, dict):
+        return data
+
+    # A string is treated as a file path
+    if isinstance(data, str):
+        if ctx:
+            await ctx.info(f"📁 Detected file path, loading: {data}")
+
+        try:
+            # Check that the file exists
+            path = Path(data)
+            if not path.exists():
+                raise ValueError(f"File not found: {data}")
+
+            # Read the CSV file
+            df = pd.read_csv(path)
+
+            # Convert to dict format
+            result = {col: df[col].tolist() for col in df.columns}
+
+            if ctx:
+                await ctx.info(f"✅ CSV file loaded: {len(df.columns)} variables, {len(df)} observations")
+
+            return result
+
+        except FileNotFoundError:
+            raise ValueError(f"File not found: {data}")
+        except Exception as e:
+            raise ValueError(f"Failed to read CSV file: {str(e)}")
+
+    # Any other type is an error
+    raise TypeError(f"Unsupported data type: {type(data)}, expected Dict or str")
+
+
+async def load_single_var_if_path(
+    data: Union[List[float], str],
+    ctx = None,
+    column_name: str = None
+) -> List[float]:
+    """
+    Smart single-variable loading: a string is treated as a file path and loaded, anything else is returned as-is
+
+    Args:
+        data: data list or CSV file path
+        ctx: MCP context object (optional, used for logging)
+        column_name: column to read from the CSV file (optional; defaults to the first column)
+
+    Returns:
+        Data list
+
+    Raises:
+        ValueError: the file does not exist or could not be read
+    """
+    # Already a list: return it unchanged
+    if isinstance(data, list):
+        return data
+
+    # A string is treated as a file path
+    if isinstance(data, str):
+        if ctx:
+            await ctx.info(f"📁 Detected file path, loading: {data}")
+
+        try:
+            # Check that the file exists
+            path = Path(data)
+            if not path.exists():
+                raise ValueError(f"File not found: {data}")
+
+            # Read the CSV file
+            df = pd.read_csv(path)
+
+            # Determine which column to read
+            if column_name:
+                if column_name not in df.columns:
+                    raise ValueError(f"Column '{column_name}' not found in CSV file. Available columns: {list(df.columns)}")
+                result = df[column_name].tolist()
+            else:
+                # Default to the first column
+                result = df.iloc[:, 0].tolist()
+                if ctx:
+                    await ctx.info(f"No column name given, using the first column: {df.columns[0]}")
+
+            if ctx:
+                await ctx.info(f"✅ CSV file loaded: {len(result)} observations")
+
+            return result
+
+        except FileNotFoundError:
+            raise ValueError(f"File not found: {data}")
+        except Exception as e:
+            raise ValueError(f"Failed to read CSV file: {str(e)}")
+
+    # Any other type is an error
+    raise TypeError(f"Unsupported data type: {type(data)}, expected List or str")
+
+
+async def load_x_data_if_path(
+    data: Union[List[List[float]], str],
+    ctx = None
+) -> List[List[float]]:
+    """
+    Smart loading of regressor data: a string is treated as a file path and loaded, anything else is returned as-is
+
+    Args:
+        data: regressor data (two-dimensional list) or CSV file path
+        ctx: MCP context object (optional, used for logging)
+
+    Returns:
+        Regressor data (two-dimensional list)
+
+    Raises:
+        ValueError: the file does not exist or could not be read
+    """
+    # Already a two-dimensional list: return it unchanged
+    if isinstance(data, list) and all(isinstance(item, list) for item in data):
+        return data
+
+    # A string is treated as a file path
+    if isinstance(data, str):
+        if ctx:
+            await ctx.info(f"📁 Detected regressor file path, loading: {data}")
+
+        try:
+            # Check that the file exists
+            path = Path(data)
+            if not path.exists():
+                raise ValueError(f"File not found: {data}")
+
+            # Read the CSV file
+            df = pd.read_csv(path)
+
+            # Convert to a two-dimensional list
+            result = df.values.tolist()
+
+            if ctx:
+                await ctx.info(f"✅ Regressor CSV file loaded: {len(result)} observations, {len(result[0]) if result else 0} regressors")
+
+            return result
+
+        except FileNotFoundError:
+            raise ValueError(f"File not found: {data}")
+        except Exception as e:
+            raise ValueError(f"Failed to read regressor CSV file: {str(e)}")
+
+    # Any other type is an error
+    raise TypeError(f"Unsupported data type: {type(data)}, expected List[List[float]] or str")
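
A minimal usage sketch (illustrative, not part of the published diff, assuming ./test_data.csv from the earlier example exists): each helper accepts either in-memory data or a CSV path, and ctx may be omitted outside an MCP session:

    import asyncio
    from aigroup_econ_mcp.tools.data_loader import load_data_if_path, load_single_var_if_path

    async def main():
        as_dict = await load_data_if_path({"x": [1.0, 2.0, 3.0]})     # dict returned unchanged
        from_csv = await load_data_if_path("./test_data.csv")          # CSV parsed into a dict
        first_col = await load_single_var_if_path("./test_data.csv")   # first column as a list
        print(len(as_dict), len(from_csv), len(first_col))

    asyncio.run(main())
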
aigroup_econ_mcp/tools/decorators.py
@@ -0,0 +1,178 @@
+"""
+Tool decorator module
+Provides automatic file-input handling, error handling, and related features
+"""
+
+from typing import Callable, Optional, Dict, Any, List
+from functools import wraps
+from mcp.server.session import ServerSession
+from mcp.server.fastmcp import Context
+from mcp.types import CallToolResult, TextContent
+
+from .file_parser import FileParser
+
+
+def with_file_input(tool_type: str):
+    """
+    Decorator that adds file-input support to a tool function
+
+    Two input modes are supported:
+    1. file_path: CSV/JSON file path
+    2. file_content: file content as a string
+
+    Args:
+        tool_type: tool type ('single_var', 'multi_var_dict', 'regression', 'panel', 'time_series')
+
+    Example:
+        @with_file_input('regression')
+        async def my_tool(ctx, y_data=None, x_data=None, file_path=None, file_content=None, file_format='auto', **kwargs):
+            # if file_path or file_content is provided, the data is filled in automatically
+            pass
+    """
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        async def wrapper(*args, **kwargs):
+            # Extract the context and the file parameters
+            ctx = args[0] if args else kwargs.get('ctx')
+            file_path = kwargs.get('file_path')
+            file_content = kwargs.get('file_content')
+            file_format = kwargs.get('file_format', 'auto')
+
+            # file_path takes precedence
+            if file_path:
+                try:
+                    await ctx.info(f"Detected file path input: {file_path}")
+
+                    # Parse from the file path
+                    parsed = FileParser.parse_file_path(file_path, file_format)
+
+                    await ctx.info(
+                        f"File parsed: {parsed['n_variables']} variables, "
+                        f"{parsed['n_observations']} observations"
+                    )
+
+                    # Convert to the tool's input format
+                    converted = FileParser.convert_to_tool_format(parsed, tool_type)
+
+                    # Update kwargs
+                    kwargs.update(converted)
+
+                    await ctx.info(f"Data converted to {tool_type} format")
+
+                except Exception as e:
+                    await ctx.error(f"File parsing failed: {str(e)}")
+                    return CallToolResult(
+                        content=[TextContent(type="text", text=f"File parsing error: {str(e)}")],
+                        isError=True
+                    )
+
+            # Otherwise, if file_content is given, parse the file content
+            elif file_content:
+                try:
+                    await ctx.info("Detected file content input, parsing...")
+
+                    # Parse the file content
+                    parsed = FileParser.parse_file_content(file_content, file_format)
+
+                    await ctx.info(
+                        f"File parsed: {parsed['n_variables']} variables, "
+                        f"{parsed['n_observations']} observations"
+                    )
+
+                    # Convert to the tool's input format
+                    converted = FileParser.convert_to_tool_format(parsed, tool_type)
+
+                    # Update kwargs
+                    kwargs.update(converted)
+
+                    await ctx.info(f"Data converted to {tool_type} format")
+
+                except Exception as e:
+                    await ctx.error(f"File parsing failed: {str(e)}")
+                    return CallToolResult(
+                        content=[TextContent(type="text", text=f"File parsing error: {str(e)}")],
+                        isError=True
+                    )
+
+            # Call the original function
+            return await func(*args, **kwargs)
+
+        return wrapper
+    return decorator
+
+
+def with_error_handling(func: Callable) -> Callable:
+    """
+    Decorator that adds uniform error handling to a tool function
+    """
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        ctx = args[0] if args else kwargs.get('ctx')
+        tool_name = func.__name__
+
+        try:
+            return await func(*args, **kwargs)
+        except Exception as e:
+            await ctx.error(f"{tool_name} failed: {str(e)}")
+            return CallToolResult(
+                content=[TextContent(type="text", text=f"Error: {str(e)}")],
+                isError=True
+            )
+
+    return wrapper
+
+
+def with_logging(func: Callable) -> Callable:
+    """
+    Decorator that adds logging to a tool function
+    """
+    @wraps(func)
+    async def wrapper(*args, **kwargs):
+        ctx = args[0] if args else kwargs.get('ctx')
+        tool_name = func.__name__
+
+        await ctx.info(f"Starting {tool_name}")
+        result = await func(*args, **kwargs)
+        await ctx.info(f"{tool_name} finished")
+
+        return result
+
+    return wrapper
+
+
+def econometric_tool(
+    tool_type: str,
+    with_file_support: bool = True,
+    with_error_handling: bool = True,
+    with_logging: bool = True
+):
+    """
+    Combined decorator: adds all standard features to an econometric tool
+
+    Args:
+        tool_type: tool type
+        with_file_support: enable file-input support
+        with_error_handling: enable error handling
+        with_logging: enable logging
+
+    Example:
+        @econometric_tool('regression')
+        async def ols_regression(ctx, y_data=None, x_data=None, **kwargs):
+            # only the core business logic needs to be written here
+            pass
+    """
+    def decorator(func: Callable) -> Callable:
+        wrapped = func
+
+        if with_error_handling:
+            wrapped = globals()['with_error_handling'](wrapped)
+
+        if with_file_support:
+            wrapped = with_file_input(tool_type)(wrapped)
+
+        if with_logging:
+            wrapped = globals()['with_logging'](wrapped)
+
+        return wrapped
+
+    return decorator
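
An illustrative sketch (not part of the published diff) of how the combined decorator is meant to be applied, mirroring the docstring example above; the tool body is a stub:

    from aigroup_econ_mcp.tools.decorators import econometric_tool
    from mcp.types import CallToolResult, TextContent

    @econometric_tool('regression')
    async def ols_regression(ctx, y_data=None, x_data=None, **kwargs):
        # file_path/file_content inputs, error handling and logging are supplied
        # by the decorator stack; only the core logic remains here (stub).
        n = len(y_data) if y_data else 0
        return CallToolResult(content=[TextContent(type="text", text=f"n={n}")])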