prophet-rb 0.3.2 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/LICENSE.txt +1 -1
- data/README.md +158 -6
- data/data-raw/LICENSE-holidays.txt +20 -0
- data/data-raw/README.md +3 -0
- data/data-raw/generated_holidays.csv +29302 -61443
- data/lib/prophet/diagnostics.rb +349 -0
- data/lib/prophet/forecaster.rb +219 -15
- data/lib/prophet/holidays.rb +5 -2
- data/lib/prophet/plot.rb +60 -10
- data/lib/prophet/stan_backend.rb +10 -1
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet.rb +5 -0
- data/stan/{unix/prophet.stan → prophet.stan} +8 -7
- data/vendor/aarch64-linux/bin/prophet +0 -0
- data/vendor/aarch64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/aarch64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/aarch64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/aarch64-linux/licenses/sundials-notice.txt +21 -0
- data/vendor/arm64-darwin/bin/prophet +0 -0
- data/vendor/arm64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/arm64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/arm64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/arm64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-darwin/bin/prophet +0 -0
- data/vendor/x86_64-darwin/lib/libtbb.dylib +0 -0
- data/vendor/x86_64-darwin/lib/libtbbmalloc.dylib +0 -0
- data/vendor/x86_64-darwin/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-darwin/licenses/sundials-notice.txt +21 -0
- data/vendor/x86_64-linux/bin/prophet +0 -0
- data/vendor/x86_64-linux/lib/libtbb.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc.so.2 +0 -0
- data/vendor/x86_64-linux/lib/libtbbmalloc_proxy.so.2 +0 -0
- data/vendor/x86_64-linux/licenses/sundials-license.txt +25 -63
- data/vendor/x86_64-linux/licenses/sundials-notice.txt +21 -0
- metadata +10 -4
- data/stan/win/prophet.stan +0 -175
| @@ -0,0 +1,349 @@ | |
| 1 | 
            +
            module Prophet
              # Cross-validation and forecast-accuracy helpers for fitted Prophet models.
              #
              # Durations given as strings are converted to seconds by `timedelta`;
              # Numeric durations are used as-is.
              module Diagnostics
                # Seconds per unit for the "<number> <unit>" strings accepted by
                # `timedelta`; the unit may be singular or plural ("1 day", "2 days").
                TIMEDELTA_UNITS = {
                  "second" => 1,
                  "minute" => 60,
                  "hour" => 3600,
                  "day" => 86400,
                  "week" => 604800
                }.freeze

                # Compute the cutoff dates used by cross validation.
                #
                # The latest cutoff is `max(ds) - horizon`; each earlier cutoff steps
                # back by `period` while the previous cutoff is at least `initial`
                # after the first date. If the window (cutoff, cutoff + horizon] holds
                # no data, the cutoff jumps to `(last ds at or before cutoff) - horizon`.
                #
                # @param df [Rover::DataFrame] history with a "ds" column
                # @param horizon [Numeric] forecast horizon (seconds, see `timedelta`)
                # @param initial [Numeric] minimum training window before the first cutoff
                # @param period [Numeric] spacing between consecutive cutoffs
                # @return [Array] cutoffs in ascending order
                # @raise [Error] if the data spans less than the horizon, or no cutoff
                #   remains after the initial window
                def self.generate_cutoffs(df, horizon, initial, period)
                  ds = df["ds"]
                  ds_min = ds.min

                  # Last cutoff is 'latest date in data - horizon' date
                  cutoff = ds.max - horizon
                  if cutoff < ds_min
                    raise Error, "Less data than horizon."
                  end

                  result = [cutoff]
                  while result[-1] >= ds_min + initial
                    cutoff -= period
                    # If data does not exist in data range (cutoff, cutoff + horizon]
                    unless ((ds > cutoff) & (ds <= cutoff + horizon)).any?
                      # Next cutoff point is 'last date before cutoff in data - horizon'.
                      # Index the "ds" vector directly (as prophet_copy does) instead of
                      # taking max of a filtered DataFrame.
                      if cutoff > ds_min
                        closest_date = ds[ds <= cutoff].max
                        cutoff = closest_date - horizon
                      end
                      # else no data left, leave cutoff as is, it will be dropped.
                    end
                    result << cutoff
                  end

                  # The loop exits once the last appended cutoff falls inside the
                  # initial window, so that one is discarded.
                  result = result[0...-1]
                  if result.empty?
                    raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
                  end
                  result.reverse
                end

                # Forecast the model's own history at a series of cutoffs and collect
                # the predictions next to the observed values.
                #
                # @param model a fitted model (`model.history` must be set)
                # @param horizon [Numeric, String] forecast horizon, e.g. "30 days"
                # @param period [Numeric, String, nil] spacing between cutoffs;
                #   defaults to half the horizon
                # @param initial [Numeric, String, nil] training window before the first
                #   cutoff; defaults to the larger of 3 * horizon and the longest
                #   seasonality period
                # @param cutoffs [Array, nil] explicit cutoff dates; when given, period
                #   is ignored and initial is derived from the first cutoff
                # @return [Rover::DataFrame] predict columns (ds, yhat and, when
                #   uncertainty_samples is set, yhat_lower/yhat_upper) plus y and cutoff
                # @raise [Error] if the model is not fit or the cutoffs are invalid
                def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
                  if model.history.nil?
                    raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
                  end

                  df = model.history.dup
                  horizon = timedelta(horizon)

                  predict_columns = ["ds", "yhat"]
                  # NOTE(review): any value other than nil/false (including 0) enables the
                  # interval columns here; confirm predict emits yhat_lower/yhat_upper
                  # under the same condition.
                  if model.uncertainty_samples
                    predict_columns.concat(["yhat_lower", "yhat_upper"])
                  end

                  # Identify largest seasonality period (in days)
                  period_max = model.seasonalities.reduce(0.0) { |acc, (_, s)| [acc, s[:period]].max }
                  # Convert days to seconds numerically rather than formatting the value
                  # into a "<n> days" string and re-parsing it (which fails for floats
                  # that interpolate in exponent form, e.g. 1.0e-05).
                  seasonality_dt = period_max.to_f * TIMEDELTA_UNITS["day"]

                  if cutoffs.nil?
                    # Set period
                    period = period.nil? ? 0.5 * horizon : timedelta(period)
                    # Set initial
                    initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)
                    # Compute Cutoffs
                    cutoffs = generate_cutoffs(df, horizon, initial, period)
                  else
                    # Fail clearly instead of calling comparison methods on nil below
                    if cutoffs.size == 0
                      raise Error, "Cutoffs must contain at least one date"
                    end
                    # The min cutoff must be strictly greater than the min date in the history
                    if cutoffs.min <= df["ds"].min
                      raise Error, "Minimum cutoff value is not strictly greater than min date in history"
                    end
                    # max value of cutoffs is <= (end date minus horizon)
                    end_date_minus_horizon = df["ds"].max - horizon
                    if cutoffs.max > end_date_minus_horizon
                      raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
                    end
                    initial = cutoffs[0] - df["ds"].min
                  end

                  # Warn when the initial window (time between the start of the history
                  # and the first cutoff) is shorter than the longest seasonality period.
                  if initial < seasonality_dt
                    msg = "Seasonality has period of #{period_max} days which is larger than initial window. Consider increasing initial."
                    model.logger&.warn(msg)
                  end

                  predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

                  # Combine all predicted DataFrame into one DataFrame
                  predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
                end

                # Fit a copy of `model` on the rows at or before `cutoff` and predict the
                # rows in (cutoff, cutoff + horizon].
                #
                # @return [Rover::DataFrame] predict_columns plus "y" and "cutoff"
                # @raise [Error] if fewer than two rows lie at or before the cutoff
                def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
                  # Generate new object with copying fitting options
                  m = prophet_copy(model, cutoff)

                  # Train model
                  history_c = df[df["ds"] <= cutoff]
                  if history_c.shape[0] < 2
                    raise Error, "Less than two datapoints before cutoff. Increase initial window."
                  end
                  m.fit(history_c, **(model.fit_kwargs || {}))

                  # Rows to forecast; filtered once and reused for features and actuals
                  index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
                  df_predicted = df[index_predicted]

                  # Get the columns for the future dataframe
                  columns = ["ds"]
                  if m.growth == "logistic"
                    columns << "cap"
                    # NOTE(review): `logistic_floor` is not in the attr_reader list added to
                    # forecaster.rb in this release; confirm the model exposes it, otherwise
                    # logistic-growth cross validation raises NoMethodError here.
                    if m.logistic_floor
                      columns << "floor"
                    end
                  end
                  columns.concat(m.extra_regressors.keys)
                  columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)

                  # Calculate yhat
                  yhat = m.predict(df_predicted[columns])

                  # Merge yhat(predicts), y(df, original data) and cutoff
                  yhat[predict_columns]
                    .merge(df_predicted[["y"]])
                    .merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
                end

                # Build an unfitted model configured like the fitted model `m`.
                #
                # When `m` had user-specified changepoints and a cutoff is given, only
                # changepoints before the last history date at or before the cutoff are
                # kept. Auto seasonalities are disabled because the seasonalities,
                # extra regressors and country holidays are deep-copied from `m`.
                #
                # @raise [Error] if `m` has not been fit
                def self.prophet_copy(m, cutoff = nil)
                  if m.history.nil?
                    raise Error, "This is for copying a fitted Prophet object."
                  end

                  if m.specified_changepoints
                    changepoints = m.changepoints
                    if !cutoff.nil?
                      # Filter change points '< cutoff'
                      last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
                      changepoints = changepoints[changepoints < last_history_date]
                    end
                  else
                    changepoints = nil
                  end

                  # Auto seasonalities are set to false because they are already set in
                  # m.seasonalities.
                  m2 = m.class.new(
                    growth: m.growth,
                    n_changepoints: m.n_changepoints,
                    changepoint_range: m.changepoint_range,
                    changepoints: changepoints,
                    yearly_seasonality: false,
                    weekly_seasonality: false,
                    daily_seasonality: false,
                    holidays: m.holidays,
                    seasonality_mode: m.seasonality_mode,
                    seasonality_prior_scale: m.seasonality_prior_scale,
                    changepoint_prior_scale: m.changepoint_prior_scale,
                    holidays_prior_scale: m.holidays_prior_scale,
                    mcmc_samples: m.mcmc_samples,
                    interval_width: m.interval_width,
                    uncertainty_samples: m.uncertainty_samples
                  )
                  m2.extra_regressors = deepcopy(m.extra_regressors)
                  m2.seasonalities = deepcopy(m.seasonalities)
                  m2.country_holidays = deepcopy(m.country_holidays)
                  m2
                end

                # Convert a duration to seconds.
                #
                # @param value [Numeric, String] seconds, or a string such as
                #   "30 days", "1 day", "12 hours" (see TIMEDELTA_UNITS)
                # @return [Numeric] the duration in seconds (Numeric input unchanged)
                # @raise [Error] if the value cannot be parsed
                def self.timedelta(value)
                  if value.is_a?(Numeric)
                    # ActiveSupport::Duration reports is_a?(Numeric)
                    value
                  elsif (m = /\A(\d+(?:\.\d+)?) (second|minute|hour|day|week)s?\z/.match(value.to_s))
                    m[1].to_f * TIMEDELTA_UNITS.fetch(m[2])
                  else
                    raise Error, "Unknown time delta"
                  end
                end

                # Recursively copy Hashes and Arrays (keys and values); any other value
                # is shallow-copied with `dup`.
                def self.deepcopy(value)
                  if value.is_a?(Hash)
                    value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
                  elsif value.is_a?(Array)
                    value.map { |v| deepcopy(v) }
                  else
                    value.dup
                  end
                end

                # Compute forecast accuracy metrics from cross_validation output.
                #
                # @param df [Rover::DataFrame] cross_validation output (ds, y, yhat,
                #   cutoff, and optionally yhat_lower/yhat_upper)
                # @param metrics [Array<String, Symbol>, nil] subset of "mse", "rmse",
                #   "mae", "mape", "mdape", "smape", "coverage"; defaults to all. The
                #   caller's array is never modified.
                # @param rolling_window [Float] fraction of rows per rolling window; if
                #   rolling_window * rows truncates to a negative number, metrics are
                #   returned per row without aggregation
                # @param monthly [Boolean] monthly horizons (not implemented)
                # @return [Rover::DataFrame, nil] "horizon" plus one column per metric,
                #   or nil when no metric can be computed
                # @raise [ArgumentError] for duplicate or unknown metric names
                # @raise [Error] if monthly is requested
                def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
                  valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
                  # Work on a copy so later deletions never mutate (or fail on a frozen)
                  # caller array
                  metrics = metrics.nil? ? valid_metrics.dup : metrics.map(&:to_s)

                  # Validate before dropping anything: Array#delete removes every
                  # occurrence, which would otherwise hide duplicates.
                  if metrics.uniq.length != metrics.length
                    raise ArgumentError, "Input metrics must be a list of unique values"
                  end
                  unless (metrics - valid_metrics).empty?
                    raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
                  end
                  if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
                    metrics.delete("coverage")
                  end

                  if monthly
                    # TODO: horizon measured in whole months between cutoff and ds
                    raise Error, "Not implemented yet"
                  end
                  df_m = df.dup
                  df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
                  df_m.sort_by! { |r| r["horizon"] }

                  # MAPE is undefined when y is (close to) zero, so skip it
                  if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
                    metrics.delete("mape")
                  end
                  return nil if metrics.empty?

                  # Rolling window size in rows, limited to [1, rows] unless negative.
                  # Array min/max is used rather than Comparable#clamp, which raises
                  # when rows is 0.
                  n_rows = df_m.shape[0]
                  w = (rolling_window * n_rows).to_i
                  w = [[w, 1].max, n_rows].min if w >= 0

                  # Compute all metrics, then attach each to the first metric's frame
                  dfs = metrics.to_h { |metric| [metric, public_send(metric, df_m, w)] }
                  res = dfs[metrics[0]]
                  metrics.drop(1).each { |metric| res[metric] = dfs[metric][metric] }
                  res
                end

                # Rolling mean of x by horizon h.
                #
                # For each unique horizon (scanning from the largest down), averages the
                # w points at or below it, pro-rating a partially included horizon by
                # that horizon's mean. Horizons with fewer than w points at or below
                # them are omitted.
                #
                # @return [Rover::DataFrame] columns "horizon" and `name`
                def self.rolling_mean_by_h(x, h, w, name)
                  # Aggregate over h
                  df = Rover::DataFrame.new({"x" => x, "h" => h})
                  df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
                  xs = df2["sum_x"]
                  ns = df2["count"]
                  hs = df2["h"]

                  trailing_i = df2.length - 1
                  x_sum = 0
                  n_sum = 0
                  # We don't know output size but it is bounded by len(df2)
                  res_x = [nil] * df2.length

                  # Start from the right and work backwards
                  (df2.length - 1).downto(0) do |i|
                    x_sum += xs[i]
                    n_sum += ns[i]
                    while n_sum >= w
                      # Include points from the previous horizon. All of them if still
                      # less than w, otherwise weight the mean by the difference.
                      # fdiv gives true division even when x holds Integers.
                      excess_n = n_sum - w
                      excess_x = (excess_n * xs[i]).fdiv(ns[i])
                      res_x[trailing_i] = (x_sum - excess_x).fdiv(w)
                      x_sum -= xs[trailing_i]
                      n_sum -= ns[trailing_i]
                      trailing_i -= 1
                    end
                  end

                  res_h = hs[(trailing_i + 1)..-1]
                  res_x = res_x[(trailing_i + 1)..-1]

                  Rover::DataFrame.new({"horizon" => res_h, name => res_x})
                end

                # Rolling median of x by horizon h.
                #
                # For each unique horizon (scanning from the largest down), takes every
                # point at that horizon and, if fewer than w, adds points from the rows
                # immediately before its first occurrence until w are collected. Stops
                # at the first horizon that cannot reach w.
                # Assumes h is sorted ascending (performance_metrics sorts by horizon).
                #
                # @return [Rover::DataFrame] columns "horizon" and `name`
                def self.rolling_median_by_h(x, h, w, name)
                  # Aggregate over h
                  df = Rover::DataFrame.new({"x" => x, "h" => h})
                  grouped = df.group("h")
                  df2 = grouped.count.sort_by { |r| r["h"] }
                  hs = df2["h"]

                  res_h = []
                  res_x = []
                  # Start from the right and work backwards
                  i = hs.length - 1
                  while i >= 0
                    h_i = hs[i]
                    xs = df[df["h"] == h_i]["x"].to_a

                    # Index just before the first row with this horizon
                    next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
                    while xs.length < w && next_idx_to_add >= 0
                      # Include points from the previous horizon. All of them if still
                      # less than w, otherwise just enough to get to w.
                      xs << x[next_idx_to_add]
                      next_idx_to_add -= 1
                    end
                    if xs.length < w
                      # Ran out of points before getting enough.
                      break
                    end
                    res_h << hs[i]
                    res_x << Rover::Vector.new(xs).median
                    i -= 1
                  end
                  res_h.reverse!
                  res_x.reverse!
                  Rover::DataFrame.new({"horizon" => res_h, name => res_x})
                end

                # Mean squared error (per row when w < 0, else rolling by horizon).
                def self.mse(df, w)
                  se = (df["y"] - df["yhat"]) ** 2
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
                  end
                  rolling_mean_by_h(se, df["horizon"], w, "mse")
                end

                # Root mean squared error: square root of the mse column.
                def self.rmse(df, w)
                  res = mse(df, w)
                  res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
                  res
                end

                # Mean absolute error (per row when w < 0, else rolling by horizon).
                def self.mae(df, w)
                  ae = (df["y"] - df["yhat"]).abs
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
                  end
                  rolling_mean_by_h(ae, df["horizon"], w, "mae")
                end

                # Mean absolute percentage error |y - yhat| / |y|.
                def self.mape(df, w)
                  ape = ((df["y"] - df["yhat"]) / df["y"]).abs
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
                  end
                  rolling_mean_by_h(ape, df["horizon"], w, "mape")
                end

                # Median absolute percentage error (rolling median by horizon).
                def self.mdape(df, w)
                  ape = ((df["y"] - df["yhat"]) / df["y"]).abs
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
                  end
                  rolling_median_by_h(ape, df["horizon"], w, "mdape")
                end

                # Symmetric MAPE: |y - yhat| / ((|y| + |yhat|) / 2).
                def self.smape(df, w)
                  sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
                  end
                  rolling_mean_by_h(sape, df["horizon"], w, "smape")
                end

                # Fraction of observations inside [yhat_lower, yhat_upper].
                def self.coverage(df, w)
                  is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
                  if w < 0
                    return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
                  end
                  rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
                end
              end
            end
         | 
    
        data/lib/prophet/forecaster.rb
    CHANGED
    
    | @@ -3,7 +3,14 @@ module Prophet | |
