aigroup-econ-mcp 0.3.7__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aigroup_econ_mcp/__init__.py +18 -18
- aigroup_econ_mcp/server.py +284 -3291
- aigroup_econ_mcp/server_v1_backup.py +1250 -0
- aigroup_econ_mcp/server_v1_old.py +1250 -0
- aigroup_econ_mcp/server_with_file_support.py +259 -0
- aigroup_econ_mcp/tools/__init__.py +3 -2
- aigroup_econ_mcp/tools/data_loader.py +171 -0
- aigroup_econ_mcp/tools/decorators.py +178 -0
- aigroup_econ_mcp/tools/file_input_handler.py +268 -0
- aigroup_econ_mcp/tools/file_parser.py +560 -0
- aigroup_econ_mcp/tools/machine_learning.py +14 -14
- aigroup_econ_mcp/tools/panel_data.py +10 -6
- aigroup_econ_mcp/tools/time_series.py +54 -127
- aigroup_econ_mcp/tools/tool_handlers.py +378 -0
- aigroup_econ_mcp/tools/tool_registry.py +170 -0
- {aigroup_econ_mcp-0.3.7.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/METADATA +287 -22
- aigroup_econ_mcp-0.4.0.dist-info/RECORD +30 -0
- aigroup_econ_mcp-0.3.7.dist-info/RECORD +0 -21
- {aigroup_econ_mcp-0.3.7.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/WHEEL +0 -0
- {aigroup_econ_mcp-0.3.7.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/entry_points.txt +0 -0
- {aigroup_econ_mcp-0.3.7.dist-info → aigroup_econ_mcp-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,560 @@
|
|
|
1
|
+
"""
|
|
2
|
+
文件解析模块
|
|
3
|
+
支持CSV和JSON格式文件的智能解析和数据转换
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import csv
|
|
8
|
+
from typing import Dict, List, Any, Union, Tuple, Optional
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import io
|
|
11
|
+
import base64
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FileParser:
    """Parser for CSV and JSON data files.

    All methods are static; the class is a namespace for helpers that turn raw
    file content into a uniform result dict with keys:
    ``data``, ``variables``, ``format``, ``data_type``, ``n_variables``,
    ``n_observations``.
    """

    # Name fragments used to recognise entity-ID and time columns.
    # Shared by _detect_data_type and _convert_panel (previously duplicated).
    _ENTITY_KEYWORDS = ['id', 'entity', 'firm', 'company', 'country', 'region']
    _TIME_KEYWORDS = ['time', 'date', 'year', 'month', 'day', 'period', 'quarter']

    @staticmethod
    def parse_file_path(
        file_path: str,
        file_format: str = "auto"
    ) -> Dict[str, Any]:
        """Parse a file from a filesystem path.

        Args:
            file_path: Relative or absolute path to the file.
            file_format: "csv", "json" or "auto" (detect from the extension,
                falling back to content sniffing for unknown extensions).

        Returns:
            Parsed-data dict as produced by ``parse_file_content``.

        Raises:
            FileNotFoundError: The path does not exist.
            ValueError: The path is not a regular file, or the format is
                unsupported / undetectable.
        """
        path = Path(file_path)

        if not path.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")
        if not path.is_file():
            raise ValueError(f"路径不是文件: {file_path}")

        # Read once; the original read the file twice on the sniffing path.
        content = path.read_text(encoding='utf-8')

        if file_format == "auto":
            ext = path.suffix.lower()
            if ext == '.csv':
                file_format = "csv"
            elif ext in ('.json', '.jsonl'):
                file_format = "json"
            # Unknown extension: keep "auto" so content sniffing decides.

        return FileParser.parse_file_content(content, file_format)

    @staticmethod
    def parse_file_content(
        content: str,
        file_format: str = "auto"
    ) -> Dict[str, Any]:
        """Parse file content (plain text or base64-encoded text).

        Args:
            content: File content — either raw text or a base64-encoded string.
            file_format: "csv", "json" or "auto".

        Returns:
            Result dict with ``data``, ``variables``, ``format``, ``data_type``
            ('univariate' / 'multivariate' / 'time_series' / 'panel'),
            ``n_variables`` and ``n_observations``.

        Raises:
            ValueError: Unsupported or undetectable format, or malformed data.
        """
        # Transparently accept base64-encoded payloads.  Whitespace is stripped
        # first so line-wrapped base64 still decodes; ``validate=True`` stops
        # plain CSV/JSON (which contains characters outside the base64
        # alphabet) from being "successfully" decoded into garbage.
        # binascii.Error and UnicodeEncodeError are ValueError subclasses.
        try:
            compact = "".join(content.split())
            decoded_content = base64.b64decode(compact, validate=True).decode('utf-8')
        except (ValueError, UnicodeDecodeError):
            decoded_content = content

        if file_format == "auto":
            file_format = FileParser._detect_format(decoded_content)

        if file_format == "csv":
            return FileParser._parse_csv(decoded_content)
        if file_format == "json":
            return FileParser._parse_json(decoded_content)
        raise ValueError(f"不支持的文件格式: {file_format}")

    @staticmethod
    def _detect_format(content: str) -> str:
        """Sniff the content format: JSON if it parses, else CSV on delimiters."""
        try:
            json.loads(content.strip())
            return "json"
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass; anything else
            # should propagate rather than be silently swallowed.
            pass

        if ',' in content or '\t' in content:
            return "csv"

        raise ValueError("无法自动检测文件格式,请明确指定")

    @staticmethod
    def _coerce(value: Any) -> Any:
        """Return ``value`` as a float when possible, otherwise unchanged.

        Non-numeric values (e.g. entity IDs, date strings) are kept as-is so
        identifier columns survive parsing.
        """
        try:
            return float(value)
        except (ValueError, TypeError):
            return value

    @staticmethod
    def _build_result(parsed_data: Dict[str, List[Any]], fmt: str) -> Dict[str, Any]:
        """Assemble the uniform result dict (deduplicates three copies of this)."""
        return {
            "data": parsed_data,
            "variables": list(parsed_data.keys()),
            "format": fmt,
            "data_type": FileParser._detect_data_type(parsed_data),
            "n_variables": len(parsed_data),
            # NOTE(review): assumes all columns have the first column's length;
            # ragged input is not checked here.
            "n_observations": len(next(iter(parsed_data.values()))),
        }

    @staticmethod
    def _parse_csv(content: str) -> Dict[str, Any]:
        """Parse CSV text.

        Supports both header and header-less numeric files; the delimiter is
        sniffed from the first line.  Cells that cannot be parsed as floats are
        kept as stripped strings (useful for ID/time columns).
        """
        lines = content.strip().split('\n')
        if not lines:
            raise ValueError("CSV文件为空")

        delimiter = FileParser._detect_delimiter(lines[0])
        rows = list(csv.reader(io.StringIO(content), delimiter=delimiter))
        if not rows:
            raise ValueError("CSV文件没有数据")

        if FileParser._has_header(rows):
            headers, data_rows = rows[0], rows[1:]
        else:
            # Header-less file: synthesize var1..varN column names.
            headers = [f"var{i + 1}" for i in range(len(rows[0]))]
            data_rows = rows

        parsed_data: Dict[str, List[Any]] = {}
        for i, header in enumerate(headers):
            column_data = [
                FileParser._coerce(row[i].strip())
                for row in data_rows
                if i < len(row)  # tolerate short rows
            ]
            if column_data:  # keep only columns that actually have data
                parsed_data[header.strip()] = column_data

        if not parsed_data:
            raise ValueError("CSV文件中没有有效的数据")

        return FileParser._build_result(parsed_data, "csv")

    @staticmethod
    def _parse_json(content: str) -> Dict[str, Any]:
        """Parse JSON text.

        Supported shapes:
        1. ``{"var": [values, ...], ...}``  (column-oriented)
        2. ``[{"var": value, ...}, ...]``   (record-oriented)
        3. ``{"data": <one of the above>, "metadata": {...}}``
        """
        try:
            json_data = json.loads(content)
        except json.JSONDecodeError as e:
            raise ValueError(f"JSON格式错误: {str(e)}")

        # Shape 1: column-oriented dict of lists.
        if isinstance(json_data, dict) and all(
            isinstance(v, list) for v in json_data.values()
        ):
            parsed_data = {
                key: [FileParser._coerce(v) for v in values]
                for key, values in json_data.items()
                if key.lower() not in ('metadata', 'info', 'description')
            }
            if parsed_data:
                return FileParser._build_result(parsed_data, "json")

        # Shape 2: record-oriented list of dicts.
        elif isinstance(json_data, list) and json_data and isinstance(json_data[0], dict):
            parsed_data = {}
            for record in json_data:
                for key, value in record.items():
                    parsed_data.setdefault(key, []).append(FileParser._coerce(value))
            if parsed_data:
                return FileParser._build_result(parsed_data, "json")

        # Shape 3: wrapper object with a "data" field.
        elif isinstance(json_data, dict) and "data" in json_data:
            return FileParser._parse_json(json.dumps(json_data["data"]))

        raise ValueError("不支持的JSON数据格式")

    @staticmethod
    def _detect_delimiter(line: str) -> str:
        """Return the most frequent common delimiter in ``line``.

        Ties resolve to the earliest candidate, matching the original behavior.
        """
        return max([',', '\t', ';', '|'], key=line.count)

    @staticmethod
    def _has_header(rows: List[List[str]]) -> bool:
        """Heuristic: the first row is a header if any cell is non-numeric.

        Requires at least two rows; a single-row file is treated as data.
        """
        if len(rows) < 2:
            return False
        for cell in rows[0]:
            try:
                float(cell.strip())
            except ValueError:
                return True
        return False

    @staticmethod
    def _detect_data_type(data: Dict[str, List]) -> str:
        """Classify the data set from its variable names.

        Returns:
            One of 'univariate', 'multivariate', 'time_series', 'panel'.
        """
        var_names = [v.lower() for v in data.keys()]
        has_time_var = any(
            any(kw in var for kw in FileParser._TIME_KEYWORDS) for var in var_names
        )
        has_entity_var = any(
            any(kw in var for kw in FileParser._ENTITY_KEYWORDS) for var in var_names
        )

        if len(data) == 1:
            return 'univariate'
        if has_entity_var and has_time_var:
            return 'panel'
        if has_time_var:
            return 'time_series'
        # BUG FIX: the original condition `has_time_var or n_vars >= 2` made
        # this branch unreachable — plain multi-column data without a time
        # variable is now correctly classified as 'multivariate'.
        return 'multivariate'

    @staticmethod
    def convert_to_tool_format(
        parsed_data: Dict[str, Any],
        tool_type: str
    ) -> Dict[str, Any]:
        """Convert a parsed-data dict into the argument shape a tool expects.

        Args:
            parsed_data: Result of ``parse_file_content``.
            tool_type: One of
                - 'single_var': first variable as List[float]
                - 'multi_var_dict' / 'time_series': Dict[str, List[float]]
                - 'multi_var_matrix': row-major List[List[float]] + feature names
                - 'regression': last variable is y, the rest are X
                - 'panel': y/X plus entity_ids and time_periods

        Returns:
            Dict with the tool's keyword arguments.

        Raises:
            ValueError: Unknown tool type or missing required columns.
        """
        data = parsed_data["data"]
        variables = parsed_data["variables"]

        if tool_type == 'single_var':
            # First variable only.
            return {"data": data[variables[0]]}

        if tool_type in ('multi_var_dict', 'time_series'):
            # Both tool types take the column-oriented dict as-is.
            return {"data": data}

        if tool_type == 'multi_var_matrix':
            # Row-major matrix: one row per observation.
            n_obs = len(data[variables[0]])
            matrix = [[data[var][i] for var in variables] for i in range(n_obs)]
            return {"data": matrix, "feature_names": variables}

        if tool_type == 'regression':
            # Convention: the last variable is the dependent variable.
            if len(variables) < 2:
                raise ValueError("回归分析至少需要2个变量(1个因变量和至少1个自变量)")
            y_var = variables[-1]
            x_vars = variables[:-1]
            y_data = data[y_var]
            x_data = [[data[var][i] for var in x_vars] for i in range(len(y_data))]
            return {"y_data": y_data, "x_data": x_data, "feature_names": x_vars}

        if tool_type == 'panel':
            return FileParser._convert_panel(data, variables)

        raise ValueError(f"不支持的工具类型: {tool_type}")

    @staticmethod
    def _convert_panel(
        data: Dict[str, List[Any]],
        variables: List[str]
    ) -> Dict[str, Any]:
        """Build panel-regression arguments.

        Identifies the entity-ID and time columns by name keywords; the last
        remaining column is the dependent variable, the rest are regressors.
        (The original version printed "Debug: ..." lines to stdout here; they
        were removed — stray stdout output corrupts an MCP stdio transport.)
        """
        entity_var: Optional[str] = None
        time_var: Optional[str] = None
        data_vars: List[str] = []

        for var in variables:
            var_lower = var.lower()
            is_entity = any(kw in var_lower for kw in FileParser._ENTITY_KEYWORDS)
            is_time = any(kw in var_lower for kw in FileParser._TIME_KEYWORDS)
            # First match wins for each role; everything else is data.
            if is_entity and entity_var is None:
                entity_var = var
            elif is_time and time_var is None:
                time_var = var
            else:
                data_vars.append(var)

        if not entity_var or not time_var:
            # Detailed diagnostics so the caller can fix their column names.
            available_vars = ', '.join(variables)
            error_msg = f"面板数据需要包含实体ID和时间标识变量。\n"
            error_msg += f"可用列: {available_vars}\n"
            error_msg += f"检测到的实体ID列: {entity_var if entity_var else '未检测到'}\n"
            error_msg += f"检测到的时间列: {time_var if time_var else '未检测到'}\n"
            error_msg += f"实体ID关键词: {FileParser._ENTITY_KEYWORDS}\n"
            error_msg += f"时间关键词: {FileParser._TIME_KEYWORDS}"
            raise ValueError(error_msg)

        if len(data_vars) < 1:
            raise ValueError(f"面板数据至少需要1个数据变量。当前数据列: {data_vars}")

        entity_ids = [str(x) for x in data[entity_var]]
        # Integral floats like 2001.0 render as "2001" (JSON numbers arrive
        # as floats after coercion).
        time_periods = [
            str(int(x)) if isinstance(x, float) and x == int(x) else str(x)
            for x in data[time_var]
        ]

        if len(data_vars) == 1:
            # Single data column: treat it as y with a constant-only regressor.
            y_data = data[data_vars[0]]
            x_data = [[1.0] for _ in range(len(y_data))]
            x_vars = ['const']
        else:
            # Last data column is the dependent variable.
            y_var = data_vars[-1]
            x_vars = data_vars[:-1]
            y_data = data[y_var]
            x_data = [[data[var][i] for var in x_vars] for i in range(len(y_data))]

        return {
            "y_data": y_data,
            "x_data": x_data,
            "entity_ids": entity_ids,
            "time_periods": time_periods,
            "feature_names": x_vars,
        }

    @staticmethod
    def auto_detect_tool_params(parsed_data: Dict[str, Any]) -> Dict[str, Any]:
        """Recommend analysis tools (and warnings) for a parsed data set.

        Args:
            parsed_data: Result of ``parse_file_content``.

        Returns:
            Dict with ``data_type``, ``suggested_tools`` and ``warnings``.
        """
        data_type = parsed_data["data_type"]
        n_vars = parsed_data["n_variables"]
        n_obs = parsed_data["n_observations"]

        # Dispatch table instead of an if/elif ladder.
        suggestions = {
            'univariate': [
                "descriptive_statistics",
                "hypothesis_testing",
                "time_series_analysis",
            ],
            'multivariate': [
                "descriptive_statistics",
                "correlation_analysis",
                "ols_regression",
                "random_forest_regression_analysis",
                "lasso_regression_analysis",
            ],
            'time_series': [
                "time_series_analysis",
                "var_model_analysis",
                "garch_model_analysis",
            ],
            'panel': [
                "panel_fixed_effects",
                "panel_random_effects",
                "panel_hausman_test",
                "panel_unit_root_test",
            ],
        }

        recommendations: Dict[str, Any] = {
            "data_type": data_type,
            "suggested_tools": suggestions.get(data_type, []),
            "warnings": [],
        }

        if n_obs < 30:
            recommendations["warnings"].append(
                f"样本量较小({n_obs}个观测),统计推断可能不可靠"
            )
        if n_vars > 10:
            recommendations["warnings"].append(
                f"变量数量较多({n_vars}个变量),可能需要特征选择"
            )
        if n_vars > n_obs / 10:
            recommendations["warnings"].append(
                "变量数量接近样本量的1/10,可能存在过拟合风险"
            )

        return recommendations
def parse_file_input(
    file_content: Optional[str] = None,
    file_format: str = "auto"
) -> Optional[Dict[str, Any]]:
    """Convenience wrapper around ``FileParser.parse_file_content``.

    Args:
        file_content: Raw or base64-encoded file text; may be None.
        file_format: "csv", "json" or "auto".

    Returns:
        The parsed-data dict, or None when no content was supplied.
    """
    return (
        None
        if file_content is None
        else FileParser.parse_file_content(file_content, file_format)
    )
|
@@ -111,7 +111,7 @@ def random_forest_regression(
|
|
|
111
111
|
raise ValueError("因变量和自变量数据不能为空")
|
|
112
112
|
|
|
113
113
|
if len(y_data) != len(x_data):
|
|
114
|
-
raise ValueError("因变量和自变量的观测数量不一致: y_data={}, x_data={
|
|
114
|
+
raise ValueError(f"因变量和自变量的观测数量不一致: y_data={len(y_data)}, x_data={len(x_data)}")
|
|
115
115
|
|
|
116
116
|
# 准备数据
|
|
117
117
|
X = np.array(x_data)
|
|
@@ -121,7 +121,7 @@ def random_forest_regression(
|
|
|
121
121
|
if feature_names is None:
|
|
122
122
|
feature_names = [f"x{i}" for i in range(X.shape[1])]
|
|
123
123
|
elif len(feature_names) != X.shape[1]:
|
|
124
|
-
raise ValueError("特征名称数量({})与自变量数量({
|
|
124
|
+
raise ValueError(f"特征名称数量({len(feature_names)})与自变量数量({X.shape[1]})不匹配")
|
|
125
125
|
|
|
126
126
|
# 数据标准化
|
|
127
127
|
scaler = StandardScaler()
|
|
@@ -210,7 +210,7 @@ def gradient_boosting_regression(
|
|
|
210
210
|
raise ValueError("因变量和自变量数据不能为空")
|
|
211
211
|
|
|
212
212
|
if len(y_data) != len(x_data):
|
|
213
|
-
raise ValueError("因变量和自变量的观测数量不一致: y_data={}, x_data={
|
|
213
|
+
raise ValueError(f"因变量和自变量的观测数量不一致: y_data={len(y_data)}, x_data={len(x_data)}")
|
|
214
214
|
|
|
215
215
|
# 准备数据
|
|
216
216
|
X = np.array(x_data)
|
|
@@ -220,7 +220,7 @@ def gradient_boosting_regression(
|
|
|
220
220
|
if feature_names is None:
|
|
221
221
|
feature_names = [f"x{i}" for i in range(X.shape[1])]
|
|
222
222
|
elif len(feature_names) != X.shape[1]:
|
|
223
|
-
raise ValueError("特征名称数量({})与自变量数量({
|
|
223
|
+
raise ValueError(f"特征名称数量({len(feature_names)})与自变量数量({X.shape[1]})不匹配")
|
|
224
224
|
|
|
225
225
|
# 数据标准化
|
|
226
226
|
scaler = StandardScaler()
|
|
@@ -364,7 +364,7 @@ def _regularized_regression(
|
|
|
364
364
|
raise ValueError("因变量和自变量数据不能为空")
|
|
365
365
|
|
|
366
366
|
if len(y_data) != len(x_data):
|
|
367
|
-
raise ValueError("因变量和自变量的观测数量不一致: y_data={}, x_data={
|
|
367
|
+
raise ValueError(f"因变量和自变量的观测数量不一致: y_data={len(y_data)}, x_data={len(x_data)}")
|
|
368
368
|
|
|
369
369
|
# 准备数据
|
|
370
370
|
X = np.array(x_data)
|
|
@@ -374,7 +374,7 @@ def _regularized_regression(
|
|
|
374
374
|
if feature_names is None:
|
|
375
375
|
feature_names = [f"x{i}" for i in range(X.shape[1])]
|
|
376
376
|
elif len(feature_names) != X.shape[1]:
|
|
377
|
-
raise ValueError("特征名称数量({})与自变量数量({
|
|
377
|
+
raise ValueError(f"特征名称数量({len(feature_names)})与自变量数量({X.shape[1]})不匹配")
|
|
378
378
|
|
|
379
379
|
# 数据标准化
|
|
380
380
|
scaler = StandardScaler()
|
|
@@ -387,7 +387,7 @@ def _regularized_regression(
|
|
|
387
387
|
elif model_type == "ridge":
|
|
388
388
|
model = Ridge(alpha=alpha, random_state=random_state)
|
|
389
389
|
else:
|
|
390
|
-
raise ValueError("不支持的模型类型: {}"
|
|
390
|
+
raise ValueError(f"不支持的模型类型: {model_type}")
|
|
391
391
|
|
|
392
392
|
# 训练模型
|
|
393
393
|
model.fit(X_scaled, y_scaled)
|
|
@@ -464,10 +464,10 @@ def cross_validation(
|
|
|
464
464
|
raise ValueError("因变量和自变量数据不能为空")
|
|
465
465
|
|
|
466
466
|
if len(y_data) != len(x_data):
|
|
467
|
-
raise ValueError("因变量和自变量的观测数量不一致: y_data={}, x_data={
|
|
467
|
+
raise ValueError(f"因变量和自变量的观测数量不一致: y_data={len(y_data)}, x_data={len(x_data)}")
|
|
468
468
|
|
|
469
469
|
if cv_folds < 2 or cv_folds > len(y_data):
|
|
470
|
-
raise ValueError("交叉验证折数应在2到样本数量之间: cv_folds={}, n_obs={
|
|
470
|
+
raise ValueError(f"交叉验证折数应在2到样本数量之间: cv_folds={cv_folds}, n_obs={len(y_data)}")
|
|
471
471
|
|
|
472
472
|
# 准备数据
|
|
473
473
|
X = np.array(x_data)
|
|
@@ -487,7 +487,7 @@ def cross_validation(
|
|
|
487
487
|
elif model_type == "ridge":
|
|
488
488
|
model = Ridge(**model_params)
|
|
489
489
|
else:
|
|
490
|
-
raise ValueError("不支持的模型类型: {}"
|
|
490
|
+
raise ValueError(f"不支持的模型类型: {model_type}")
|
|
491
491
|
|
|
492
492
|
# 执行交叉验证
|
|
493
493
|
cv = KFold(n_splits=cv_folds, shuffle=True, random_state=42)
|
|
@@ -546,7 +546,7 @@ def feature_importance_analysis(
|
|
|
546
546
|
raise ValueError("因变量和自变量数据不能为空")
|
|
547
547
|
|
|
548
548
|
if len(y_data) != len(x_data):
|
|
549
|
-
raise ValueError("因变量和自变量的观测数量不一致: y_data={}, x_data={
|
|
549
|
+
raise ValueError(f"因变量和自变量的观测数量不一致: y_data={len(y_data)}, x_data={len(x_data)}")
|
|
550
550
|
|
|
551
551
|
# 准备数据
|
|
552
552
|
X = np.array(x_data)
|
|
@@ -556,7 +556,7 @@ def feature_importance_analysis(
|
|
|
556
556
|
if feature_names is None:
|
|
557
557
|
feature_names = [f"x{i}" for i in range(X.shape[1])]
|
|
558
558
|
elif len(feature_names) != X.shape[1]:
|
|
559
|
-
raise ValueError("特征名称数量({})与自变量数量({
|
|
559
|
+
raise ValueError(f"特征名称数量({len(feature_names)})与自变量数量({X.shape[1]})不匹配")
|
|
560
560
|
|
|
561
561
|
# 数据标准化
|
|
562
562
|
scaler = StandardScaler()
|
|
@@ -568,7 +568,7 @@ def feature_importance_analysis(
|
|
|
568
568
|
elif method == "gradient_boosting":
|
|
569
569
|
model = GradientBoostingRegressor(n_estimators=100, random_state=42)
|
|
570
570
|
else:
|
|
571
|
-
raise ValueError("不支持的特征重要性分析方法: {}"
|
|
571
|
+
raise ValueError(f"不支持的特征重要性分析方法: {method}")
|
|
572
572
|
|
|
573
573
|
# 训练模型
|
|
574
574
|
model.fit(X_scaled, y)
|
|
@@ -649,7 +649,7 @@ def compare_ml_models(
|
|
|
649
649
|
results[model_name] = result.model_dump()
|
|
650
650
|
|
|
651
651
|
except Exception as e:
|
|
652
|
-
print("模型 {} 运行失败: {}"
|
|
652
|
+
print(f"模型 {model_name} 运行失败: {e}")
|
|
653
653
|
continue
|
|
654
654
|
|
|
655
655
|
# 找出最佳模型(基于R²得分)
|