decline-curve 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. decline_curve-0.1.2/LICENSE +1 -0
  2. decline_curve-0.1.2/PKG-INFO +70 -0
  3. decline_curve-0.1.2/README.md +45 -0
  4. decline_curve-0.1.2/decline_analysis/__init__.py +1 -0
  5. decline_curve-0.1.2/decline_analysis/__main__.py +51 -0
  6. decline_curve-0.1.2/decline_analysis/data.py +100 -0
  7. decline_curve-0.1.2/decline_analysis/dca.py +164 -0
  8. decline_curve-0.1.2/decline_analysis/economics.py +33 -0
  9. decline_curve-0.1.2/decline_analysis/evaluate.py +56 -0
  10. decline_curve-0.1.2/decline_analysis/example.py +9 -0
  11. decline_curve-0.1.2/decline_analysis/forecast.py +109 -0
  12. decline_curve-0.1.2/decline_analysis/forecast_arima.py +249 -0
  13. decline_curve-0.1.2/decline_analysis/forecast_chronos.py +229 -0
  14. decline_curve-0.1.2/decline_analysis/forecast_timesfm.py +183 -0
  15. decline_curve-0.1.2/decline_analysis/models.py +190 -0
  16. decline_curve-0.1.2/decline_analysis/plot.py +289 -0
  17. decline_curve-0.1.2/decline_analysis/reserves.py +28 -0
  18. decline_curve-0.1.2/decline_analysis/sensitivity.py +63 -0
  19. decline_curve-0.1.2/decline_analysis/utils/data_loader.py +50 -0
  20. decline_curve-0.1.2/decline_curve.egg-info/PKG-INFO +70 -0
  21. decline_curve-0.1.2/decline_curve.egg-info/SOURCES.txt +35 -0
  22. decline_curve-0.1.2/decline_curve.egg-info/dependency_links.txt +1 -0
  23. decline_curve-0.1.2/decline_curve.egg-info/entry_points.txt +2 -0
  24. decline_curve-0.1.2/decline_curve.egg-info/requires.txt +14 -0
  25. decline_curve-0.1.2/decline_curve.egg-info/top_level.txt +1 -0
  26. decline_curve-0.1.2/pyproject.toml +66 -0
  27. decline_curve-0.1.2/setup.cfg +4 -0
  28. decline_curve-0.1.2/tests/test_data_loader.py +453 -0
  29. decline_curve-0.1.2/tests/test_dca.py +243 -0
  30. decline_curve-0.1.2/tests/test_economics.py +343 -0
  31. decline_curve-0.1.2/tests/test_evaluate.py +216 -0
  32. decline_curve-0.1.2/tests/test_forecast.py +260 -0
  33. decline_curve-0.1.2/tests/test_forecast_arima.py +283 -0
  34. decline_curve-0.1.2/tests/test_models.py +213 -0
  35. decline_curve-0.1.2/tests/test_plot.py +293 -0
  36. decline_curve-0.1.2/tests/test_reserves.py +382 -0
  37. decline_curve-0.1.2/tests/test_sensitivity.py +250 -0