| 3 3 | 
             
                include Holidays
         | 
| 4 4 | 
             
                include Plot
         | 
| 5 5 |  | 
| 6 | 
            -
                attr_reader :logger, :params, :train_holiday_names
         | 
| 6 | 
            +
                attr_reader :logger, :params, :train_holiday_names,
         | 
| 7 | 
            +
                  :history, :seasonalities, :specified_changepoints, :fit_kwargs,
         | 
| 8 | 
            +
                  :growth, :changepoints, :n_changepoints, :changepoint_range,
         | 
| 9 | 
            +
                  :holidays, :seasonality_mode, :seasonality_prior_scale,
         | 
| 10 | 
            +
                  :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
         | 
| 11 | 
            +
                  :interval_width, :uncertainty_samples
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                attr_accessor :extra_regressors, :seasonalities, :country_holidays
         | 
| 7 14 |  | 
| 8 15 | 
             
                def initialize(
         | 
| 9 16 | 
             
                  growth: "linear",
         | 
| @@ -82,7 +89,7 @@ module Prophet | |
| 82 89 | 
             
                    raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
         | 
| 83 90 | 
             
                  end
         | 
| 84 91 | 
             
                  if @holidays
         | 
| 85 | 
            -
                    if  | 
| 92 | 
            +
                    if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
         | 
| 86 93 | 
             
                      raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
         | 
| 87 94 | 
             
                    end
         | 
| 88 95 | 
             
                    @holidays["ds"] = to_datetime(@holidays["ds"])
         | 
| @@ -118,8 +125,8 @@ module Prophet | |
| 118 125 | 
             
                    "holidays", "zeros", "extra_regressors_additive", "yhat",
         | 
| 119 126 | 
             
                    "extra_regressors_multiplicative", "multiplicative_terms",
         | 
| 120 127 | 
             
                  ]
         | 
