prophet-rb 0.3.1 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -2
- data/LICENSE.txt +1 -1
- data/README.md +149 -2
- data/data-raw/LICENSE-holidays.txt +20 -0
- data/data-raw/README.md +3 -0
- data/data-raw/generated_holidays.csv +29302 -61443
- data/lib/prophet/diagnostics.rb +349 -0
- data/lib/prophet/forecaster.rb +214 -4
- data/lib/prophet/holidays.rb +6 -10
- data/lib/prophet/plot.rb +56 -6
- data/lib/prophet/stan_backend.rb +10 -1
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet.rb +23 -7
- data/stan/{unix/prophet.stan → prophet.stan} +8 -7
- data/vendor/aarch64-linux/bin/prophet +0 -0
- data/vendor/aarch64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/aarch64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/aarch64-linux/licenses/sundials-notice.txt +21 -0
- data/vendor/arm64-darwin/bin/prophet +0 -0
- data/vendor/arm64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/arm64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/arm64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/arm64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-darwin/bin/prophet +0 -0
- data/vendor/x86_64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/x86_64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/x86_64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-linux/bin/prophet +0 -0
- data/vendor/x86_64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/x86_64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-linux/licenses/sundials-notice.txt +21 -0
- metadata +10 -4
- data/stan/win/prophet.stan +0 -175
@@ -0,0 +1,349 @@
|
|
1
|
+
module Prophet
  module Diagnostics
    # Generate cutoff dates for cross validation.
    #
    # Starts at `max(ds) - horizon` and steps backwards by `period`. When a step
    # leaves no data in (cutoff, cutoff + horizon], the cutoff is moved to
    # `(latest ds at or before cutoff) - horizon` instead. Stepping stops once a
    # cutoff falls below `min(ds) + initial`; that final cutoff is discarded.
    #
    # @param df [Rover::DataFrame] history with a "ds" column
    # @param horizon [Numeric] forecast horizon in seconds (as returned by timedelta)
    # @param initial [Numeric] minimum training window in seconds
    # @param period [Numeric] spacing between cutoffs in seconds
    # @return [Array] cutoffs in ascending order
    # @raise [Prophet::Error] if the history is shorter than the horizon, or no
    #   cutoff remains after the initial window
    def self.generate_cutoffs(df, horizon, initial, period)
      # Last cutoff is 'latest date in data - horizon' date
      cutoff = df["ds"].max - horizon
      if cutoff < df["ds"].min
        raise Error, "Less data than horizon."
      end
      result = [cutoff]
      while result[-1] >= df["ds"].min + initial
        cutoff -= period
        # If data does not exist in data range (cutoff, cutoff + horizon]
        unless ((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)).any?
          # Next cutoff point is 'last date before cutoff in data - horizon'
          if cutoff > df["ds"].min
            closest_date = df["ds"][df["ds"] <= cutoff].max
            cutoff = closest_date - horizon
          end
          # else no data left, leave cutoff as is, it will be dropped.
        end
        result << cutoff
      end
      # The last cutoff generated is past the initial window, so drop it
      result = result[0...-1]
      if result.empty?
        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
      end
      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
      result.reverse
    end

    # Cross-validate a fitted model: refit a copy at each cutoff on the history
    # up to that cutoff, and forecast the points in (cutoff, cutoff + horizon].
    #
    # @param model fitted forecaster (its history must be set)
    # @param horizon [Numeric, String] forecast horizon: seconds, or e.g. "30 days"
    # @param period [Numeric, String, nil] spacing between cutoffs; defaults to half the horizon
    # @param initial [Numeric, String, nil] initial training window; defaults to the
    #   larger of 3 * horizon and the longest seasonality period
    # @param cutoffs [Array, nil] explicit cutoffs; when given, period/initial are not used
    # @return [Rover::DataFrame] "ds", "yhat" (plus "yhat_lower"/"yhat_upper" when
    #   model.uncertainty_samples is truthy), "y" and "cutoff" for every cutoff
    # @raise [Prophet::Error] if the model is unfit or the cutoffs are out of range
    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
      if model.history.nil?
        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
      end

      df = model.history.dup
      horizon = timedelta(horizon)

      predict_columns = ["ds", "yhat"]
      if model.uncertainty_samples
        predict_columns.concat(["yhat_lower", "yhat_upper"])
      end

      # Identify largest seasonality period
      period_max = 0.0
      model.seasonalities.each do |_, s|
        period_max = [period_max, s[:period]].max
      end
      seasonality_dt = timedelta("#{period_max} days")

      if cutoffs.nil?
        # Set period
        period = period.nil? ? 0.5 * horizon : timedelta(period)

        # Set initial
        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)

        # Compute Cutoffs
        cutoffs = generate_cutoffs(df, horizon, initial, period)
      else
        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
        if cutoffs.min <= df["ds"].min
          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
        end
        # max value of cutoffs is <= (end date minus horizon)
        end_date_minus_horizon = df["ds"].max - horizon
        if cutoffs.max > end_date_minus_horizon
          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
        end
        initial = cutoffs[0] - df["ds"].min
      end

      # Check if the initial window
      # (that is, the amount of time between the start of the history and the first cutoff)
      # is less than the maximum seasonality period
      if initial < seasonality_dt
        msg = "Seasonality has period of #{period_max} days "
        msg += "which is larger than initial window. "
        msg += "Consider increasing initial."
        # previously built but never emitted
        model.logger&.warn(msg)
      end

      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

      # Combine all predicted DataFrame into one DataFrame
      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
    end

    # Refit a copy of `model` on the history up to `cutoff`, and predict the rows
    # of `df` in (cutoff, cutoff + horizon].
    #
    # @return [Rover::DataFrame] `predict_columns` of the prediction merged with
    #   the actual "y" values and a constant "cutoff" column
    # @raise [Prophet::Error] if fewer than two rows are at or before the cutoff
    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
      # Generate new object with copying fitting options
      m = prophet_copy(model, cutoff)
      # Train model
      history_c = df[df["ds"] <= cutoff]
      if history_c.shape[0] < 2
        raise Error, "Less than two datapoints before cutoff. Increase initial window."
      end
      m.fit(history_c, **model.fit_kwargs)
      # Calculate yhat
      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
      # Get the columns for the future dataframe
      columns = ["ds"]
      if m.growth == "logistic"
        columns << "cap"
        if m.logistic_floor
          columns << "floor"
        end
      end
      columns.concat(m.extra_regressors.keys)
      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
      yhat = m.predict(df[index_predicted][columns])
      # Merge yhat(predicts), y(df, original data) and cutoff
      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
    end

    # Build an unfitted model with the same settings as the fitted model `m`.
    #
    # Auto seasonalities are disabled because the copied `seasonalities` already
    # contain them. When `m` had user-specified changepoints and `cutoff` is
    # given, only changepoints before the last history date <= cutoff are kept.
    #
    # @raise [Prophet::Error] if `m` has not been fit
    def self.prophet_copy(m, cutoff = nil)
      if m.history.nil?
        raise Error, "This is for copying a fitted Prophet object."
      end

      if m.specified_changepoints
        changepoints = m.changepoints
        if !cutoff.nil?
          # Filter change points '< cutoff'
          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
          changepoints = changepoints[changepoints < last_history_date]
        end
      else
        changepoints = nil
      end

      # Auto seasonalities are set to False because they are already set in
      # m.seasonalities.
      m2 = m.class.new(
        growth: m.growth,
        n_changepoints: m.n_changepoints,
        changepoint_range: m.changepoint_range,
        changepoints: changepoints,
        yearly_seasonality: false,
        weekly_seasonality: false,
        daily_seasonality: false,
        holidays: m.holidays,
        seasonality_mode: m.seasonality_mode,
        seasonality_prior_scale: m.seasonality_prior_scale,
        changepoint_prior_scale: m.changepoint_prior_scale,
        holidays_prior_scale: m.holidays_prior_scale,
        mcmc_samples: m.mcmc_samples,
        interval_width: m.interval_width,
        uncertainty_samples: m.uncertainty_samples
      )
      m2.extra_regressors = deepcopy(m.extra_regressors)
      m2.seasonalities = deepcopy(m.seasonalities)
      m2.country_holidays = deepcopy(m.country_holidays)
      m2
    end

    # Convert a time delta to seconds.
    #
    # Numerics (including ActiveSupport::Duration) are returned unchanged.
    # Strings such as "1 day", "30 days" or "1.5 days" are converted to seconds.
    #
    # @raise [Prophet::Error] for any other value
    def self.timedelta(value)
      if value.is_a?(Numeric)
        # ActiveSupport::Duration is a numeric
        value
      elsif (value.is_a?(String) || value.is_a?(Symbol)) && (m = /\A(\d+(?:\.\d+)?) days?\z/.match(value))
        m[1].to_f * 86400
      else
        raise Error, "Unknown time delta"
      end
    end

    # Recursively copy Hashes and Arrays; other values are shallow-copied with dup.
    def self.deepcopy(value)
      if value.is_a?(Hash)
        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
      elsif value.is_a?(Array)
        value.map { |v| deepcopy(v) }
      else
        value.dup
      end
    end

    # Compute forecast error metrics from cross_validation output, as a function
    # of horizon (ds - cutoff).
    #
    # "coverage" is skipped when "yhat_lower"/"yhat_upper" are missing, and
    # "mape" is skipped when some |y| < 1e-8.
    #
    # @param df [Rover::DataFrame] output of cross_validation
    # @param metrics [Array<String>, nil] subset of "mse", "rmse", "mae", "mape",
    #   "mdape", "smape", "coverage" (all when nil); the caller's array is not modified
    # @param rolling_window [Numeric] fraction of rows per rolling window; when
    #   rolling_window * rows truncates to a negative number, per-row values are
    #   returned without averaging
    # @param monthly [Boolean] monthly horizons (not implemented yet)
    # @return [Rover::DataFrame, nil] "horizon" plus one column per metric, or nil
    #   when no metric remains
    # @raise [ArgumentError] for duplicate or unknown metrics
    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
      # work on a copy so metrics dropped below never leak into the caller's array
      metrics = metrics.nil? ? valid_metrics : metrics.dup
      if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
        metrics.delete("coverage")
      end
      if metrics.uniq.length != metrics.length
        raise ArgumentError, "Input metrics must be a list of unique values"
      end
      if !Set.new(metrics).subset?(Set.new(valid_metrics))
        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
      end
      df_m = df.dup
      if monthly
        raise Error, "Not implemented yet"
        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
      else
        df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
      end
      df_m.sort_by! { |r| r["horizon"] }
      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
        # logger.info("Skipping MAPE because y close to 0")
        metrics.delete("mape")
      end
      if metrics.length == 0
        return nil
      end
      w = (rolling_window * df_m.shape[0]).to_i
      if w >= 0
        w = [w, 1].max
        w = [w, df_m.shape[0]].min
      end
      # Compute all metrics
      dfs = {}
      metrics.each do |metric|
        # metric was validated against valid_metrics above
        dfs[metric] = send(metric, df_m, w)
      end
      res = dfs[metrics[0]]
      metrics.each do |metric|
        res_m = dfs[metric]
        res[metric] = res_m[metric]
      end
      res
    end

    # Rolling mean of `x` by horizon `h` (h sorted ascending).
    #
    # For each horizon, working from the largest down, averages the `w` points
    # with the largest horizons at or below it. Groups are consumed whole, except
    # the earliest group, which contributes a proportional share of its sum to
    # reach exactly `w`. Horizons that cannot collect `w` points are dropped.
    #
    # @return [Rover::DataFrame] columns "horizon" and `name`
    def self.rolling_mean_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
      xs = df2["sum_x"]
      ns = df2["count"]
      hs = df2["h"]

      trailing_i = df2.length - 1
      x_sum = 0
      n_sum = 0
      # We don't know output size but it is bounded by len(df2)
      res_x = [nil] * df2.length

      # Start from the right and work backwards
      (df2.length - 1).downto(0) do |i|
        x_sum += xs[i]
        n_sum += ns[i]
        while n_sum >= w
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise weight the mean by the difference
          excess_n = n_sum - w
          excess_x = excess_n * xs[i] / ns[i]
          res_x[trailing_i] = (x_sum - excess_x) / w
          x_sum -= xs[trailing_i]
          n_sum -= ns[trailing_i]
          trailing_i -= 1
        end
      end

      res_h = hs[(trailing_i + 1)..-1]
      res_x = res_x[(trailing_i + 1)..-1]

      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Rolling median of `x` by horizon `h` (h sorted ascending).
    #
    # For each horizon, from the largest down, takes all points at that horizon
    # plus earlier points (in `h` order) until there are at least `w` values, and
    # takes their median. Stops at the first horizon that cannot reach `w`.
    #
    # @return [Rover::DataFrame] columns "horizon" and `name`
    def self.rolling_median_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      grouped = df.group("h")
      df2 = grouped.count.sort_by { |r| r["h"] }
      hs = df2["h"]

      res_h = []
      res_x = []
      # Start from the right and work backwards
      i = hs.length - 1
      while i >= 0
        h_i = hs[i]
        xs = df[df["h"] == h_i]["x"].to_a

        # index just before the first occurrence of h_i in h
        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
        while xs.length < w && next_idx_to_add >= 0
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise just enough to get to w.
          xs << x[next_idx_to_add]
          next_idx_to_add -= 1
        end
        if xs.length < w
          # Ran out of points before getting enough.
          break
        end
        res_h << hs[i]
        res_x << Rover::Vector.new(xs).median
        i -= 1
      end
      res_h.reverse!
      res_x.reverse!
      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Mean squared error by horizon (per-row squared errors when w < 0).
    def self.mse(df, w)
      se = (df["y"] - df["yhat"]) ** 2
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
      end
      rolling_mean_by_h(se, df["horizon"], w, "mse")
    end

    # Root mean squared error by horizon (square root of mse).
    def self.rmse(df, w)
      res = mse(df, w)
      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
      res
    end

    # Mean absolute error by horizon (per-row absolute errors when w < 0).
    def self.mae(df, w)
      ae = (df["y"] - df["yhat"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
      end
      rolling_mean_by_h(ae, df["horizon"], w, "mae")
    end

    # Mean absolute percent error by horizon (per-row values when w < 0).
    def self.mape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
      end
      rolling_mean_by_h(ape, df["horizon"], w, "mape")
    end

    # Median absolute percent error by horizon (per-row values when w < 0).
    def self.mdape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
      end
      rolling_median_by_h(ape, df["horizon"], w, "mdape")
    end

    # Symmetric mean absolute percent error by horizon (per-row values when w < 0).
    def self.smape(df, w)
      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
      end
      rolling_mean_by_h(sape, df["horizon"], w, "smape")
    end

    # Fraction of y values inside [yhat_lower, yhat_upper] by horizon
    # (per-row booleans when w < 0).
    def self.coverage(df, w)
      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
      end
      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
    end
  end
end
|
data/lib/prophet/forecaster.rb
CHANGED
@@ -3,7 +3,14 @@ module Prophet
|
|
3
3
|
include Holidays
include Plot

# Read-only model state; several of these are needed by Prophet::Diagnostics
# (cross_validation / prophet_copy), including logistic_floor.
attr_reader :logger, :params, :train_holiday_names,
  :history, :specified_changepoints, :fit_kwargs,
  :growth, :changepoints, :n_changepoints, :changepoint_range,
  :holidays, :seasonality_mode, :seasonality_prior_scale,
  :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
  :interval_width, :uncertainty_samples, :logistic_floor

# seasonalities gets both reader and writer here; declaring it in attr_reader
# too would define the same reader twice.
attr_accessor :extra_regressors, :seasonalities, :country_holidays
|
7
14
|
|
8
15
|
def initialize(
|
9
16
|
growth: "linear",
|
@@ -176,8 +183,10 @@ module Prophet
|
|
176
183
|
|
177
184
|
initialize_scales(initialize_scales, df)
|
178
185
|
|
179
|
-
if @logistic_floor
|
180
|
-
|
186
|
+
if @logistic_floor
|
187
|
+
unless df.include?("floor")
|
188
|
+
raise ArgumentError, "Expected column \"floor\"."
|
189
|
+
end
|
181
190
|
else
|
182
191
|
df["floor"] = 0
|
183
192
|
end
|
@@ -207,7 +216,12 @@ module Prophet
|
|
207
216
|
def initialize_scales(initialize_scales, df)
|
208
217
|
return unless initialize_scales
|
209
218
|
|
210
|
-
|
219
|
+
if @growth == "logistic" && df.include?("floor")
|
220
|
+
@logistic_floor = true
|
221
|
+
floor = df["floor"]
|
222
|
+
else
|
223
|
+
floor = 0.0
|
224
|
+
end
|
211
225
|
@y_scale = (df["y"] - floor).abs.max
|
212
226
|
@y_scale = 1 if @y_scale == 0
|
213
227
|
@start = df["ds"].min
|
@@ -386,6 +400,12 @@ module Prophet
|
|
386
400
|
|
387
401
|
def add_country_holidays(country_name)
|
388
402
|
raise Error, "Country holidays must be added prior to model fitting." if @history
|
403
|
+
|
404
|
+
# Fix for previously documented keyword argument
|
405
|
+
if country_name.is_a?(Hash) && country_name[:country_name]
|
406
|
+
country_name = country_name[:country_name]
|
407
|
+
end
|
408
|
+
|
389
409
|
# Validate names.
|
390
410
|
get_holiday_names(country_name).each do |name|
|
391
411
|
# Allow merging with existing holidays
|
@@ -965,6 +985,12 @@ module Prophet
|
|
965
985
|
Rover::DataFrame.new({"ds" => dates})
|
966
986
|
end
|
967
987
|
|
988
|
+
# Serialize the fitted model (the Hash built by as_json) to a JSON string.
#
# Any arguments are accepted and ignored. They include the generator state that
# JSON.generate passes when the model is nested inside another object.
def to_json(*)
  require "json"

  JSON.generate(as_json)
end
|
993
|
+
|
968
994
|
private
|
969
995
|
|
970
996
|
# Time is preferred over DateTime in Ruby docs
|
@@ -1011,5 +1037,189 @@ module Prophet
|
|
1011
1037
|
u = Numo::DFloat.new(size).rand(-0.5, 0.5)
|
1012
1038
|
loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
|
1013
1039
|
end
|
1040
|
+
|
1041
|
+
# Attribute groups used by as_json / from_json, named after the Python Prophet
# serialization groups they mirror. Frozen so they cannot be altered at runtime.

# Values serialized as-is
SIMPLE_ATTRIBUTES = %w[
  growth n_changepoints specified_changepoints changepoint_range
  yearly_seasonality weekly_seasonality daily_seasonality
  seasonality_mode seasonality_prior_scale changepoint_prior_scale
  holidays_prior_scale mcmc_samples interval_width uncertainty_samples
  y_scale logistic_floor country_holidays component_modes
].freeze

# Vectors serialized as pandas Series JSON (split orient)
PD_SERIES = %w[changepoints history_dates train_holiday_names].freeze

# Times serialized as seconds since the epoch
PD_TIMESTAMP = %w[start].freeze

# Durations serialized as float seconds
PD_TIMEDELTA = %w[t_scale].freeze

# Data frames serialized as pandas table-orient JSON
PD_DATAFRAME = %w[holidays history train_component_cols].freeze

# Numo arrays serialized as plain arrays
NP_ARRAY = %w[changepoints_t].freeze

# Ordered hashes serialized as [keys, hash]
ORDEREDDICT = %w[seasonalities extra_regressors].freeze
|
1060
|
+
|
1061
|
+
# Build the JSON-ready Hash for a fitted model, grouping attributes as listed in
# SIMPLE_ATTRIBUTES, PD_SERIES, PD_TIMESTAMP, PD_TIMEDELTA, PD_DATAFRAME,
# NP_ARRAY and ORDEREDDICT.
#
# @return [Hash] model description with String keys
# @raise [Prophet::Error] if the model has not been fit
def as_json
  if @history.nil?
    raise Error, "This can only be used to serialize models that have already been fit."
  end

  # Times become ISO 8601 strings with milliseconds; other values (e.g. the
  # holiday name strings in train_holiday_names) are passed through unchanged.
  format_value = ->(value) { value.is_a?(Time) ? value.iso8601(3) : value }

  model_dict =
    SIMPLE_ATTRIBUTES.to_h do |attribute|
      [attribute, instance_variable_get("@#{attribute}")]
    end

  # Handle attributes of non-core types
  PD_SERIES.each do |attribute|
    series = instance_variable_get("@#{attribute}")
    if series.nil?
      model_dict[attribute] = nil
    else
      d = {
        # Only date series are named "ds"; from_json parses "ds" series back
        # into times, so the holiday-name series must not carry that name.
        "name" => attribute == "train_holiday_names" ? nil : "ds",
        "index" => series.size.times.to_a,
        "data" => series.to_a.map(&format_value)
      }
      model_dict[attribute] = JSON.generate(d)
    end
  end
  PD_TIMESTAMP.each do |attribute|
    model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
  end
  PD_TIMEDELTA.each do |attribute|
    model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
  end
  PD_DATAFRAME.each do |attribute|
    frame = instance_variable_get("@#{attribute}")
    if frame.nil?
      model_dict[attribute] = nil
    else
      # use same format as Pandas
      frame = frame.dup
      frame["ds"] = frame["ds"].map(&format_value) if frame["ds"]
      frame.delete("col")

      fields =
        frame.types.map do |k, t|
          type =
            case t
            when :object
              # "ds" holds the formatted times; other object columns (such as
              # holiday names) are strings, not datetimes
              k == "ds" ? "datetime" : "string"
            when :int64
              "integer"
            else
              "number"
            end
          {"name" => k, "type" => type}
        end

      d = {
        "schema" => {
          "fields" => fields,
          "pandas_version" => "0.20.0"
        },
        "data" => frame.to_a
      }
      model_dict[attribute] = JSON.generate(d)
    end
  end
  NP_ARRAY.each do |attribute|
    model_dict[attribute] = instance_variable_get("@#{attribute}").to_a
  end
  ORDEREDDICT.each do |attribute|
    dict = instance_variable_get("@#{attribute}")
    model_dict[attribute] = [dict.keys, dict.transform_keys(&:to_s)]
  end
  # Other attributes with special handling
  # fit_kwargs -> Transform any numpy types before serializing.
  # They do not need to be transformed back on deserializing.
  # TODO deep copy
  fit_kwargs = @fit_kwargs.to_h { |k, v| [k.to_s, v.dup] }
  if fit_kwargs.key?("init")
    fit_kwargs["init"] = fit_kwargs["init"].transform_values { |v| v.is_a?(Numo::NArray) ? v.to_a : v }
  end
  model_dict["fit_kwargs"] = fit_kwargs

  # Params (Dict[str, np.ndarray])
  model_dict["params"] = params.transform_values(&:to_a)
  # Attributes that are skipped: stan_fit, stan_backend
  # Returns 1.0 for Prophet 1.1
  model_dict["__prophet_version"] = "1.0"
  model_dict
end
|
1158
|
+
|
1159
|
+
# Rebuild a fitted model from a JSON string produced by to_json.
#
# Every serialized attribute is written directly into the instance variables
# of a fresh model. stan_fit is reset to nil.
#
# @param model_json [String] JSON produced by to_json
# @return a model ready for predict
# @raise [KeyError] if a required attribute is missing from the JSON
def self.from_json(model_json)
  require "json"
  require "time"

  model_dict = JSON.parse(model_json)

  # We will overwrite all attributes set in init anyway
  model = Prophet.new
  # Simple types
  SIMPLE_ATTRIBUTES.each do |attribute|
    model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute))
  end
  PD_SERIES.each do |attribute|
    if model_dict[attribute].nil?
      model.instance_variable_set("@#{attribute}", nil)
    else
      d = JSON.parse(model_dict.fetch(attribute))
      s = Rover::Vector.new(d["data"])
      # only series named "ds" hold times
      if d["name"] == "ds"
        s = s.map { |v| Time.parse(v).utc }
      end
      model.instance_variable_set("@#{attribute}", s)
    end
  end
  PD_TIMESTAMP.each do |attribute|
    # stored as seconds since the epoch; restore in UTC like the parsed "ds" times
    model.instance_variable_set("@#{attribute}", Time.at(model_dict.fetch(attribute)).utc)
  end
  PD_TIMEDELTA.each do |attribute|
    model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute).to_f)
  end
  PD_DATAFRAME.each do |attribute|
    if model_dict[attribute].nil?
      model.instance_variable_set("@#{attribute}", nil)
    else
      d = JSON.parse(model_dict.fetch(attribute))
      df = Rover::DataFrame.new(d["data"])
      df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
      if attribute == "train_component_cols"
        # Special handling because of named index column
        # df.columns.name = 'component'
        # df.index.name = 'col'
      end
      model.instance_variable_set("@#{attribute}", df)
    end
  end
  NP_ARRAY.each do |attribute|
    model.instance_variable_set("@#{attribute}", Numo::NArray.cast(model_dict.fetch(attribute)))
  end
  ORDEREDDICT.each do |attribute|
    key_list, unordered_dict = model_dict.fetch(attribute)
    od = {}
    key_list.each do |key|
      od[key] = unordered_dict[key].transform_keys(&:to_sym)
    end
    model.instance_variable_set("@#{attribute}", od)
  end
  # Other attributes with special handling
  # fit_kwargs
  model.instance_variable_set(:@fit_kwargs, model_dict["fit_kwargs"].transform_keys(&:to_sym))
  # Params (Dict[str, np.ndarray])
  model.instance_variable_set(:@params, model_dict["params"].transform_values { |v| Numo::NArray.cast(v) })
  # Skipped attributes
  # model.stan_backend = nil
  model.instance_variable_set(:@stan_fit, nil)
  model
