ml4t-diagnostic 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. ml4t/diagnostic/AGENT.md +25 -0
  2. ml4t/diagnostic/__init__.py +166 -0
  3. ml4t/diagnostic/backends/__init__.py +10 -0
  4. ml4t/diagnostic/backends/adapter.py +192 -0
  5. ml4t/diagnostic/backends/polars_backend.py +899 -0
  6. ml4t/diagnostic/caching/__init__.py +40 -0
  7. ml4t/diagnostic/caching/cache.py +331 -0
  8. ml4t/diagnostic/caching/decorators.py +131 -0
  9. ml4t/diagnostic/caching/smart_cache.py +339 -0
  10. ml4t/diagnostic/config/AGENT.md +24 -0
  11. ml4t/diagnostic/config/README.md +267 -0
  12. ml4t/diagnostic/config/__init__.py +219 -0
  13. ml4t/diagnostic/config/barrier_config.py +277 -0
  14. ml4t/diagnostic/config/base.py +301 -0
  15. ml4t/diagnostic/config/event_config.py +148 -0
  16. ml4t/diagnostic/config/feature_config.py +404 -0
  17. ml4t/diagnostic/config/multi_signal_config.py +55 -0
  18. ml4t/diagnostic/config/portfolio_config.py +215 -0
  19. ml4t/diagnostic/config/report_config.py +391 -0
  20. ml4t/diagnostic/config/sharpe_config.py +202 -0
  21. ml4t/diagnostic/config/signal_config.py +206 -0
  22. ml4t/diagnostic/config/trade_analysis_config.py +310 -0
  23. ml4t/diagnostic/config/validation.py +279 -0
  24. ml4t/diagnostic/core/__init__.py +29 -0
  25. ml4t/diagnostic/core/numba_utils.py +315 -0
  26. ml4t/diagnostic/core/purging.py +372 -0
  27. ml4t/diagnostic/core/sampling.py +471 -0
  28. ml4t/diagnostic/errors/__init__.py +205 -0
  29. ml4t/diagnostic/evaluation/AGENT.md +26 -0
  30. ml4t/diagnostic/evaluation/__init__.py +437 -0
  31. ml4t/diagnostic/evaluation/autocorrelation.py +531 -0
  32. ml4t/diagnostic/evaluation/barrier_analysis.py +1050 -0
  33. ml4t/diagnostic/evaluation/binary_metrics.py +910 -0
  34. ml4t/diagnostic/evaluation/dashboard.py +715 -0
  35. ml4t/diagnostic/evaluation/diagnostic_plots.py +1037 -0
  36. ml4t/diagnostic/evaluation/distribution/__init__.py +499 -0
  37. ml4t/diagnostic/evaluation/distribution/moments.py +299 -0
  38. ml4t/diagnostic/evaluation/distribution/tails.py +777 -0
  39. ml4t/diagnostic/evaluation/distribution/tests.py +470 -0
  40. ml4t/diagnostic/evaluation/drift/__init__.py +139 -0
  41. ml4t/diagnostic/evaluation/drift/analysis.py +432 -0
  42. ml4t/diagnostic/evaluation/drift/domain_classifier.py +517 -0
  43. ml4t/diagnostic/evaluation/drift/population_stability_index.py +310 -0
  44. ml4t/diagnostic/evaluation/drift/wasserstein.py +388 -0
  45. ml4t/diagnostic/evaluation/event_analysis.py +647 -0
  46. ml4t/diagnostic/evaluation/excursion.py +390 -0
  47. ml4t/diagnostic/evaluation/feature_diagnostics.py +873 -0
  48. ml4t/diagnostic/evaluation/feature_outcome.py +666 -0
  49. ml4t/diagnostic/evaluation/framework.py +935 -0
  50. ml4t/diagnostic/evaluation/metric_registry.py +255 -0
  51. ml4t/diagnostic/evaluation/metrics/AGENT.md +23 -0
  52. ml4t/diagnostic/evaluation/metrics/__init__.py +133 -0
  53. ml4t/diagnostic/evaluation/metrics/basic.py +160 -0
  54. ml4t/diagnostic/evaluation/metrics/conditional_ic.py +469 -0
  55. ml4t/diagnostic/evaluation/metrics/feature_outcome.py +475 -0
  56. ml4t/diagnostic/evaluation/metrics/ic_statistics.py +446 -0
  57. ml4t/diagnostic/evaluation/metrics/importance_analysis.py +338 -0
  58. ml4t/diagnostic/evaluation/metrics/importance_classical.py +375 -0
  59. ml4t/diagnostic/evaluation/metrics/importance_mda.py +371 -0
  60. ml4t/diagnostic/evaluation/metrics/importance_shap.py +715 -0
  61. ml4t/diagnostic/evaluation/metrics/information_coefficient.py +527 -0
  62. ml4t/diagnostic/evaluation/metrics/interactions.py +772 -0
  63. ml4t/diagnostic/evaluation/metrics/monotonicity.py +226 -0
  64. ml4t/diagnostic/evaluation/metrics/risk_adjusted.py +324 -0
  65. ml4t/diagnostic/evaluation/multi_signal.py +550 -0
  66. ml4t/diagnostic/evaluation/portfolio_analysis/__init__.py +83 -0
  67. ml4t/diagnostic/evaluation/portfolio_analysis/analysis.py +734 -0
  68. ml4t/diagnostic/evaluation/portfolio_analysis/metrics.py +589 -0
  69. ml4t/diagnostic/evaluation/portfolio_analysis/results.py +334 -0
  70. ml4t/diagnostic/evaluation/report_generation.py +824 -0
  71. ml4t/diagnostic/evaluation/signal_selector.py +452 -0
  72. ml4t/diagnostic/evaluation/stat_registry.py +139 -0
  73. ml4t/diagnostic/evaluation/stationarity/__init__.py +97 -0
  74. ml4t/diagnostic/evaluation/stationarity/analysis.py +518 -0
  75. ml4t/diagnostic/evaluation/stationarity/augmented_dickey_fuller.py +296 -0
  76. ml4t/diagnostic/evaluation/stationarity/kpss_test.py +308 -0
  77. ml4t/diagnostic/evaluation/stationarity/phillips_perron.py +365 -0
  78. ml4t/diagnostic/evaluation/stats/AGENT.md +43 -0
  79. ml4t/diagnostic/evaluation/stats/__init__.py +191 -0
  80. ml4t/diagnostic/evaluation/stats/backtest_overfitting.py +219 -0
  81. ml4t/diagnostic/evaluation/stats/bootstrap.py +228 -0
  82. ml4t/diagnostic/evaluation/stats/deflated_sharpe_ratio.py +591 -0
  83. ml4t/diagnostic/evaluation/stats/false_discovery_rate.py +295 -0
  84. ml4t/diagnostic/evaluation/stats/hac_standard_errors.py +108 -0
  85. ml4t/diagnostic/evaluation/stats/minimum_track_record.py +408 -0
  86. ml4t/diagnostic/evaluation/stats/moments.py +164 -0
  87. ml4t/diagnostic/evaluation/stats/rademacher_adjustment.py +436 -0
  88. ml4t/diagnostic/evaluation/stats/reality_check.py +155 -0
  89. ml4t/diagnostic/evaluation/stats/sharpe_inference.py +219 -0
  90. ml4t/diagnostic/evaluation/themes.py +330 -0
  91. ml4t/diagnostic/evaluation/threshold_analysis.py +957 -0
  92. ml4t/diagnostic/evaluation/trade_analysis.py +1136 -0
  93. ml4t/diagnostic/evaluation/trade_dashboard/__init__.py +32 -0
  94. ml4t/diagnostic/evaluation/trade_dashboard/app.py +315 -0
  95. ml4t/diagnostic/evaluation/trade_dashboard/export/__init__.py +18 -0
  96. ml4t/diagnostic/evaluation/trade_dashboard/export/csv.py +82 -0
  97. ml4t/diagnostic/evaluation/trade_dashboard/export/html.py +276 -0
  98. ml4t/diagnostic/evaluation/trade_dashboard/io.py +166 -0
  99. ml4t/diagnostic/evaluation/trade_dashboard/normalize.py +304 -0
  100. ml4t/diagnostic/evaluation/trade_dashboard/stats.py +386 -0
  101. ml4t/diagnostic/evaluation/trade_dashboard/style.py +79 -0
  102. ml4t/diagnostic/evaluation/trade_dashboard/tabs/__init__.py +21 -0
  103. ml4t/diagnostic/evaluation/trade_dashboard/tabs/patterns.py +354 -0
  104. ml4t/diagnostic/evaluation/trade_dashboard/tabs/shap_analysis.py +280 -0
  105. ml4t/diagnostic/evaluation/trade_dashboard/tabs/stat_validation.py +186 -0
  106. ml4t/diagnostic/evaluation/trade_dashboard/tabs/worst_trades.py +236 -0
  107. ml4t/diagnostic/evaluation/trade_dashboard/types.py +129 -0
  108. ml4t/diagnostic/evaluation/trade_shap/__init__.py +102 -0
  109. ml4t/diagnostic/evaluation/trade_shap/alignment.py +188 -0
  110. ml4t/diagnostic/evaluation/trade_shap/characterize.py +413 -0
  111. ml4t/diagnostic/evaluation/trade_shap/cluster.py +302 -0
  112. ml4t/diagnostic/evaluation/trade_shap/explain.py +208 -0
  113. ml4t/diagnostic/evaluation/trade_shap/hypotheses/__init__.py +23 -0
  114. ml4t/diagnostic/evaluation/trade_shap/hypotheses/generator.py +290 -0
  115. ml4t/diagnostic/evaluation/trade_shap/hypotheses/matcher.py +251 -0
  116. ml4t/diagnostic/evaluation/trade_shap/hypotheses/templates.yaml +467 -0
  117. ml4t/diagnostic/evaluation/trade_shap/models.py +386 -0
  118. ml4t/diagnostic/evaluation/trade_shap/normalize.py +116 -0
  119. ml4t/diagnostic/evaluation/trade_shap/pipeline.py +263 -0
  120. ml4t/diagnostic/evaluation/trade_shap_dashboard.py +283 -0
  121. ml4t/diagnostic/evaluation/trade_shap_diagnostics.py +588 -0
  122. ml4t/diagnostic/evaluation/validated_cv.py +535 -0
  123. ml4t/diagnostic/evaluation/visualization.py +1050 -0
  124. ml4t/diagnostic/evaluation/volatility/__init__.py +45 -0
  125. ml4t/diagnostic/evaluation/volatility/analysis.py +351 -0
  126. ml4t/diagnostic/evaluation/volatility/arch.py +258 -0
  127. ml4t/diagnostic/evaluation/volatility/garch.py +460 -0
  128. ml4t/diagnostic/integration/__init__.py +48 -0
  129. ml4t/diagnostic/integration/backtest_contract.py +671 -0
  130. ml4t/diagnostic/integration/data_contract.py +316 -0
  131. ml4t/diagnostic/integration/engineer_contract.py +226 -0
  132. ml4t/diagnostic/logging/__init__.py +77 -0
  133. ml4t/diagnostic/logging/logger.py +245 -0
  134. ml4t/diagnostic/logging/performance.py +234 -0
  135. ml4t/diagnostic/logging/progress.py +234 -0
  136. ml4t/diagnostic/logging/wandb.py +412 -0
  137. ml4t/diagnostic/metrics/__init__.py +9 -0
  138. ml4t/diagnostic/metrics/percentiles.py +128 -0
  139. ml4t/diagnostic/py.typed +1 -0
  140. ml4t/diagnostic/reporting/__init__.py +43 -0
  141. ml4t/diagnostic/reporting/base.py +130 -0
  142. ml4t/diagnostic/reporting/html_renderer.py +275 -0
  143. ml4t/diagnostic/reporting/json_renderer.py +51 -0
  144. ml4t/diagnostic/reporting/markdown_renderer.py +117 -0
  145. ml4t/diagnostic/results/AGENT.md +24 -0
  146. ml4t/diagnostic/results/__init__.py +105 -0
  147. ml4t/diagnostic/results/barrier_results/__init__.py +36 -0
  148. ml4t/diagnostic/results/barrier_results/hit_rate.py +304 -0
  149. ml4t/diagnostic/results/barrier_results/precision_recall.py +266 -0
  150. ml4t/diagnostic/results/barrier_results/profit_factor.py +297 -0
  151. ml4t/diagnostic/results/barrier_results/tearsheet.py +397 -0
  152. ml4t/diagnostic/results/barrier_results/time_to_target.py +305 -0
  153. ml4t/diagnostic/results/barrier_results/validation.py +38 -0
  154. ml4t/diagnostic/results/base.py +177 -0
  155. ml4t/diagnostic/results/event_results.py +349 -0
  156. ml4t/diagnostic/results/feature_results.py +787 -0
  157. ml4t/diagnostic/results/multi_signal_results.py +431 -0
  158. ml4t/diagnostic/results/portfolio_results.py +281 -0
  159. ml4t/diagnostic/results/sharpe_results.py +448 -0
  160. ml4t/diagnostic/results/signal_results/__init__.py +74 -0
  161. ml4t/diagnostic/results/signal_results/ic.py +581 -0
  162. ml4t/diagnostic/results/signal_results/irtc.py +110 -0
  163. ml4t/diagnostic/results/signal_results/quantile.py +392 -0
  164. ml4t/diagnostic/results/signal_results/tearsheet.py +456 -0
  165. ml4t/diagnostic/results/signal_results/turnover.py +213 -0
  166. ml4t/diagnostic/results/signal_results/validation.py +147 -0
  167. ml4t/diagnostic/signal/AGENT.md +17 -0
  168. ml4t/diagnostic/signal/__init__.py +69 -0
  169. ml4t/diagnostic/signal/_report.py +152 -0
  170. ml4t/diagnostic/signal/_utils.py +261 -0
  171. ml4t/diagnostic/signal/core.py +275 -0
  172. ml4t/diagnostic/signal/quantile.py +148 -0
  173. ml4t/diagnostic/signal/result.py +214 -0
  174. ml4t/diagnostic/signal/signal_ic.py +129 -0
  175. ml4t/diagnostic/signal/turnover.py +182 -0
  176. ml4t/diagnostic/splitters/AGENT.md +19 -0
  177. ml4t/diagnostic/splitters/__init__.py +36 -0
  178. ml4t/diagnostic/splitters/base.py +501 -0
  179. ml4t/diagnostic/splitters/calendar.py +421 -0
  180. ml4t/diagnostic/splitters/calendar_config.py +91 -0
  181. ml4t/diagnostic/splitters/combinatorial.py +1064 -0
  182. ml4t/diagnostic/splitters/config.py +322 -0
  183. ml4t/diagnostic/splitters/cpcv/__init__.py +57 -0
  184. ml4t/diagnostic/splitters/cpcv/combinations.py +119 -0
  185. ml4t/diagnostic/splitters/cpcv/partitioning.py +263 -0
  186. ml4t/diagnostic/splitters/cpcv/purge_engine.py +379 -0
  187. ml4t/diagnostic/splitters/cpcv/windows.py +190 -0
  188. ml4t/diagnostic/splitters/group_isolation.py +329 -0
  189. ml4t/diagnostic/splitters/persistence.py +316 -0
  190. ml4t/diagnostic/splitters/utils.py +207 -0
  191. ml4t/diagnostic/splitters/walk_forward.py +757 -0
  192. ml4t/diagnostic/utils/__init__.py +42 -0
  193. ml4t/diagnostic/utils/config.py +542 -0
  194. ml4t/diagnostic/utils/dependencies.py +318 -0
  195. ml4t/diagnostic/utils/sessions.py +127 -0
  196. ml4t/diagnostic/validation/__init__.py +54 -0
  197. ml4t/diagnostic/validation/dataframe.py +274 -0
  198. ml4t/diagnostic/validation/returns.py +280 -0
  199. ml4t/diagnostic/validation/timeseries.py +299 -0
  200. ml4t/diagnostic/visualization/AGENT.md +19 -0
  201. ml4t/diagnostic/visualization/__init__.py +223 -0
  202. ml4t/diagnostic/visualization/backtest/__init__.py +98 -0
  203. ml4t/diagnostic/visualization/backtest/cost_attribution.py +762 -0
  204. ml4t/diagnostic/visualization/backtest/executive_summary.py +895 -0
  205. ml4t/diagnostic/visualization/backtest/interactive_controls.py +673 -0
  206. ml4t/diagnostic/visualization/backtest/statistical_validity.py +874 -0
  207. ml4t/diagnostic/visualization/backtest/tearsheet.py +565 -0
  208. ml4t/diagnostic/visualization/backtest/template_system.py +373 -0
  209. ml4t/diagnostic/visualization/backtest/trade_plots.py +1172 -0
  210. ml4t/diagnostic/visualization/barrier_plots.py +782 -0
  211. ml4t/diagnostic/visualization/core.py +1060 -0
  212. ml4t/diagnostic/visualization/dashboards/__init__.py +36 -0
  213. ml4t/diagnostic/visualization/dashboards/base.py +582 -0
  214. ml4t/diagnostic/visualization/dashboards/importance.py +801 -0
  215. ml4t/diagnostic/visualization/dashboards/interaction.py +263 -0
  216. ml4t/diagnostic/visualization/dashboards.py +43 -0
  217. ml4t/diagnostic/visualization/data_extraction/__init__.py +48 -0
  218. ml4t/diagnostic/visualization/data_extraction/importance.py +649 -0
  219. ml4t/diagnostic/visualization/data_extraction/interaction.py +504 -0
  220. ml4t/diagnostic/visualization/data_extraction/types.py +113 -0
  221. ml4t/diagnostic/visualization/data_extraction/validation.py +66 -0
  222. ml4t/diagnostic/visualization/feature_plots.py +888 -0
  223. ml4t/diagnostic/visualization/interaction_plots.py +618 -0
  224. ml4t/diagnostic/visualization/portfolio/__init__.py +41 -0
  225. ml4t/diagnostic/visualization/portfolio/dashboard.py +514 -0
  226. ml4t/diagnostic/visualization/portfolio/drawdown_plots.py +341 -0
  227. ml4t/diagnostic/visualization/portfolio/returns_plots.py +487 -0
  228. ml4t/diagnostic/visualization/portfolio/risk_plots.py +301 -0
  229. ml4t/diagnostic/visualization/report_generation.py +1343 -0
  230. ml4t/diagnostic/visualization/signal/__init__.py +103 -0
  231. ml4t/diagnostic/visualization/signal/dashboard.py +911 -0
  232. ml4t/diagnostic/visualization/signal/event_plots.py +514 -0
  233. ml4t/diagnostic/visualization/signal/ic_plots.py +635 -0
  234. ml4t/diagnostic/visualization/signal/multi_signal_dashboard.py +974 -0
  235. ml4t/diagnostic/visualization/signal/multi_signal_plots.py +603 -0
  236. ml4t/diagnostic/visualization/signal/quantile_plots.py +625 -0
  237. ml4t/diagnostic/visualization/signal/turnover_plots.py +400 -0
  238. ml4t/diagnostic/visualization/trade_shap/__init__.py +90 -0
  239. ml4t_diagnostic-0.1.0a1.dist-info/METADATA +1044 -0
  240. ml4t_diagnostic-0.1.0a1.dist-info/RECORD +242 -0
  241. ml4t_diagnostic-0.1.0a1.dist-info/WHEEL +4 -0
  242. ml4t_diagnostic-0.1.0a1.dist-info/licenses/LICENSE +21 -0
ml4t/diagnostic/evaluation/event_analysis.py
@@ -0,0 +1,647 @@
+ """Event Study Analysis Module.
+
+ This module implements event study methodology following MacKinlay (1997)
+ "Event Studies in Economics and Finance" for measuring abnormal returns
+ around corporate events, announcements, or other market events.
+
+ Classes
+ -------
+ EventStudyAnalysis
+     Main class for conducting event studies
+
+ References
+ ----------
+ MacKinlay, A.C. (1997). "Event Studies in Economics and Finance",
+     Journal of Economic Literature, 35(1), 13-39.
+ Boehmer, E., Musumeci, J., & Poulsen, A.B. (1991). "Event-study methodology
+     under conditions of event-induced variance", Journal of Financial Economics, 30(2), 253-272.
+ Corrado, C.J. (1989). "A nonparametric test for abnormal security-price
+     performance in event studies", Journal of Financial Economics, 23(2), 385-395.
+ """
+
+ from __future__ import annotations
+
+ import warnings
+ from typing import TYPE_CHECKING, Any
+
+ import numpy as np
+ import polars as pl
+ from scipy import stats
+
+ from ml4t.diagnostic.config.event_config import EventConfig
+ from ml4t.diagnostic.results.event_results import AbnormalReturnResult, EventStudyResult
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
38
+ class EventStudyAnalysis:
39
+ """Event study analysis for measuring abnormal returns around events.
40
+
41
+ Implements the standard event study methodology with support for:
42
+ - Market model (CAPM-based expected returns)
43
+ - Mean-adjusted model
44
+ - Market-adjusted model
45
+
46
+ And statistical tests:
47
+ - Standard t-test
48
+ - BMP test (Boehmer et al. 1991, robust to event-induced variance)
49
+ - Corrado rank test (non-parametric)
50
+
51
+ Parameters
52
+ ----------
53
+ returns : pl.DataFrame
54
+ Asset returns in long format with columns: [date, asset, return].
55
+ Returns should be simple returns (not log returns).
56
+ events : pl.DataFrame
57
+ Events to analyze with columns: [date, asset]. Optionally
58
+ includes [event_type, event_id] for grouping.
59
+ benchmark : pl.DataFrame
60
+ Market/benchmark returns with columns: [date, return].
61
+ config : EventConfig, optional
62
+ Configuration for the analysis.
63
+
64
+ Examples
65
+ --------
66
+ >>> returns_df = pl.DataFrame({
67
+ ... 'date': [...],
68
+ ... 'asset': [...],
69
+ ... 'return': [...]
70
+ ... })
71
+ >>> events_df = pl.DataFrame({
72
+ ... 'date': ['2023-01-15', '2023-02-20'],
73
+ ... 'asset': ['AAPL', 'MSFT']
74
+ ... })
75
+ >>> benchmark_df = pl.DataFrame({
76
+ ... 'date': [...],
77
+ ... 'return': [...] # Market returns
78
+ ... })
79
+ >>> analysis = EventStudyAnalysis(returns_df, events_df, benchmark_df)
80
+ >>> result = analysis.run()
81
+ >>> print(result.summary())
82
+ """
+
+     def __init__(
+         self,
+         returns: pl.DataFrame | pd.DataFrame,
+         events: pl.DataFrame | pd.DataFrame,
+         benchmark: pl.DataFrame | pd.DataFrame,
+         config: EventConfig | None = None,
+     ) -> None:
+         """Initialize event study analysis."""
+         self.config = config or EventConfig()
+
+         # Convert to Polars if needed
+         self._returns = self._to_polars(returns)
+         self._events = self._to_polars(events)
+         self._benchmark = self._to_polars(benchmark)
+
+         # Validate inputs
+         self._validate_inputs()
+
+         # Prepare data
+         self._prepare_data()
+
+         # Cache for computed results
+         self._ar_results: list[AbnormalReturnResult] | None = None
+         self._aggregated_result: EventStudyResult | None = None
+
+     def _to_polars(self, df: Any) -> pl.DataFrame:
+         """Convert DataFrame to Polars if needed."""
+         if isinstance(df, pl.DataFrame):
+             return df
+         try:
+             import pandas as pd
+
+             if isinstance(df, pd.DataFrame):
+                 return pl.from_pandas(df)
+         except ImportError:
+             pass
+         raise TypeError(f"Expected Polars or Pandas DataFrame, got {type(df)}")
+
+     def _validate_inputs(self) -> None:
+         """Validate input DataFrames have required columns."""
+         # Check returns
+         required_return_cols = {"date", "asset", "return"}
+         if not required_return_cols.issubset(set(self._returns.columns)):
+             raise ValueError(
+                 f"returns DataFrame missing columns: {required_return_cols - set(self._returns.columns)}"
+             )
+
+         # Check events
+         required_event_cols = {"date", "asset"}
+         if not required_event_cols.issubset(set(self._events.columns)):
+             raise ValueError(
+                 f"events DataFrame missing columns: {required_event_cols - set(self._events.columns)}"
+             )
+
+         # Check benchmark
+         required_bench_cols = {"date", "return"}
+         if not required_bench_cols.issubset(set(self._benchmark.columns)):
+             raise ValueError(
+                 f"benchmark DataFrame missing columns: {required_bench_cols - set(self._benchmark.columns)}"
+             )
+
+         # Check we have events
+         if len(self._events) == 0:
+             raise ValueError("No events provided")
+
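
For reference, a minimal sketch of inputs that pass this validation: long format, one row per (date, asset) for returns, one row per date for the benchmark (values illustrative):

import polars as pl

returns_df = pl.DataFrame({
    "date":   ["2023-01-02", "2023-01-03", "2023-01-02", "2023-01-03"],
    "asset":  ["AAPL", "AAPL", "MSFT", "MSFT"],
    "return": [0.011, -0.004, 0.006, 0.002],   # simple returns
})
benchmark_df = pl.DataFrame({
    "date":   ["2023-01-02", "2023-01-03"],
    "return": [0.008, -0.001],                 # market returns
})
events_df = pl.DataFrame({"date": ["2023-01-03"], "asset": ["AAPL"]})
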
+     def _prepare_data(self) -> None:
+         """Prepare data for analysis (sorting, date alignment)."""
+         # Sort by date
+         self._returns = self._returns.sort("date")
+         self._benchmark = self._benchmark.sort("date")
+
+         # Create date-indexed lookup for benchmark
+         self._benchmark_dict: dict[Any, float] = dict(
+             zip(
+                 self._benchmark["date"].to_list(),
+                 self._benchmark["return"].to_list(),
+                 strict=False,
+             )
+         )
+
+         # Get unique dates for index mapping
+         self._all_dates = sorted(self._returns["date"].unique().to_list())
+         self._date_to_idx = {d: i for i, d in enumerate(self._all_dates)}
+
+         # Add event_id if not present
+         if "event_id" not in self._events.columns:
+             self._events = self._events.with_row_index("event_id").with_columns(
+                 pl.col("event_id").cast(pl.Utf8).alias("event_id")
+             )
+
+     def _get_estimation_window_data(
+         self, asset: str, event_date: Any
+     ) -> tuple[np.ndarray, np.ndarray] | None:
+         """Get returns for estimation window.
+
+         Returns
+         -------
+         tuple[np.ndarray, np.ndarray] | None
+             (asset_returns, market_returns) for estimation window,
+             or None if insufficient data.
+         """
+         est_start, est_end = self.config.window.estimation_window
+
+         # Find event date index
+         if event_date not in self._date_to_idx:
+             return None
+         event_idx = self._date_to_idx[event_date]
+
+         # Calculate estimation window indices
+         start_idx = event_idx + est_start
+         end_idx = event_idx + est_end
+
+         if start_idx < 0:
+             return None
+
+         # Get dates in estimation window
+         est_dates = self._all_dates[start_idx : end_idx + 1]
+
+         if len(est_dates) < self.config.min_estimation_obs:
+             return None
+
+         # Get asset returns
+         asset_data = self._returns.filter(
+             (pl.col("asset") == asset) & (pl.col("date").is_in(est_dates))
+         ).sort("date")
+
+         if len(asset_data) < self.config.min_estimation_obs:
+             return None
+
+         # Pair asset returns with benchmark returns on matching dates
+         asset_returns = []
+         market_returns = []
+         for row in asset_data.iter_rows(named=True):
+             date = row["date"]
+             if date in self._benchmark_dict:
+                 asset_returns.append(row["return"])
+                 market_returns.append(self._benchmark_dict[date])
+
+         if len(asset_returns) < self.config.min_estimation_obs:
+             return None
+
+         return np.array(asset_returns), np.array(market_returns)
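
To make the index arithmetic concrete: windows are offsets in trading days, applied to positions in the sorted unique-date list rather than to calendar dates. A small sketch (window values here are illustrative, not the package defaults):

all_dates = ["d1", "d2", "d3", "d4", "d5", "d6", "d7"]   # sorted trading days
date_to_idx = {d: i for i, d in enumerate(all_dates)}

event_idx = date_to_idx["d6"]                             # event on d6 -> index 5
est_start, est_end = -4, -2                               # estimation window offsets
est_dates = all_dates[event_idx + est_start : event_idx + est_end + 1]
print(est_dates)                                          # ['d2', 'd3', 'd4']
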
+
+     def _estimate_market_model(
+         self, asset_returns: np.ndarray, market_returns: np.ndarray
+     ) -> tuple[float, float, float, float]:
+         """Estimate market model parameters via OLS.
+
+         Model: R = α + β*Rm + ε, so the abnormal return is
+         AR = R - (α + β*Rm).
+
+         Returns
+         -------
+         tuple[float, float, float, float]
+             (alpha, beta, r_squared, residual_std)
+         """
+         # OLS regression: R_asset = alpha + beta * R_market + epsilon
+         X = np.column_stack([np.ones(len(market_returns)), market_returns])
+         y = asset_returns
+
+         # Least-squares fit
+         try:
+             coeffs, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
+             alpha, beta = coeffs[0], coeffs[1]
+
+             # Calculate R-squared
+             y_pred = alpha + beta * market_returns
+             ss_res = np.sum((y - y_pred) ** 2)
+             ss_tot = np.sum((y - np.mean(y)) ** 2)
+             r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0.0
+
+             # Residual standard deviation (ddof=2: two estimated parameters)
+             residual_std = np.std(y - y_pred, ddof=2)
+
+             return float(alpha), float(beta), float(r_squared), float(residual_std)
+         except np.linalg.LinAlgError:
+             return 0.0, 1.0, 0.0, float(np.std(asset_returns, ddof=1))
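
A quick self-contained check of what this estimation does: generate returns from a known alpha and beta, fit by least squares, and confirm the parameters are recovered (a sketch, not part of the package):

import numpy as np

rng = np.random.default_rng(0)
rm = rng.normal(0.0005, 0.01, 250)                  # 250 days of market returns
r = 0.0002 + 1.3 * rm + rng.normal(0, 0.005, 250)   # true alpha=0.0002, beta=1.3

X = np.column_stack([np.ones(len(rm)), rm])
coeffs, _, _, _ = np.linalg.lstsq(X, r, rcond=None)
print(coeffs)   # ~[0.0002, 1.3] up to noise
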
+
+     def _get_event_window_data(
+         self, asset: str, event_date: Any
+     ) -> dict[int, tuple[float, float]] | None:
+         """Get returns for event window.
+
+         Returns
+         -------
+         dict[int, tuple[float, float]] | None
+             {relative_day: (asset_return, market_return)}
+         """
+         evt_start, evt_end = self.config.window.event_window
+
+         if event_date not in self._date_to_idx:
+             return None
+         event_idx = self._date_to_idx[event_date]
+
+         result = {}
+         for rel_day in range(evt_start, evt_end + 1):
+             day_idx = event_idx + rel_day
+             if 0 <= day_idx < len(self._all_dates):
+                 date = self._all_dates[day_idx]
+
+                 # Get asset return
+                 asset_ret = self._returns.filter(
+                     (pl.col("asset") == asset) & (pl.col("date") == date)
+                 )
+
+                 if len(asset_ret) > 0 and date in self._benchmark_dict:
+                     result[rel_day] = (
+                         asset_ret["return"][0],
+                         self._benchmark_dict[date],
+                     )
+
+         return result if result else None
295
+
296
+ def _compute_abnormal_return_single(
297
+ self, event_row: dict[str, Any]
298
+ ) -> AbnormalReturnResult | None:
299
+ """Compute abnormal returns for a single event."""
300
+ asset = event_row["asset"]
301
+ event_date = event_row["date"]
302
+ event_id = str(event_row.get("event_id", f"{asset}_{event_date}"))
303
+
304
+ # Get estimation window data
305
+ est_data = self._get_estimation_window_data(asset, event_date)
306
+ if est_data is None:
307
+ return None
308
+
309
+ asset_est_returns, market_est_returns = est_data
310
+
311
+ # Estimate model parameters
312
+ alpha, beta, r2, residual_std = 0.0, 1.0, 0.0, 0.0
313
+
314
+ if self.config.model == "market_model":
315
+ alpha, beta, r2, residual_std = self._estimate_market_model(
316
+ asset_est_returns, market_est_returns
317
+ )
318
+ elif self.config.model == "mean_adjusted":
319
+ alpha = float(np.mean(asset_est_returns))
320
+ beta = 0.0
321
+ residual_std = float(np.std(asset_est_returns, ddof=1))
322
+ elif self.config.model == "market_adjusted":
323
+ alpha = 0.0
324
+ beta = 1.0
325
+ residual_std = float(np.std(asset_est_returns - market_est_returns, ddof=1))
326
+
327
+ # Get event window data
328
+ event_data = self._get_event_window_data(asset, event_date)
329
+ if event_data is None:
330
+ return None
331
+
332
+ # Compute abnormal returns
333
+ ar_by_day: dict[int, float] = {}
334
+ for rel_day, (asset_ret, market_ret) in event_data.items():
335
+ if self.config.model == "market_model":
336
+ expected_ret = alpha + beta * market_ret
337
+ elif self.config.model == "mean_adjusted":
338
+ expected_ret = alpha
339
+ else: # market_adjusted
340
+ expected_ret = market_ret
341
+
342
+ ar_by_day[rel_day] = asset_ret - expected_ret
343
+
344
+ # Compute CAR
345
+ car = sum(ar_by_day.values())
346
+
347
+ return AbnormalReturnResult(
348
+ event_id=event_id,
349
+ asset=asset,
350
+ event_date=str(event_date),
351
+ ar_by_day=ar_by_day,
352
+ car=car,
353
+ estimation_alpha=alpha if self.config.model == "market_model" else None,
354
+ estimation_beta=beta if self.config.model == "market_model" else None,
355
+ estimation_r2=r2 if self.config.model == "market_model" else None,
356
+ estimation_residual_std=residual_std,
357
+ )
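
The three expected-return models reduce to three one-liners. With a fitted alpha of 0.02%, beta of 1.2, an estimation-window mean return of 0.04%, a day-0 asset return of 2%, and a market return of 0.5% (all numbers illustrative):

asset_ret, market_ret = 0.020, 0.005
alpha, beta, mean_ret = 0.0002, 1.2, 0.0004

ar_market_model    = asset_ret - (alpha + beta * market_ret)  # 0.020 - 0.0062 = 0.0138
ar_mean_adjusted   = asset_ret - mean_ret                     # 0.0196
ar_market_adjusted = asset_ret - market_ret                   # 0.0150
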
+
+     def compute_abnormal_returns(self) -> list[AbnormalReturnResult]:
+         """Compute abnormal returns for all events.
+
+         Returns
+         -------
+         list[AbnormalReturnResult]
+             Abnormal return results for each valid event.
+         """
+         if self._ar_results is not None:
+             return self._ar_results
+
+         results = []
+         n_skipped = 0
+
+         for row in self._events.iter_rows(named=True):
+             result = self._compute_abnormal_return_single(row)
+             if result is not None:
+                 results.append(result)
+             else:
+                 n_skipped += 1
+
+         if n_skipped > 0:
+             warnings.warn(
+                 f"Skipped {n_skipped} events due to insufficient data",
+                 stacklevel=2,
+             )
+
+         self._ar_results = results
+         return results
+
+     def aggregate(self, group_by: str | None = None) -> EventStudyResult:
+         """Aggregate individual results to AAR and CAAR.
+
+         Parameters
+         ----------
+         group_by : str | None
+             Column to group by (e.g., 'event_type'). If None,
+             aggregates all events together. Not yet applied; reserved
+             for grouped aggregation.
+
+         Returns
+         -------
+         EventStudyResult
+             Aggregated event study results.
+         """
+         ar_results = self.compute_abnormal_returns()
+
+         if len(ar_results) == 0:
+             raise ValueError("No valid events to aggregate")
+
+         # Collect all relative days
+         all_days = set()
+         for r in ar_results:
+             all_days.update(r.ar_by_day.keys())
+         sorted_days = sorted(all_days)
+
+         # Compute AAR (average AR across events for each day)
+         aar_by_day: dict[int, float] = {}
+         ar_matrix: dict[int, list[float]] = {d: [] for d in sorted_days}
+
+         for r in ar_results:
+             for day in sorted_days:
+                 if day in r.ar_by_day:
+                     ar_matrix[day].append(r.ar_by_day[day])
+
+         for day in sorted_days:
+             if ar_matrix[day]:
+                 aar_by_day[day] = float(np.mean(ar_matrix[day]))
+             else:
+                 aar_by_day[day] = 0.0
+
+         # Per-event cumulative ARs at each day, so CAAR dispersion is
+         # measured on cumulative (not single-day) abnormal returns
+         car_matrix: dict[int, list[float]] = {d: [] for d in sorted_days}
+         for r in ar_results:
+             cum = 0.0
+             for day in sorted_days:
+                 if day in r.ar_by_day:
+                     cum += r.ar_by_day[day]
+                 car_matrix[day].append(cum)
+
+         # Compute CAAR and its statistics
+         caar_values = []
+         caar_std = []
+         cumsum = 0.0
+
+         for day in sorted_days:
+             cumsum += aar_by_day[day]
+             caar_values.append(cumsum)
+
+             # Cross-sectional standard deviation of cumulative ARs at this day
+             if len(car_matrix[day]) > 1:
+                 caar_std.append(float(np.std(car_matrix[day], ddof=1)))
+             else:
+                 caar_std.append(0.0)
+
+         # Compute confidence intervals
+         n_events = len(ar_results)
+         z_score = stats.norm.ppf(1 - self.config.alpha / 2)
+
+         caar_ci_lower = []
+         caar_ci_upper = []
+         for caar, std in zip(caar_values, caar_std, strict=True):
+             se = std / np.sqrt(n_events) if n_events > 0 else 0.0
+             caar_ci_lower.append(caar - z_score * se)
+             caar_ci_upper.append(caar + z_score * se)
+
+         # Run statistical test
+         test_stat, p_value = self._run_statistical_test(ar_results, ar_matrix)
+
+         result = EventStudyResult(
+             aar_by_day=aar_by_day,
+             caar=caar_values,
+             caar_dates=sorted_days,
+             caar_std=caar_std,
+             caar_ci_lower=caar_ci_lower,
+             caar_ci_upper=caar_ci_upper,
+             test_statistic=test_stat,
+             p_value=p_value,
+             test_name=self.config.test,
+             n_events=n_events,
+             model_name=self.config.model,
+             event_window=self.config.window.event_window,
+             confidence_level=self.config.confidence_level,
+             individual_results=ar_results,
+         )
+
+         self._aggregated_result = result
+         return result
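
A compressed numeric picture of the aggregation, using two events over a three-day window (values illustrative):

import numpy as np

ar = {  # relative day -> AR per event
    -1: [0.001, -0.003],
     0: [0.012,  0.008],
     1: [0.002,  0.001],
}
aar = {d: float(np.mean(v)) for d, v in ar.items()}   # {-1: -0.001, 0: 0.010, 1: 0.0015}
caar = np.cumsum([aar[d] for d in sorted(aar)])       # [-0.001, 0.009, 0.0105]
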
+
+     def _run_statistical_test(
+         self,
+         ar_results: list[AbnormalReturnResult],
+         ar_matrix: dict[int, list[float]],
+     ) -> tuple[float, float]:
+         """Run statistical significance test.
+
+         Returns
+         -------
+         tuple[float, float]
+             (test_statistic, p_value)
+         """
+         if self.config.test == "t_test":
+             return self._t_test(ar_results, ar_matrix)
+         elif self.config.test == "boehmer":
+             return self._bmp_test(ar_results)
+         elif self.config.test == "corrado":
+             return self._corrado_test(ar_results, ar_matrix)
+         else:
+             return self._t_test(ar_results, ar_matrix)
+
+     def _t_test(
+         self,
+         ar_results: list[AbnormalReturnResult],
+         ar_matrix: dict[int, list[float]],
+     ) -> tuple[float, float]:
+         """Standard parametric t-test on CAAR.
+
+         H0: CAAR = 0
+         Test statistic: t = CAAR / SE(CAAR)
+         """
+         # Get CARs for all events
+         cars = [r.car for r in ar_results]
+         n = len(cars)
+
+         if n < 2:
+             return 0.0, 1.0
+
+         mean_car = np.mean(cars)
+         std_car = np.std(cars, ddof=1)
+         se_car = std_car / np.sqrt(n)
+
+         if se_car == 0:
+             return 0.0, 1.0
+
+         t_stat = mean_car / se_car
+         p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=n - 1))
+
+         return float(t_stat), float(p_value)
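
An equivalent check with scipy's one-sample t-test, which this hand-rolled version matches (sketch):

import numpy as np
from scipy import stats

cars = np.array([0.021, -0.004, 0.015, 0.009, 0.012])   # illustrative CARs
t_stat, p_value = stats.ttest_1samp(cars, popmean=0.0)
# Same as mean(cars) / (std(cars, ddof=1) / sqrt(len(cars))), two-sided
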
+
+     def _bmp_test(self, ar_results: list[AbnormalReturnResult]) -> tuple[float, float]:
+         """Boehmer, Musumeci, Poulsen (1991) test.
+
+         Robust to event-induced variance changes: each event's CAR is
+         standardized by its estimation-period residual volatility.
+
+         SAR_i = CAR_i / σ_i
+         Test statistic: Z = mean(SAR) / SE(SAR)
+         """
+         # Compute standardized abnormal returns
+         sars = []
+         for r in ar_results:
+             if r.estimation_residual_std and r.estimation_residual_std > 0:
+                 sar = r.car / r.estimation_residual_std
+             else:
+                 sar = r.car  # Fallback to unstandardized
+             sars.append(sar)
+
+         n = len(sars)
+         if n < 2:
+             return 0.0, 1.0
+
+         mean_sar = np.mean(sars)
+         std_sar = np.std(sars, ddof=1)
+         se_sar = std_sar / np.sqrt(n)
+
+         if se_sar == 0:
+             return 0.0, 1.0
+
+         z_stat = mean_sar / se_sar
+         p_value = 2 * (1 - stats.norm.cdf(abs(z_stat)))
+
+         return float(z_stat), float(p_value)
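
Why the standardization helps, in three lines: a high-volatility event no longer dominates the cross-sectional average once each CAR is scaled by its own estimation-period volatility (illustrative numbers):

cars   = [0.050, 0.004, 0.006]   # one noisy event with a big CAR
sigmas = [0.040, 0.004, 0.005]   # its estimation-period residual std is big too
sars   = [c / s for c, s in zip(cars, sigmas)]
# [1.25, 1.0, 1.2] -- comparable magnitudes after scaling
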
+
+     def _corrado_test(
+         self,
+         ar_results: list[AbnormalReturnResult],
+         ar_matrix: dict[int, list[float]],
+     ) -> tuple[float, float]:
+         """Corrado (1989) non-parametric rank test.
+
+         Robust to non-normality in returns. Each event's ARs are ranked
+         across its event-window days; under H0 the event-day (t=0) rank
+         has no tendency to exceed the mid-rank (T+1)/2.
+         """
+         n_events = len(ar_results)
+         if n_events < 2:
+             return 0.0, 1.0
+
+         # Rank each event's ARs across its event-window days and collect
+         # rank deviations from the expected mid-rank
+         day0_deviations: list[float] = []
+         deviations_by_day: dict[int, list[float]] = {}
+         for r in ar_results:
+             days = sorted(r.ar_by_day.keys())
+             if 0 not in r.ar_by_day or len(days) < 2:
+                 continue
+             ranks = stats.rankdata([r.ar_by_day[d] for d in days])
+             expected_rank = (len(days) + 1) / 2
+             for d, k in zip(days, ranks, strict=True):
+                 deviations_by_day.setdefault(d, []).append(k - expected_rank)
+             day0_deviations.append(ranks[days.index(0)] - expected_rank)
+
+         if len(day0_deviations) < 2:
+             # Fallback to t-test
+             return self._t_test(ar_results, ar_matrix)
+
+         # S(K): dispersion of the cross-sectional mean rank deviation across days
+         daily_means = [np.mean(v) for v in deviations_by_day.values()]
+         s_k = float(np.sqrt(np.mean(np.square(daily_means))))
+
+         if s_k == 0:
+             return 0.0, 1.0
+
+         z_stat = float(np.mean(day0_deviations)) / s_k
+         p_value = 2 * (1 - stats.norm.cdf(abs(z_stat)))
+
+         return float(z_stat), float(p_value)
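
The rank mechanics in isolation, for a single event over five days (sketch): the event-day AR is ranked against the other window days, and deviations from the mid-rank feed the statistic.

import numpy as np
from scipy import stats

ar_by_day = {-2: -0.002, -1: 0.001, 0: 0.015, 1: 0.003, 2: -0.001}
days = sorted(ar_by_day)
ranks = stats.rankdata([ar_by_day[d] for d in days])   # [1. 3. 5. 4. 2.]
expected = (len(days) + 1) / 2                         # mid-rank = 3.0
day0_dev = ranks[days.index(0)] - expected             # 5.0 - 3.0 = 2.0
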
+
+     def run(self) -> EventStudyResult:
+         """Run complete event study analysis.
+
+         This is the main entry point that computes abnormal returns,
+         aggregates results, and runs statistical tests.
+
+         Returns
+         -------
+         EventStudyResult
+             Complete event study results.
+
+         Examples
+         --------
+         >>> analysis = EventStudyAnalysis(returns, events, benchmark)
+         >>> result = analysis.run()
+         >>> print(result.summary())
+         >>> if result.is_significant:
+         ...     print("Significant abnormal returns detected!")
+         """
+         return self.aggregate()
+
+     def create_tear_sheet(self) -> EventStudyResult:
+         """Alias for run(); creates complete event study results."""
+         return self.run()
+
+     @property
+     def n_events(self) -> int:
+         """Number of events in the study."""
+         return len(self._events)
+
+     @property
+     def n_valid_events(self) -> int:
+         """Number of events with sufficient data for analysis."""
+         ar_results = self.compute_abnormal_returns()
+         return len(ar_results)
+
+     @property
+     def assets(self) -> list[str]:
+         """List of unique assets in the events."""
+         return self._events["asset"].unique().sort().to_list()
+
+     @property
+     def date_range(self) -> tuple[Any, Any]:
+         """Date range of the returns data."""
+         return self._all_dates[0], self._all_dates[-1]
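
Putting it together, an end-to-end sketch on synthetic data (dates, tickers, and sizes are illustrative; assumes the default EventConfig windows fit inside the ~400 days of history before the event):

import numpy as np
import polars as pl
from datetime import date, timedelta

rng = np.random.default_rng(7)
dates = [date(2022, 1, 3) + timedelta(days=i) for i in range(500)]
mkt = rng.normal(0.0003, 0.01, 500)

returns_df = pl.DataFrame({
    "date": dates * 2,
    "asset": ["AAPL"] * 500 + ["MSFT"] * 500,
    "return": np.concatenate([mkt + rng.normal(0, 0.005, 500),
                              mkt + rng.normal(0, 0.005, 500)]),
})
benchmark_df = pl.DataFrame({"date": dates, "return": mkt})
events_df = pl.DataFrame({"date": [dates[400]], "asset": ["AAPL"]})

result = EventStudyAnalysis(returns_df, events_df, benchmark_df).run()
print(result.n_events, result.p_value)
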