prophet-rb 0.4.0 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +101 -4
- data/lib/prophet/diagnostics.rb +349 -0
- data/lib/prophet/forecaster.rb +18 -4
- data/lib/prophet/plot.rb +56 -6
- data/lib/prophet/stan_backend.rb +10 -1
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet.rb +1 -0
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 69d58f060a9bda44b1ab8666ded81b3e61ff1b95b08ae40727fa7a2dfee57ee0
|
4
|
+
data.tar.gz: dc76685a8b45ca7cad79561f986af9c66e5e49bd45dd5e10c72f91088d9b470a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 34d2fd0587110c6de9db334c44a4859a5de92d1cc124efb5f865c845d504816f3d3cd34d63776a709de396af8f9b2cf82a390cfb9f3e49e685ef8be4f469fb3e
|
7
|
+
data.tar.gz: '0587986407bb68a9ca97928ad65bcbb42fe101bd9d2e50a44b2126df38d67c04544932ec39009c6f1d6467736aa0a8404ac4aad04187dcb953285d4d5ee80a77'
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -88,7 +88,8 @@ Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/qu
|
|
88
88
|
- [Multiplicative Seasonality](#multiplicative-seasonality)
|
89
89
|
- [Uncertainty Intervals](#uncertainty-intervals)
|
90
90
|
- [Non-Daily Data](#non-daily-data)
|
91
|
-
- [
|
91
|
+
- [Diagnostics](#diagnostics)
|
92
|
+
- [Additional Topics](#additional-topics)
|
92
93
|
|
93
94
|
## Advanced Quick Start
|
94
95
|
|
@@ -177,11 +178,24 @@ df = Rover.read_csv("example_wp_log_R.csv")
|
|
177
178
|
df["cap"] = 8.5
|
178
179
|
m = Prophet.new(growth: "logistic")
|
179
180
|
m.fit(df)
|
180
|
-
future = m.make_future_dataframe(periods:
|
181
|
+
future = m.make_future_dataframe(periods: 1826)
|
181
182
|
future["cap"] = 8.5
|
182
183
|
forecast = m.predict(future)
|
183
184
|
```
|
184
185
|
|
186
|
+
Saturating minimum
|
187
|
+
|
188
|
+
```ruby
|
189
|
+
df["y"] = 10 - df["y"]
|
190
|
+
df["cap"] = 6
|
191
|
+
df["floor"] = 1.5
|
192
|
+
future["cap"] = 6
|
193
|
+
future["floor"] = 1.5
|
194
|
+
m = Prophet.new(growth: "logistic")
|
195
|
+
m.fit(df)
|
196
|
+
forecast = m.predict(future)
|
197
|
+
```
|
198
|
+
|
185
199
|
## Trend Changepoints
|
186
200
|
|
187
201
|
[Explanation](https://facebook.github.io/prophet/docs/trend_changepoints.html)
|
@@ -308,9 +322,64 @@ future = m.make_future_dataframe(periods: 300, freq: "H")
|
|
308
322
|
forecast = m.predict(future)
|
309
323
|
```
|
310
324
|
|
311
|
-
##
|
325
|
+
## Diagnostics
|
326
|
+
|
327
|
+
[Explanation](https://facebook.github.io/prophet/docs/diagnostics.html)
|
328
|
+
|
329
|
+
Cross validation
|
330
|
+
|
331
|
+
```ruby
|
332
|
+
df_cv = Prophet::Diagnostics.cross_validation(m, initial: "730 days", period: "180 days", horizon: "365 days")
|
333
|
+
```
|
334
|
+
|
335
|
+
Custom cutoffs
|
336
|
+
|
337
|
+
```ruby
|
338
|
+
cutoffs = ["2013-02-15", "2013-08-15", "2014-02-15"].map { |v| Time.parse("#{v} 00:00:00 UTC") }
|
339
|
+
df_cv2 = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "365 days")
|
340
|
+
```
|
341
|
+
|
342
|
+
Get performance metrics
|
312
343
|
|
313
|
-
|
344
|
+
```ruby
|
345
|
+
df_p = Prophet::Diagnostics.performance_metrics(df_cv)
|
346
|
+
```
|
347
|
+
|
348
|
+
Plot cross validation metrics
|
349
|
+
|
350
|
+
```ruby
|
351
|
+
Prophet::Plot.plot_cross_validation_metric(df_cv, metric: "mape")
|
352
|
+
```
|
353
|
+
|
354
|
+
Hyperparameter tuning
|
355
|
+
|
356
|
+
```ruby
|
357
|
+
param_grid = {
|
358
|
+
changepoint_prior_scale: [0.001, 0.01, 0.1, 0.5],
|
359
|
+
seasonality_prior_scale: [0.01, 0.1, 1.0, 10.0]
|
360
|
+
}
|
361
|
+
|
362
|
+
# Generate all combinations of parameters
|
363
|
+
all_params = param_grid.values[0].product(*param_grid.values[1..-1]).map { |v| param_grid.keys.zip(v).to_h }
|
364
|
+
rmses = [] # Store the RMSEs for each params here
|
365
|
+
|
366
|
+
# Use cross validation to evaluate all parameters
|
367
|
+
all_params.each do |params|
|
368
|
+
m = Prophet.new(**params).fit(df) # Fit model with given params
|
369
|
+
df_cv = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "30 days")
|
370
|
+
df_p = Prophet::Diagnostics.performance_metrics(df_cv, rolling_window: 1)
|
371
|
+
rmses << df_p["rmse"][0]
|
372
|
+
end
|
373
|
+
|
374
|
+
# Find the best parameters
|
375
|
+
tuning_results = Rover::DataFrame.new(all_params)
|
376
|
+
tuning_results["rmse"] = rmses
|
377
|
+
p tuning_results
|
378
|
+
```
|
379
|
+
|
380
|
+
## Additional Topics
|
381
|
+
|
382
|
+
[Explanation](https://facebook.github.io/prophet/docs/additional_topics.html)
|
314
383
|
|
315
384
|
Save a model
|
316
385
|
|
@@ -326,6 +395,34 @@ m = Prophet.from_json(File.read("model.json"))
|
|
326
395
|
|
327
396
|
Uses the same format as Python, so models can be saved and loaded in either language
|
328
397
|
|
398
|
+
Flat trend
|
399
|
+
|
400
|
+
```ruby
|
401
|
+
m = Prophet.new(growth: "flat")
|
402
|
+
```
|
403
|
+
|
404
|
+
Updating fitted models
|
405
|
+
|
406
|
+
```ruby
|
407
|
+
def stan_init(m)
|
408
|
+
res = {}
|
409
|
+
["k", "m", "sigma_obs"].each do |pname|
|
410
|
+
res[pname] = m.params[pname][0, true][0]
|
411
|
+
end
|
412
|
+
["delta", "beta"].each do |pname|
|
413
|
+
res[pname] = m.params[pname][0, true]
|
414
|
+
end
|
415
|
+
res
|
416
|
+
end
|
417
|
+
|
418
|
+
df = Rover.read_csv("example_wp_log_peyton_manning.csv")
|
419
|
+
df1 = df[df["ds"] <= "2016-01-19"] # All data except the last day
|
420
|
+
m1 = Prophet.new.fit(df1) # A model fit to all data except the last day
|
421
|
+
|
422
|
+
m2 = Prophet.new.fit(df) # Adding the last day, fitting from scratch
|
423
|
+
m2 = Prophet.new.fit(df, init: stan_init(m1)) # Adding the last day, warm-starting from m1
|
424
|
+
```
|
425
|
+
|
329
426
|
## Resources
|
330
427
|
|
331
428
|
- [Forecasting at Scale](https://peerj.com/preprints/3190.pdf)
|
@@ -0,0 +1,349 @@
|
|
1
|
+
module Prophet
  # Cross validation and forecast accuracy metrics, ported from Python
  # Prophet's diagnostics module.
  #
  # Durations (horizon, period, initial) are handled internally as seconds;
  # see `timedelta` for the accepted input formats.
  module Diagnostics
    # Seconds per unit for string durations accepted by `timedelta`.
    TIMEDELTA_UNITS = {
      "week" => 7 * 86400,
      "day" => 86400,
      "hour" => 3600,
      "minute" => 60,
      "second" => 1
    }.freeze
    private_constant :TIMEDELTA_UNITS

    # Generate cutoff dates for cross validation.
    #
    # The last cutoff is `max(ds) - horizon`; earlier cutoffs step back by
    # `period` while they stay at least `initial` after the first date. When a
    # window (cutoff, cutoff + horizon] contains no data, the cutoff is moved
    # to the last date at or before it, minus `horizon`.
    #
    # @param df [Rover::DataFrame] history with a "ds" column
    # @param horizon [Numeric] forecast horizon in seconds
    # @param initial [Numeric] minimum training window in seconds
    # @param period [Numeric] spacing between cutoffs in seconds
    # @return [Array] cutoffs in ascending order
    # @raise [Prophet::Error] if there is not enough data for any cutoff
    def self.generate_cutoffs(df, horizon, initial, period)
      ds = df["ds"]
      ds_min = ds.min

      # Last cutoff is 'latest date in data - horizon' date
      cutoff = ds.max - horizon
      raise Error, "Less data than horizon." if cutoff < ds_min

      result = [cutoff]
      while result[-1] >= ds_min + initial
        cutoff -= period
        # If data does not exist in data range (cutoff, cutoff + horizon]
        unless ((ds > cutoff) & (ds <= cutoff + horizon)).any?
          # Next cutoff point is 'last date before cutoff in data - horizon'
          if cutoff > ds_min
            # max of the "ds" column itself (Python: df[...].max()['ds'])
            closest_date = ds[ds <= cutoff].max
            cutoff = closest_date - horizon
          end
          # else no data left, leave cutoff as is, it will be dropped.
        end
        result << cutoff
      end
      # The final cutoff generated falls inside the initial window; drop it.
      result = result[0...-1]
      if result.empty?
        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
      end
      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
      result.reverse
    end

    # Cross validation for time series.
    #
    # For each cutoff, fits a copy of `model` (same settings and
    # seasonalities) on the history up to the cutoff and forecasts the
    # history points in (cutoff, cutoff + horizon].
    #
    # @param model [Prophet::Forecaster] a fitted model
    # @param horizon [String, Numeric] forecast horizon, e.g. "365 days", or seconds
    # @param period [String, Numeric, nil] spacing between generated cutoffs (default: horizon / 2)
    # @param initial [String, Numeric, nil] first training window
    #   (default: the larger of 3 * horizon and the longest seasonality period)
    # @param cutoffs [Array, nil] explicit cutoff times; when given, period and initial are ignored
    # @return [Rover::DataFrame] ds, yhat, (yhat_lower, yhat_upper,) y and cutoff columns
    # @raise [Prophet::Error] if the model is unfitted or the cutoffs are invalid
    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
      if model.history.nil?
        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
      end

      df = model.history.dup
      horizon = timedelta(horizon)

      predict_columns = ["ds", "yhat"]
      # Python treats uncertainty_samples == 0 as "no intervals"; in Ruby 0 is
      # truthy, so check for it explicitly.
      if model.uncertainty_samples && model.uncertainty_samples != 0
        predict_columns.concat(["yhat_lower", "yhat_upper"])
      end

      # Identify largest seasonality period (seasonality periods are in days)
      period_max = ([0.0] + model.seasonalities.map { |_, s| s[:period] }).max
      seasonality_dt = period_max * 86400.0

      if cutoffs.nil?
        # Set period
        period = period.nil? ? 0.5 * horizon : timedelta(period)

        # Set initial
        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)

        # Compute Cutoffs
        cutoffs = generate_cutoffs(df, horizon, initial, period)
      else
        raise Error, "Cutoffs must not be empty" if cutoffs.empty?

        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
        if cutoffs.min <= df["ds"].min
          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
        end
        # max value of cutoffs is <= (end date minus horizon)
        end_date_minus_horizon = df["ds"].max - horizon
        if cutoffs.max > end_date_minus_horizon
          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
        end
        initial = cutoffs[0] - df["ds"].min
      end

      # Check if the initial window
      # (that is, the amount of time between the start of the history and the first cutoff)
      # is less than the maximum seasonality period
      if initial < seasonality_dt
        msg = "Seasonality has period of #{period_max} days "
        msg += "which is larger than initial window. "
        msg += "Consider increasing initial."
        model.logger&.warn(msg)
      end

      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

      # Combine all predicted DataFrame into one DataFrame
      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
    end

    # Fit a copy of `model` on the history up to `cutoff` and predict the
    # history points in (cutoff, cutoff + horizon].
    #
    # @param df [Rover::DataFrame] the full model history
    # @param model [Prophet::Forecaster] the fitted source model
    # @param cutoff [Object] the cutoff time
    # @param horizon [Numeric] horizon in seconds
    # @param predict_columns [Array<String>] forecast columns to keep
    # @return [Rover::DataFrame] predict_columns plus "y" and "cutoff"
    # @raise [Prophet::Error] if fewer than two points precede the cutoff
    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
      # Generate new object with copying fitting options
      m = prophet_copy(model, cutoff)
      # Train model
      history_c = df[df["ds"] <= cutoff]
      if history_c.shape[0] < 2
        raise Error, "Less than two datapoints before cutoff. Increase initial window."
      end
      m.fit(history_c, **(model.fit_kwargs || {}))
      # Calculate yhat
      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
      # Get the columns for the future dataframe
      columns = ["ds"]
      if m.growth == "logistic"
        columns << "cap"
        # Python checks m.logistic_floor here. A model that uses a floor
        # requires a "floor" column in its history, so testing for the column
        # covers that case without depending on a logistic_floor reader.
        columns << "floor" if df.include?("floor")
      end
      columns.concat(m.extra_regressors.keys)
      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
      yhat = m.predict(df[index_predicted][columns])
      # Merge yhat(predicts), y(df, original data) and cutoff
      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
    end

    # Build an unfitted copy of a fitted model with the same settings.
    #
    # Auto seasonalities are disabled on the copy; the fitted model's
    # seasonalities, extra regressors and country holidays are deep-copied
    # instead. Specified changepoints are kept, restricted to those before
    # the last history date at or before `cutoff` when a cutoff is given.
    #
    # @param m [Prophet::Forecaster] a fitted model
    # @param cutoff [Object, nil] optional cutoff time
    # @return [Prophet::Forecaster] a new, unfitted model
    # @raise [Prophet::Error] if `m` has not been fit
    def self.prophet_copy(m, cutoff = nil)
      if m.history.nil?
        raise Error, "This is for copying a fitted Prophet object."
      end

      if m.specified_changepoints
        changepoints = m.changepoints
        if !cutoff.nil?
          # Filter change points '< cutoff'
          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
          changepoints = changepoints[changepoints < last_history_date]
        end
      else
        changepoints = nil
      end

      # Auto seasonalities are set to False because they are already set in
      # m.seasonalities.
      m2 = m.class.new(
        growth: m.growth,
        n_changepoints: m.n_changepoints,
        changepoint_range: m.changepoint_range,
        changepoints: changepoints,
        yearly_seasonality: false,
        weekly_seasonality: false,
        daily_seasonality: false,
        holidays: m.holidays,
        seasonality_mode: m.seasonality_mode,
        seasonality_prior_scale: m.seasonality_prior_scale,
        changepoint_prior_scale: m.changepoint_prior_scale,
        holidays_prior_scale: m.holidays_prior_scale,
        mcmc_samples: m.mcmc_samples,
        interval_width: m.interval_width,
        uncertainty_samples: m.uncertainty_samples
      )
      m2.extra_regressors = deepcopy(m.extra_regressors)
      m2.seasonalities = deepcopy(m.seasonalities)
      m2.country_holidays = deepcopy(m.country_holidays)
      m2
    end

    # Convert a duration to seconds.
    #
    # Numerics (including ActiveSupport::Duration, which is a Numeric) are
    # returned unchanged. Strings must look like "<number> <unit>", where the
    # unit is week(s), day(s), hour(s), minute(s) or second(s), for example
    # "365 days", "1 day" or "12 hours".
    #
    # @param value [String, Numeric]
    # @return [Numeric] seconds (a Float for string input)
    # @raise [Prophet::Error] for any other input
    def self.timedelta(value)
      return value if value.is_a?(Numeric)

      match = /\A(\d+(?:\.\d+)?) (week|day|hour|minute|second)s?\z/.match(value.to_s)
      raise Error, "Unknown time delta" unless match

      match[1].to_f * TIMEDELTA_UNITS.fetch(match[2])
    end

    # Recursively copy hashes and arrays; every other value is `dup`ed.
    def self.deepcopy(value)
      if value.is_a?(Hash)
        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
      elsif value.is_a?(Array)
        value.map { |v| deepcopy(v) }
      else
        value.dup
      end
    end

    # Compute forecast accuracy metrics from cross validation results.
    #
    # Metrics are grouped by horizon (ds - cutoff). With a non-negative
    # `rolling_window`, each value averages (or, for mdape, takes the median
    # of) the nearest `rolling_window` fraction of predictions (at least 1
    # point). A negative `rolling_window` returns per-prediction values.
    #
    # "coverage" is skipped when yhat_lower/yhat_upper are missing, and
    # "mape" is skipped when some |y| < 1e-8. The caller's `metrics` array is
    # not modified.
    #
    # @param df [Rover::DataFrame] output of `cross_validation`
    # @param metrics [Array<String>, nil] subset of the valid metrics (default: all)
    # @param rolling_window [Float] fraction of data in each window
    # @param monthly [Boolean] not implemented yet
    # @return [Rover::DataFrame, nil] horizon plus one column per metric, or nil if no metrics remain
    # @raise [ArgumentError] for duplicate or unknown metrics
    # @raise [Prophet::Error] if monthly is requested
    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
      # Work on our own array so the deletions below never touch the caller's
      metrics = metrics.nil? ? valid_metrics : Array(metrics).dup
      if metrics.include?("coverage") && !(df.include?("yhat_lower") && df.include?("yhat_upper"))
        metrics.delete("coverage")
      end
      if metrics.uniq.length != metrics.length
        raise ArgumentError, "Input metrics must be a list of unique values"
      end
      unless (metrics - valid_metrics).empty?
        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
      end
      # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
      raise Error, "Not implemented yet" if monthly

      df_m = df.dup
      df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
      df_m.sort_by! { |r| r["horizon"] }
      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
        # logger.info("Skipping MAPE because y close to 0")
        metrics.delete("mape")
      end
      return nil if metrics.empty?

      w = (rolling_window * df_m.shape[0]).to_i
      if w >= 0
        w = [w, 1].max
        w = [w, df_m.shape[0]].min
      end
      # Compute all metrics (names were validated against valid_metrics above)
      dfs = metrics.to_h { |metric| [metric, public_send(metric, df_m, w)] }
      res = dfs[metrics[0]]
      metrics.drop(1).each do |metric|
        res_m = dfs[metric]
        # Python asserts this; column assignment is only valid row-for-row
        unless res_m["horizon"].to_a == res["horizon"].to_a
          raise Error, "Horizons of #{metric} do not match horizons of #{metrics[0]}"
        end
        res[metric] = res_m[metric]
      end
      res
    end

    # Rolling mean of x over the sorted unique horizons h.
    #
    # For each horizon, averages the w points with the closest smaller (or
    # equal) horizons; the oldest contributing horizon is only partially
    # weighted when needed to reach exactly w points. Horizons without w
    # points available are omitted.
    #
    # @return [Rover::DataFrame] columns "horizon" and `name`
    def self.rolling_mean_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
      xs = df2["sum_x"]
      ns = df2["count"]
      hs = df2["h"]

      trailing_i = df2.length - 1
      # Float accumulators so integer inputs are not truncated by division
      x_sum = 0.0
      n_sum = 0
      # We don't know output size but it is bounded by len(df2)
      res_x = [nil] * df2.length

      # Start from the right and work backwards
      (df2.length - 1).downto(0) do |i|
        x_sum += xs[i]
        n_sum += ns[i]
        while n_sum >= w
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise weight the mean by the difference
          excess_n = n_sum - w
          excess_x = excess_n * xs[i] / ns[i].to_f
          res_x[trailing_i] = (x_sum - excess_x) / w
          x_sum -= xs[trailing_i]
          n_sum -= ns[trailing_i]
          trailing_i -= 1
        end
      end

      res_h = hs[(trailing_i + 1)..-1]
      res_x = res_x[(trailing_i + 1)..-1]

      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Rolling median of x over the sorted unique horizons h.
    #
    # For each horizon (largest first), takes all points at that horizon and,
    # if fewer than w, adds points from earlier rows until w are collected.
    # Stops at the first horizon that cannot reach w points. Relies on x and
    # h being ordered by horizon (performance_metrics sorts by horizon).
    #
    # @return [Rover::DataFrame] columns "horizon" and `name`
    def self.rolling_median_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      grouped = df.group("h")
      df2 = grouped.count.sort_by { |r| r["h"] }
      hs = df2["h"]

      res_h = []
      res_x = []
      # Start from the right and work backwards
      i = hs.length - 1
      while i >= 0
        h_i = hs[i]
        xs = df[df["h"] == h_i]["x"].to_a

        # Index of the first row with horizon h_i, minus one
        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
        while xs.length < w && next_idx_to_add >= 0
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise just enough to get to w.
          xs << x[next_idx_to_add]
          next_idx_to_add -= 1
        end
        if xs.length < w
          # Ran out of points before getting enough.
          break
        end
        res_h << hs[i]
        res_x << Rover::Vector.new(xs).median
        i -= 1
      end
      res_h.reverse!
      res_x.reverse!
      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Mean squared error; per prediction when w < 0.
    def self.mse(df, w)
      se = (df["y"] - df["yhat"]) ** 2
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
      end
      rolling_mean_by_h(se, df["horizon"], w, "mse")
    end

    # Root mean squared error (square root of the rolling mse).
    def self.rmse(df, w)
      res = mse(df, w)
      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
      res
    end

    # Mean absolute error; per prediction when w < 0.
    def self.mae(df, w)
      ae = (df["y"] - df["yhat"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
      end
      rolling_mean_by_h(ae, df["horizon"], w, "mae")
    end

    # Mean absolute percent error (as a fraction); per prediction when w < 0.
    def self.mape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
      end
      rolling_mean_by_h(ape, df["horizon"], w, "mape")
    end

    # Median absolute percent error (as a fraction); per prediction when w < 0.
    def self.mdape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
      end
      rolling_median_by_h(ape, df["horizon"], w, "mdape")
    end

    # Symmetric mean absolute percent error; per prediction when w < 0.
    def self.smape(df, w)
      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
      end
      rolling_mean_by_h(sape, df["horizon"], w, "smape")
    end

    # Fraction of y values inside [yhat_lower, yhat_upper]; per-prediction
    # booleans when w < 0.
    def self.coverage(df, w)
      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
      end
      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
    end
  end
end
|
data/lib/prophet/forecaster.rb
CHANGED
@@ -3,7 +3,14 @@ module Prophet
|
|
3
3
|
include Holidays
|
4
4
|
include Plot
|
5
5
|
|
6
|
-
attr_reader :logger, :params, :train_holiday_names
|
6
|
+
attr_reader :logger, :params, :train_holiday_names,
|
7
|
+
:history, :seasonalities, :specified_changepoints, :fit_kwargs,
|
8
|
+
:growth, :changepoints, :n_changepoints, :changepoint_range,
|
9
|
+
:holidays, :seasonality_mode, :seasonality_prior_scale,
|
10
|
+
:holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
|
11
|
+
:interval_width, :uncertainty_samples
|
12
|
+
|
13
|
+
attr_accessor :extra_regressors, :seasonalities, :country_holidays
|
7
14
|
|
8
15
|
def initialize(
|
9
16
|
growth: "linear",
|
@@ -176,8 +183,10 @@ module Prophet
|
|
176
183
|
|
177
184
|
initialize_scales(initialize_scales, df)
|
178
185
|
|
179
|
-
if @logistic_floor
|
180
|
-
|
186
|
+
if @logistic_floor
|
187
|
+
unless df.include?("floor")
|
188
|
+
raise ArgumentError, "Expected column \"floor\"."
|
189
|
+
end
|
181
190
|
else
|
182
191
|
df["floor"] = 0
|
183
192
|
end
|
@@ -207,7 +216,12 @@ module Prophet
|
|
207
216
|
def initialize_scales(initialize_scales, df)
|
208
217
|
return unless initialize_scales
|
209
218
|
|
210
|
-
|
219
|
+
if @growth == "logistic" && df.include?("floor")
|
220
|
+
@logistic_floor = true
|
221
|
+
floor = df["floor"]
|
222
|
+
else
|
223
|
+
floor = 0.0
|
224
|
+
end
|
211
225
|
@y_scale = (df["y"] - floor).abs.max
|
212
226
|
@y_scale = 1 if @y_scale == 0
|
213
227
|
@start = df["ds"].min
|
data/lib/prophet/plot.rb
CHANGED
@@ -111,6 +111,61 @@ module Prophet
|
|
111
111
|
artists
|
112
112
|
end
|
113
113
|
|
114
|
+
# Plot a cross validation performance metric against forecast horizon.
#
# Per-prediction values are drawn as faint points and the rolling-window
# value as a line, following Python Prophet's plot_cross_validation_metric.
#
# @param df_cv [Rover::DataFrame] output of Diagnostics.cross_validation
# @param metric [String] a metric accepted by Diagnostics.performance_metrics
# @param rolling_window [Float] fraction of data in each rolling window
# @param ax [Object, nil] existing matplotlib axes; a new figure is created when nil
# @param figsize [Array] size of the new figure when ax is nil
# @param color [String] color of the rolling-window line
# @param point_color [String] color of the per-prediction points
# @return [Object] the matplotlib figure
# @raise [Prophet::Error] if the metric cannot be computed for df_cv
def self.plot_cross_validation_metric(df_cv, metric:, rolling_window: 0.1, ax: nil, figsize: [10, 6], color: "b", point_color: "gray")
  # Get the metric at the level of individual predictions, and with the rolling window.
  df_none = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: -1)
  df_h = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: rolling_window)
  # performance_metrics returns nil when the requested metric was dropped
  # (e.g. "mape" with y close to 0, or "coverage" without interval columns)
  if df_none.nil? || df_h.nil?
    raise Error, "Cannot compute metric #{metric} for these cross validation results"
  end

  if ax.nil?
    fig = plt.figure(facecolor: "w", figsize: figsize)
    ax = fig.add_subplot(111)
  else
    fig = ax.get_figure
  end

  # Horizons are time differences in seconds. Target ~10 ticks by choosing
  # the largest time unit shorter than a tenth of the longest horizon.
  tick_w = df_none["horizon"].max * 1e9 / 10.0 # nanoseconds
  dt_names = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds", "nanoseconds"]
  # nanoseconds per unit, in the same order as dt_names
  dt_conversions = [
    24 * 60 * 60 * 10**9,
    60 * 60 * 10**9,
    60 * 10**9,
    10**9,
    10**6,
    10**3,
    1.0
  ]
  # Falls back to the smallest unit, as Python's loop does when nothing matches
  i = dt_conversions.index { |unit_ns| unit_ns < tick_w } || dt_conversions.length - 1

  x_plt = df_none["horizon"] * 1e9 / dt_conversions[i].to_f
  x_plt_h = df_h["horizon"] * 1e9 / dt_conversions[i].to_f

  ax.plot(x_plt.to_a, df_none[metric].to_a, ".", alpha: 0.1, c: point_color)
  ax.plot(x_plt_h.to_a, df_h[metric].to_a, "-", c: color)
  ax.grid(true)

  ax.set_xlabel("Horizon (#{dt_names[i]})")
  ax.set_ylabel(metric)
  fig
end
|
159
|
+
|
160
|
+
# Load matplotlib lazily and return its Pyplot module.
#
# @return [Module] Matplotlib::Pyplot
# @raise [Error] when the matplotlib gem cannot be required
def self.plt
  require "matplotlib/pyplot"
  Matplotlib::Pyplot
rescue LoadError
  raise Error, "Install the matplotlib gem for plots"
end
|
168
|
+
|
114
169
|
private
|
115
170
|
|
116
171
|
def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
|
@@ -263,12 +318,7 @@ module Prophet
|
|
263
318
|
end
|
264
319
|
|
265
320
|
def plt
|
266
|
-
|
267
|
-
require "matplotlib/pyplot"
|
268
|
-
rescue LoadError
|
269
|
-
raise Error, "Install the matplotlib gem for plots"
|
270
|
-
end
|
271
|
-
Matplotlib::Pyplot
|
321
|
+
Plot.plt
|
272
322
|
end
|
273
323
|
|
274
324
|
def dates
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -13,6 +13,11 @@ module Prophet
|
|
13
13
|
|
14
14
|
def fit(stan_init, stan_data, **kwargs)
|
15
15
|
stan_init, stan_data = prepare_data(stan_init, stan_data)
|
16
|
+
|
17
|
+
if !kwargs[:inits] && kwargs[:init]
|
18
|
+
kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
|
19
|
+
end
|
20
|
+
|
16
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
17
22
|
iterations = 10000
|
18
23
|
|
@@ -49,6 +54,10 @@ module Prophet
|
|
49
54
|
def sampling(stan_init, stan_data, samples, **kwargs)
|
50
55
|
stan_init, stan_data = prepare_data(stan_init, stan_data)
|
51
56
|
|
57
|
+
if !kwargs[:inits] && kwargs[:init]
|
58
|
+
kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
|
59
|
+
end
|
60
|
+
|
52
61
|
kwargs[:chains] ||= 4
|
53
62
|
kwargs[:warmup_iters] ||= samples / 2
|
54
63
|
|
@@ -128,7 +137,7 @@ module Prophet
|
|
128
137
|
stan_data["t_change"] = stan_data["t_change"].to_a
|
129
138
|
stan_data["s_a"] = stan_data["s_a"].to_a
|
130
139
|
stan_data["s_m"] = stan_data["s_m"].to_a
|
131
|
-
stan_data["X"] = stan_data["X"].to_numo.to_a
|
140
|
+
stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a
|
132
141
|
stan_init["delta"] = stan_init["delta"].to_a
|
133
142
|
stan_init["beta"] = stan_init["beta"].to_a
|
134
143
|
[stan_init, stan_data]
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.
|
4
|
+
version: 0.4.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-07-
|
11
|
+
date: 2022-07-10 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -66,6 +66,7 @@ files:
|
|
66
66
|
- data-raw/generated_holidays.csv
|
67
67
|
- lib/prophet-rb.rb
|
68
68
|
- lib/prophet.rb
|
69
|
+
- lib/prophet/diagnostics.rb
|
69
70
|
- lib/prophet/forecaster.rb
|
70
71
|
- lib/prophet/holidays.rb
|
71
72
|
- lib/prophet/plot.rb
|