aigroup-econ-mcp 1.4.3__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. PKG-INFO +344 -322
  2. README.md +335 -320
  3. __init__.py +1 -1
  4. aigroup_econ_mcp-2.0.1.dist-info/METADATA +732 -0
  5. aigroup_econ_mcp-2.0.1.dist-info/RECORD +170 -0
  6. cli.py +4 -0
  7. econometrics/advanced_methods/modern_computing_machine_learning/__init__.py +30 -0
  8. econometrics/advanced_methods/modern_computing_machine_learning/causal_forest.py +253 -0
  9. econometrics/advanced_methods/modern_computing_machine_learning/double_ml.py +268 -0
  10. econometrics/advanced_methods/modern_computing_machine_learning/gradient_boosting.py +249 -0
  11. econometrics/advanced_methods/modern_computing_machine_learning/hierarchical_clustering.py +243 -0
  12. econometrics/advanced_methods/modern_computing_machine_learning/kmeans_clustering.py +293 -0
  13. econometrics/advanced_methods/modern_computing_machine_learning/neural_network.py +264 -0
  14. econometrics/advanced_methods/modern_computing_machine_learning/random_forest.py +195 -0
  15. econometrics/advanced_methods/modern_computing_machine_learning/support_vector_machine.py +226 -0
  16. econometrics/advanced_methods/modern_computing_machine_learning/test_all_modules.py +329 -0
  17. econometrics/advanced_methods/modern_computing_machine_learning/test_report.md +107 -0
  18. econometrics/causal_inference/__init__.py +66 -0
  19. econometrics/causal_inference/causal_identification_strategy/__init__.py +104 -0
  20. econometrics/causal_inference/causal_identification_strategy/control_function.py +112 -0
  21. econometrics/causal_inference/causal_identification_strategy/difference_in_differences.py +107 -0
  22. econometrics/causal_inference/causal_identification_strategy/event_study.py +119 -0
  23. econometrics/causal_inference/causal_identification_strategy/first_difference.py +89 -0
  24. econometrics/causal_inference/causal_identification_strategy/fixed_effects.py +103 -0
  25. econometrics/causal_inference/causal_identification_strategy/hausman_test.py +69 -0
  26. econometrics/causal_inference/causal_identification_strategy/instrumental_variables.py +145 -0
  27. econometrics/causal_inference/causal_identification_strategy/mediation_analysis.py +121 -0
  28. econometrics/causal_inference/causal_identification_strategy/moderation_analysis.py +109 -0
  29. econometrics/causal_inference/causal_identification_strategy/propensity_score_matching.py +140 -0
  30. econometrics/causal_inference/causal_identification_strategy/random_effects.py +100 -0
  31. econometrics/causal_inference/causal_identification_strategy/regression_discontinuity.py +98 -0
  32. econometrics/causal_inference/causal_identification_strategy/synthetic_control.py +111 -0
  33. econometrics/causal_inference/causal_identification_strategy/triple_difference.py +86 -0
  34. econometrics/distribution_analysis/__init__.py +28 -0
  35. econometrics/distribution_analysis/oaxaca_blinder.py +184 -0
  36. econometrics/distribution_analysis/time_series_decomposition.py +152 -0
  37. econometrics/distribution_analysis/variance_decomposition.py +179 -0
  38. econometrics/missing_data/__init__.py +18 -0
  39. econometrics/missing_data/imputation_methods.py +219 -0
  40. econometrics/nonparametric/__init__.py +35 -0
  41. econometrics/nonparametric/gam_model.py +117 -0
  42. econometrics/nonparametric/kernel_regression.py +161 -0
  43. econometrics/nonparametric/quantile_regression.py +249 -0
  44. econometrics/nonparametric/spline_regression.py +100 -0
  45. econometrics/spatial_econometrics/__init__.py +68 -0
  46. econometrics/spatial_econometrics/geographically_weighted_regression.py +211 -0
  47. econometrics/spatial_econometrics/gwr_simple.py +154 -0
  48. econometrics/spatial_econometrics/spatial_autocorrelation.py +356 -0
  49. econometrics/spatial_econometrics/spatial_durbin_model.py +177 -0
  50. econometrics/spatial_econometrics/spatial_regression.py +315 -0
  51. econometrics/spatial_econometrics/spatial_weights.py +226 -0
  52. econometrics/specific_data_modeling/micro_discrete_limited_data/README.md +164 -0
  53. econometrics/specific_data_modeling/micro_discrete_limited_data/__init__.py +40 -0
  54. econometrics/specific_data_modeling/micro_discrete_limited_data/count_data_models.py +311 -0
  55. econometrics/specific_data_modeling/micro_discrete_limited_data/discrete_choice_models.py +294 -0
  56. econometrics/specific_data_modeling/micro_discrete_limited_data/limited_dependent_variable_models.py +282 -0
  57. econometrics/statistical_inference/__init__.py +21 -0
  58. econometrics/statistical_inference/bootstrap_methods.py +162 -0
  59. econometrics/statistical_inference/permutation_test.py +177 -0
  60. econometrics/survival_analysis/__init__.py +18 -0
  61. econometrics/survival_analysis/survival_models.py +259 -0
  62. econometrics/tests/causal_inference_tests/__init__.py +3 -0
  63. econometrics/tests/causal_inference_tests/detailed_test.py +441 -0
  64. econometrics/tests/causal_inference_tests/test_all_methods.py +418 -0
  65. econometrics/tests/causal_inference_tests/test_causal_identification_strategy.py +202 -0
  66. econometrics/tests/causal_inference_tests/test_difference_in_differences.py +53 -0
  67. econometrics/tests/causal_inference_tests/test_instrumental_variables.py +44 -0
  68. econometrics/tests/specific_data_modeling_tests/test_micro_discrete_limited_data.py +189 -0
  69. econometrics/未开发大类优先级分析.md +544 -0
  70. pyproject.toml +9 -2
  71. server.py +15 -1
  72. tools/__init__.py +75 -1
  73. tools/causal_inference_adapter.py +658 -0
  74. tools/distribution_analysis_adapter.py +121 -0
  75. tools/gwr_simple_adapter.py +54 -0
  76. tools/machine_learning_adapter.py +567 -0
  77. tools/mcp_tool_groups/__init__.py +15 -1
  78. tools/mcp_tool_groups/causal_inference_tools.py +643 -0
  79. tools/mcp_tool_groups/distribution_analysis_tools.py +169 -0
  80. tools/mcp_tool_groups/machine_learning_tools.py +422 -0
  81. tools/mcp_tool_groups/microecon_tools.py +325 -0
  82. tools/mcp_tool_groups/missing_data_tools.py +117 -0
  83. tools/mcp_tool_groups/nonparametric_tools.py +225 -0
  84. tools/mcp_tool_groups/spatial_econometrics_tools.py +323 -0
  85. tools/mcp_tool_groups/statistical_inference_tools.py +131 -0
  86. tools/mcp_tools_registry.py +13 -3
  87. tools/microecon_adapter.py +412 -0
  88. tools/missing_data_adapter.py +73 -0
  89. tools/nonparametric_adapter.py +190 -0
  90. tools/spatial_econometrics_adapter.py +318 -0
  91. tools/statistical_inference_adapter.py +90 -0
  92. tools/survival_analysis_adapter.py +46 -0
  93. aigroup_econ_mcp-1.4.3.dist-info/METADATA +0 -710
  94. aigroup_econ_mcp-1.4.3.dist-info/RECORD +0 -92
  95. {aigroup_econ_mcp-1.4.3.dist-info → aigroup_econ_mcp-2.0.1.dist-info}/WHEEL +0 -0
  96. {aigroup_econ_mcp-1.4.3.dist-info → aigroup_econ_mcp-2.0.1.dist-info}/entry_points.txt +0 -0
  97. {aigroup_econ_mcp-1.4.3.dist-info → aigroup_econ_mcp-2.0.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,179 @@
+ """
+ Variance decomposition (one-way ANOVA)
+ Implemented on top of scipy and statsmodels
+ """
+
+ from typing import List, Optional, Dict
+ from pydantic import BaseModel, Field
+ import numpy as np
+
+ try:
+     from scipy import stats
+     import statsmodels.api as sm
+     from statsmodels.formula.api import ols
+     SCIPY_AVAILABLE = True
+ except ImportError:
+     SCIPY_AVAILABLE = False
+     stats = None
+     sm = None
+
+
+ class VarianceDecompositionResult(BaseModel):
+     """Variance decomposition result"""
+     total_variance: float = Field(..., description="Total variance")
+     between_group_variance: float = Field(..., description="Between-group variance")
+     within_group_variance: float = Field(..., description="Within-group variance")
+     f_statistic: float = Field(..., description="F statistic")
+     p_value: float = Field(..., description="p-value")
+     eta_squared: float = Field(..., description="Eta squared (effect size)")
+     omega_squared: float = Field(..., description="Omega squared (bias-corrected effect size)")
+     group_means: Dict[str, float] = Field(..., description="Mean of each group")
+     group_variances: Dict[str, float] = Field(..., description="Variance of each group")
+     group_sizes: Dict[str, int] = Field(..., description="Sample size of each group")
+     n_groups: int = Field(..., description="Number of groups")
+     total_n: int = Field(..., description="Total sample size")
+     summary: str = Field(..., description="Summary text")
+
+
+ def variance_decomposition(
+     values: List[float],
+     groups: List[str],
+     group_names: Optional[List[str]] = None
+ ) -> VarianceDecompositionResult:
+     """
+     Variance decomposition / one-way ANOVA
+
+     Args:
+         values: List of observations
+         groups: List of group labels
+         group_names: Optional group-name mapping
+
+     Returns:
+         VarianceDecompositionResult: variance decomposition result
+
+     Raises:
+         ImportError: scipy is not installed
+         ValueError: invalid input data
+     """
+     if not SCIPY_AVAILABLE:
+         raise ImportError("scipy and statsmodels are not installed. Run: pip install scipy statsmodels")
+
+     # Input validation
+     if not values or not groups:
+         raise ValueError("values and groups must not be empty")
+
+     if len(values) != len(groups):
+         raise ValueError(f"Length of values ({len(values)}) does not match length of groups ({len(groups)})")
+
+     # Prepare data
+     y = np.array(values, dtype=np.float64)
+     g = np.array(groups)
+
+     # Unique group labels
+     unique_groups = np.unique(g)
+     n_groups = len(unique_groups)
+
+     if n_groups < 2:
+         raise ValueError("Variance decomposition requires at least 2 groups")
+
+     # Overall statistics
+     grand_mean = y.mean()
+     total_variance = y.var(ddof=1)
+     total_n = len(y)
+
+     # Per-group statistics
+     group_means = {}
+     group_variances = {}
+     group_sizes = {}
+
+     # Split the data by group
+     groups_data = []
+     for group_id in unique_groups:
+         mask = g == group_id
+         group_data = y[mask]
+         groups_data.append(group_data)
+
+         group_key = str(group_id)
+         group_means[group_key] = float(group_data.mean())
+         group_variances[group_key] = float(group_data.var(ddof=1))
+         group_sizes[group_key] = int(len(group_data))
+
+     # Run one-way ANOVA
+     f_stat, p_value = stats.f_oneway(*groups_data)
+
+     # Between-group and within-group sums of squares
+     # SS_between = Σ nᵢ(ȳᵢ - ȳ)²
+     ss_between = sum(
+         group_sizes[str(gid)] * (group_means[str(gid)] - grand_mean)**2
+         for gid in unique_groups
+     )
+
+     # SS_within = Σ (nᵢ - 1)sᵢ²
+     ss_within = sum(
+         (group_sizes[str(gid)] - 1) * group_variances[str(gid)]
+         for gid in unique_groups
+     )
+
+     # SS_total
+     ss_total = (total_n - 1) * total_variance
+
+     # Degrees of freedom
+     df_between = n_groups - 1
+     df_within = total_n - n_groups
+
+     # Mean squares
+     ms_between = ss_between / df_between
+     ms_within = ss_within / df_within
+
+     # Between- and within-group variance (as shares of the total variance)
+     between_group_var = ss_between / (total_n - 1)
+     within_group_var = ss_within / (total_n - 1)
+
+     # Effect sizes
+     # Eta squared = SS_between / SS_total
+     eta_squared = ss_between / ss_total if ss_total > 0 else 0.0
+
+     # Omega squared (less biased than eta squared)
+     omega_squared = (ss_between - df_between * ms_within) / (ss_total + ms_within)
+     omega_squared = max(0.0, omega_squared)  # clip at zero
+
+     # Build the summary
+     summary = f"""Variance decomposition (ANOVA):
+ - Total sample size: {total_n}
+ - Number of groups: {n_groups}
+ - Total variance: {total_variance:.4f}
+
+ Decomposition:
+ - Between-group variance: {between_group_var:.4f} ({eta_squared*100:.1f}%)
+ - Within-group variance: {within_group_var:.4f} ({(1-eta_squared)*100:.1f}%)
+
+ F test:
+ - F statistic: {f_stat:.4f}
+ - p-value: {p_value:.4f}
+ - Conclusion: {'between-group differences are significant' if p_value < 0.05 else 'between-group differences are not significant'}
+
+ Effect sizes:
+ - Eta²: {eta_squared:.4f}
+ - Omega²: {omega_squared:.4f}
+
+ Group means:
+ """
+     for gid in unique_groups:
+         gkey = str(gid)
+         summary += f"  {gkey}: {group_means[gkey]:.4f} (n={group_sizes[gkey]}, s²={group_variances[gkey]:.4f})\n"
+
+     return VarianceDecompositionResult(
+         total_variance=float(total_variance),
+         between_group_variance=float(between_group_var),
+         within_group_variance=float(within_group_var),
+         f_statistic=float(f_stat),
+         p_value=float(p_value),
+         eta_squared=float(eta_squared),
+         omega_squared=float(omega_squared),
+         group_means=group_means,
+         group_variances=group_variances,
+         group_sizes=group_sizes,
+         n_groups=n_groups,
+         total_n=total_n,
+         summary=summary
+     )
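
A quick way to review the new variance_decomposition entry point is to run it on made-up grouped data. A minimal sketch, assuming the import path implied by the file list above (the numbers are invented):

    from econometrics.distribution_analysis.variance_decomposition import variance_decomposition

    values = [2.1, 2.4, 2.0, 3.5, 3.8, 3.6, 5.0, 5.2, 4.9]
    groups = ["A", "A", "A", "B", "B", "B", "C", "C", "C"]

    result = variance_decomposition(values=values, groups=groups)
    print(result.f_statistic, result.p_value)  # one-way ANOVA F test
    print(result.eta_squared)                  # share of variance explained by group membership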
@@ -0,0 +1,18 @@
+ """
+ Missing data handling module
+ Provides several methods for imputing and handling missing data
+ """
+
+ from .imputation_methods import (
+     simple_imputation,
+     multiple_imputation,
+     SimpleImputationResult,
+     MultipleImputationResult
+ )
+
+ __all__ = [
+     'simple_imputation',
+     'multiple_imputation',
+     'SimpleImputationResult',
+     'MultipleImputationResult'
+ ]
@@ -0,0 +1,219 @@
+ """
+ Missing-data imputation methods
+ Implemented on top of sklearn.impute
+ """
+
+ from typing import List, Optional, Dict
+ from pydantic import BaseModel, Field
+ import numpy as np
+
+ try:
+     # enable_iterative_imputer must be imported before IterativeImputer
+     from sklearn.experimental import enable_iterative_imputer  # noqa: F401
+     from sklearn.impute import SimpleImputer, IterativeImputer
+     SKLEARN_AVAILABLE = True
+ except ImportError:
+     SKLEARN_AVAILABLE = False
+     SimpleImputer = None
+     IterativeImputer = None
+
+
+ class SimpleImputationResult(BaseModel):
+     """Simple imputation result"""
+     imputed_data: List[List[float]] = Field(..., description="Imputed data")
+     missing_mask: List[List[bool]] = Field(..., description="Missing-value mask")
+     n_missing: int = Field(..., description="Total number of missing values")
+     missing_rate: float = Field(..., description="Missing rate")
+     imputation_method: str = Field(..., description="Imputation method")
+     fill_values: List[float] = Field(..., description="Fill value for each column")
+     n_observations: int = Field(..., description="Number of observations")
+     n_features: int = Field(..., description="Number of features")
+     summary: str = Field(..., description="Summary text")
+
+
+ class MultipleImputationResult(BaseModel):
+     """Multiple imputation result"""
+     imputed_datasets: List[List[List[float]]] = Field(..., description="Imputed datasets")
+     n_imputations: int = Field(..., description="Number of imputations")
+     missing_mask: List[List[bool]] = Field(..., description="Missing-value mask")
+     n_missing: int = Field(..., description="Total number of missing values")
+     missing_rate: float = Field(..., description="Missing rate")
+     convergence_info: Dict = Field(..., description="Convergence information")
+     n_observations: int = Field(..., description="Number of observations")
+     n_features: int = Field(..., description="Number of features")
+     summary: str = Field(..., description="Summary text")
+
+
+ def simple_imputation(
+     data: List[List[float]],
+     strategy: str = "mean",
+     fill_value: Optional[float] = None
+ ) -> SimpleImputationResult:
+     """
+     Simple imputation
+
+     Args:
+         data: Data containing missing values (2-D list; NaN marks missing)
+         strategy: Imputation strategy - "mean", "median",
+             "most_frequent", or "constant"
+         fill_value: Fill value used when strategy="constant"
+
+     Returns:
+         SimpleImputationResult: simple imputation result
+
+     Raises:
+         ImportError: sklearn is not installed
+         ValueError: invalid input data
+     """
+     if not SKLEARN_AVAILABLE:
+         raise ImportError("sklearn is not installed. Run: pip install scikit-learn")
+
+     # Input validation
+     if not data:
+         raise ValueError("data must not be empty")
+
+     # Convert to a numpy array
+     X = np.array(data, dtype=np.float64)
+
+     if X.ndim == 1:
+         X = X.reshape(-1, 1)
+
+     n, k = X.shape
+
+     # Missing-value mask
+     missing_mask = np.isnan(X)
+     n_missing = int(missing_mask.sum())
+     missing_rate = float(n_missing / (n * k))
+
+     # Simple imputation
+     if strategy == "constant":
+         if fill_value is None:
+             fill_value = 0.0
+         imputer = SimpleImputer(strategy=strategy, fill_value=fill_value)
+     else:
+         imputer = SimpleImputer(strategy=strategy)
+
+     # Run the imputation
+     X_imputed = imputer.fit_transform(X)
+
+     # Fill values
+     fill_values = imputer.statistics_.tolist()
+
+     # Build the summary
+     summary = f"""Simple imputation:
+ - Observations: {n}
+ - Features: {k}
+ - Missing values: {n_missing}
+ - Missing rate: {missing_rate*100:.2f}%
+ - Strategy: {strategy}
+
+ Per-column fill values:
+ """
+     for i, val in enumerate(fill_values):
+         col_missing = int(missing_mask[:, i].sum())
+         summary += f"  Column {i+1}: {val:.4f} ({col_missing} missing)\n"
+
+     return SimpleImputationResult(
+         imputed_data=X_imputed.tolist(),
+         missing_mask=missing_mask.tolist(),
+         n_missing=n_missing,
+         missing_rate=missing_rate,
+         imputation_method=strategy,
+         fill_values=fill_values,
+         n_observations=n,
+         n_features=k,
+         summary=summary
+     )
+
+
+ def multiple_imputation(
+     data: List[List[float]],
+     n_imputations: int = 5,
+     max_iter: int = 10,
+     random_state: Optional[int] = None
+ ) -> MultipleImputationResult:
+     """
+     Multiple imputation (MICE - Multivariate Imputation by Chained Equations)
+
+     Args:
+         data: Data containing missing values
+         n_imputations: Number of imputed datasets to generate
+         max_iter: Maximum number of iterations
+         random_state: Random seed
+
+     Returns:
+         MultipleImputationResult: multiple imputation result
+     """
+     if not SKLEARN_AVAILABLE:
+         raise ImportError("sklearn is not installed. Run: pip install scikit-learn")
+
+     # Input validation
+     if not data:
+         raise ValueError("data must not be empty")
+
+     X = np.array(data, dtype=np.float64)
+
+     if X.ndim == 1:
+         X = X.reshape(-1, 1)
+
+     n, k = X.shape
+
+     # Missing-value statistics
+     missing_mask = np.isnan(X)
+     n_missing = int(missing_mask.sum())
+     missing_rate = float(n_missing / (n * k))
+
+     # Run the multiple imputation
+     imputed_datasets = []
+     convergence_info = {"iterations": [], "converged": []}
+
+     for i in range(n_imputations):
+         # Per-imputation random seed
+         seed = random_state + i if random_state is not None else None
+
+         # Create the iterative imputer
+         imputer = IterativeImputer(
+             max_iter=max_iter,
+             random_state=seed,
+             verbose=0
+         )
+
+         # Run the imputation
+         X_imputed = imputer.fit_transform(X)
+         imputed_datasets.append(X_imputed.tolist())
+
+         # Record convergence information
+         convergence_info["iterations"].append(imputer.n_iter_)
+         convergence_info["converged"].append(imputer.n_iter_ < max_iter)
+
+     # Average number of iterations to convergence
+     avg_iter = np.mean(convergence_info["iterations"])
+     n_converged = sum(convergence_info["converged"])
+
+     # Build the summary
+     summary = f"""Multiple imputation (MICE):
+ - Observations: {n}
+ - Features: {k}
+ - Missing values: {n_missing}
+ - Missing rate: {missing_rate*100:.2f}%
+ - Imputations: {n_imputations}
+ - Max iterations: {max_iter}
+
+ Convergence:
+ - Average iterations: {avg_iter:.1f}
+ - Converged datasets: {n_converged}/{n_imputations}
+
+ Note: generates {n_imputations} complete imputed datasets;
+ analyze each and pool the results (Rubin's rules).
+ """
+
+     return MultipleImputationResult(
+         imputed_datasets=imputed_datasets,
+         n_imputations=n_imputations,
+         missing_mask=missing_mask.tolist(),
+         n_missing=n_missing,
+         missing_rate=missing_rate,
+         convergence_info=convergence_info,
+         n_observations=n,
+         n_features=k,
+         summary=summary
+     )
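
The MICE summary points callers at Rubin's rules for pooling. A minimal sketch of that workflow, again with invented data and the import path implied by the file list above: estimate the same quantity on each imputed dataset, then average the per-dataset estimates to get the pooled point estimate.

    import numpy as np
    from econometrics.missing_data import multiple_imputation

    data = [
        [1.0, 2.0],
        [2.0, float("nan")],
        [float("nan"), 4.0],
        [4.0, 5.0],
        [5.0, 6.5],
    ]

    result = multiple_imputation(data, n_imputations=5, max_iter=10, random_state=42)

    # Rubin's rules, point estimate only: average the estimate (here, the
    # mean of column 0) across the imputed datasets.
    estimates = [np.mean(np.array(ds)[:, 0]) for ds in result.imputed_datasets]
    print(np.mean(estimates), result.missing_rate)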
@@ -0,0 +1,35 @@
+ """
+ Nonparametric and semiparametric methods module
+ Relaxes linear or parametric assumptions on functional form
+ """
+
+ from .kernel_regression import (
+     kernel_regression,
+     KernelRegressionResult
+ )
+
+ from .quantile_regression import (
+     quantile_regression,
+     QuantileRegressionResult
+ )
+
+ from .spline_regression import (
+     spline_regression,
+     SplineRegressionResult
+ )
+
+ from .gam_model import (
+     gam_model,
+     GAMResult
+ )
+
+ __all__ = [
+     'kernel_regression',
+     'KernelRegressionResult',
+     'quantile_regression',
+     'QuantileRegressionResult',
+     'spline_regression',
+     'SplineRegressionResult',
+     'gam_model',
+     'GAMResult'
+ ]
@@ -0,0 +1,117 @@
+ """
+ Generalized additive model (GAM)
+ Implemented on top of the pygam library
+ """
+
+ from typing import List, Optional
+ from pydantic import BaseModel, Field
+ import numpy as np
+
+ try:
+     from pygam import LinearGAM, LogisticGAM, s, f
+     PYGAM_AVAILABLE = True
+ except ImportError:
+     PYGAM_AVAILABLE = False
+     LinearGAM = None
+
+
+ class GAMResult(BaseModel):
+     """GAM model result"""
+     fitted_values: List[float] = Field(..., description="Fitted values")
+     residuals: List[float] = Field(..., description="Residuals")
+     deviance: float = Field(..., description="Deviance")
+     aic: float = Field(..., description="AIC")
+     aicc: float = Field(..., description="AICc")
+     r_squared: float = Field(..., description="Pseudo R²")
+     n_splines: List[int] = Field(..., description="Number of splines per feature")
+     problem_type: str = Field(..., description="Problem type")
+     n_observations: int = Field(..., description="Number of observations")
+     summary: str = Field(..., description="Summary text")
+
+
+ def gam_model(
+     y_data: List[float],
+     x_data: List[List[float]],
+     problem_type: str = "regression",
+     n_splines: int = 10,
+     lam: float = 0.6
+ ) -> GAMResult:
+     """
+     Generalized additive model
+
+     Args:
+         y_data: Dependent variable
+         x_data: Independent variables (2-D list)
+         problem_type: "regression" or "classification"
+         n_splines: Number of splines per feature
+         lam: Smoothing parameter (lambda)
+
+     Returns:
+         GAMResult: GAM model result
+     """
+     if not PYGAM_AVAILABLE:
+         raise ImportError("pygam is not installed. Run: pip install pygam")
+
+     # Prepare data
+     y = np.array(y_data, dtype=np.float64)
+     X = np.array(x_data, dtype=np.float64)
+
+     if X.ndim == 1:
+         X = X.reshape(-1, 1)
+
+     n, k = X.shape
+
+     # Build one smooth term per feature and combine them
+     terms = s(0, n_splines=n_splines, lam=lam)
+     for i in range(1, k):
+         terms = terms + s(i, n_splines=n_splines, lam=lam)
+
+     # Create the GAM model
+     if problem_type == "regression":
+         gam = LinearGAM(terms)
+     elif problem_type == "classification":
+         gam = LogisticGAM(terms)
+     else:
+         raise ValueError(f"Unsupported problem type: {problem_type}")
+
+     # Fit the model
+     gam.fit(X, y)
+
+     # Fitted values
+     y_pred = gam.predict(X)
+
+     # Residuals
+     residuals = y - y_pred
+
+     # Model statistics
+     deviance = float(gam.statistics_['deviance'])
+     aic = float(gam.statistics_['AIC'])
+     aicc = float(gam.statistics_['AICc'])
+
+     # Pseudo R²
+     r_squared = float(gam.statistics_['pseudo_r2']['explained_deviance'])
+
+     # Spline counts per feature
+     n_splines_list = [n_splines] * k
+
+     summary = f"""Generalized additive model (GAM):
+ - Observations: {n}
+ - Features: {k}
+ - Problem type: {problem_type}
+ - Splines per feature: {n_splines}
+ - Smoothing parameter: {lam}
+ - Deviance: {deviance:.4f}
+ - AIC: {aic:.2f}
+ - AICc: {aicc:.2f}
+ - Pseudo R²: {r_squared:.4f}
+ """
+
+     return GAMResult(
+         fitted_values=y_pred.tolist(),
+         residuals=residuals.tolist(),
+         deviance=deviance,
+         aic=aic,
+         aicc=aicc,
+         r_squared=r_squared,
+         n_splines=n_splines_list,
+         problem_type=problem_type,
+         n_observations=n,
+         summary=summary
+     )
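
And a minimal sketch of the GAM wrapper on a noisy sine curve (invented data; import path implied by the file list above):

    import numpy as np
    from econometrics.nonparametric import gam_model

    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 2.0 * np.pi, 200)
    y = np.sin(x) + rng.normal(scale=0.2, size=x.size)

    result = gam_model(
        y_data=y.tolist(),
        x_data=[[v] for v in x],  # one feature column
        problem_type="regression",
        n_splines=10,
        lam=0.6,
    )
    print(result.r_squared, result.aic)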