end
|
1014
1224
|
end
|
1015
1225
|
end
|
data/lib/prophet/holidays.rb
CHANGED
@@ -2,25 +2,21 @@ module Prophet
|
|
2
2
|
module Holidays
  # Unique holiday names for `country` over the years 1995-2045.
  #
  # When the bundled holiday data has no entries for the country, logs a
  # warning and returns the empty result instead of raising.
  #
  # @param country [String] value matched against the "country" column of the holiday data
  # @return unique values of the "holiday" column for that country
  def get_holiday_names(country)
    years = (1995..2045).to_a
    holiday_names = make_holidays_df(years, country)["holiday"].uniq
    if holiday_names.size == 0
      # TODO raise error in 0.4.0
      logger.warn "Holidays in #{country} are not currently supported"
    end
    holiday_names
  end

  # Rows of the bundled holiday table for `country` whose year is in
  # `year_list`, limited to the "ds" and "holiday" columns.
  def make_holidays_df(year_list, country)
    holidays_df[(holidays_df["country"] == country) & (holidays_df["year"].in?(year_list))][["ds", "holiday"]]
  end

  # Holiday table read from data-raw/generated_holidays.csv, memoized in
  # @holidays_df on the including object.
  # TODO improve performance
  def holidays_df
    @holidays_df ||= begin
      holidays_file = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
      Rover.read_csv(holidays_file, converters: [:date, :numeric])
    end
  end
end
|
26
22
|
end
|