prophet-rb 0.4.0 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +101 -4
- data/lib/prophet/diagnostics.rb +349 -0
- data/lib/prophet/forecaster.rb +18 -4
- data/lib/prophet/plot.rb +56 -6
- data/lib/prophet/stan_backend.rb +10 -1
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet.rb +1 -0
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 69d58f060a9bda44b1ab8666ded81b3e61ff1b95b08ae40727fa7a2dfee57ee0
|
4
|
+
data.tar.gz: dc76685a8b45ca7cad79561f986af9c66e5e49bd45dd5e10c72f91088d9b470a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 34d2fd0587110c6de9db334c44a4859a5de92d1cc124efb5f865c845d504816f3d3cd34d63776a709de396af8f9b2cf82a390cfb9f3e49e685ef8be4f469fb3e
|
7
|
+
data.tar.gz: '0587986407bb68a9ca97928ad65bcbb42fe101bd9d2e50a44b2126df38d67c04544932ec39009c6f1d6467736aa0a8404ac4aad04187dcb953285d4d5ee80a77'
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -88,7 +88,8 @@ Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/qu
|
|
88
88
|
- [Multiplicative Seasonality](#multiplicative-seasonality)
|
89
89
|
- [Uncertainty Intervals](#uncertainty-intervals)
|
90
90
|
- [Non-Daily Data](#non-daily-data)
|
91
|
-
- [
|
91
|
+
- [Diagnostics](#diagnostics)
|
92
|
+
- [Additional Topics](#additional-topics)
|
92
93
|
|
93
94
|
## Advanced Quick Start
|
94
95
|
|
@@ -177,11 +178,24 @@ df = Rover.read_csv("example_wp_log_R.csv")
|
|
177
178
|
df["cap"] = 8.5
|
178
179
|
m = Prophet.new(growth: "logistic")
|
179
180
|
m.fit(df)
|
180
|
-
future = m.make_future_dataframe(periods:
|
181
|
+
future = m.make_future_dataframe(periods: 1826)
|
181
182
|
future["cap"] = 8.5
|
182
183
|
forecast = m.predict(future)
|
183
184
|
```
|
184
185
|
|
186
|
+
Saturating minimum
|
187
|
+
|
188
|
+
```ruby
|
189
|
+
df["y"] = 10 - df["y"]
|
190
|
+
df["cap"] = 6
|
191
|
+
df["floor"] = 1.5
|
192
|
+
future["cap"] = 6
|
193
|
+
future["floor"] = 1.5
|
194
|
+
m = Prophet.new(growth: "logistic")
|
195
|
+
m.fit(df)
|
196
|
+
forecast = m.predict(future)
|
197
|
+
```
|
198
|
+
|
185
199
|
## Trend Changepoints
|
186
200
|
|
187
201
|
[Explanation](https://facebook.github.io/prophet/docs/trend_changepoints.html)
|
@@ -308,9 +322,64 @@ future = m.make_future_dataframe(periods: 300, freq: "H")
|
|
308
322
|
forecast = m.predict(future)
|
309
323
|
```
|
310
324
|
|
311
|
-
##
|
325
|
+
## Diagnostics
|
326
|
+
|
327
|
+
[Explanation](https://facebook.github.io/prophet/docs/diagnostics.html)
|
328
|
+
|
329
|
+
Cross validation
|
330
|
+
|
331
|
+
```ruby
|
332
|
+
df_cv = Prophet::Diagnostics.cross_validation(m, initial: "730 days", period: "180 days", horizon: "365 days")
|
333
|
+
```
|
334
|
+
|
335
|
+
Custom cutoffs
|
336
|
+
|
337
|
+
```ruby
|
338
|
+
cutoffs = ["2013-02-15", "2013-08-15", "2014-02-15"].map { |v| Time.parse("#{v} 00:00:00 UTC") }
|
339
|
+
df_cv2 = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "365 days")
|
340
|
+
```
|
341
|
+
|
342
|
+
Get performance metrics
|
312
343
|
|
313
|
-
|
344
|
+
```ruby
|
345
|
+
df_p = Prophet::Diagnostics.performance_metrics(df_cv)
|
346
|
+
```
|
347
|
+
|
348
|
+
Plot cross validation metrics
|
349
|
+
|
350
|
+
```ruby
|
351
|
+
Prophet::Plot.plot_cross_validation_metric(df_cv, metric: "mape")
|
352
|
+
```
|
353
|
+
|
354
|
+
Hyperparameter tuning
|
355
|
+
|
356
|
+
```ruby
|
357
|
+
param_grid = {
|
358
|
+
changepoint_prior_scale: [0.001, 0.01, 0.1, 0.5],
|
359
|
+
seasonality_prior_scale: [0.01, 0.1, 1.0, 10.0]
|
360
|
+
}
|
361
|
+
|
362
|
+
# Generate all combinations of parameters
|
363
|
+
all_params = param_grid.values[0].product(*param_grid.values[1..-1]).map { |v| param_grid.keys.zip(v).to_h }
|
364
|
+
rmses = [] # Store the RMSEs for each params here
|
365
|
+
|
366
|
+
# Use cross validation to evaluate all parameters
|
367
|
+
all_params.each do |params|
|
368
|
+
m = Prophet.new(**params).fit(df) # Fit model with given params
|
369
|
+
df_cv = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "30 days")
|
370
|
+
df_p = Prophet::Diagnostics.performance_metrics(df_cv, rolling_window: 1)
|
371
|
+
rmses << df_p["rmse"][0]
|
372
|
+
end
|
373
|
+
|
374
|
+
# Find the best parameters
|
375
|
+
tuning_results = Rover::DataFrame.new(all_params)
|
376
|
+
tuning_results["rmse"] = rmses
|
377
|
+
p tuning_results
|
378
|
+
```
|
379
|
+
|
380
|
+
## Additional Topics
|
381
|
+
|
382
|
+
[Explanation](https://facebook.github.io/prophet/docs/additional_topics.html)
|
314
383
|
|
315
384
|
Save a model
|
316
385
|
|
@@ -326,6 +395,34 @@ m = Prophet.from_json(File.read("model.json"))
|
|
326
395
|
|
327
396
|
Uses the same format as Python, so models can be saved and loaded in either language
|
328
397
|
|
398
|
+
Flat trend
|
399
|
+
|
400
|
+
```ruby
|
401
|
+
m = Prophet.new(growth: "flat")
|
402
|
+
```
|
403
|
+
|
404
|
+
Updating fitted models
|
405
|
+
|
406
|
+
```ruby
|
407
|
+
def stan_init(m)
|
408
|
+
res = {}
|
409
|
+
["k", "m", "sigma_obs"].each do |pname|
|
410
|
+
res[pname] = m.params[pname][0, true][0]
|
411
|
+
end
|
412
|
+
["delta", "beta"].each do |pname|
|
413
|
+
res[pname] = m.params[pname][0, true]
|
414
|
+
end
|
415
|
+
res
|
416
|
+
end
|
417
|
+
|
418
|
+
df = Rover.read_csv("example_wp_log_peyton_manning.csv")
|
419
|
+
df1 = df[df["ds"] <= "2016-01-19"] # All data except the last day
|
420
|
+
m1 = Prophet.new.fit(df1) # A model fit to all data except the last day
|
421
|
+
|
422
|
+
m2 = Prophet.new.fit(df) # Adding the last day, fitting from scratch
|
423
|
+
m2 = Prophet.new.fit(df, init: stan_init(m1)) # Adding the last day, warm-starting from m1
|
424
|
+
```
|
425
|
+
|
329
426
|
## Resources
|
330
427
|
|
331
428
|
- [Forecasting at Scale](https://peerj.com/preprints/3190.pdf)
|
@@ -0,0 +1,349 @@
|
|
1
|
+
module Prophet
|
2
|
+
module Diagnostics
|
3
|
+
# Builds the cross-validation cutoffs by walking backwards in time from
# "latest date - horizon" in steps of +period+.
#
# df      - frame with a "ds" column
# horizon - amount subtracted from / added to "ds" values for one forecast window
# initial - size of the initial training window, measured from the earliest "ds"
# period  - spacing between consecutive cutoffs
#
# Returns the cutoffs in ascending order.
# Raises Error when the data is shorter than the horizon, or when no cutoff
# is left after the initial window.
def self.generate_cutoffs(df, horizon, initial, period)
  dates = df["ds"]
  first_date = dates.min

  # Last cutoff is 'latest date in data - horizon' date
  cutoff = dates.max - horizon
  raise Error, "Less data than horizon." if cutoff < first_date

  cutoffs = [cutoff]
  while cutoffs.last >= first_date + initial
    cutoff -= period
    # Does any data exist in the range (cutoff, cutoff + horizon]?
    has_data = ((dates > cutoff) & (dates <= cutoff + horizon)).any?
    # If not, the next cutoff point is 'last date before cutoff in data - horizon'.
    # When no data is left at all the cutoff is kept as is; it gets dropped below.
    if !has_data && cutoff > first_date
      closest_date = df[dates <= cutoff].max["ds"]
      cutoff = closest_date - horizon
    end
    cutoffs << cutoff
  end

  # The last entry is the one that fell before the end of the initial window
  # (it is what stopped the loop), so it is discarded.
  cutoffs.pop
  if cutoffs.empty?
    raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
  end
  # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
  cutoffs.reverse
end
|
30
|
+
|
31
|
+
# Runs cross validation for a fitted model: for every cutoff a copy of the
# model is fit on the history up to the cutoff and used to predict the
# following +horizon+.
#
# model    - fitted model (its history must be present)
# horizon: - forecast horizon; a Numeric or a string understood by timedelta
# period:  - spacing between cutoffs; defaults to half the horizon
# initial: - initial training window; defaults to the larger of three
#            horizons and the longest seasonality period
# cutoffs: - explicit list of cutoffs; when given, period and initial are not used
#
# Returns one frame with the predictions of all cutoffs concatenated.
# Raises Error when the model is not fit or the supplied cutoffs fall outside
# the usable range of the history.
def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
  if model.history.nil?
    raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
  end

  df = model.history.dup
  horizon = timedelta(horizon)

  predict_columns = ["ds", "yhat"]
  if model.uncertainty_samples
    predict_columns.concat(["yhat_lower", "yhat_upper"])
  end

  # Identify largest seasonality period
  period_max = 0.0
  model.seasonalities.each do |_, s|
    period_max = [period_max, s[:period]].max
  end
  seasonality_dt = timedelta("#{period_max} days")

  if cutoffs.nil?
    # Set period
    period = period.nil? ? 0.5 * horizon : timedelta(period)

    # Set initial
    initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)

    # Compute Cutoffs
    cutoffs = generate_cutoffs(df, horizon, initial, period)
  else
    # The min cutoff must be strictly greater than the min date in the history
    if cutoffs.min <= df["ds"].min
      raise Error, "Minimum cutoff value is not strictly greater than min date in history"
    end
    # max value of cutoffs is <= (end date minus horizon)
    end_date_minus_horizon = df["ds"].max - horizon
    if cutoffs.max > end_date_minus_horizon
      raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
    end
    initial = cutoffs[0] - df["ds"].min
  end

  # Check if the initial window
  # (that is, the amount of time between the start of the history and the first cutoff)
  # is less than the maximum seasonality period
  if initial < seasonality_dt
    msg = "Seasonality has period of #{period_max} days " \
      "which is larger than initial window. " \
      "Consider increasing initial."
    # Report through the model's logger instead of discarding the message;
    # safe navigation keeps models without a logger working.
    model.logger&.warn(msg)
  end

  predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

  # Combine all predicted DataFrame into one DataFrame
  predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
end
|
88
|
+
|
89
|
+
# Fits a copy of +model+ on the history up to +cutoff+ and predicts the rows
# of +df+ that fall in (cutoff, cutoff + horizon].
#
# Returns a frame with the requested +predict_columns+, the observed "y" and
# a "cutoff" column holding +cutoff+ for every row.
# Raises Error when fewer than two rows lie at or before the cutoff.
def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
  # New, unfitted object that copies the fitting options
  m = prophet_copy(model, cutoff)

  # Train on everything up to and including the cutoff
  training = df[df["ds"] <= cutoff]
  raise Error, "Less than two datapoints before cutoff. Increase initial window." if training.shape[0] < 2
  m.fit(training, **model.fit_kwargs)

  # Rows to predict: (cutoff, cutoff + horizon]
  in_horizon = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)

  # Columns the future dataframe needs
  columns = ["ds"]
  if m.growth == "logistic"
    columns << "cap"
    columns << "floor" if m.logistic_floor
  end
  columns.concat(m.extra_regressors.keys)
  condition_names = m.seasonalities.map { |_, props| props[:condition_name] }.compact
  columns.concat(condition_names)

  yhat = m.predict(df[in_horizon][columns])

  # Merge yhat (predictions), y (original data) and the cutoff
  observed = df[in_horizon][["y"]]
  cutoff_column = Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length})
  yhat[predict_columns].merge(observed).merge(cutoff_column)
end
|
114
|
+
|
115
|
+
# Creates a new, unfitted model of the same class as +m+ with the same
# fitting options.
#
# m      - fitted model to copy the options from
# cutoff - optional; when given and changepoints were specified, only the
#          changepoints before the last history date at or before the
#          cutoff are carried over
#
# Raises Error when +m+ has no history.
def self.prophet_copy(m, cutoff = nil)
  raise Error, "This is for copying a fitted Prophet object." if m.history.nil?

  changepoints = nil
  if m.specified_changepoints
    changepoints = m.changepoints
    unless cutoff.nil?
      # Filter change points '< cutoff'
      history_dates = m.history["ds"]
      last_history_date = history_dates[history_dates <= cutoff].max
      changepoints = changepoints[changepoints < last_history_date]
    end
  end

  # Auto seasonalities are set to false because they are already set in
  # m.seasonalities, which is copied over below.
  options = {
    growth: m.growth,
    n_changepoints: m.n_changepoints,
    changepoint_range: m.changepoint_range,
    changepoints: changepoints,
    yearly_seasonality: false,
    weekly_seasonality: false,
    daily_seasonality: false,
    holidays: m.holidays,
    seasonality_mode: m.seasonality_mode,
    seasonality_prior_scale: m.seasonality_prior_scale,
    changepoint_prior_scale: m.changepoint_prior_scale,
    holidays_prior_scale: m.holidays_prior_scale,
    mcmc_samples: m.mcmc_samples,
    interval_width: m.interval_width,
    uncertainty_samples: m.uncertainty_samples
  }
  copy = m.class.new(**options)
  copy.extra_regressors = deepcopy(m.extra_regressors)
  copy.seasonalities = deepcopy(m.seasonalities)
  copy.country_holidays = deepcopy(m.country_holidays)
  copy
end
|
155
|
+
|
156
|
+
# Converts a time delta to a number of seconds.
#
# value - a Numeric, returned unchanged, or a string of the form
#         "<number> <unit>" where unit is day(s), hour(s), minute(s) or
#         second(s), e.g. "365.25 days", "1 day", "12 hours"
#
# Returns the Numeric as given, or a Float number of seconds for strings.
# Raises Error for anything else.
def self.timedelta(value)
  # ActiveSupport::Duration is a numeric
  return value if value.is_a?(Numeric)

  seconds_per_unit = {"day" => 86400, "hour" => 3600, "minute" => 60, "second" => 1}
  # to_s lets non-String input (nil, symbols, ...) fall through to the
  # Error below instead of failing inside the regexp match.
  m = /\A(\d+(?:\.\d+)?) (day|hour|minute|second)s?\z/.match(value.to_s)
  raise Error, "Unknown time delta" unless m

  m[1].to_f * seconds_per_unit.fetch(m[2])
end
|
166
|
+
|
167
|
+
# Recursively copies hashes and arrays (hash keys included); any other
# value is copied with #dup.
def self.deepcopy(value)
  case value
  when Hash
    value.map { |key, item| [deepcopy(key), deepcopy(item)] }.to_h
  when Array
    value.map { |item| deepcopy(item) }
  else
    value.dup
  end
end
|
176
|
+
|
177
|
+
# Computes forecast-error metrics from cross-validation output.
#
# df              - frame with "ds", "cutoff", "y" and "yhat" columns;
#                   "yhat_lower"/"yhat_upper" are needed for coverage
# metrics:        - list of metric names to compute; nil means all of
#                   mse, rmse, mae, mape, mdape, smape, coverage.
#                   The caller's array is never modified.
# rolling_window: - fraction of the rows used for each rolling window;
#                   a negative value makes the metric methods return
#                   per-row values instead
# monthly:        - not implemented yet; a truthy value raises Error
#
# "coverage" is dropped when the interval columns are missing and "mape"
# is dropped when some |y| is below 1e-8.
#
# Returns a frame with a "horizon" column plus one column per metric, or
# nil when no metric is left to compute.
# Raises ArgumentError for duplicated or unknown metric names.
def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
  valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
  # Work on a private copy: entries may be removed below and that must not
  # leak into the array the caller passed in.
  metrics = metrics.nil? ? valid_metrics.dup : metrics.dup
  if (df["yhat_lower"].nil? || df["yhat_upper"].nil?) && metrics.include?("coverage")
    metrics.delete("coverage")
  end
  if metrics.uniq.length != metrics.length
    raise ArgumentError, "Input metrics must be a list of unique values"
  end
  unless (metrics - valid_metrics).empty?
    raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
  end
  df_m = df.dup
  if monthly
    raise Error, "Not implemented yet"
    # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
  else
    df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
  end
  df_m.sort_by! { |r| r["horizon"] }
  if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
    # logger.info("Skipping MAPE because y close to 0")
    metrics.delete("mape")
  end
  return nil if metrics.empty?

  # Window size in rows; a non-negative size is kept within 1..row count,
  # a negative one is passed through to request per-row metrics.
  row_count = df_m.shape[0]
  w = (rolling_window * row_count).to_i
  w = [[w, 1].max, row_count].min if w >= 0

  # Compute all metrics
  dfs = {}
  metrics.each do |metric|
    dfs[metric] = send(metric, df_m, w)
  end
  # Collect every metric column in the first metric's frame
  res = dfs[metrics[0]]
  metrics.each do |metric|
    res[metric] = dfs[metric][metric]
  end
  res
end
|
223
|
+
|
224
|
+
# Rolling mean of +x+ over a window of +w+ points, ordered by horizon +h+.
#
# x    - values to average
# h    - horizon of each value
# w    - number of points per window
# name - name of the output column
#
# Returns a frame with "horizon" and +name+ columns; horizons for which a
# full window could not be filled are left out.
def self.rolling_mean_by_h(x, h, w, name)
  # Aggregate over h: per-horizon sum of x and number of points
  frame = Rover::DataFrame.new({"x" => x, "h" => h})
  per_h = frame.group("h").sum("x").inner_join(frame.group("h").count).sort_by { |r| r["h"] }
  sums = per_h["sum_x"]
  counts = per_h["count"]
  horizons = per_h["h"]

  row_count = per_h.length
  tail = row_count - 1
  running_sum = 0
  running_n = 0
  # We don't know the output size but it is bounded by the number of horizons
  means = Array.new(row_count)

  # Start from the right and work backwards
  (0...row_count).reverse_each do |i|
    running_sum += sums[i]
    running_n += counts[i]
    while running_n >= w
      # Include points from the previous horizon. All of them if still
      # less than w, otherwise weight the mean by the difference
      surplus_n = running_n - w
      surplus_x = surplus_n * sums[i] / counts[i]
      means[tail] = (running_sum - surplus_x) / w
      running_sum -= sums[tail]
      running_n -= counts[tail]
      tail -= 1
    end
  end

  # Only the positions after the final tail index received a value
  filled = (tail + 1)..-1
  Rover::DataFrame.new({"horizon" => horizons[filled], name => means[filled]})
end
|
259
|
+
|
260
|
+
# Rolling median of +x+ over a window of +w+ points, ordered by horizon +h+.
#
# x    - values to take the median of
# h    - horizon of each value
# w    - minimum number of points per window
# name - name of the output column
#
# Returns a frame with "horizon" and +name+ columns; processing stops at the
# first horizon (going backwards) for which +w+ points cannot be gathered.
def self.rolling_median_by_h(x, h, w, name)
  # Aggregate over h to get the distinct horizons in ascending order
  frame = Rover::DataFrame.new({"x" => x, "h" => h})
  horizons = frame.group("h").count.sort_by { |r| r["h"] }["h"]

  out_h = []
  out_x = []
  # Start from the right and work backwards
  (horizons.length - 1).downto(0) do |i|
    h_i = horizons[i]
    window = frame[frame["h"] == h_i]["x"].to_a

    # argmax of the 0/1 mask minus one — presumably the index just before
    # the first row with this horizon (assumes h is sorted; confirm).
    prev_idx = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
    while window.length < w && prev_idx >= 0
      # Include points from the previous horizon. All of them if still
      # less than w, otherwise just enough to get to w.
      window << x[prev_idx]
      prev_idx -= 1
    end
    # Ran out of points before getting enough.
    break if window.length < w

    out_h << h_i
    out_x << Rover::Vector.new(window).median
  end

  Rover::DataFrame.new({"horizon" => out_h.reverse, name => out_x.reverse})
end
|
294
|
+
|
295
|
+
# Mean squared error of "yhat" against "y".
#
# With a negative +w+ the per-row squared errors are returned; otherwise
# their rolling mean over +w+ points by horizon.
def self.mse(df, w)
  squared_error = (df["y"] - df["yhat"]) ** 2
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => squared_error})
  else
    rolling_mean_by_h(squared_error, df["horizon"], w, "mse")
  end
end
|
302
|
+
|
303
|
+
# Root mean squared error: the result of mse with its "mse" column replaced
# by an "rmse" column holding the square roots.
def self.rmse(df, w)
  frame = mse(df, w)
  roots = frame.delete("mse").map { |v| Math.sqrt(v) }
  frame["rmse"] = roots
  frame
end
|
308
|
+
|
309
|
+
# Mean absolute error of "yhat" against "y".
#
# With a negative +w+ the per-row absolute errors are returned; otherwise
# their rolling mean over +w+ points by horizon.
def self.mae(df, w)
  absolute_error = (df["y"] - df["yhat"]).abs
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => absolute_error})
  else
    rolling_mean_by_h(absolute_error, df["horizon"], w, "mae")
  end
end
|
316
|
+
|
317
|
+
# Mean absolute percent error: |y - yhat| relative to y.
#
# With a negative +w+ the per-row values are returned; otherwise their
# rolling mean over +w+ points by horizon.
def self.mape(df, w)
  percent_error = ((df["y"] - df["yhat"]) / df["y"]).abs
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => percent_error})
  else
    rolling_mean_by_h(percent_error, df["horizon"], w, "mape")
  end
