prophet-rb 0.3.1 → 0.4.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +17 -2
- data/LICENSE.txt +1 -1
- data/README.md +149 -2
- data/data-raw/LICENSE-holidays.txt +20 -0
- data/data-raw/README.md +3 -0
- data/data-raw/generated_holidays.csv +29302 -61443
- data/lib/prophet/diagnostics.rb +349 -0
- data/lib/prophet/forecaster.rb +214 -4
- data/lib/prophet/holidays.rb +6 -10
- data/lib/prophet/plot.rb +56 -6
- data/lib/prophet/stan_backend.rb +10 -1
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet.rb +23 -7
- data/stan/{unix/prophet.stan → prophet.stan} +8 -7
- data/vendor/aarch64-linux/bin/prophet +0 -0
- data/vendor/aarch64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/aarch64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/aarch64-linux/licenses/sundials-notice.txt +21 -0
- data/vendor/arm64-darwin/bin/prophet +0 -0
- data/vendor/arm64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/arm64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/arm64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/arm64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-darwin/bin/prophet +0 -0
- data/vendor/x86_64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/x86_64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/x86_64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-linux/bin/prophet +0 -0
- data/vendor/x86_64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/x86_64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-linux/licenses/sundials-notice.txt +21 -0
- metadata +10 -4
- data/stan/win/prophet.stan +0 -175
data/lib/prophet/diagnostics.rb
ADDED
@@ -0,0 +1,349 @@
+module Prophet
+  module Diagnostics
+    def self.generate_cutoffs(df, horizon, initial, period)
+      # Last cutoff is 'latest date in data - horizon' date
+      cutoff = df["ds"].max - horizon
+      if cutoff < df["ds"].min
+        raise Error, "Less data than horizon."
+      end
+      result = [cutoff]
+      while result[-1] >= df["ds"].min + initial
+        cutoff -= period
+        # If data does not exist in data range (cutoff, cutoff + horizon]
+        if !(((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)).any?)
+          # Next cutoff point is 'last date before cutoff in data - horizon'
+          if cutoff > df["ds"].min
+            closest_date = df[df["ds"] <= cutoff].max["ds"]
+            cutoff = closest_date - horizon
+          end
+          # else no data left, leave cutoff as is, it will be dropped.
+        end
+        result << cutoff
+      end
+      result = result[0...-1]
+      if result.length == 0
+        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
+      end
+      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
+      result.reverse
+    end
+
+    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
+      if model.history.nil?
+        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
+      end
+
+      df = model.history.dup
+      horizon = timedelta(horizon)
+
+      predict_columns = ["ds", "yhat"]
+      if model.uncertainty_samples
+        predict_columns.concat(["yhat_lower", "yhat_upper"])
+      end
+
+      # Identify largest seasonality period
+      period_max = 0.0
+      model.seasonalities.each do |_, s|
+        period_max = [period_max, s[:period]].max
+      end
+      seasonality_dt = timedelta("#{period_max} days")
+
+      if cutoffs.nil?
+        # Set period
+        period = period.nil? ? 0.5 * horizon : timedelta(period)
+
+        # Set initial
+        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)
+
+        # Compute Cutoffs
+        cutoffs = generate_cutoffs(df, horizon, initial, period)
+      else
+        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
+        if cutoffs.min <= df["ds"].min
+          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
+        end
+        # max value of cutoffs is <= (end date minus horizon)
+        end_date_minus_horizon = df["ds"].max - horizon
+        if cutoffs.max > end_date_minus_horizon
+          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
+        end
+        initial = cutoffs[0] - df["ds"].min
+      end
+
+      # Check if the initial window
+      # (that is, the amount of time between the start of the history and the first cutoff)
+      # is less than the maximum seasonality period
+      if initial < seasonality_dt
+        msg = "Seasonality has period of #{period_max} days "
+        msg += "which is larger than initial window. "
+        msg += "Consider increasing initial."
+        # logger.warn(msg)
+      end
+
+      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }
+
+      # Combine all predicted DataFrame into one DataFrame
+      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
+    end
+
+    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
+      # Generate new object with copying fitting options
+      m = prophet_copy(model, cutoff)
+      # Train model
+      history_c = df[df["ds"] <= cutoff]
+      if history_c.shape[0] < 2
+        raise Error, "Less than two datapoints before cutoff. Increase initial window."
+      end
+      m.fit(history_c, **model.fit_kwargs)
+      # Calculate yhat
+      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
+      # Get the columns for the future dataframe
+      columns = ["ds"]
+      if m.growth == "logistic"
+        columns << "cap"
+        if m.logistic_floor
+          columns << "floor"
+        end
+      end
+      columns.concat(m.extra_regressors.keys)
+      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
+      yhat = m.predict(df[index_predicted][columns])
+      # Merge yhat(predicts), y(df, original data) and cutoff
+      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
+    end
+
+    def self.prophet_copy(m, cutoff = nil)
+      if m.history.nil?
+        raise Error, "This is for copying a fitted Prophet object."
+      end
+
+      if m.specified_changepoints
+        changepoints = m.changepoints
+        if !cutoff.nil?
+          # Filter change points '< cutoff'
+          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
+          changepoints = changepoints[changepoints < last_history_date]
+        end
+      else
+        changepoints = nil
+      end
+
+      # Auto seasonalities are set to False because they are already set in
+      # m.seasonalities.
+      m2 = m.class.new(
+        growth: m.growth,
+        n_changepoints: m.n_changepoints,
+        changepoint_range: m.changepoint_range,
+        changepoints: changepoints,
+        yearly_seasonality: false,
+        weekly_seasonality: false,
+        daily_seasonality: false,
+        holidays: m.holidays,
+        seasonality_mode: m.seasonality_mode,
+        seasonality_prior_scale: m.seasonality_prior_scale,
+        changepoint_prior_scale: m.changepoint_prior_scale,
+        holidays_prior_scale: m.holidays_prior_scale,
+        mcmc_samples: m.mcmc_samples,
+        interval_width: m.interval_width,
+        uncertainty_samples: m.uncertainty_samples
+      )
+      m2.extra_regressors = deepcopy(m.extra_regressors)
+      m2.seasonalities = deepcopy(m.seasonalities)
+      m2.country_holidays = deepcopy(m.country_holidays)
+      m2
+    end
+
+    def self.timedelta(value)
+      if value.is_a?(Numeric)
+        # ActiveSupport::Duration is a numeric
+        value
+      elsif (m = /\A(\d+(\.\d+)?) days\z/.match(value))
+        m[1].to_f * 86400
+      else
+        raise Error, "Unknown time delta"
+      end
+    end
+
+    def self.deepcopy(value)
+      if value.is_a?(Hash)
+        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
+      elsif value.is_a?(Array)
+        value.map { |v| deepcopy(v) }
+      else
+        value.dup
+      end
+    end
+
+    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
+      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
+      if metrics.nil?
+        metrics = valid_metrics
+      end
+      if (df["yhat_lower"].nil? || df["yhat_upper"].nil?) && metrics.include?("coverage")
+        metrics.delete("coverage")
+      end
+      if metrics.uniq.length != metrics.length
+        raise ArgumentError, "Input metrics must be a list of unique values"
+      end
+      if !Set.new(metrics).subset?(Set.new(valid_metrics))
+        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
+      end
+      df_m = df.dup
+      if monthly
+        raise Error, "Not implemented yet"
+        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
+      else
+        df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
+      end
+      df_m.sort_by! { |r| r["horizon"] }
+      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
+        # logger.info("Skipping MAPE because y close to 0")
+        metrics.delete("mape")
+      end
+      if metrics.length == 0
+        return nil
+      end
+      w = (rolling_window * df_m.shape[0]).to_i
+      if w >= 0
+        w = [w, 1].max
+        w = [w, df_m.shape[0]].min
+      end
+      # Compute all metrics
+      dfs = {}
+      metrics.each do |metric|
+        dfs[metric] = send(metric, df_m, w)
+      end
+      res = dfs[metrics[0]]
+      metrics.each do |metric|
+        res_m = dfs[metric]
+        res[metric] = res_m[metric]
+      end
+      res
+    end
+
+    def self.rolling_mean_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
+      xs = df2["sum_x"]
+      ns = df2["count"]
+      hs = df2["h"]
+
+      trailing_i = df2.length - 1
+      x_sum = 0
+      n_sum = 0
+      # We don't know output size but it is bounded by len(df2)
+      res_x = [nil] * df2.length
+
+      # Start from the right and work backwards
+      (df2.length - 1).downto(0) do |i|
+        x_sum += xs[i]
+        n_sum += ns[i]
+        while n_sum >= w
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise weight the mean by the difference
+          excess_n = n_sum - w
+          excess_x = excess_n * xs[i] / ns[i]
+          res_x[trailing_i] = (x_sum - excess_x) / w
+          x_sum -= xs[trailing_i]
+          n_sum -= ns[trailing_i]
+          trailing_i -= 1
+        end
+      end
+
+      res_h = hs[(trailing_i + 1)..-1]
+      res_x = res_x[(trailing_i + 1)..-1]
+
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.rolling_median_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      grouped = df.group("h")
+      df2 = grouped.count.sort_by { |r| r["h"] }
+      hs = df2["h"]
+
+      res_h = []
+      res_x = []
+      # Start from the right and work backwards
+      i = hs.length - 1
+      while i >= 0
+        h_i = hs[i]
+        xs = df[df["h"] == h_i]["x"].to_a
+
+        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
+        while xs.length < w && next_idx_to_add >= 0
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise just enough to get to w.
+          xs << x[next_idx_to_add]
+          next_idx_to_add -= 1
+        end
+        if xs.length < w
+          # Ran out of points before getting enough.
+          break
+        end
+        res_h << hs[i]
+        res_x << Rover::Vector.new(xs).median
+        i -= 1
+      end
+      res_h.reverse!
+      res_x.reverse!
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.mse(df, w)
+      se = (df["y"] - df["yhat"]) ** 2
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
+      end
+      rolling_mean_by_h(se, df["horizon"], w, "mse")
+    end
+
+    def self.rmse(df, w)
+      res = mse(df, w)
+      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
+      res
+    end
+
+    def self.mae(df, w)
+      ae = (df["y"] - df["yhat"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
+      end
+      rolling_mean_by_h(ae, df["horizon"], w, "mae")
+    end
+
+    def self.mape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
+      end
+      rolling_mean_by_h(ape, df["horizon"], w, "mape")
+    end
+
+    def self.mdape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
+      end
+      rolling_median_by_h(ape, df["horizon"], w, "mdape")
+    end
+
+    def self.smape(df, w)
+      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
+      end
+      rolling_mean_by_h(sape, df["horizon"], w, "smape")
+    end
+
+    def self.coverage(df, w)
+      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
+      end
+      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
+    end
+  end
+end
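The new Diagnostics module ports Python Prophet's cross-validation and performance-metrics API to Ruby. A minimal usage sketch, assuming a model m already fitted on a dataframe df with "ds" and "y" columns (names here are placeholders; the "N days" strings are parsed by the timedelta helper above):

    m = Prophet.new
    m.fit(df)

    # Simulated historical forecasts: refit at each cutoff, predict over the horizon
    df_cv = Prophet::Diagnostics.cross_validation(
      m, initial: "730 days", period: "180 days", horizon: "365 days"
    )

    # Rolling-window accuracy by horizon; "coverage" is dropped automatically
    # when yhat_lower/yhat_upper are absent
    df_p = Prophet::Diagnostics.performance_metrics(df_cv)

Per the w < 0 branches above, passing a negative rolling_window to performance_metrics returns the raw per-row metrics instead of the rolling aggregates.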
data/lib/prophet/forecaster.rb
CHANGED
@@ -3,7 +3,14 @@ module Prophet
     include Holidays
     include Plot

-    attr_reader :logger, :params, :train_holiday_names
+    attr_reader :logger, :params, :train_holiday_names,
+      :history, :seasonalities, :specified_changepoints, :fit_kwargs,
+      :growth, :changepoints, :n_changepoints, :changepoint_range,
+      :holidays, :seasonality_mode, :seasonality_prior_scale,
+      :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
+      :interval_width, :uncertainty_samples
+
+    attr_accessor :extra_regressors, :seasonalities, :country_holidays

     def initialize(
       growth: "linear",
@@ -176,8 +183,10 @@ module Prophet

       initialize_scales(initialize_scales, df)

-      if @logistic_floor
-        raise ArgumentError, "Expected column \"floor\"." unless df.include?("floor")
+      if @logistic_floor
+        unless df.include?("floor")
+          raise ArgumentError, "Expected column \"floor\"."
+        end
       else
         df["floor"] = 0
       end
@@ -207,7 +216,12 @@ module Prophet
     def initialize_scales(initialize_scales, df)
       return unless initialize_scales

-      floor = @logistic_floor ? df["floor"] : 0.0
+      if @growth == "logistic" && df.include?("floor")
+        @logistic_floor = true
+        floor = df["floor"]
+      else
+        floor = 0.0
+      end
       @y_scale = (df["y"] - floor).abs.max
       @y_scale = 1 if @y_scale == 0
       @start = df["ds"].min
@@ -386,6 +400,12 @@ module Prophet

     def add_country_holidays(country_name)
       raise Error, "Country holidays must be added prior to model fitting." if @history
+
+      # Fix for previously documented keyword argument
+      if country_name.is_a?(Hash) && country_name[:country_name]
+        country_name = country_name[:country_name]
+      end
+
       # Validate names.
       get_holiday_names(country_name).each do |name|
         # Allow merging with existing holidays
@@ -965,6 +985,12 @@ module Prophet
       Rover::DataFrame.new({"ds" => dates})
     end

+    def to_json
+      require "json"
+
+      JSON.generate(as_json)
+    end
+
     private

     # Time is preferred over DateTime in Ruby docs
@@ -1011,5 +1037,189 @@ module Prophet
       u = Numo::DFloat.new(size).rand(-0.5, 0.5)
       loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
     end
+
+    SIMPLE_ATTRIBUTES = [
+      "growth", "n_changepoints", "specified_changepoints", "changepoint_range",
+      "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
+      "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
+      "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
+      "y_scale", "logistic_floor", "country_holidays", "component_modes"
+    ]
+
+    PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
+
+    PD_TIMESTAMP = ["start"]
+
+    PD_TIMEDELTA = ["t_scale"]
+
+    PD_DATAFRAME = ["holidays", "history", "train_component_cols"]
+
+    NP_ARRAY = ["changepoints_t"]
+
+    ORDEREDDICT = ["seasonalities", "extra_regressors"]
+
+    def as_json
+      if @history.nil?
+        raise Error, "This can only be used to serialize models that have already been fit."
+      end
+
+      model_dict =
+        SIMPLE_ATTRIBUTES.to_h do |attribute|
+          [attribute, instance_variable_get("@#{attribute}")]
+        end
+
+      # Handle attributes of non-core types
+      PD_SERIES.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          v = instance_variable_get("@#{attribute}")
+          d = {
+            "name" => "ds",
+            "index" => v.size.times.to_a,
+            "data" => v.to_a.map { |v| v.iso8601(3) }
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_DATAFRAME.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          # use same format as Pandas
+          v = instance_variable_get("@#{attribute}")

+          v = v.dup
+          v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
+          v.delete("col")

+          fields =
+            v.types.map do |k, t|
+              type =
+                case t
+                when :object
+                  "datetime"
+                when :int64
+                  "integer"
+                else
+                  "number"
+                end
+              {"name" => k, "type" => type}
+            end

+          d = {
+            "schema" => {
+              "fields" => fields,
+              "pandas_version" => "0.20.0"
+            },
+            "data" => v.to_a
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_a
+      end
+      ORDEREDDICT.each do |attribute|
+        model_dict[attribute] = [
+          instance_variable_get("@#{attribute}").keys,
+          instance_variable_get("@#{attribute}").transform_keys(&:to_s)
+        ]
+      end
+      # Other attributes with special handling
+      # fit_kwargs -> Transform any numpy types before serializing.
+      # They do not need to be transformed back on deserializing.
+      # TODO deep copy
+      fit_kwargs = @fit_kwargs.to_h { |k, v| [k.to_s, v.dup] }
+      if fit_kwargs.key?("init")
+        fit_kwargs["init"].each do |k, v|
+          if v.is_a?(Numo::NArray)
+            fit_kwargs["init"][k] = v.to_a
+          # elsif v.is_a?(Float)
+          #   fit_kwargs["init"][k] = v.to_f
+          end
+        end
+      end
+      model_dict["fit_kwargs"] = fit_kwargs

+      # Params (Dict[str, np.ndarray])
+      model_dict["params"] = params.transform_values(&:to_a)
+      # Attributes that are skipped: stan_fit, stan_backend
+      # Returns 1.0 for Prophet 1.1
+      model_dict["__prophet_version"] = "1.0"
+      model_dict
+    end

+    def self.from_json(model_json)
+      require "json"

+      model_dict = JSON.parse(model_json)

+      # We will overwrite all attributes set in init anyway
+      model = Prophet.new
+      # Simple types
+      SIMPLE_ATTRIBUTES.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute))
+      end
+      PD_SERIES.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          s = Rover::Vector.new(d["data"])
+          if d["name"] == "ds"
+            s = s.map { |v| Time.parse(v).utc }
+          end
+          model.instance_variable_set("@#{attribute}", s)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Time.at(model_dict.fetch(attribute)))
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute).to_f)
+      end
+      PD_DATAFRAME.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          df = Rover::DataFrame.new(d["data"])
+          df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
+          if attribute == "train_component_cols"
+            # Special handling because of named index column
+            # df.columns.name = 'component'
+            # df.index.name = 'col'
+          end
+          model.instance_variable_set("@#{attribute}", df)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Numo::NArray.cast(model_dict.fetch(attribute)))
+      end
+      ORDEREDDICT.each do |attribute|
+        key_list, unordered_dict = model_dict.fetch(attribute)
+        od = {}
+        key_list.each do |key|
+          od[key] = unordered_dict[key].transform_keys(&:to_sym)
+        end
+        model.instance_variable_set("@#{attribute}", od)
+      end
+      # Other attributes with special handling
+      # fit_kwargs
+      model.instance_variable_set(:@fit_kwargs, model_dict["fit_kwargs"].transform_keys(&:to_sym))
+      # Params (Dict[str, np.ndarray])
+      model.instance_variable_set(:@params, model_dict["params"].transform_values { |v| Numo::NArray.cast(v) })
+      # Skipped attributes
+      # model.stan_backend = nil
+      model.instance_variable_set(:@stan_fit, nil)
+      model
+    end
   end
 end
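Together with to_json above, this enables a fit-once, load-later workflow. A round-trip sketch; from_json is defined on the forecaster class here, and a top-level Prophet.from_json delegator presumably arrives via the lib/prophet.rb change, though that hunk is not shown:

    m = Prophet.new
    m.fit(df)
    File.write("model.json", m.to_json)

    # Later, or in another process
    m2 = Prophet::Forecaster.from_json(File.read("model.json"))
    forecast = m2.predict(m2.make_future_dataframe(periods: 30))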
data/lib/prophet/holidays.rb
CHANGED
@@ -2,25 +2,21 @@ module Prophet
   module Holidays
     def get_holiday_names(country)
       years = (1995..2045).to_a
-      make_holidays_df(years, country)["holiday"].uniq
+      holiday_names = make_holidays_df(years, country)["holiday"].uniq
+      # TODO raise error in 0.4.0
+      logger.warn "Holidays in #{country} are not currently supported"
+      holiday_names
     end

     def make_holidays_df(year_list, country)
       holidays_df[(holidays_df["country"] == country) & (holidays_df["year"].in?(year_list))][["ds", "holiday"]]
     end

-    # TODO
+    # TODO improve performance
     def holidays_df
       @holidays_df ||= begin
-        holidays = {"ds" => [], "holiday" => [], "country" => [], "year" => []}
         holidays_file = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
-        CSV.foreach(holidays_file, headers: true, converters: [:date, :numeric]) do |row|
-          holidays["ds"] << row["ds"]
-          holidays["holiday"] << row["holiday"]
-          holidays["country"] << row["country"]
-          holidays["year"] << row["year"]
-        end
-        Rover::DataFrame.new(holidays)
+        Rover.read_csv(holidays_file, converters: [:date, :numeric])
       end
     end
   end