prophet-rb 0.3.2 → 0.4.2

Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +15 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +158 -6
  5. data/data-raw/LICENSE-holidays.txt +20 -0
  6. data/data-raw/README.md +3 -0
  7. data/data-raw/generated_holidays.csv +29302 -61443
  8. data/lib/prophet/diagnostics.rb +349 -0
  9. data/lib/prophet/forecaster.rb +219 -15
  10. data/lib/prophet/holidays.rb +5 -2
  11. data/lib/prophet/plot.rb +60 -10
  12. data/lib/prophet/stan_backend.rb +10 -1
  13. data/lib/prophet/version.rb +1 -1
  14. data/lib/prophet.rb +5 -0
  15. data/stan/{unix/prophet.stan → prophet.stan} +8 -7
  16. data/vendor/aarch64-linux/bin/prophet +0 -0
  17. data/vendor/aarch64-linux/lib/libtbb.so.2 +0 -0
  18. data/vendor/aarch64-linux/lib/libtbbmalloc.so.2 +0 -0
  19. data/vendor/aarch64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
  20. data/vendor/aarch64-linux/licenses/sundials-license.txt +25 -63
  21. data/vendor/aarch64-linux/licenses/sundials-notice.txt +21 -0
  22. data/vendor/arm64-darwin/bin/prophet +0 -0
  23. data/vendor/arm64-darwin/lib/libtbb.dylib +0 -0
  24. data/vendor/arm64-darwin/lib/libtbbmalloc.dylib +0 -0
  25. data/vendor/arm64-darwin/licenses/sundials-license.txt +25 -63
  26. data/vendor/arm64-darwin/licenses/sundials-notice.txt +21 -0
  27. data/vendor/x86_64-darwin/bin/prophet +0 -0
  28. data/vendor/x86_64-darwin/lib/libtbb.dylib +0 -0
  29. data/vendor/x86_64-darwin/lib/libtbbmalloc.dylib +0 -0
  30. data/vendor/x86_64-darwin/licenses/sundials-license.txt +25 -63
  31. data/vendor/x86_64-darwin/licenses/sundials-notice.txt +21 -0
  32. data/vendor/x86_64-linux/bin/prophet +0 -0
  33. data/vendor/x86_64-linux/lib/libtbb.so.2 +0 -0
  34. data/vendor/x86_64-linux/lib/libtbbmalloc.so.2 +0 -0
  35. data/vendor/x86_64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
  36. data/vendor/x86_64-linux/licenses/sundials-license.txt +25 -63
  37. data/vendor/x86_64-linux/licenses/sundials-notice.txt +21 -0
  38. metadata +10 -4
  39. data/stan/win/prophet.stan +0 -175
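
To pick up these changes in an application, bump the dependency in your Gemfile (the pessimistic constraint below is one reasonable choice, not the only one):

```ruby
# Gemfile
gem "prophet-rb", "~> 0.4.2"
```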
data/lib/prophet/diagnostics.rb
@@ -0,0 +1,349 @@
+module Prophet
+  module Diagnostics
+    def self.generate_cutoffs(df, horizon, initial, period)
+      # Last cutoff is 'latest date in data - horizon' date
+      cutoff = df["ds"].max - horizon
+      if cutoff < df["ds"].min
+        raise Error, "Less data than horizon."
+      end
+      result = [cutoff]
+      while result[-1] >= df["ds"].min + initial
+        cutoff -= period
+        # If data does not exist in data range (cutoff, cutoff + horizon]
+        if !(((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)).any?)
+          # Next cutoff point is 'last date before cutoff in data - horizon'
+          if cutoff > df["ds"].min
+            closest_date = df[df["ds"] <= cutoff].max["ds"]
+            cutoff = closest_date - horizon
+          end
+          # else no data left, leave cutoff as is, it will be dropped.
+        end
+        result << cutoff
+      end
+      result = result[0...-1]
+      if result.length == 0
+        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
+      end
+      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
+      result.reverse
+    end
+
+    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
+      if model.history.nil?
+        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
+      end
+
+      df = model.history.dup
+      horizon = timedelta(horizon)
+
+      predict_columns = ["ds", "yhat"]
+      if model.uncertainty_samples
+        predict_columns.concat(["yhat_lower", "yhat_upper"])
+      end
+
+      # Identify largest seasonality period
+      period_max = 0.0
+      model.seasonalities.each do |_, s|
+        period_max = [period_max, s[:period]].max
+      end
+      seasonality_dt = timedelta("#{period_max} days")
+
+      if cutoffs.nil?
+        # Set period
+        period = period.nil? ? 0.5 * horizon : timedelta(period)
+
+        # Set initial
+        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)
+
+        # Compute Cutoffs
+        cutoffs = generate_cutoffs(df, horizon, initial, period)
+      else
+        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
+        if cutoffs.min <= df["ds"].min
+          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
+        end
+        # max value of cutoffs is <= (end date minus horizon)
+        end_date_minus_horizon = df["ds"].max - horizon
+        if cutoffs.max > end_date_minus_horizon
+          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
+        end
+        initial = cutoffs[0] - df["ds"].min
+      end
+
+      # Check if the initial window
+      # (that is, the amount of time between the start of the history and the first cutoff)
+      # is less than the maximum seasonality period
+      if initial < seasonality_dt
+        msg = "Seasonality has period of #{period_max} days "
+        msg += "which is larger than initial window. "
+        msg += "Consider increasing initial."
+        # logger.warn(msg)
+      end
+
+      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }
+
+      # Combine all predicted DataFrame into one DataFrame
+      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
+    end
+
+    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
+      # Generate new object with copying fitting options
+      m = prophet_copy(model, cutoff)
+      # Train model
+      history_c = df[df["ds"] <= cutoff]
+      if history_c.shape[0] < 2
+        raise Error, "Less than two datapoints before cutoff. Increase initial window."
+      end
+      m.fit(history_c, **model.fit_kwargs)
+      # Calculate yhat
+      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
+      # Get the columns for the future dataframe
+      columns = ["ds"]
+      if m.growth == "logistic"
+        columns << "cap"
+        if m.logistic_floor
+          columns << "floor"
+        end
+      end
+      columns.concat(m.extra_regressors.keys)
+      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
+      yhat = m.predict(df[index_predicted][columns])
+      # Merge yhat(predicts), y(df, original data) and cutoff
+      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
+    end
+
+    def self.prophet_copy(m, cutoff = nil)
+      if m.history.nil?
+        raise Error, "This is for copying a fitted Prophet object."
+      end
+
+      if m.specified_changepoints
+        changepoints = m.changepoints
+        if !cutoff.nil?
+          # Filter change points '< cutoff'
+          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
+          changepoints = changepoints[changepoints < last_history_date]
+        end
+      else
+        changepoints = nil
+      end
+
+      # Auto seasonalities are set to False because they are already set in
+      # m.seasonalities.
+      m2 = m.class.new(
+        growth: m.growth,
+        n_changepoints: m.n_changepoints,
+        changepoint_range: m.changepoint_range,
+        changepoints: changepoints,
+        yearly_seasonality: false,
+        weekly_seasonality: false,
+        daily_seasonality: false,
+        holidays: m.holidays,
+        seasonality_mode: m.seasonality_mode,
+        seasonality_prior_scale: m.seasonality_prior_scale,
+        changepoint_prior_scale: m.changepoint_prior_scale,
+        holidays_prior_scale: m.holidays_prior_scale,
+        mcmc_samples: m.mcmc_samples,
+        interval_width: m.interval_width,
+        uncertainty_samples: m.uncertainty_samples
+      )
+      m2.extra_regressors = deepcopy(m.extra_regressors)
+      m2.seasonalities = deepcopy(m.seasonalities)
+      m2.country_holidays = deepcopy(m.country_holidays)
+      m2
+    end
+
+    def self.timedelta(value)
+      if value.is_a?(Numeric)
+        # ActiveSupport::Duration is a numeric
+        value
+      elsif (m = /\A(\d+(\.\d+)?) days\z/.match(value))
+        m[1].to_f * 86400
+      else
+        raise Error, "Unknown time delta"
+      end
+    end
+
+    def self.deepcopy(value)
+      if value.is_a?(Hash)
+        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
+      elsif value.is_a?(Array)
+        value.map { |v| deepcopy(v) }
+      else
+        value.dup
+      end
+    end
+
+    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
+      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
+      if metrics.nil?
+        metrics = valid_metrics
+      end
+      if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
+        metrics.delete("coverage")
+      end
+      if metrics.uniq.length != metrics.length
+        raise ArgumentError, "Input metrics must be a list of unique values"
+      end
+      if !Set.new(metrics).subset?(Set.new(valid_metrics))
+        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
+      end
+      df_m = df.dup
+      if monthly
+        raise Error, "Not implemented yet"
+        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
+      else
+        df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
+      end
+      df_m.sort_by! { |r| r["horizon"] }
+      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
+        # logger.info("Skipping MAPE because y close to 0")
+        metrics.delete("mape")
+      end
+      if metrics.length == 0
+        return nil
+      end
+      w = (rolling_window * df_m.shape[0]).to_i
+      if w >= 0
+        w = [w, 1].max
+        w = [w, df_m.shape[0]].min
+      end
+      # Compute all metrics
+      dfs = {}
+      metrics.each do |metric|
+        dfs[metric] = send(metric, df_m, w)
+      end
+      res = dfs[metrics[0]]
+      metrics.each do |metric|
+        res_m = dfs[metric]
+        res[metric] = res_m[metric]
+      end
+      res
+    end
+
+    def self.rolling_mean_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
+      xs = df2["sum_x"]
+      ns = df2["count"]
+      hs = df2["h"]
+
+      trailing_i = df2.length - 1
+      x_sum = 0
+      n_sum = 0
+      # We don't know output size but it is bounded by len(df2)
+      res_x = [nil] * df2.length
+
+      # Start from the right and work backwards
+      (df2.length - 1).downto(0) do |i|
+        x_sum += xs[i]
+        n_sum += ns[i]
+        while n_sum >= w
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise weight the mean by the difference
+          excess_n = n_sum - w
+          excess_x = excess_n * xs[i] / ns[i]
+          res_x[trailing_i] = (x_sum - excess_x) / w
+          x_sum -= xs[trailing_i]
+          n_sum -= ns[trailing_i]
+          trailing_i -= 1
+        end
+      end
+
+      res_h = hs[(trailing_i + 1)..-1]
+      res_x = res_x[(trailing_i + 1)..-1]
+
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.rolling_median_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      grouped = df.group("h")
+      df2 = grouped.count.sort_by { |r| r["h"] }
+      hs = df2["h"]
+
+      res_h = []
+      res_x = []
+      # Start from the right and work backwards
+      i = hs.length - 1
+      while i >= 0
+        h_i = hs[i]
+        xs = df[df["h"] == h_i]["x"].to_a
+
+        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
+        while xs.length < w && next_idx_to_add >= 0
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise just enough to get to w.
+          xs << x[next_idx_to_add]
+          next_idx_to_add -= 1
+        end
+        if xs.length < w
+          # Ran out of points before getting enough.
+          break
+        end
+        res_h << hs[i]
+        res_x << Rover::Vector.new(xs).median
+        i -= 1
+      end
+      res_h.reverse!
+      res_x.reverse!
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.mse(df, w)
+      se = (df["y"] - df["yhat"]) ** 2
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
+      end
+      rolling_mean_by_h(se, df["horizon"], w, "mse")
+    end
+
+    def self.rmse(df, w)
+      res = mse(df, w)
+      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
+      res
+    end
+
+    def self.mae(df, w)
+      ae = (df["y"] - df["yhat"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
+      end
+      rolling_mean_by_h(ae, df["horizon"], w, "mae")
+    end
+
+    def self.mape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
+      end
+      rolling_mean_by_h(ape, df["horizon"], w, "mape")
+    end
+
+    def self.mdape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
+      end
+      rolling_median_by_h(ape, df["horizon"], w, "mdape")
+    end
+
+    def self.smape(df, w)
+      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
+      end
+      rolling_mean_by_h(sape, df["horizon"], w, "smape")
+    end
+
+    def self.coverage(df, w)
+      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
+      end
+      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
+    end
+  end
+end
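
The new Prophet::Diagnostics module above ports Prophet's rolling-origin cross-validation and performance metrics to Ruby. A minimal usage sketch against the signatures shown in the diff ("history.csv" is a placeholder; any frame with "ds" and "y" columns works):

```ruby
require "prophet"

# Placeholder data file with "ds" (date) and "y" (value) columns.
df = Rover.read_csv("history.csv")
m = Prophet.new
m.fit(df)

# Forecast 365 days past each cutoff, with cutoffs spaced 180 days apart
# and at least 730 days of training data before the first cutoff.
df_cv = Prophet::Diagnostics.cross_validation(
  m, initial: "730 days", period: "180 days", horizon: "365 days"
)

# Aggregate errors by horizon over a 10% rolling window (the default).
df_p = Prophet::Diagnostics.performance_metrics(df_cv)
df_p["rmse"]
```

Explicit cutoffs: can be passed instead of initial/period, subject to the validation branch in cross_validation above.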
data/lib/prophet/forecaster.rb
@@ -3,7 +3,14 @@ module Prophet
     include Holidays
     include Plot
 
-    attr_reader :logger, :params, :train_holiday_names
+    attr_reader :logger, :params, :train_holiday_names,
+      :history, :seasonalities, :specified_changepoints, :fit_kwargs,
+      :growth, :changepoints, :n_changepoints, :changepoint_range,
+      :holidays, :seasonality_mode, :seasonality_prior_scale,
+      :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
+      :interval_width, :uncertainty_samples
+
+    attr_accessor :extra_regressors, :seasonalities, :country_holidays
 
     def initialize(
       growth: "linear",
@@ -82,7 +89,7 @@ module Prophet
         raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
       end
       if @holidays
-        if !@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday")
+        if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
           raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
         end
         @holidays["ds"] = to_datetime(@holidays["ds"])
@@ -118,8 +125,8 @@ module Prophet
         "holidays", "zeros", "extra_regressors_additive", "yhat",
         "extra_regressors_multiplicative", "multiplicative_terms",
       ]
-      rn_l = reserved_names.map { |n| n + "_lower" }
-      rn_u = reserved_names.map { |n| n + "_upper" }
+      rn_l = reserved_names.map { |n| "#{n}_lower" }
+      rn_u = reserved_names.map { |n| "#{n}_upper" }
       reserved_names.concat(rn_l)
       reserved_names.concat(rn_u)
       reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
@@ -135,7 +142,7 @@ module Prophet
       if check_seasonalities && @seasonalities[name]
         raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
       end
-      if check_regressors and @extra_regressors[name]
+      if check_regressors && @extra_regressors[name]
         raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
       end
     end
@@ -160,7 +167,7 @@ module Prophet
           raise ArgumentError, "Found NaN in column #{name.inspect}"
         end
       end
-      @seasonalities.values.each do |props|
+      @seasonalities.each_value do |props|
         condition_name = props[:condition_name]
         if condition_name
           if !df.include?(condition_name)
@@ -176,8 +183,10 @@ module Prophet
 
       initialize_scales(initialize_scales, df)
 
-      if @logistic_floor && !df.include?("floor")
-        raise ArgumentError, "Expected column \"floor\"."
+      if @logistic_floor
+        unless df.include?("floor")
+          raise ArgumentError, "Expected column \"floor\"."
+        end
       else
         df["floor"] = 0
       end
@@ -207,7 +216,12 @@ module Prophet
     def initialize_scales(initialize_scales, df)
       return unless initialize_scales
 
-      floor = 0
+      if @growth == "logistic" && df.include?("floor")
+        @logistic_floor = true
+        floor = df["floor"]
+      else
+        floor = 0.0
+      end
       @y_scale = (df["y"] - floor).abs.max
       @y_scale = 1 if @y_scale == 0
       @start = df["ds"].min
@@ -467,14 +481,14 @@ module Prophet
       end
       # Add totals additive and multiplicative components, and regressors
       ["additive", "multiplicative"].each do |mode|
-        components = add_group_component(components, mode + "_terms", modes[mode])
+        components = add_group_component(components, "#{mode}_terms", modes[mode])
         regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
           .map { |r, props| r }
-        components = add_group_component(components, "extra_regressors_" + mode, regressors_by_mode)
+        components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
 
         # Add combination components to modes
-        modes[mode] << mode + "_terms"
-        modes[mode] << "extra_regressors_" + mode
+        modes[mode] << "#{mode}_terms"
+        modes[mode] << "extra_regressors_#{mode}"
       end
       # After all of the additive/multiplicative groups have been added,
       modes[@seasonality_mode] << "holidays"
@@ -803,8 +817,8 @@ module Prophet
         end
         data[component] = comp.mean(axis: 1, nan: true)
         if @uncertainty_samples
-          data[component + "_lower"] = comp.percentile(lower_p, axis: 1)
-          data[component + "_upper"] = comp.percentile(upper_p, axis: 1)
+          data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
+          data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
         end
       end
       Rover::DataFrame.new(data)
@@ -971,6 +985,12 @@ module Prophet
       Rover::DataFrame.new({"ds" => dates})
     end
 
+    def to_json
+      require "json"
+
+      JSON.generate(as_json)
+    end
+
     private
 
     # Time is preferred over DateTime in Ruby docs
@@ -1017,5 +1037,189 @@ module Prophet
       u = Numo::DFloat.new(size).rand(-0.5, 0.5)
       loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
     end
+
+    SIMPLE_ATTRIBUTES = [
+      "growth", "n_changepoints", "specified_changepoints", "changepoint_range",
+      "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
+      "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
+      "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
+      "y_scale", "logistic_floor", "country_holidays", "component_modes"
+    ]
+
+    PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
+
+    PD_TIMESTAMP = ["start"]
+
+    PD_TIMEDELTA = ["t_scale"]
+
+    PD_DATAFRAME = ["holidays", "history", "train_component_cols"]
+
+    NP_ARRAY = ["changepoints_t"]
+
+    ORDEREDDICT = ["seasonalities", "extra_regressors"]
+
+    def as_json
+      if @history.nil?
+        raise Error, "This can only be used to serialize models that have already been fit."
+      end
+
+      model_dict =
+        SIMPLE_ATTRIBUTES.to_h do |attribute|
+          [attribute, instance_variable_get("@#{attribute}")]
+        end
+
+      # Handle attributes of non-core types
+      PD_SERIES.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          v = instance_variable_get("@#{attribute}")
+          d = {
+            "name" => "ds",
+            "index" => v.size.times.to_a,
+            "data" => v.to_a.map { |v| v.iso8601(3) }
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_DATAFRAME.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          # use same format as Pandas
+          v = instance_variable_get("@#{attribute}")
+
+          v = v.dup
+          v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
+          v.delete("col")
+
+          fields =
+            v.types.map do |k, t|
+              type =
+                case t
+                when :object
+                  "datetime"
+                when :int64
+                  "integer"
+                else
+                  "number"
+                end
+              {"name" => k, "type" => type}
+            end
+
+          d = {
+            "schema" => {
+              "fields" => fields,
+              "pandas_version" => "0.20.0"
+            },
+            "data" => v.to_a
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_a
+      end
+      ORDEREDDICT.each do |attribute|
+        model_dict[attribute] = [
+          instance_variable_get("@#{attribute}").keys,
+          instance_variable_get("@#{attribute}").transform_keys(&:to_s)
+        ]
+      end
+      # Other attributes with special handling
+      # fit_kwargs -> Transform any numpy types before serializing.
+      # They do not need to be transformed back on deserializing.
+      # TODO deep copy
+      fit_kwargs = @fit_kwargs.to_h { |k, v| [k.to_s, v.dup] }
+      if fit_kwargs.key?("init")
+        fit_kwargs["init"].each do |k, v|
+          if v.is_a?(Numo::NArray)
+            fit_kwargs["init"][k] = v.to_a
+          # elsif v.is_a?(Float)
+          #   fit_kwargs["init"][k] = v.to_f
+          end
+        end
+      end
+      model_dict["fit_kwargs"] = fit_kwargs
+
+      # Params (Dict[str, np.ndarray])
+      model_dict["params"] = params.transform_values(&:to_a)
+      # Attributes that are skipped: stan_fit, stan_backend
+      # Returns 1.0 for Prophet 1.1
+      model_dict["__prophet_version"] = "1.0"
+      model_dict
+    end
+
+    def self.from_json(model_json)
+      require "json"
+
+      model_dict = JSON.parse(model_json)
+
+      # We will overwrite all attributes set in init anyway
+      model = Prophet.new
+      # Simple types
+      SIMPLE_ATTRIBUTES.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute))
+      end
+      PD_SERIES.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          s = Rover::Vector.new(d["data"])
+          if d["name"] == "ds"
+            s = s.map { |v| Time.parse(v).utc }
+          end
+          model.instance_variable_set("@#{attribute}", s)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Time.at(model_dict.fetch(attribute)))
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute).to_f)
+      end
+      PD_DATAFRAME.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          df = Rover::DataFrame.new(d["data"])
+          df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
+          if attribute == "train_component_cols"
+            # Special handling because of named index column
+            # df.columns.name = 'component'
+            # df.index.name = 'col'
+          end
+          model.instance_variable_set("@#{attribute}", df)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Numo::NArray.cast(model_dict.fetch(attribute)))
+      end
+      ORDEREDDICT.each do |attribute|
+        key_list, unordered_dict = model_dict.fetch(attribute)
+        od = {}
+        key_list.each do |key|
+          od[key] = unordered_dict[key].transform_keys(&:to_sym)
+        end
+        model.instance_variable_set("@#{attribute}", od)
+      end
+      # Other attributes with special handling
+      # fit_kwargs
+      model.instance_variable_set(:@fit_kwargs, model_dict["fit_kwargs"].transform_keys(&:to_sym))
+      # Params (Dict[str, np.ndarray])
+      model.instance_variable_set(:@params, model_dict["params"].transform_values { |v| Numo::NArray.cast(v) })
+      # Skipped attributes
+      # model.stan_backend = nil
+      model.instance_variable_set(:@stan_fit, nil)
+      model
+    end
   end
 end
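
The serialization block above mirrors the Python package's JSON layout, so a fitted model can be round-tripped through JSON. A sketch, assuming the top-level Prophet.from_json convenience this release documents in the README:

```ruby
# Serialize a fitted model; as_json raises if the model has not been fit.
File.write("model.json", m.to_json)

# Restore it later. stan_fit is skipped during serialization, so the
# restored model can predict but keeps no sampler state.
m2 = Prophet.from_json(File.read("model.json"))
future = m2.make_future_dataframe(periods: 30)
forecast = m2.predict(future)
```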
data/lib/prophet/holidays.rb
@@ -3,8 +3,11 @@ module Prophet
     def get_holiday_names(country)
       years = (1995..2045).to_a
       holiday_names = make_holidays_df(years, country)["holiday"].uniq
-      # TODO raise error in 0.4.0
-      logger.warn "Holidays in #{country} are not currently supported"
+      if holiday_names.size == 0
+        # TODO raise error in 0.5.0
+        # raise ArgumentError, "Holidays in #{country} are not currently supported"
+        logger.warn "Holidays in #{country} are not currently supported"
+      end
       holiday_names
     end
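
The holidays change makes the warning conditional: it now fires only when a country yields no holiday names, with an ArgumentError planned for 0.5.0 per the TODO. An illustrative sketch (Holidays is included in the forecaster, so get_holiday_names is callable on a model; "XX" is a made-up country code):

```ruby
m = Prophet.new
m.get_holiday_names("US") # => names of US holidays between 1995 and 2045; no warning
m.get_holiday_names("XX") # logs "Holidays in XX are not currently supported"
```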