end
|
324
|
+
|
325
|
+
# Median absolute percent error: |y - yhat| relative to y.
#
# With a negative +w+ the per-row values are returned; otherwise their
# rolling median over +w+ points by horizon.
def self.mdape(df, w)
  percent_error = ((df["y"] - df["yhat"]) / df["y"]).abs
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => percent_error})
  else
    rolling_median_by_h(percent_error, df["horizon"], w, "mdape")
  end
end
|
332
|
+
|
333
|
+
# Symmetric mean absolute percentage error: |y - yhat| divided by the mean
# of |y| and |yhat|.
#
# With a negative +w+ the per-row values are returned; otherwise their
# rolling mean over +w+ points by horizon.
def self.smape(df, w)
  numerator = (df["y"] - df["yhat"]).abs
  denominator = (df["y"].abs + df["yhat"].abs) / 2
  symmetric_error = numerator / denominator
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => symmetric_error})
  else
    rolling_mean_by_h(symmetric_error, df["horizon"], w, "smape")
  end
end
|
340
|
+
|
341
|
+
# Coverage: whether "y" lies within ["yhat_lower", "yhat_upper"].
#
# With a negative +w+ the per-row flags are returned; otherwise the flags
# are converted to floats and their rolling mean over +w+ points by horizon
# is returned.
def self.coverage(df, w)
  above_lower = df["y"] >= df["yhat_lower"]
  below_upper = df["y"] <= df["yhat_upper"]
  is_covered = above_lower & below_upper
  if w < 0
    Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
  else
    rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
  end