| 121 | 
            -
                  rn_l = reserved_names.map { |n| n | 
| 122 | 
            -
                  rn_u = reserved_names.map { |n| n | 
| 128 | 
            +
                  rn_l = reserved_names.map { |n| "#{n}_lower" }
         | 
| 129 | 
            +
                  rn_u = reserved_names.map { |n| "#{n}_upper" }
         | 
| 123 130 | 
             
                  reserved_names.concat(rn_l)
         | 
| 124 131 | 
             
                  reserved_names.concat(rn_u)
         | 
| 125 132 | 
             
                  reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
         | 
| @@ -135,7 +142,7 @@ module Prophet | |
| 135 142 | 
             
                  if check_seasonalities && @seasonalities[name]
         | 
| 136 143 | 
             
                    raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
         | 
| 137 144 | 
             
                  end
         | 
| 138 | 
            -
                  if check_regressors  | 
| 145 | 
            +
                  if check_regressors && @extra_regressors[name]
         | 
| 139 146 | 
             
                    raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
         | 
| 140 147 | 
             
                  end
         | 
| 141 148 | 
             
                end
         | 
| @@ -160,7 +167,7 @@ module Prophet | |
| 160 167 | 
             
                      raise ArgumentError, "Found NaN in column #{name.inspect}"
         | 
| 161 168 | 
             
                    end
         | 
| 162 169 | 
             
                  end
         | 
| 163 | 
            -
                  @seasonalities. | 
| 170 | 
            +
                  @seasonalities.each_value do |props|
         | 
| 164 171 | 
             
                    condition_name = props[:condition_name]
         | 
| 165 172 | 
             
                    if condition_name
         | 
| 166 173 | 
             
                      if !df.include?(condition_name)
         | 
| @@ -176,8 +183,10 @@ module Prophet | |
| 176 183 |  | 
| 177 184 | 
             
                  initialize_scales(initialize_scales, df)
         | 
| 178 185 |  | 
| 179 | 
            -
                  if @logistic_floor | 
| 180 | 
            -
                     | 
| 186 | 
            +
                  if @logistic_floor
         | 
| 187 | 
            +
                    unless df.include?("floor")
         | 
| 188 | 
            +
                      raise ArgumentError, "Expected column \"floor\"."
         | 
| 189 | 
            +
                    end
         | 
| 181 190 | 
             
                  else
         | 
| 182 191 | 
             
                    df["floor"] = 0
         | 
| 183 192 | 
             
                  end
         | 
| @@ -207,7 +216,12 @@ module Prophet | |
| 207 216 | 
             
                def initialize_scales(initialize_scales, df)
         | 
| 208 217 | 
             
                  return unless initialize_scales
         | 
| 209 218 |  | 
| 210 | 
            -
                   | 
| 219 | 
            +
                  if @growth == "logistic" && df.include?("floor")
         | 
| 220 | 
            +
                    @logistic_floor = true
         | 
| 221 | 
            +
                    floor = df["floor"]
         | 
| 222 | 
            +
                  else
         | 
| 223 | 
            +
                    floor = 0.0
         | 
| 224 | 
            +
                  end
         | 
| 211 225 | 
             
                  @y_scale = (df["y"] - floor).abs.max
         | 
| 212 226 | 
             
                  @y_scale = 1 if @y_scale == 0
         | 
| 213 227 | 
             
                  @start = df["ds"].min
         | 
| @@ -467,14 +481,14 @@ module Prophet | |
| 467 481 | 
             
                  end
         | 
| 468 482 | 
             
                  # Add totals additive and multiplicative components, and regressors
         | 
| 469 483 | 
             
                  ["additive", "multiplicative"].each do |mode|
         | 
| 470 | 
            -
                    components = add_group_component(components, mode | 
| 484 | 
            +
                    components = add_group_component(components, "#{mode}_terms", modes[mode])
         | 
| 471 485 | 
             
                    regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
         | 
| 472 486 | 
             
                      .map { |r, props| r }
         | 
| 473 | 
            -
                    components = add_group_component(components, "extra_regressors_" | 
| 487 | 
            +
                    components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
         | 
| 474 488 |  | 
| 475 489 | 
             
                    # Add combination components to modes
         | 
| 476 | 
            -
                    modes[mode] << mode | 
| 477 | 
            -
                    modes[mode] << "extra_regressors_" | 
| 490 | 
            +
                    modes[mode] << "#{mode}_terms"
         | 
| 491 | 
            +
                    modes[mode] << "extra_regressors_#{mode}"
         | 
| 478 492 | 
             
                  end
         | 
| 479 493 | 
             
                  # After all of the additive/multiplicative groups have been added,
         | 
| 480 494 | 
             
                  modes[@seasonality_mode] << "holidays"
         | 
| @@ -803,8 +817,8 @@ module Prophet | |
| 803 817 | 
             
                    end
         | 
| 804 818 | 
             
                    data[component] = comp.mean(axis: 1, nan: true)
         | 
| 805 819 | 
             
                    if @uncertainty_samples
         | 
| 806 | 
            -
                      data[component | 
| 807 | 
            -
                      data[component | 
| 820 | 
            +
                      data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
         | 
| 821 | 
            +
                      data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
         | 
| 808 822 | 
             
                    end
         | 
| 809 823 | 
             
                  end
         | 
| 810 824 | 
             
                  Rover::DataFrame.new(data)
         | 
| @@ -971,6 +985,12 @@ module Prophet | |
| 971 985 | 
             
                  Rover::DataFrame.new({"ds" => dates})
         | 
| 972 986 | 
             
                end
         | 
| 973 987 |  | 
| 988 | 
            +
                # Serializes the fitted model to a JSON string (the layout is built by as_json).
                #
                # Extra arguments are accepted and ignored so the model also works when the
                # json library calls to_json(state) on it, e.g. while generating a Hash or
                # Array that contains the model.
                #
                # @return [String] the model encoded as JSON
                def to_json(*_args)
                  require "json"

                  JSON.generate(as_json)
                end

                private
         | 
| 975 995 |  | 
| 976 996 | 
             
                # Time is preferred over DateTime in Ruby docs
         | 
| @@ -1017,5 +1037,189 @@ module Prophet | |
| 1017 1037 | 
             
                  u = Numo::DFloat.new(size).rand(-0.5, 0.5)
         | 
| 1018 1038 | 
             
                  loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
         | 
| 1019 1039 | 
             
                end
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                # Attribute groups used by as_json / from_json. Each group is named after the
                # pandas/numpy type the attribute has in Python Prophet, which decides how it
                # is encoded. The lists are frozen so they cannot be changed at runtime.
                SIMPLE_ATTRIBUTES = [
                  "growth", "n_changepoints", "specified_changepoints", "changepoint_range",
                  "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
                  "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
                  "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
                  "y_scale", "logistic_floor", "country_holidays", "component_modes"
                ].freeze

                # vectors encoded as pandas Series JSON (orient "split")
                PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"].freeze

                # Time values encoded as seconds since the epoch (Time#to_f)
                PD_TIMESTAMP = ["start"].freeze

                # scalar encoded as a plain float (to_f)
                PD_TIMEDELTA = ["t_scale"].freeze

                # data frames encoded as pandas table-schema JSON
                PD_DATAFRAME = ["holidays", "history", "train_component_cols"].freeze

                # Numo arrays encoded as nested Ruby arrays
                NP_ARRAY = ["changepoints_t"].freeze

                # hashes encoded as [key list, hash] pairs so key order survives the round trip
                ORDEREDDICT = ["seasonalities", "extra_regressors"].freeze
         | 
| 1060 | 
            +
             | 
| 1061 | 
            +
                # Builds a JSON-compatible Hash describing the fitted model.
                #
                # The layout follows the attribute groups (named after the pandas/numpy types
                # Python Prophet uses): series and data frames are embedded as JSON strings in
                # pandas' "split" and "table" formats, arrays become plain Ruby arrays.
                #
                # @return [Hash] the model state
                # @raise [Error] if the model has not been fit yet
                def as_json
                  if @history.nil?
                    raise Error, "This can only be used to serialize models that have already been fit."
                  end

                  model_dict =
                    SIMPLE_ATTRIBUTES.to_h do |attribute|
                      [attribute, instance_variable_get("@#{attribute}")]
                    end

                  # Handle attributes of non-core types
                  PD_SERIES.each do |attribute|
                    series = instance_variable_get("@#{attribute}")
                    model_dict[attribute] = series.nil? ? nil : series_as_json(series)
                  end
                  PD_TIMESTAMP.each do |attribute|
                    model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
                  end
                  PD_TIMEDELTA.each do |attribute|
                    model_dict[attribute] = instance_variable_get("@#{attribute}").to_f
                  end
                  PD_DATAFRAME.each do |attribute|
                    df = instance_variable_get("@#{attribute}")
                    model_dict[attribute] = df.nil? ? nil : dataframe_as_json(df)
                  end
                  NP_ARRAY.each do |attribute|
                    model_dict[attribute] = instance_variable_get("@#{attribute}").to_a
                  end
                  ORDEREDDICT.each do |attribute|
                    dict = instance_variable_get("@#{attribute}")
                    model_dict[attribute] = [dict.keys, dict.transform_keys(&:to_s)]
                  end

                  # fit_kwargs: Numo arrays in "init" become plain arrays so they serialize.
                  # They do not need to be transformed back on deserializing.
                  # transform_values builds a new hash, so @fit_kwargs itself is left untouched.
                  fit_kwargs = @fit_kwargs.to_h { |k, v| [k.to_s, v.dup] }
                  if fit_kwargs.key?("init")
                    fit_kwargs["init"] = fit_kwargs["init"].transform_values { |v| v.is_a?(Numo::NArray) ? v.to_a : v }
                  end
                  model_dict["fit_kwargs"] = fit_kwargs

                  # Params (Dict[str, np.ndarray])
                  model_dict["params"] = params.transform_values(&:to_a)
                  # Attributes that are skipped: stan_fit, stan_backend
                  # Returns 1.0 for Prophet 1.1
                  model_dict["__prophet_version"] = "1.0"
                  model_dict
                end

                # Encodes a vector as pandas Series JSON (orient "split"). Only a vector of Time
                # values is named "ds" and written as ISO 8601 strings; from_json parses a series
                # back into Times only when its name is "ds". Other vectors (e.g. holiday names)
                # are written as-is with a null name.
                def series_as_json(series)
                  values = series.to_a
                  times = values.all? { |x| x.is_a?(Time) }
                  JSON.generate({
                    "name" => times ? "ds" : nil,
                    "index" => values.size.times.to_a,
                    "data" => times ? values.map { |t| t.iso8601(3) } : values
                  })
                end

                # Encodes a data frame in pandas' table-schema JSON format (same format as
                # Pandas); times in "ds" are written as ISO 8601 strings and "col" is dropped.
                def dataframe_as_json(df)
                  df = df.dup
                  df["ds"] = df["ds"].map { |t| t.iso8601(3) } if df.include?("ds")
                  df.delete("col")

                  fields =
                    df.types.map do |name, type|
                      {"name" => name, "type" => pandas_field_type(name, type)}
                    end

                  JSON.generate({
                    "schema" => {
                      "fields" => fields,
                      "pandas_version" => "0.20.0"
                    },
                    "data" => df.to_a
                  })
                end

                # Maps a Rover column type to a pandas table-schema field type. Only "ds" is
                # typed "datetime"; other object columns (such as the holidays frame's
                # "holiday" names) are typed "string".
                def pandas_field_type(name, type)
                  case type
                  when :object
                    name == "ds" ? "datetime" : "string"
                  when :bool
                    "boolean"
                  else
                    type.to_s.include?("int") ? "integer" : "number"
                  end
                end
         | 
| 1158 | 
            +
             | 
| 1159 | 
            +
                # Rebuilds a model from JSON produced by to_json.
                #
                # A fresh model is created with Prophet.new and every serialized attribute is
                # written over the defaults it was initialized with; stan_fit is reset to nil.
                #
                # @param model_json [String] JSON text produced by to_json
                # @return the deserialized model
                def self.from_json(model_json)
                  require "json"
                  require "time" # Time.parse

                  model_dict = JSON.parse(model_json)

                  # We will overwrite all attributes set in init anyway
                  model = Prophet.new
                  # Simple types
                  SIMPLE_ATTRIBUTES.each do |attribute|
                    model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute))
                  end
                  PD_SERIES.each do |attribute|
                    if model_dict[attribute].nil?
                      model.instance_variable_set("@#{attribute}", nil)
                    else
                      d = JSON.parse(model_dict.fetch(attribute))
                      s = Rover::Vector.new(d["data"])
                      # only series named "ds" hold timestamps
                      s = s.map { |v| Time.parse(v).utc } if d["name"] == "ds"
                      model.instance_variable_set("@#{attribute}", s)
                    end
                  end
                  PD_TIMESTAMP.each do |attribute|
                    # seconds since the epoch; restored in UTC like the other timestamps below
                    model.instance_variable_set("@#{attribute}", Time.at(model_dict.fetch(attribute)).utc)
                  end
                  PD_TIMEDELTA.each do |attribute|
                    model.instance_variable_set("@#{attribute}", model_dict.fetch(attribute).to_f)
                  end
                  PD_DATAFRAME.each do |attribute|
                    if model_dict[attribute].nil?
                      model.instance_variable_set("@#{attribute}", nil)
                    else
                      d = JSON.parse(model_dict.fetch(attribute))
                      df = Rover::DataFrame.new(d["data"])
                      df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df.include?("ds")
                      # Python Prophet also restores the names of train_component_cols'
                      # index ("col") and columns ("component"); that is not replicated here.
                      model.instance_variable_set("@#{attribute}", df)
                    end
                  end
                  NP_ARRAY.each do |attribute|
                    model.instance_variable_set("@#{attribute}", Numo::NArray.cast(model_dict.fetch(attribute)))
                  end
                  ORDEREDDICT.each do |attribute|
                    # stored as [key list, hash]; rebuild the hash in the recorded key order
                    key_list, unordered_dict = model_dict.fetch(attribute)
                    od = key_list.to_h { |key| [key, unordered_dict[key].transform_keys(&:to_sym)] }
                    model.instance_variable_set("@#{attribute}", od)
                  end
                  # Other attributes with special handling
                  # fit_kwargs
                  model.instance_variable_set(:@fit_kwargs, model_dict["fit_kwargs"].transform_keys(&:to_sym))
                  # Params (Dict[str, np.ndarray])
                  model.instance_variable_set(:@params, model_dict["params"].transform_values { |v| Numo::NArray.cast(v) })
                  # Skipped attributes
                  # model.stan_backend = nil
                  model.instance_variable_set(:@stan_fit, nil)
                  model
                end
         | 
| 1020 1224 | 
             
              end
         | 
| 1021 1225 | 
             
            end
         | 
    
        data/lib/prophet/holidays.rb
    CHANGED
    
    | @@ -3,8 +3,11 @@ module Prophet | |
| 3 3 | 
             
                # Returns the unique holiday names for +country+, looked up over 1995-2045.
                # When none are found it only logs a warning and returns the empty result.
                def get_holiday_names(country)
                  lookup_years = (1995..2045).to_a
                  names = make_holidays_df(lookup_years, country)["holiday"].uniq
                  # TODO raise error in 0.5.0
                  # raise ArgumentError, "Holidays in #{country} are not currently supported"
                  logger.warn "Holidays in #{country} are not currently supported" if names.size.zero?
                  names
                end
         | 
| 10 13 |  |