@@ -0,0 +1 @@
1
+ Apache License 2.0
@@ -0,0 +1,70 @@
1
+ Metadata-Version: 2.4
2
+ Name: decline_curve
3
+ Version: 0.1.2
4
+ Summary: Decline curve analysis for oil well production using Arps and LLM-based models
5
+ Author: Kyle T. Jones
6
+ License-Expression: Apache-2.0
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: numpy>=1.23
11
+ Requires-Dist: pandas>=2.0
12
+ Requires-Dist: scipy>=1.10
13
+ Requires-Dist: matplotlib>=3.7
14
+ Requires-Dist: statsmodels>=0.14
15
+ Requires-Dist: tqdm>=4.66
16
+ Requires-Dist: transformers>=4.41
17
+ Requires-Dist: torch>=2.0
18
+ Requires-Dist: accelerate>=0.29
19
+ Requires-Dist: einops>=0.7
20
+ Requires-Dist: pmdarima>=2.0
21
+ Requires-Dist: numpy-financial>=1.0
22
+ Requires-Dist: requests>=2.25
23
+ Requires-Dist: xlrd>=2.0
24
+ Dynamic: license-file
25
+
26
+ # Decline Analysis
27
+
28
+ A Python package for decline curve analysis of oil well production using Arps models and LLM-based forecasting methods.
29
+
30
+ ## Features
31
+
32
+ - Traditional Arps decline curve analysis (exponential, hyperbolic, harmonic)
33
+ - Advanced forecasting with machine learning models (ARIMA, Chronos, TimesFM)
34
+ - Economic analysis and reserves estimation
35
+ - Sensitivity analysis capabilities
36
+ - Data visualization and plotting tools
37
+
38
+ ## Installation
39
+
40
+ ```bash
41
+ pip install decline-curve
42
+ ```
43
+
44
+ ## Usage
45
+
46
+ ```python
47
+ from decline_analysis import dca
48
+
49
+ # Load your production data
50
+ # Run decline curve analysis
51
+ # Generate forecasts and economic metrics
52
+ ```
53
+
54
+ ## Development
55
+
56
+ Install in development mode:
57
+
58
+ ```bash
59
+ pip install -e .
60
+ ```
61
+
62
+ Run tests:
63
+
64
+ ```bash
65
+ pytest
66
+ ```
67
+
68
+ ## License
69
+
70
+ Apache-2.0
@@ -0,0 +1,45 @@
1
+ # Decline Analysis
2
+
3
+ A Python package for decline curve analysis of oil well production using Arps models and LLM-based forecasting methods.
4
+
5
+ ## Features
6
+
7
+ - Traditional Arps decline curve analysis (exponential, hyperbolic, harmonic)
8
+ - Advanced forecasting with machine learning models (ARIMA, Chronos, TimesFM)
9
+ - Economic analysis and reserves estimation
10
+ - Sensitivity analysis capabilities
11
+ - Data visualization and plotting tools
12
+
13
+ ## Installation
14
+
15
+ ```bash
16
+ pip install decline-curve
17
+ ```
18
+
19
+ ## Usage
20
+
21
+ ```python
22
+ from decline_analysis import dca
23
+
24
+ # Load your production data
25
+ # Run decline curve analysis
26
+ # Generate forecasts and economic metrics
27
+ ```
28
+
29
+ ## Development
30
+
31
+ Install in development mode:
32
+
33
+ ```bash
34
+ pip install -e .
35
+ ```
36
+
37
+ Run tests:
38
+
39
+ ```bash
40
+ pytest
41
+ ```
42
+
43
+ ## License
44
+
45
+ Apache-2.0
@@ -0,0 +1 @@
1
+ from . import dca
@@ -0,0 +1,51 @@
1
+ import argparse
2
+
3
+ import pandas as pd
4
+
5
+ from . import dca
6
+
7
+
8
def main():
    """Command-line entry point for the decline curve forecast tool.

    With ``--benchmark`` the first ``--top_n`` wells of the CSV are scored via
    ``dca.benchmark`` and a metrics table is printed. Otherwise the single
    well named by ``--well`` is forecast and plotted against its history.

    In single-well mode the CSV must contain ``well_id``, ``date`` and
    ``oil_bbl`` columns.

    Raises:
        ValueError: If ``--well`` is missing outside benchmark mode, or if no
            rows match the requested well.
    """
    parser = argparse.ArgumentParser(description="Decline curve forecast tool")
    parser.add_argument("csv", help="Input CSV file")
    parser.add_argument(
        "--model",
        default="arps",
        # "arima" is also accepted by dca.forecast / dca.benchmark.
        choices=["arps", "timesfm", "chronos", "arima"],
    )
    parser.add_argument(
        "--kind",
        default="hyperbolic",
        choices=["exponential", "harmonic", "hyperbolic"],
    )
    parser.add_argument("--horizon", type=int, default=12)
    parser.add_argument("--well", help="Well ID to forecast")
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--top_n", type=int, default=10)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    df = pd.read_csv(args.csv)

    if args.benchmark:
        result = dca.benchmark(
            df,
            model=args.model,
            kind=args.kind,
            horizon=args.horizon,
            top_n=args.top_n,
            verbose=args.verbose,
        )
        print(result.to_string(index=False))
        return

    if args.well is None:
        raise ValueError("Must provide --well when not using --benchmark")

    # argparse always yields a string, while read_csv parses purely numeric IDs
    # as integers; compare as strings so e.g. ``--well 33053`` still matches.
    sub = df[df["well_id"].astype(str) == args.well].copy()
    if sub.empty:
        raise ValueError(f"No rows found for well {args.well!r}")
    sub["date"] = pd.to_datetime(sub["date"])
    y = sub.set_index("date")["oil_bbl"].asfreq("MS")
    yhat = dca.forecast(
        y,
        model=args.model,
        kind=args.kind,
        horizon=args.horizon,
        verbose=args.verbose,
    )
    dca.plot(y, yhat, title=f"{args.well} {args.model}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,100 @@
1
+ from pathlib import Path
2
+ from typing import List, Optional
3
+
4
+ import pandas as pd
5
+
6
+
7
def load_production_csvs(
    paths: List[str],
    date_col: str = "date",
    well_id_col: str = "well_id",
    oil_col: str = "oil_bbl",
) -> pd.DataFrame:
    """Read several well-level production CSVs and stack them into one frame.

    Only the date, well-id and oil columns are kept. The date column is
    parsed with ``pd.to_datetime``, and rows are sorted by well and then date.

    Args:
        paths: CSV file paths, read in the given order.
        date_col: Name of the date column in the files.
        well_id_col: Name of the well identifier column in the files.
        oil_col: Name of the oil volume column in the files.

    Returns:
        A DataFrame indexed by ``date`` with columns ``well_id`` and
        ``oil_bbl``.

    Raises:
        ValueError: If a file lacks one of the required columns. ``pd.concat``
            also raises ValueError when ``paths`` is empty.
    """
    wanted = [date_col, well_id_col, oil_col]

    def _read_one(path):
        # Validate and normalise a single file before stacking.
        frame = pd.read_csv(path)
        _assert_cols(frame, wanted)
        frame[date_col] = pd.to_datetime(frame[date_col])
        return frame[wanted]

    stacked = pd.concat([_read_one(path) for path in paths], ignore_index=True)
    renamed = stacked.sort_values([well_id_col, date_col]).rename(
        columns={date_col: "date", well_id_col: "well_id", oil_col: "oil_bbl"}
    )
    return renamed.set_index("date")
36
+
37
+
38
def to_monthly(
    df: pd.DataFrame, well_id_col: str = "well_id", oil_col: str = "oil_bbl"
) -> pd.DataFrame:
    """Aggregate a date-indexed, per-well frame to calendar-month totals.

    Months are labelled by their first day (``"MS"``). This matches the
    ``asfreq("MS")`` grid used by the CLI and ``dca.benchmark``, and the
    alias works on every pandas version. The month-end ``"M"`` alias is
    deprecated in pandas 2.2 and removed in later releases.

    Args:
        df: A DataFrame indexed by a DatetimeIndex. Any index name is
            accepted; the output index is always named ``date``.
        well_id_col: Well id column.
        oil_col: Oil volume column to sum within each month.

    Returns:
        A monthly panel indexed by ``date`` with columns ``well_id_col`` and
        ``oil_col``. Months with no rows between a well's first and last
        observation are filled with 0 by the resample-sum.
    """
    # Normalise the index name so the final set_index("date") cannot fail on
    # unnamed or differently-named date indexes.
    indexed = df.rename_axis("date")
    return (
        indexed.groupby(well_id_col)
        .resample("MS")[oil_col]
        .sum()
        .reset_index()
        .set_index("date")
    )
58
+
59
+
60
def make_panel(df: pd.DataFrame, first_n_months: Optional[int] = None) -> pd.DataFrame:
    """Add a per-well relative time counter ``t`` for decline fitting.

    Rows are ordered by ``well_id`` and then the ``date`` index level. ``t``
    counts rows within each well starting at 0; it is a row counter, not a
    calendar-month offset.

    Args:
        df: Monthly panel with index named ``date`` and a ``well_id`` column.
        first_n_months: If given, keep only rows with ``t < first_n_months``.

    Returns:
        A sorted copy of ``df`` with the extra integer column ``t``.
    """
    panel = df.sort_values(["well_id", "date"]).copy()
    panel["t"] = panel.groupby("well_id").cumcount()
    if first_n_months is None:
        return panel
    keep = panel["t"] < first_n_months
    return panel[keep]
75
+
76
+
77
def load_price_csv(
    path: str, date_col: str = "date", price_col: str = "price"
) -> pd.DataFrame:
    """Read a price CSV into a date-indexed, chronologically sorted frame.

    Args:
        path: Path to the price CSV.
        date_col: Name of the date column; it is parsed and becomes the index
            (renamed ``date``).
        price_col: Name of the price column; it is renamed to ``price``.

    Returns:
        The file's rows indexed by ``date`` and sorted by it. Columns other
        than the date and price columns are kept unchanged.

    Raises:
        ValueError: If the date or price column is missing.
    """
    raw = pd.read_csv(path)
    _assert_cols(raw, [date_col, price_col])
    raw[date_col] = pd.to_datetime(raw[date_col])
    renamed = raw.rename(columns={date_col: "date", price_col: "price"})
    return renamed.set_index("date").sort_index()
95
+
96
+
97
+ def _assert_cols(df: pd.DataFrame, cols: List[str]) -> None:
98
+ missing = [c for c in cols if c not in df.columns]
99
+ if missing:
100
+ raise ValueError(f"Missing columns: {missing}")
@@ -0,0 +1,164 @@
1
+ from typing import Dict, List, Literal, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from .economics import economic_metrics
7
+ from .evaluate import mae, rmse, smape
8
+ from .forecast import Forecaster
9
+ from .models import ArpsParams
10
+ from .plot import plot_forecast
11
+ from .reserves import forecast_and_reserves
12
+ from .sensitivity import run_sensitivity
13
+ from .utils.data_loader import scrape_ndic
14
+
15
+
16
def forecast(
    series: pd.Series,
    model: Literal["arps", "timesfm", "chronos", "arima"] = "arps",
    kind: Optional[Literal["exponential", "harmonic", "hyperbolic"]] = "hyperbolic",
    horizon: int = 12,
    verbose: bool = False,
) -> pd.Series:
    """Forecast ``series`` with the chosen backend via ``Forecaster``.

    Args:
        series: Datetime-indexed production series.
        model: Forecasting backend name.
        kind: Arps decline variant, passed through to ``Forecaster.forecast``.
        horizon: Number of periods to forecast.
        verbose: If True, print the model, horizon and the head of the result.

    Returns:
        The series returned by ``Forecaster.forecast``.
    """
    result = Forecaster(series).forecast(model=model, kind=kind, horizon=horizon)
    if not verbose:
        return result
    print(f"Forecast model: {model}, horizon: {horizon}")
    print(result.head())
    return result
29
+
30
+
31
def evaluate(y_true: pd.Series, y_pred: pd.Series) -> dict:
    """Score ``y_pred`` against ``y_true`` on the dates both series share.

    Returns:
        Dict with ``rmse``, ``mae`` and ``smape`` computed on the index
        intersection.
    """
    shared = y_true.index.intersection(y_pred.index)
    actual, predicted = y_true.loc[shared], y_pred.loc[shared]
    scores = {}
    scores["rmse"] = rmse(actual, predicted)
    scores["mae"] = mae(actual, predicted)
    scores["smape"] = smape(actual, predicted)
    return scores
40
+
41
+
42
def plot(
    y: pd.Series,
    yhat: pd.Series,
    title: str = "Forecast",
    filename: Optional[str] = None,
):
    """Plot history ``y`` against forecast ``yhat`` using ``plot_forecast``.

    ``filename`` is passed through unchanged (presumably an optional output
    path; see ``plot_forecast``). Returns None.
    """
    plot_forecast(y, yhat, title, filename)
49
+
50
+
51
def benchmark(
    df: pd.DataFrame,
    model: Literal["arps", "timesfm", "chronos", "arima"] = "arps",
    kind: Optional[str] = "hyperbolic",
    horizon: int = 12,
    well_col: str = "well_id",
    date_col: str = "date",
    value_col: str = "oil_bbl",
    top_n: int = 10,
    verbose: bool = False,
) -> pd.DataFrame:
    """Forecast and score the first ``top_n`` wells in ``df``.

    Wells are taken in order of first appearance. Each well's history is
    placed on a month-start (``"MS"``) grid, and wells whose grid spans fewer
    than 24 months are skipped. Metrics come from :func:`evaluate`, which
    compares forecast and history on overlapping dates. For Arps the forecast
    also covers the history, so this is an in-sample fit check.

    Any exception while preparing, forecasting or scoring one well is treated
    as a per-well failure: the well is skipped (and reported if ``verbose``)
    so the remaining wells are still benchmarked.

    Returns:
        One row per successfully scored well with ``rmse``, ``mae``,
        ``smape`` and the ``well_col`` identifier.
    """
    out = []
    wells = df[well_col].unique()[:top_n]
    for wid in wells:
        try:
            # Preparation lives inside the try: bad dates or duplicate months
            # (which make asfreq raise) should only skip this well.
            wdf = df[df[well_col] == wid].copy()
            wdf = wdf[[date_col, value_col]].dropna()
            wdf[date_col] = pd.to_datetime(wdf[date_col])
            wdf = wdf.set_index(date_col).asfreq("MS")
            if len(wdf) < 24:
                continue
            y = wdf[value_col]
            yhat = forecast(y, model=model, kind=kind, horizon=horizon)
            metrics = evaluate(y, yhat)
        except Exception as e:  # deliberate per-well best effort boundary
            if verbose:
                print(f"{wid} failed: {e}")
            continue
        metrics[well_col] = wid
        out.append(metrics)
        if verbose:
            print(f"{wid}: {metrics}")
    return pd.DataFrame(out)
84
+
85
+
86
def sensitivity_analysis(
    param_grid: List[Tuple[float, float, float]],
    prices: List[float],
    opex: float,
    discount_rate: float = 0.10,
    t_max: float = 240,
    econ_limit: float = 10.0,
    dt: float = 1.0,
) -> pd.DataFrame:
    """
    Evaluate decline/economics scenarios over Arps parameters and prices.

    Thin wrapper that forwards every argument positionally to
    ``run_sensitivity``.

    Args:
        param_grid: (qi, di, b) Arps parameter tuples to evaluate.
        prices: Oil/gas prices to evaluate.
        opex: Operating cost per unit.
        discount_rate: Annual discount rate (default 0.10).
        t_max: Forecast horizon in months (default 240).
        econ_limit: Minimum economic production rate (default 10.0).
        dt: Time step in months (default 1.0).

    Returns:
        The DataFrame produced by ``run_sensitivity`` (documented as holding
        EUR, NPV and payback per scenario).
    """
    forwarded = (param_grid, prices, opex, discount_rate, t_max, econ_limit, dt)
    return run_sensitivity(*forwarded)
113
+
114
+
115
def economics(
    production: pd.Series, price: float, opex: float, discount_rate: float = 0.10
) -> Dict:
    """
    Compute economic metrics for a production forecast.

    Args:
        production: Monthly production forecast. Its ``.values`` are what is
            passed on.
        price: Unit price ($/bbl or $/mcf).
        opex: Operating cost per unit.
        discount_rate: Annual discount rate (default 0.10).

    Returns:
        The dict from ``economic_metrics`` (``npv``, ``cash_flow``,
        ``payback_month``).
    """
    volumes = production.values
    return economic_metrics(volumes, price, opex, discount_rate)
131
+
132
+
133
def reserves(
    params: ArpsParams, t_max: float = 240, dt: float = 1.0, econ_limit: float = 10.0
) -> Dict:
    """
    Forecast production from Arps parameters and estimate ultimate recovery.

    Delegates to ``forecast_and_reserves`` with the arguments in order.

    Args:
        params: Arps decline parameters (qi, di, b).
        t_max: Forecast horizon in months (default 240).
        dt: Time step in months (default 1.0).
        econ_limit: Minimum economic production rate (default 10.0).

    Returns:
        The dict from ``forecast_and_reserves`` (documented as holding the
        forecast, time array and EUR).
    """
    result = forecast_and_reserves(params, t_max, dt, econ_limit)
    return result
149
+
150
+
151
def load_ndic_data(
    months_list: List[str], output_dir: str = "ndic_raw"
) -> pd.DataFrame:
    """
    Fetch North Dakota Industrial Commission (NDIC) production data.

    Delegates to ``scrape_ndic``.

    Args:
        months_list: Month strings to fetch, e.g. ``['2023-01', '2023-02']``.
        output_dir: Directory passed to ``scrape_ndic`` for raw files
            (default ``'ndic_raw'``).

    Returns:
        The combined DataFrame returned by ``scrape_ndic``.
    """
    combined = scrape_ndic(months_list, output_dir)
    return combined
@@ -0,0 +1,33 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ from numpy_financial import npv
5
+
6
+
7
def economic_metrics(
    q: np.ndarray,
    price: float,
    opex: float,
    discount_rate: float = 0.10,
    time_step_months: float = 1.0,
) -> Dict:
    """
    Calculate cash flow, NPV and payback from forecasted production.

    Args:
        q: Production volume for each time step.
        price: Unit price.
        opex: Operating cost per unit produced.
        discount_rate: Annual (nominal) discount rate.
        time_step_months: Length of each time step in months. The per-step
            discount rate is ``discount_rate / 12 * time_step_months``.

    Returns:
        Dict with:
            ``cash_flow``: ``(price - opex) * q`` for each step.
            ``npv``: ``numpy_financial.npv`` of ``cash_flow`` at the per-step
                rate (that function leaves the first step undiscounted).
            ``payback_month``: index of the first step whose cumulative cash
                flow is positive (a month index when ``time_step_months`` is
                1), or -1 if cumulative cash flow never turns positive.
    """
    # Scale the nominal monthly rate by the step length so non-monthly steps
    # (e.g. quarterly) are discounted at the matching per-step rate. With the
    # default of 1.0 this is exactly discount_rate / 12.
    period_rate = discount_rate / 12 * time_step_months
    cash_flow = (price - opex) * q
    npv_val = npv(period_rate, cash_flow)
    cum_cf = np.cumsum(cash_flow)
    positive = cum_cf > 0
    payback_month = int(np.argmax(positive)) if np.any(positive) else -1
    return {"npv": npv_val, "cash_flow": cash_flow, "payback_month": payback_month}
@@ -0,0 +1,56 @@
1
+ """
2
+ Evaluation metrics for decline curve analysis forecasts.
3
+ """
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+
9
def rmse(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Root of the mean squared difference between actual and predicted."""
    squared_error = np.square(y_true - y_pred)
    return np.sqrt(np.mean(squared_error))
12
+
13
+
14
def mae(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Mean of the absolute differences between actual and predicted."""
    absolute_error = np.abs(y_true - y_pred)
    return np.mean(absolute_error)
17
+
18
+
19
def smape(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Symmetric Mean Absolute Percentage Error, in percent.

    Each term is ``|y_pred - y_true| / ((|y_true| + |y_pred|) / 2)``. A term
    where both values are zero is a perfect prediction, so it contributes 0.
    Without this rule it would be the undefined 0/0, which turns the whole
    metric into NaN for array inputs (e.g. shut-in months with zero
    production). Series inputs keep pandas' index alignment and NaN-skipping
    mean.
    """
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # The denominator is zero only when both values are zero, and then the
    # numerator is zero too; dividing by 1 there yields the intended 0 term.
    safe_denominator = np.where(denominator == 0, 1, denominator)
    return np.mean(numerator / safe_denominator) * 100
24
+
25
+
26
def mape(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Mean Absolute Percentage Error relative to the actual values, in percent.

    Zero actual values make the corresponding term infinite (or NaN).
    """
    relative_error = (y_true - y_pred) / y_true
    return np.mean(np.abs(relative_error)) * 100
29
+
30
+
31
def r2_score(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Coefficient of determination, ``1 - SS_res / SS_tot``.

    When the actual values are constant (``SS_tot == 0``) the ratio is
    undefined. In that case this returns 1.0 if the predictions match
    exactly and 0.0 otherwise.
    """
    residual_ss = np.sum((y_true - y_pred) ** 2)
    total_ss = np.sum((y_true - np.mean(y_true)) ** 2)
    if total_ss != 0:
        return 1 - (residual_ss / total_ss)
    return 1.0 if residual_ss == 0 else 0.0
46
+
47
+
48
def evaluate_forecast(y_true: pd.Series, y_pred: pd.Series) -> dict:
    """Return every metric in this module for one forecast, keyed by name.

    Keys, in order: ``rmse``, ``mae``, ``smape``, ``mape``, ``r2``.
    """
    metric_fns = (
        ("rmse", rmse),
        ("mae", mae),
        ("smape", smape),
        ("mape", mape),
        ("r2", r2_score),
    )
    return {name: fn(y_true, y_pred) for name, fn in metric_fns}
@@ -0,0 +1,9 @@
1
+ import pandas as pd
2
+
3
+ from decline_analysis import dca
4
+
5
+ # Example usage - replace with your actual data
6
+ # df = pd.read_csv("your_production_data.csv")
7
+ # series = df[df["well_id"] == "WELL_001"].set_index("date")["oil_bbl"]
8
+ # yhat = dca.forecast(series, model="arps", kind="hyperbolic", horizon=12)
9
+ # dca.plot(series, yhat)
@@ -0,0 +1,109 @@
1
+ from typing import Literal, Optional
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from .forecast_chronos import forecast_chronos
8
+ from .forecast_timesfm import forecast_timesfm
9
+ from .models import fit_arps, predict_arps
10
+
11
try:
    from .forecast_arima import forecast_arima
except ImportError:
    # The ARIMA backend could not be imported (presumably a missing or broken
    # dependency). Expose a stub under the same name so this module still
    # loads and ARIMA requests fail with a clear message.
    ARIMA_AVAILABLE = False

    def forecast_arima(*_args, **_kwargs):
        """Stand-in that always raises because the ARIMA backend is missing."""
        raise ImportError("ARIMA forecasting is not available due to dependency issues")
else:
    ARIMA_AVAILABLE = True
20
+
21
+
22
+ from .evaluate import mae, rmse, smape
23
+ from .plot import _range_markers, tufte_style
24
+
25
+
26
class Forecaster:
    """Fit, evaluate and plot forecasts for one datetime-indexed series.

    On construction the series is put on a regular frequency, taken from its
    index or inferred from it. Missing values are dropped into
    ``self.series``, and the frequency is kept in ``self.freq``: ``dropna()``
    can discard ``index.freq`` (e.g. when interior values are missing), and
    forecasts need it to build their date index. ``forecast`` stores its
    result in ``self.last_forecast`` for ``evaluate`` and ``plot``.
    """

    def __init__(self, series: pd.Series):
        if not isinstance(series.index, pd.DatetimeIndex):
            raise ValueError("Input must be indexed by datetime")
        if series.index.freq is None:
            inferred = pd.infer_freq(series.index)
            if inferred is None:
                # Without a regular frequency the forecast index cannot be
                # laid out meaningfully, so fail early with a clear message.
                raise ValueError(
                    "Could not infer a regular frequency from the series index; "
                    "resample or call .asfreq() before forecasting"
                )
            series = series.asfreq(inferred)
        self.freq = series.index.freq
        self.series = series.dropna().copy()
        self.last_forecast = None

    @staticmethod
    def _grid_positions(index: pd.DatetimeIndex, freq) -> tuple:
        """Map ``index`` to integer offsets on a regular ``freq`` grid.

        Returns ``(positions, grid_length)``, where ``grid_length`` counts
        every period from ``index[0]`` to ``index[-1]`` inclusive, missing
        ones included. If some dates do not sit on that grid, falls back to
        consecutive offsets ``0..len(index)-1``.
        """
        if len(index) == 0:
            return np.arange(0), 0
        grid = pd.date_range(index[0], index[-1], freq=freq, unit=index.unit)
        positions = grid.get_indexer(index)
        if (positions < 0).any():
            return np.arange(len(index)), len(index)
        return positions, len(grid)

    def forecast(
        self,
        model: Literal["arps", "timesfm", "chronos", "arima"],
        kind: Optional[Literal["exponential", "harmonic", "hyperbolic"]] = "hyperbolic",
        horizon: Optional[int] = 12,
    ) -> pd.Series:
        """Forecast ``horizon`` periods ahead with the chosen model.

        For ``"arps"`` and ``"arima"`` the result covers history plus
        forecast on one regular ``self.freq`` grid starting at the first
        observation. Periods missing from the history keep their true time
        offsets; they are NaN in the ARIMA history part. ``"timesfm"`` and
        ``"chronos"`` return whatever their backends produce.

        Raises:
            ValueError: If ``model`` is unknown, or if the ARIMA backend does
                not return exactly ``horizon`` values.
        """
        if model == "arps":
            # Fit on true period offsets so gaps do not compress time.
            t, n_grid = self._grid_positions(self.series.index, self.freq)
            q = self.series.to_numpy()
            params = fit_arps(t, q, kind=kind)
            full_t = np.arange(n_grid + horizon)
            yhat = predict_arps(full_t, params)
            idx = pd.date_range(
                self.series.index[0], periods=len(yhat), freq=self.freq
            )
            forecast = pd.Series(yhat, index=idx, name=f"arps_{kind}")

        elif model == "timesfm":
            forecast = forecast_timesfm(self.series, horizon=horizon)

        elif model == "chronos":
            forecast = forecast_chronos(self.series, horizon=horizon)

        elif model == "arima":
            forecast_part = forecast_arima(self.series, horizon=horizon)
            # Place observed values at their grid positions (NaN in gaps),
            # then append the forecast values.
            t, n_grid = self._grid_positions(self.series.index, self.freq)
            history = np.full(n_grid, np.nan)
            history[t] = self.series.to_numpy(dtype=float)
            full_index = pd.date_range(
                self.series.index[0], periods=n_grid + horizon, freq=self.freq
            )
            forecast = pd.Series(
                np.concatenate([history, np.asarray(forecast_part, dtype=float)]),
                index=full_index,
                name="arima_forecast",
            )

        else:
            raise ValueError(f"Unknown model: {model}")

        self.last_forecast = forecast
        return forecast

    def evaluate(self, actual: pd.Series) -> dict:
        """Score the last forecast against ``actual`` on shared dates.

        Returns:
            Dict with ``rmse``, ``mae`` and ``smape``.

        Raises:
            RuntimeError: If ``forecast`` has not been called yet.
            ValueError: If the two series share no dates.
        """
        if self.last_forecast is None:
            raise RuntimeError("Call .forecast() first.")
        common = self.last_forecast.index.intersection(actual.index)
        if len(common) == 0:
            raise ValueError("No overlapping dates to compare.")
        yhat = self.last_forecast.loc[common]
        ytrue = actual.loc[common]
        return {
            "rmse": rmse(ytrue, yhat),
            "mae": mae(ytrue, yhat),
            "smape": smape(ytrue, yhat),
        }

    def plot(self, title: str = "Forecast", filename: Optional[str] = None):
        """Plot history and the last forecast, optionally saving to ``filename``.

        Raises:
            RuntimeError: If ``forecast`` has not been called yet.
        """
        if self.last_forecast is None:
            raise RuntimeError("Call .forecast() first.")
        tufte_style()
        fig, ax = plt.subplots()
        hist = self.series
        fcst = self.last_forecast

        ax.plot(hist.index, hist.values, lw=1.0, label="history")
        ax.plot(fcst.index, fcst.values, lw=1.2, label="forecast")

        _range_markers(ax, hist.values)
        ax.set_xlabel("Date")
        ax.set_ylabel("Production")
        ax.set_title(title)
        ax.legend()
        if filename:
            plt.savefig(filename, bbox_inches="tight")
        plt.show()