end
|
348
|
+
end
|
349
|
+
end
|
data/lib/prophet/forecaster.rb
CHANGED
@@ -3,7 +3,14 @@ module Prophet
|
|
3
3
|
include Holidays
|
4
4
|
include Plot
|
5
5
|
|
6
|
-
attr_reader :logger, :params, :train_holiday_names
|
6
|
+
# Read-only access to the fitting options and fitted state.
# :seasonalities is not listed here because the attr_accessor below already
# defines its reader; declaring it twice defines the method twice.
attr_reader :logger, :params, :train_holiday_names,
  :history, :specified_changepoints, :fit_kwargs,
  :growth, :changepoints, :n_changepoints, :changepoint_range,
  :holidays, :seasonality_mode, :seasonality_prior_scale,
  :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
  :interval_width, :uncertainty_samples

# Writable so that a model copy can be given these after construction
attr_accessor :extra_regressors, :seasonalities, :country_holidays
|
7
14
|
|
8
15
|
def initialize(
|
9
16
|
growth: "linear",
|
@@ -176,8 +183,10 @@ module Prophet
|
|
176
183
|
|
177
184
|
initialize_scales(initialize_scales, df)
|
178
185
|
|
179
|
-
if @logistic_floor
|
180
|
-
|
186
|
+
if @logistic_floor
|
187
|
+
unless df.include?("floor")
|
188
|
+
raise ArgumentError, "Expected column \"floor\"."
|
189
|
+
end
|
181
190
|
else
|
182
191
|
df["floor"] = 0
|
183
192
|
end
|
@@ -207,7 +216,12 @@ module Prophet
|
|
207
216
|
def initialize_scales(initialize_scales, df)
|
208
217
|
return unless initialize_scales
|
209
218
|
|
210
|
-
|
219
|
+
if @growth == "logistic" && df.include?("floor")
|
220
|
+
@logistic_floor = true
|
221
|
+
floor = df["floor"]
|
222
|
+
else
|
223
|
+
floor = 0.0
|
224
|
+
end
|
211
225
|
@y_scale = (df["y"] - floor).abs.max
|
212
226
|
@y_scale = 1 if @y_scale == 0
|
213
227
|
@start = df["ds"].min
|
data/lib/prophet/plot.rb
CHANGED
@@ -111,6 +111,61 @@ module Prophet
|
|
111
111
|
artists
|
112
112
|
end
|
113
113
|
|
114
|
+
# Plots a performance metric from cross validation against the horizon:
# one dot per individual prediction and a line for the rolling-window value.
#
# df_cv           - output of Diagnostics.cross_validation
# metric:         - name of the metric to plot
# rolling_window: - passed to Diagnostics.performance_metrics for the line
# ax:             - existing axes to draw on; a new figure is created when nil
# figsize:        - size of the new figure
# color:          - color of the rolling-window line
# point_color:    - color of the per-prediction dots
#
# Returns the figure.
def self.plot_cross_validation_metric(df_cv, metric:, rolling_window: 0.1, ax: nil, figsize: [10, 6], color: "b", point_color: "gray")
  if ax.nil?
    fig = plt.figure(facecolor: "w", figsize: figsize)
    ax = fig.add_subplot(111)
  else
    fig = ax.get_figure
  end
  # Get the metric at the level of individual predictions, and with the rolling window.
  df_none = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: -1)
  df_h = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: rolling_window)

  # Some work because matplotlib does not handle timedelta
  # Target ~10 ticks. The 1e9 factor scales the horizon to nanoseconds
  # (the horizon values are treated as seconds here).
  tick_w = df_none["horizon"].max * 1e9 / 10.0
  dt_names = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds", "nanoseconds"]
  # Nanoseconds per unit, in the same order as dt_names
  dt_conversions = [
    24 * 60 * 60 * 10 ** 9,
    60 * 60 * 10 ** 9,
    60 * 10 ** 9,
    10 ** 9,
    10 ** 6,
    10 ** 3,
    1.0
  ]
  # Find the largest time resolution that has <1 unit per bin, i.e. the
  # first unit shorter than one tick; fall back to the smallest unit.
  i = dt_conversions.index { |unit_ns| unit_ns < tick_w } || dt_conversions.length - 1

  x_plt = df_none["horizon"] * 1e9 / dt_conversions[i].to_f
  x_plt_h = df_h["horizon"] * 1e9 / dt_conversions[i].to_f

  ax.plot(x_plt.to_a, df_none[metric].to_a, ".", alpha: 0.1, c: point_color)
  ax.plot(x_plt_h.to_a, df_h[metric].to_a, "-", c: color)
  ax.grid(true)

  ax.set_xlabel("Horizon (#{dt_names[i]})")
  ax.set_ylabel(metric)
  fig
