prophet-rb 0.3.1 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +17 -2
  3. data/LICENSE.txt +1 -1
  4. data/README.md +149 -2
  5. data/data-raw/LICENSE-holidays.txt +20 -0
  6. data/data-raw/README.md +3 -0
  7. data/data-raw/generated_holidays.csv +29302 -61443
  8. data/lib/prophet/diagnostics.rb +349 -0
  9. data/lib/prophet/forecaster.rb +214 -4
  10. data/lib/prophet/holidays.rb +6 -10
  11. data/lib/prophet/plot.rb +56 -6
  12. data/lib/prophet/stan_backend.rb +10 -1
  13. data/lib/prophet/version.rb +1 -1
  14. data/lib/prophet.rb +23 -7
  15. data/stan/{unix/prophet.stan → prophet.stan} +8 -7
  16. data/vendor/aarch64-linux/bin/prophet +0 -0
  17. data/vendor/aarch64-linux/lib/libtbb.so.2 +0 -0
  18. data/vendor/aarch64-linux/lib/libtbbmalloc.so.2 +0 -0
  19. data/vendor/aarch64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
  20. data/vendor/aarch64-linux/licenses/sundials-license.txt +25 -63
  21. data/vendor/aarch64-linux/licenses/sundials-notice.txt +21 -0
  22. data/vendor/arm64-darwin/bin/prophet +0 -0
  23. data/vendor/arm64-darwin/lib/libtbb.dylib +0 -0
  24. data/vendor/arm64-darwin/lib/libtbbmalloc.dylib +0 -0
  25. data/vendor/arm64-darwin/licenses/sundials-license.txt +25 -63
  26. data/vendor/arm64-darwin/licenses/sundials-notice.txt +21 -0
  27. data/vendor/x86_64-darwin/bin/prophet +0 -0
  28. data/vendor/x86_64-darwin/lib/libtbb.dylib +0 -0
  29. data/vendor/x86_64-darwin/lib/libtbbmalloc.dylib +0 -0
  30. data/vendor/x86_64-darwin/licenses/sundials-license.txt +25 -63
  31. data/vendor/x86_64-darwin/licenses/sundials-notice.txt +21 -0
  32. data/vendor/x86_64-linux/bin/prophet +0 -0
  33. data/vendor/x86_64-linux/lib/libtbb.so.2 +0 -0
  34. data/vendor/x86_64-linux/lib/libtbbmalloc.so.2 +0 -0
  35. data/vendor/x86_64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
  36. data/vendor/x86_64-linux/licenses/sundials-license.txt +25 -63
  37. data/vendor/x86_64-linux/licenses/sundials-notice.txt +21 -0
  38. metadata +10 -4
  39. data/stan/win/prophet.stan +0 -175
--- /dev/null
+++ b/data/lib/prophet/diagnostics.rb
@@ -0,0 +1,349 @@
+module Prophet
+  module Diagnostics
+    def self.generate_cutoffs(df, horizon, initial, period)
+      # Last cutoff is 'latest date in data - horizon' date
+      cutoff = df["ds"].max - horizon
+      if cutoff < df["ds"].min
+        raise Error, "Less data than horizon."
+      end
+      result = [cutoff]
+      while result[-1] >= df["ds"].min + initial
+        cutoff -= period
+        # If data does not exist in data range (cutoff, cutoff + horizon]
+        if !(((df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)).any?)
+          # Next cutoff point is 'last date before cutoff in data - horizon'
+          if cutoff > df["ds"].min
+            closest_date = df[df["ds"] <= cutoff].max["ds"]
+            cutoff = closest_date - horizon
+          end
+          # else no data left, leave cutoff as is, it will be dropped.
+        end
+        result << cutoff
+      end
+      result = result[0...-1]
+      if result.length == 0
+        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
+      end
+      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
+      result.reverse
+    end
+
+    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
+      if model.history.nil?
+        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
+      end
+
+      df = model.history.dup
+      horizon = timedelta(horizon)
+
+      predict_columns = ["ds", "yhat"]
+      if model.uncertainty_samples
+        predict_columns.concat(["yhat_lower", "yhat_upper"])
+      end
+
+      # Identify largest seasonality period
+      period_max = 0.0
+      model.seasonalities.each do |_, s|
+        period_max = [period_max, s[:period]].max
+      end
+      seasonality_dt = timedelta("#{period_max} days")
+
+      if cutoffs.nil?
+        # Set period
+        period = period.nil? ? 0.5 * horizon : timedelta(period)
+
+        # Set initial
+        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)
+
+        # Compute Cutoffs
+        cutoffs = generate_cutoffs(df, horizon, initial, period)
+      else
+        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
+        if cutoffs.min <= df["ds"].min
+          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
+        end
+        # max value of cutoffs is <= (end date minus horizon)
+        end_date_minus_horizon = df["ds"].max - horizon
+        if cutoffs.max > end_date_minus_horizon
+          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
+        end
+        initial = cutoffs[0] - df["ds"].min
+      end
+
+      # Check if the initial window
+      # (that is, the amount of time between the start of the history and the first cutoff)
+      # is less than the maximum seasonality period
+      if initial < seasonality_dt
+        msg = "Seasonality has period of #{period_max} days "
+        msg += "which is larger than initial window. "
+        msg += "Consider increasing initial."
+        # logger.warn(msg)
+      end
+
+      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }
+
+      # Combine all predicted DataFrame into one DataFrame
+      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
+    end
+
+    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
+      # Generate new object with copying fitting options
+      m = prophet_copy(model, cutoff)
+      # Train model
+      history_c = df[df["ds"] <= cutoff]
+      if history_c.shape[0] < 2
+        raise Error, "Less than two datapoints before cutoff. Increase initial window."
+      end
+      m.fit(history_c, **model.fit_kwargs)
+      # Calculate yhat
+      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
+      # Get the columns for the future dataframe
+      columns = ["ds"]
+      if m.growth == "logistic"
+        columns << "cap"
+        if m.logistic_floor
+          columns << "floor"
+        end
+      end
+      columns.concat(m.extra_regressors.keys)
+      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
+      yhat = m.predict(df[index_predicted][columns])
+      # Merge yhat(predicts), y(df, original data) and cutoff
+      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
+    end
+
+    def self.prophet_copy(m, cutoff = nil)
+      if m.history.nil?
+        raise Error, "This is for copying a fitted Prophet object."
+      end
+
+      if m.specified_changepoints
+        changepoints = m.changepoints
+        if !cutoff.nil?
+          # Filter change points '< cutoff'
+          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
+          changepoints = changepoints[changepoints < last_history_date]
+        end
+      else
+        changepoints = nil
+      end
+
+      # Auto seasonalities are set to False because they are already set in
+      # m.seasonalities.
+      m2 = m.class.new(
+        growth: m.growth,
+        n_changepoints: m.n_changepoints,
+        changepoint_range: m.changepoint_range,
+        changepoints: changepoints,
+        yearly_seasonality: false,
+        weekly_seasonality: false,
+        daily_seasonality: false,
+        holidays: m.holidays,
+        seasonality_mode: m.seasonality_mode,
+        seasonality_prior_scale: m.seasonality_prior_scale,
+        changepoint_prior_scale: m.changepoint_prior_scale,
+        holidays_prior_scale: m.holidays_prior_scale,
+        mcmc_samples: m.mcmc_samples,
+        interval_width: m.interval_width,
+        uncertainty_samples: m.uncertainty_samples
+      )
+      m2.extra_regressors = deepcopy(m.extra_regressors)
+      m2.seasonalities = deepcopy(m.seasonalities)
+      m2.country_holidays = deepcopy(m.country_holidays)
+      m2
+    end
+
+    def self.timedelta(value)
+      if value.is_a?(Numeric)
+        # ActiveSupport::Duration is a numeric
+        value
+      elsif (m = /\A(\d+(\.\d+)?) days\z/.match(value))
+        m[1].to_f * 86400
+      else
+        raise Error, "Unknown time delta"
+      end
+    end
+
+    def self.deepcopy(value)
+      if value.is_a?(Hash)
+        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
+      elsif value.is_a?(Array)
+        value.map { |v| deepcopy(v) }
+      else
+        value.dup
+      end
+    end
+
+    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
+      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
+      if metrics.nil?
+        metrics = valid_metrics
+      end
+      if (df["yhat_lower"].nil? || df["yhat_upper"].nil?) && metrics.include?("coverage")
+        metrics.delete("coverage")
+      end
+      if metrics.uniq.length != metrics.length
+        raise ArgumentError, "Input metrics must be a list of unique values"
+      end
+      if !Set.new(metrics).subset?(Set.new(valid_metrics))
+        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
+      end
+      df_m = df.dup
+      if monthly
+        raise Error, "Not implemented yet"
+        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
+      else
+        df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
+      end
+      df_m.sort_by! { |r| r["horizon"] }
+      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
+        # logger.info("Skipping MAPE because y close to 0")
+        metrics.delete("mape")
+      end
+      if metrics.length == 0
+        return nil
+      end
+      w = (rolling_window * df_m.shape[0]).to_i
+      if w >= 0
+        w = [w, 1].max
+        w = [w, df_m.shape[0]].min
+      end
+      # Compute all metrics
+      dfs = {}
+      metrics.each do |metric|
+        dfs[metric] = send(metric, df_m, w)
+      end
+      res = dfs[metrics[0]]
+      metrics.each do |metric|
+        res_m = dfs[metric]
+        res[metric] = res_m[metric]
+      end
+      res
+    end
+
+    def self.rolling_mean_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
+      xs = df2["sum_x"]
+      ns = df2["count"]
+      hs = df2["h"]
+
+      trailing_i = df2.length - 1
+      x_sum = 0
+      n_sum = 0
+      # We don't know output size but it is bounded by len(df2)
+      res_x = [nil] * df2.length
+
+      # Start from the right and work backwards
+      (df2.length - 1).downto(0) do |i|
+        x_sum += xs[i]
+        n_sum += ns[i]
+        while n_sum >= w
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise weight the mean by the difference
+          excess_n = n_sum - w
+          excess_x = excess_n * xs[i] / ns[i]
+          res_x[trailing_i] = (x_sum - excess_x) / w
+          x_sum -= xs[trailing_i]
+          n_sum -= ns[trailing_i]
+          trailing_i -= 1
+        end
+      end
+
+      res_h = hs[(trailing_i + 1)..-1]
+      res_x = res_x[(trailing_i + 1)..-1]
+
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.rolling_median_by_h(x, h, w, name)
+      # Aggregate over h
+      df = Rover::DataFrame.new({"x" => x, "h" => h})
+      grouped = df.group("h")
+      df2 = grouped.count.sort_by { |r| r["h"] }
+      hs = df2["h"]
+
+      res_h = []
+      res_x = []
+      # Start from the right and work backwards
+      i = hs.length - 1
+      while i >= 0
+        h_i = hs[i]
+        xs = df[df["h"] == h_i]["x"].to_a
+
+        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
+        while xs.length < w && next_idx_to_add >= 0
+          # Include points from the previous horizon. All of them if still
+          # less than w, otherwise just enough to get to w.
+          xs << x[next_idx_to_add]
+          next_idx_to_add -= 1
+        end
+        if xs.length < w
+          # Ran out of points before getting enough.
+          break
+        end
+        res_h << hs[i]
+        res_x << Rover::Vector.new(xs).median
+        i -= 1
+      end
+      res_h.reverse!
+      res_x.reverse!
+      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
+    end
+
+    def self.mse(df, w)
+      se = (df["y"] - df["yhat"]) ** 2
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
+      end
+      rolling_mean_by_h(se, df["horizon"], w, "mse")
+    end
+
+    def self.rmse(df, w)
+      res = mse(df, w)
+      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
+      res
+    end
+
+    def self.mae(df, w)
+      ae = (df["y"] - df["yhat"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
+      end
+      rolling_mean_by_h(ae, df["horizon"], w, "mae")
+    end
+
+    def self.mape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
+      end
+      rolling_mean_by_h(ape, df["horizon"], w, "mape")
+    end
+
+    def self.mdape(df, w)
+      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
+      end
+      rolling_median_by_h(ape, df["horizon"], w, "mdape")
+    end
+
+    def self.smape(df, w)
+      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
+      end
+      rolling_mean_by_h(sape, df["horizon"], w, "smape")
+    end
+
+    def self.coverage(df, w)
+      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
+      if w < 0
+        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
+      end
+      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
+    end
+  end
+end
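The new Diagnostics module ports upstream Prophet's simulated historical forecasts (cross-validation) and per-horizon performance metrics. A minimal usage sketch against the signatures above; the CSV path and window values are illustrative:

    require "prophet"

    # assumes a CSV with "ds" and "y" columns
    df = Rover.read_csv("example_wp_log_peyton_manning.csv")
    m = Prophet.new
    m.fit(df)

    # horizon/initial/period accept "N days" strings or numeric seconds (see timedelta above)
    df_cv = Prophet::Diagnostics.cross_validation(m, initial: "730 days", period: "180 days", horizon: "365 days")

    # per-horizon error metrics (mse, rmse, mae, mape, mdape, smape, coverage)
    df_p = Prophet::Diagnostics.performance_metrics(df_cv)

Note that this relies on the expanded attr_reader/attr_accessor surface added to Forecaster below, which is what lets prophet_copy rebuild a model from a fitted one.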
--- a/data/lib/prophet/forecaster.rb
+++ b/data/lib/prophet/forecaster.rb
@@ -3,7 +3,14 @@ module Prophet
     include Holidays
     include Plot
 
-    attr_reader :logger, :params, :train_holiday_names
+    attr_reader :logger, :params, :train_holiday_names,
+      :history, :seasonalities, :specified_changepoints, :fit_kwargs,
+      :growth, :changepoints, :n_changepoints, :changepoint_range,
+      :holidays, :seasonality_mode, :seasonality_prior_scale,
+      :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
+      :interval_width, :uncertainty_samples
+
+    attr_accessor :extra_regressors, :seasonalities, :country_holidays
 
     def initialize(
       growth: "linear",
@@ -176,8 +183,10 @@ module Prophet
 
       initialize_scales(initialize_scales, df)
 
-      if @logistic_floor && !df.include?("floor")
-        raise ArgumentError, "Expected column \"floor\"."
+      if @logistic_floor
+        unless df.include?("floor")
+          raise ArgumentError, "Expected column \"floor\"."
+        end
       else
         df["floor"] = 0
       end
@@ -207,7 +216,12 @@ module Prophet
     def initialize_scales(initialize_scales, df)
       return unless initialize_scales
 
-      floor = 0
+      if @growth == "logistic" && df.include?("floor")
+        @logistic_floor = true
+        floor = df["floor"]
+      else
+        floor = 0.0
+      end
       @y_scale = (df["y"] - floor).abs.max
       @y_scale = 1 if @y_scale == 0
       @start = df["ds"].min
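Taken together, the two hunks above make a user-supplied "floor" column enable @logistic_floor during scale initialization instead of being overwritten with 0. A sketch of the call pattern this supports, with made-up cap/floor values:

    df["cap"] = 6.0    # illustrative carrying capacity
    df["floor"] = 1.5  # illustrative saturating minimum
    m = Prophet.new(growth: "logistic")
    m.fit(df)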
@@ -386,6 +400,12 @@ module Prophet
 
     def add_country_holidays(country_name)
       raise Error, "Country holidays must be added prior to model fitting." if @history
+
+      # Fix for previously documented keyword argument
+      if country_name.is_a?(Hash) && country_name[:country_name]
+        country_name = country_name[:country_name]
+      end
+
       # Validate names.
       get_holiday_names(country_name).each do |name|
         # Allow merging with existing holidays
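The added guard accepts the keyword form that earlier documentation showed, so both of these calls now behave the same:

    m = Prophet.new
    # equivalent after this change
    m.add_country_holidays("US")
    # m.add_country_holidays(country_name: "US")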
@@ -965,6 +985,12 @@ module Prophet
       Rover::DataFrame.new({"ds" => dates})
     end
 
+    def to_json
+      require "json"
+
+      JSON.generate(as_json)
+    end
+
     private
 
     # Time is preferred over DateTime in Ruby docs
@@ -1011,5 +1037,189 @@ module Prophet
       u = Numo::DFloat.new(size).rand(-0.5, 0.5)
       loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
     end
+
+    SIMPLE_ATTRIBUTES = [
+      "growth", "n_changepoints", "specified_changepoints", "changepoint_range",
+      "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
+      "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
+      "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
+      "y_scale", "logistic_floor", "country_holidays", "component_modes"
+    ]
+
+    PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
+
+    PD_TIMESTAMP = ["start"]
+
+    PD_TIMEDELTA = ["t_scale"]
+
+    PD_DATAFRAME = ["holidays", "history", "train_component_cols"]
+
+    NP_ARRAY = ["changepoints_t"]
+
+    ORDEREDDICT = ["seasonalities", "extra_regressors"]
+
+    def as_json
+      if @history.nil?
+        raise Error, "This can only be used to serialize models that have already been fit."
+      end
+
+      model_dict =
+        SIMPLE_ATTRIBUTES.to_h do |attribute|
+          [attribute, instance_variable_get("@#{attribute}")]
+        end
+
+      # Handle attributes of non-core types
+      PD_SERIES.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          v = instance_variable_get("@#{attribute}")
+          d = {
+            "name" => "ds",
+            "index" => v.size.times.to_a,
+            "data" => v.to_a.map { |v| v.iso8601(3) }
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
+      end
+      PD_DATAFRAME.each do |attribute|
+        if instance_variable_get("@#{attribute}").nil?
+          model_dict[attribute] = nil
+        else
+          # use same format as Pandas
+          v = instance_variable_get("@#{attribute}")
+
+          v = v.dup
+          v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
+          v.delete("col")
+
+          fields =
+            v.types.map do |k, t|
+              type =
+                case t
+                when :object
+                  "datetime"
+                when :int64
+                  "integer"
+                else
+                  "number"
+                end
+              {"name" => k, "type" => type}
+            end
+
+          d = {
+            "schema" => {
+              "fields" => fields,
+              "pandas_version" => "0.20.0"
+            },
+            "data" => v.to_a
+          }
+          model_dict[attribute] = JSON.generate(d)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model_dict[attribute] = instance_variable_get("@#{attribute}").to_a
+      end
+      ORDEREDDICT.each do |attribute|
+        model_dict[attribute] = [
+          instance_variable_get("@#{attribute}").keys,
+          instance_variable_get("@#{attribute}").transform_keys(&:to_s)
+        ]
+      end
+      # Other attributes with special handling
+      # fit_kwargs -> Transform any numpy types before serializing.
+      # They do not need to be transformed back on deserializing.
+      # TODO deep copy
+      fit_kwargs = @fit_kwargs.to_h { |k, v| [k.to_s, v.dup] }
+      if fit_kwargs.key?("init")
+        fit_kwargs["init"].each do |k, v|
+          if v.is_a?(Numo::NArray)
+            fit_kwargs["init"][k] = v.to_a
+          # elsif v.is_a?(Float)
+          #   fit_kwargs["init"][k] = v.to_f
+          end
+        end
+      end
+      model_dict["fit_kwargs"] = fit_kwargs
+
+      # Params (Dict[str, np.ndarray])
+      model_dict["params"] = params.transform_values(&:to_a)
+      # Attributes that are skipped: stan_fit, stan_backend
+      # Returns 1.0 for Prophet 1.1
+      model_dict["__prophet_version"] = "1.0"
+      model_dict
+    end
+
+    def self.from_json(model_json)
+      require "json"
+
+      model_dict = JSON.parse(model_json)
+
+      # We will overwrite all attributes set in init anyway
+      model = Prophet.new
+      # Simple types
+      SIMPLE_ATTRIBUTES.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute))
+      end
+      PD_SERIES.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          s = Rover::Vector.new(d["data"])
+          if d["name"] == "ds"
+            s = s.map { |v| Time.parse(v).utc }
+          end
+          model.instance_variable_set("@#{attribute}", s)
+        end
+      end
+      PD_TIMESTAMP.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Time.at(model_dict.fetch(attribute)))
+      end
+      PD_TIMEDELTA.each do |attribute|
+        model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute).to_f)
+      end
+      PD_DATAFRAME.each do |attribute|
+        if model_dict[attribute].nil?
+          model.instance_variable_set("@#{attribute}", nil)
+        else
+          d = JSON.parse(model_dict.fetch(attribute))
+          df = Rover::DataFrame.new(d["data"])
+          df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
+          if attribute == "train_component_cols"
+            # Special handling because of named index column
+            # df.columns.name = 'component'
+            # df.index.name = 'col'
+          end
+          model.instance_variable_set("@#{attribute}", df)
+        end
+      end
+      NP_ARRAY.each do |attribute|
+        model.instance_variable_set("@#{attribute}", Numo::NArray.cast(model_dict.fetch(attribute)))
+      end
+      ORDEREDDICT.each do |attribute|
+        key_list, unordered_dict = model_dict.fetch(attribute)
+        od = {}
+        key_list.each do |key|
+          od[key] = unordered_dict[key].transform_keys(&:to_sym)
+        end
+        model.instance_variable_set("@#{attribute}", od)
+      end
+      # Other attributes with special handling
+      # fit_kwargs
+      model.instance_variable_set(:@fit_kwargs, model_dict["fit_kwargs"].transform_keys(&:to_sym))
+      # Params (Dict[str, np.ndarray])
+      model.instance_variable_set(:@params, model_dict["params"].transform_values { |v| Numo::NArray.cast(v) })
+      # Skipped attributes
+      # model.stan_backend = nil
+      model.instance_variable_set(:@stan_fit, nil)
+      model
+    end
   end
 end
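With to_json above and as_json/from_json here, a fitted model can round-trip through JSON using pandas-style payloads (see the PD_* handling above), so the format is intended to interoperate with the Python implementation. A short sketch, assuming m is a fitted model; make_future_dataframe is the gem's existing API:

    File.write("model.json", m.to_json)

    m2 = Prophet::Forecaster.from_json(File.read("model.json"))
    forecast = m2.predict(m2.make_future_dataframe(periods: 30))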
--- a/data/lib/prophet/holidays.rb
+++ b/data/lib/prophet/holidays.rb
@@ -2,25 +2,21 @@ module Prophet
   module Holidays
     def get_holiday_names(country)
       years = (1995..2045).to_a
-      make_holidays_df(years, country)["holiday"].uniq
+      holiday_names = make_holidays_df(years, country)["holiday"].uniq
+      # TODO raise error in 0.4.0
+      logger.warn "Holidays in #{country} are not currently supported"
+      holiday_names
     end
 
     def make_holidays_df(year_list, country)
       holidays_df[(holidays_df["country"] == country) & (holidays_df["year"].in?(year_list))][["ds", "holiday"]]
     end
 
-    # TODO marshal on installation
+    # TODO improve performance
     def holidays_df
       @holidays_df ||= begin
-        holidays = {"ds" => [], "holiday" => [], "country" => [], "year" => []}
         holidays_file = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
-        CSV.foreach(holidays_file, headers: true, converters: [:date, :numeric]) do |row|
-          holidays["ds"] << row["ds"]
-          holidays["holiday"] << row["holiday"]
-          holidays["country"] << row["country"]
-          holidays["year"] << row["year"]
-        end
-        Rover::DataFrame.new(holidays)
+        Rover.read_csv(holidays_file, converters: [:date, :numeric])
       end
     end
   end
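The bundled holidays table (the regenerated generated_holidays.csv above) now loads in one call via Rover.read_csv rather than a row-by-row CSV loop; callers are unaffected. For instance:

    m = Prophet.new
    m.get_holiday_names("US")  # distinct holiday names from the bundled CSV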