end
|
159
|
+
|
160
|
+
# Loads matplotlib's pyplot on first use and returns the module.
#
# Raises Error when the matplotlib gem cannot be loaded.
def self.plt
  begin
    # Required here rather than at load time so the gem stays optional
    require "matplotlib/pyplot"
  rescue LoadError
    raise Error, "Install the matplotlib gem for plots"
  end

  Matplotlib::Pyplot
end
|
168
|
+
|
114
169
|
private
|
115
170
|
|
116
171
|
def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
|
@@ -263,12 +318,7 @@ module Prophet
|
|
263
318
|
end
|
264
319
|
|
265
320
|
def plt
|
266
|
-
|
267
|
-
require "matplotlib/pyplot"
|
268
|
-
rescue LoadError
|
269
|
-
raise Error, "Install the matplotlib gem for plots"
|
270
|
-
end
|
271
|
-
Matplotlib::Pyplot
|
321
|
+
Plot.plt
|
272
322
|
end
|
273
323
|
|
274
324
|
def dates
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -13,6 +13,11 @@ module Prophet
|
|
13
13
|
|
14
14
|
def fit(stan_init, stan_data, **kwargs)
|
15
15
|
stan_init, stan_data = prepare_data(stan_init, stan_data)
|
16
|
+
|
17
|
+
if !kwargs[:inits] && kwargs[:init]
|
18
|
+
kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
|
19
|
+
end
|
20
|
+
|
16
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
17
22
|
iterations = 10000
|
18
23
|
|
@@ -49,6 +54,10 @@ module Prophet
|
|
49
54
|
def sampling(stan_init, stan_data, samples, **kwargs)
|
50
55
|
stan_init, stan_data = prepare_data(stan_init, stan_data)
|
51
56
|
|
57
|
+
if !kwargs[:inits] && kwargs[:init]
|
58
|
+
kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
|
59
|
+
end
|
60
|
+
|
52
61
|
kwargs[:chains] ||= 4
|
53
62
|
kwargs[:warmup_iters] ||= samples / 2
|
54
63
|
|
@@ -128,7 +137,7 @@ module Prophet
|
|
128
137
|
stan_data["t_change"] = stan_data["t_change"].to_a
|
129
138
|
stan_data["s_a"] = stan_data["s_a"].to_a
|
130
139
|
stan_data["s_m"] = stan_data["s_m"].to_a
|
131
|
-
stan_data["X"] = stan_data["X"].to_numo.to_a
|
140
|
+
stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a
|
132
141
|
stan_init["delta"] = stan_init["delta"].to_a
|
133
142
|
stan_init["beta"] = stan_init["beta"].to_a
|
134
143
|
[stan_init, stan_data]
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.
|
4
|
+
version: 0.4.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-07-
|
11
|
+
date: 2022-07-10 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -66,6 +66,7 @@ files:
|
|
66
66
|
- data-raw/generated_holidays.csv
|
67
67
|
- lib/prophet-rb.rb
|
68
68
|
- lib/prophet.rb
|
69
|
+
- lib/prophet/diagnostics.rb
|
69
70
|
- lib/prophet/forecaster.rb
|
70
71
|
- lib/prophet/holidays.rb
|
71
72
|
- lib/prophet/plot.rb
|