prophet-rb 0.4.0 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2bcbd43f3750bc8c70fe28f2b0b42e88464724601daa754a9d60d8e8e354c701
4
- data.tar.gz: c16f1ab7f2d48419543b54f326fc98ac1261c7e5ceaf7d3b0fd352b87a2fdcb3
3
+ metadata.gz: 69d58f060a9bda44b1ab8666ded81b3e61ff1b95b08ae40727fa7a2dfee57ee0
4
+ data.tar.gz: dc76685a8b45ca7cad79561f986af9c66e5e49bd45dd5e10c72f91088d9b470a
5
5
  SHA512:
6
- metadata.gz: e8b3cf363a665f063d7045b3c34c124021281e063e61ff99e7bd98ba774a0e40a2d379aeb53f521833724b11d5ff1c7934820b73095ad5384b70066fc20e5e36
7
- data.tar.gz: 177f8eab90b9be0e5112c3d158ef9b522309143c559a97ffa206b7cc63a83f0eb70e75d7e1f8ab9640e57e403bd85d1469ecdb9137b0be8acc03c7ff09a91447
6
+ metadata.gz: 34d2fd0587110c6de9db334c44a4859a5de92d1cc124efb5f865c845d504816f3d3cd34d63776a709de396af8f9b2cf82a390cfb9f3e49e685ef8be4f469fb3e
7
+ data.tar.gz: '0587986407bb68a9ca97928ad65bcbb42fe101bd9d2e50a44b2126df38d67c04544932ec39009c6f1d6467736aa0a8404ac4aad04187dcb953285d4d5ee80a77'
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.4.1 (2022-07-10)
2
+
3
+ - Added support for cross validation and performance metrics
4
+ - Added support for updating fitted models
5
+ - Added support for saturating minimum forecasts
6
+
1
7
  ## 0.4.0 (2022-07-07)
2
8
 
3
9
  - Added support for saving and loading models
data/README.md CHANGED
@@ -88,7 +88,8 @@ Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/qu
88
88
  - [Multiplicative Seasonality](#multiplicative-seasonality)
89
89
  - [Uncertainty Intervals](#uncertainty-intervals)
90
90
  - [Non-Daily Data](#non-daily-data)
91
- - [Saving Models](#saving-models)
91
+ - [Diagnostics](#diagnostics)
92
+ - [Additional Topics](#additional-topics)
92
93
 
93
94
  ## Advanced Quick Start
94
95
 
@@ -177,11 +178,24 @@ df = Rover.read_csv("example_wp_log_R.csv")
177
178
  df["cap"] = 8.5
178
179
  m = Prophet.new(growth: "logistic")
179
180
  m.fit(df)
180
- future = m.make_future_dataframe(periods: 365)
181
+ future = m.make_future_dataframe(periods: 1826)
181
182
  future["cap"] = 8.5
182
183
  forecast = m.predict(future)
183
184
  ```
184
185
 
186
+ Saturating minimum
187
+
188
+ ```ruby
189
+ df["y"] = 10 - df["y"]
190
+ df["cap"] = 6
191
+ df["floor"] = 1.5
192
+ future["cap"] = 6
193
+ future["floor"] = 1.5
194
+ m = Prophet.new(growth: "logistic")
195
+ m.fit(df)
196
+ forecast = m.predict(future)
197
+ ```
198
+
185
199
  ## Trend Changepoints
186
200
 
187
201
  [Explanation](https://facebook.github.io/prophet/docs/trend_changepoints.html)
@@ -308,9 +322,64 @@ future = m.make_future_dataframe(periods: 300, freq: "H")
308
322
  forecast = m.predict(future)
309
323
  ```
310
324
 
311
- ## Saving Models
325
+ ## Diagnostics
326
+
327
+ [Explanation](https://facebook.github.io/prophet/docs/diagnostics.html)
328
+
329
+ Cross validation
330
+
331
+ ```ruby
332
+ df_cv = Prophet::Diagnostics.cross_validation(m, initial: "730 days", period: "180 days", horizon: "365 days")
333
+ ```
334
+
335
+ Custom cutoffs
336
+
337
+ ```ruby
338
+ cutoffs = ["2013-02-15", "2013-08-15", "2014-02-15"].map { |v| Time.parse("#{v} 00:00:00 UTC") }
339
+ df_cv2 = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "365 days")
340
+ ```
341
+
342
+ Get performance metrics
312
343
 
313
- [Explanation](https://facebook.github.io/prophet/docs/additional_topics.html#saving-models)
344
+ ```ruby
345
+ df_p = Prophet::Diagnostics.performance_metrics(df_cv)
346
+ ```
347
+
348
+ Plot cross validation metrics
349
+
350
+ ```ruby
351
+ Prophet::Plot.plot_cross_validation_metric(df_cv, metric: "mape")
352
+ ```
353
+
354
+ Hyperparameter tuning
355
+
356
+ ```ruby
357
+ param_grid = {
358
+ changepoint_prior_scale: [0.001, 0.01, 0.1, 0.5],
359
+ seasonality_prior_scale: [0.01, 0.1, 1.0, 10.0]
360
+ }
361
+
362
+ # Generate all combinations of parameters
363
+ all_params = param_grid.values[0].product(*param_grid.values[1..-1]).map { |v| param_grid.keys.zip(v).to_h }
364
+ rmses = [] # Store the RMSEs for each params here
365
+
366
+ # Use cross validation to evaluate all parameters
367
+ all_params.each do |params|
368
+ m = Prophet.new(**params).fit(df) # Fit model with given params
369
+ df_cv = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "30 days")
370
+ df_p = Prophet::Diagnostics.performance_metrics(df_cv, rolling_window: 1)
371
+ rmses << df_p["rmse"][0]
372
+ end
373
+
374
+ # Find the best parameters
375
+ tuning_results = Rover::DataFrame.new(all_params)
376
+ tuning_results["rmse"] = rmses
377
+ p tuning_results
378
+ ```
379
+
380
+ ## Additional Topics
381
+
382
+ [Explanation](https://facebook.github.io/prophet/docs/additional_topics.html)
314
383
 
315
384
  Save a model
316
385
 
@@ -326,6 +395,34 @@ m = Prophet.from_json(File.read("model.json"))
326
395
 
327
396
  Uses the same format as Python, so models can be saved and loaded in either language
328
397
 
398
+ Flat trend
399
+
400
+ ```ruby
401
+ m = Prophet.new(growth: "flat")
402
+ ```
403
+
404
+ Updating fitted models
405
+
406
+ ```ruby
407
+ def stan_init(m)
408
+ res = {}
409
+ ["k", "m", "sigma_obs"].each do |pname|
410
+ res[pname] = m.params[pname][0, true][0]
411
+ end
412
+ ["delta", "beta"].each do |pname|
413
+ res[pname] = m.params[pname][0, true]
414
+ end
415
+ res
416
+ end
417
+
418
+ df = Rover.read_csv("example_wp_log_peyton_manning.csv")
419
+ df1 = df[df["ds"] <= "2016-01-19"] # All data except the last day
420
+ m1 = Prophet.new.fit(df1) # A model fit to all data except the last day
421
+
422
+ m2 = Prophet.new.fit(df) # Adding the last day, fitting from scratch
423
+ m2 = Prophet.new.fit(df, init: stan_init(m1)) # Adding the last day, warm-starting from m1
424
+ ```
425
+
329
426
  ## Resources
330
427
 
331
428
  - [Forecasting at Scale](https://peerj.com/preprints/3190.pdf)
@@ -0,0 +1,349 @@
1
module Prophet
  # Cross validation and forecast accuracy metrics for fitted models,
  # following the Python Prophet diagnostics API. Time deltas are handled as
  # plain Numeric seconds throughout (see .timedelta).
  module Diagnostics
    # Seconds in one unit of each time-delta unit accepted by .timedelta
    # (singular or plural form, e.g. "1 day" / "730 days").
    TIMEDELTA_UNITS = {
      "week" => 7 * 86400,
      "day" => 86400,
      "hour" => 3600,
      "minute" => 60,
      "second" => 1
    }.freeze

    # Generates the cutoff times used by cross validation.
    #
    # The first cutoff is (last ds - horizon); further cutoffs step backwards
    # by +period+ while at least +initial+ seconds of history precede them.
    #
    # @param df [Rover::DataFrame] history with a "ds" column
    # @param horizon [Numeric] forecast horizon, in seconds
    # @param initial [Numeric] minimum training window, in seconds
    # @param period [Numeric] spacing between cutoffs, in seconds
    # @return [Array] cutoffs in ascending order
    # @raise [Error] if there is not enough data for a single cutoff
    def self.generate_cutoffs(df, horizon, initial, period)
      ds = df["ds"]
      ds_min = ds.min
      # Last cutoff is 'latest date in data - horizon' date
      cutoff = ds.max - horizon
      if cutoff < ds_min
        raise Error, "Less data than horizon."
      end
      result = [cutoff]
      while result[-1] >= ds_min + initial
        cutoff -= period
        # If data does not exist in data range (cutoff, cutoff + horizon]
        unless ((ds > cutoff) & (ds <= cutoff + horizon)).any?
          # Next cutoff point is 'last date before cutoff in data - horizon'
          if cutoff > ds_min
            closest_date = ds[ds <= cutoff].max
            cutoff = closest_date - horizon
          end
          # else no data left, leave cutoff as is, it will be dropped.
        end
        result << cutoff
      end
      # The last cutoff appended by the loop failed the initial-window check
      result = result[0...-1]
      if result.length == 0
        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
      end
      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
      result.reverse
    end

    # Cross-validates a fitted model by refitting it at each cutoff and
    # forecasting +horizon+ past it.
    #
    # @param model [Prophet::Forecaster] a fitted model
    # @param horizon [Numeric, String] forecast horizon (see .timedelta)
    # @param period [Numeric, String, nil] spacing between cutoffs; defaults to horizon / 2
    # @param initial [Numeric, String, nil] initial training window; defaults to
    #   max(3 * horizon, longest seasonality period)
    # @param cutoffs [Array, nil] explicit cutoff times (overrides period/initial)
    # @return [Rover::DataFrame] ds, yhat, (yhat_lower, yhat_upper,) y and cutoff columns
    # @raise [Error] if the model is not fit or the cutoffs are invalid
    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
      if model.history.nil?
        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
      end

      df = model.history.dup
      horizon = timedelta(horizon)

      predict_columns = ["ds", "yhat"]
      # uncertainty_samples of 0 (or false) disables intervals; 0 is truthy in
      # Ruby, so it has to be checked explicitly
      samples = model.uncertainty_samples
      if samples && samples != 0
        predict_columns.concat(["yhat_lower", "yhat_upper"])
      end

      # Identify largest seasonality period (in days)
      period_max = 0.0
      model.seasonalities.each do |_, s|
        period_max = [period_max, s[:period]].max
      end
      # Convert numerically instead of via a "N days" string, which would not
      # parse for floats rendered in scientific notation
      seasonality_dt = period_max.to_f * TIMEDELTA_UNITS["day"]

      if cutoffs.nil?
        # Set period
        period = period.nil? ? 0.5 * horizon : timedelta(period)

        # Set initial
        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)

        # Compute Cutoffs
        cutoffs = generate_cutoffs(df, horizon, initial, period)
      else
        cutoffs = cutoffs.to_a
        if cutoffs.empty?
          raise Error, "Cutoffs must not be empty"
        end
        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
        if cutoffs.min <= df["ds"].min
          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
        end
        # max value of cutoffs is <= (end date minus horizon)
        end_date_minus_horizon = df["ds"].max - horizon
        if cutoffs.max > end_date_minus_horizon
          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
        end
        # Use the earliest cutoff so unsorted input is handled
        initial = cutoffs.min - df["ds"].min
      end

      # Check if the initial window
      # (that is, the amount of time between the start of the history and the first cutoff)
      # is less than the maximum seasonality period
      if initial < seasonality_dt
        msg = "Seasonality has period of #{period_max} days "
        msg += "which is larger than initial window. "
        msg += "Consider increasing initial."
        model.logger&.warn(msg)
      end

      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

      # Combine all predicted DataFrame into one DataFrame
      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
    end

    # Refits a copy of +model+ on history up to +cutoff+ and predicts the
    # observed points in (cutoff, cutoff + horizon].
    #
    # @return [Rover::DataFrame] predict_columns plus "y" and "cutoff"
    # @raise [Error] if fewer than two points precede the cutoff
    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
      # Generate new object with copying fitting options
      m = prophet_copy(model, cutoff)
      # Train model
      history_c = df[df["ds"] <= cutoff]
      if history_c.shape[0] < 2
        raise Error, "Less than two datapoints before cutoff. Increase initial window."
      end
      m.fit(history_c, **model.fit_kwargs)
      # Calculate yhat
      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
      # Get the columns for the future dataframe
      columns = ["ds"]
      if m.growth == "logistic"
        columns << "cap"
        # The forecaster's visible attr_reader list has no logistic_floor reader.
        # Fitting enables the logistic floor when the training data (rows of df)
        # has a "floor" column, so check df directly.
        columns << "floor" if df.include?("floor")
      end
      columns.concat(m.extra_regressors.keys)
      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
      yhat = m.predict(df[index_predicted][columns])
      # Merge yhat(predicts), y(df, original data) and cutoff
      yhat[predict_columns].merge(df[index_predicted][["y"]]).merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
    end

    # Builds an unfitted copy of a fitted model with the same options,
    # seasonalities, regressors and country holidays.
    #
    # @param m [Prophet::Forecaster] a fitted model
    # @param cutoff [Object, nil] when given, user-specified changepoints are
    #   limited to those before the last history date <= cutoff
    # @return [Prophet::Forecaster]
    # @raise [Error] if +m+ has not been fit
    def self.prophet_copy(m, cutoff = nil)
      if m.history.nil?
        raise Error, "This is for copying a fitted Prophet object."
      end

      if m.specified_changepoints
        changepoints = m.changepoints
        if !cutoff.nil?
          # Filter change points '< cutoff'
          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
          changepoints = changepoints[changepoints < last_history_date]
        end
      else
        changepoints = nil
      end

      # Auto seasonalities are set to false because they are already set in
      # m.seasonalities.
      m2 = m.class.new(
        growth: m.growth,
        n_changepoints: m.n_changepoints,
        changepoint_range: m.changepoint_range,
        changepoints: changepoints,
        yearly_seasonality: false,
        weekly_seasonality: false,
        daily_seasonality: false,
        holidays: m.holidays,
        seasonality_mode: m.seasonality_mode,
        seasonality_prior_scale: m.seasonality_prior_scale,
        changepoint_prior_scale: m.changepoint_prior_scale,
        holidays_prior_scale: m.holidays_prior_scale,
        mcmc_samples: m.mcmc_samples,
        interval_width: m.interval_width,
        uncertainty_samples: m.uncertainty_samples
      )
      m2.extra_regressors = deepcopy(m.extra_regressors)
      m2.seasonalities = deepcopy(m.seasonalities)
      m2.country_holidays = deepcopy(m.country_holidays)
      m2
    end

    # Converts a time delta to seconds.
    #
    # @param value [Numeric, String] seconds (returned unchanged), or a string
    #   such as "730 days", "1 day", "12 hours", "2 weeks", "30 minutes"
    # @return [Numeric] seconds
    # @raise [Error] for unrecognized values
    def self.timedelta(value)
      if value.is_a?(Numeric)
        # ActiveSupport::Duration is a numeric
        value
      elsif (m = /\A(\d+(?:\.\d+)?) (week|day|hour|minute|second)s?\z/.match(value))
        m[1].to_f * TIMEDELTA_UNITS.fetch(m[2])
      else
        raise Error, "Unknown time delta"
      end
    end

    # Recursively copies hashes and arrays (keys and values); other values
    # are shallow-copied with #dup.
    def self.deepcopy(value)
      if value.is_a?(Hash)
        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
      elsif value.is_a?(Array)
        value.map { |v| deepcopy(v) }
      else
        value.dup
      end
    end

    # Computes accuracy metrics from cross validation results, aggregated by
    # horizon over a rolling window.
    #
    # @param df [Rover::DataFrame] output of .cross_validation
    # @param metrics [Array<String>, nil] subset of mse, rmse, mae, mape, mdape,
    #   smape, coverage (default: all). The caller's array is not modified.
    # @param rolling_window [Numeric] fraction of rows per window; negative
    #   values return per-prediction metrics without aggregation
    # @param monthly [Boolean] not implemented yet
    # @return [Rover::DataFrame, nil] a "horizon" column plus one column per
    #   metric, or nil when no metric can be computed
    # @raise [ArgumentError] for duplicate or unknown metrics
    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
      # Work on a copy; metrics may be removed below
      metrics = (metrics.nil? ? valid_metrics : metrics).dup
      has_intervals = df.include?("yhat_lower") && df.include?("yhat_upper")
      if !has_intervals && metrics.include?("coverage")
        metrics.delete("coverage")
      end
      if metrics.uniq.length != metrics.length
        raise ArgumentError, "Input metrics must be a list of unique values"
      end
      if !Set.new(metrics).subset?(Set.new(valid_metrics))
        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
      end
      df_m = df.dup
      if monthly
        raise Error, "Not implemented yet"
        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
      else
        df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
      end
      df_m.sort_by! { |r| r["horizon"] }
      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
        # logger.info("Skipping MAPE because y close to 0")
        metrics.delete("mape")
      end
      return nil if metrics.empty?

      n = df_m.shape[0]
      w = (rolling_window * n).to_i
      if w >= 0
        # Clamp to [1, n] (w becomes 0 only when there are no rows)
        w = [[w, 1].max, n].min
      end
      # Compute all metrics (names were validated against valid_metrics above)
      dfs = metrics.to_h { |metric| [metric, public_send(metric, df_m, w)] }
      res = dfs[metrics[0]]
      metrics.drop(1).each do |metric|
        res[metric] = dfs[metric][metric]
      end
      res
    end

    # Rolling mean of +x+ over the last +w+ points, aggregated by horizon +h+,
    # working backwards from the largest horizon.
    #
    # @return [Rover::DataFrame] "horizon" and +name+ columns
    def self.rolling_mean_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
      xs = df2["sum_x"]
      ns = df2["count"]
      hs = df2["h"]

      trailing_i = df2.length - 1
      x_sum = 0
      n_sum = 0
      # We don't know output size but it is bounded by len(df2)
      res_x = [nil] * df2.length

      # Start from the right and work backwards
      (df2.length - 1).downto(0) do |i|
        x_sum += xs[i]
        n_sum += ns[i]
        while n_sum >= w
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise weight the mean by the difference.
          # Float division guards against integer-valued inputs.
          excess_n = n_sum - w
          excess_x = excess_n * xs[i] / ns[i].to_f
          res_x[trailing_i] = (x_sum - excess_x) / w.to_f
          x_sum -= xs[trailing_i]
          n_sum -= ns[trailing_i]
          trailing_i -= 1
        end
      end

      res_h = hs[(trailing_i + 1)..-1]
      res_x = res_x[(trailing_i + 1)..-1]

      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Rolling median of +x+ over at least +w+ points, aggregated by horizon +h+,
    # working backwards from the largest horizon.
    #
    # @return [Rover::DataFrame] "horizon" and +name+ columns
    def self.rolling_median_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      grouped = df.group("h")
      df2 = grouped.count.sort_by { |r| r["h"] }
      hs = df2["h"]

      res_h = []
      res_x = []
      # Start from the right and work backwards
      i = hs.length - 1
      while i >= 0
        h_i = hs[i]
        xs = df[df["h"] == h_i]["x"].to_a

        # Index just before the first point with this horizon
        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
        while xs.length < w && next_idx_to_add >= 0
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise just enough to get to w.
          xs << x[next_idx_to_add]
          next_idx_to_add -= 1
        end
        if xs.length < w
          # Ran out of points before getting enough.
          break
        end
        res_h << hs[i]
        res_x << Rover::Vector.new(xs).median
        i -= 1
      end
      res_h.reverse!
      res_x.reverse!
      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Mean squared error.
    def self.mse(df, w)
      se = (df["y"] - df["yhat"]) ** 2
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se})
      end
      rolling_mean_by_h(se, df["horizon"], w, "mse")
    end

    # Root mean squared error.
    def self.rmse(df, w)
      res = mse(df, w)
      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
      res
    end

    # Mean absolute error.
    def self.mae(df, w)
      ae = (df["y"] - df["yhat"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae})
      end
      rolling_mean_by_h(ae, df["horizon"], w, "mae")
    end

    # Mean absolute percent error.
    def self.mape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape})
      end
      rolling_mean_by_h(ape, df["horizon"], w, "mape")
    end

    # Median absolute percent error.
    def self.mdape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape})
      end
      rolling_median_by_h(ape, df["horizon"], w, "mdape")
    end

    # Symmetric mean absolute percentage error.
    def self.smape(df, w)
      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape})
      end
      rolling_mean_by_h(sape, df["horizon"], w, "smape")
    end

    # Fraction of observations inside [yhat_lower, yhat_upper].
    def self.coverage(df, w)
      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
      if w < 0
        return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered})
      end
      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
    end
  end
end
@@ -3,7 +3,14 @@ module Prophet
3
3
  include Holidays
4
4
  include Plot
5
5
 
6
- attr_reader :logger, :params, :train_holiday_names
6
+ attr_reader :logger, :params, :train_holiday_names,
7
+ :history, :seasonalities, :specified_changepoints, :fit_kwargs,
8
+ :growth, :changepoints, :n_changepoints, :changepoint_range,
9
+ :holidays, :seasonality_mode, :seasonality_prior_scale,
10
+ :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
11
+ :interval_width, :uncertainty_samples
12
+
13
+ attr_accessor :extra_regressors, :seasonalities, :country_holidays
7
14
 
8
15
  def initialize(
9
16
  growth: "linear",
@@ -176,8 +183,10 @@ module Prophet
176
183
 
177
184
  initialize_scales(initialize_scales, df)
178
185
 
179
- if @logistic_floor && !df.include?("floor")
180
- raise ArgumentError, "Expected column \"floor\"."
186
+ if @logistic_floor
187
+ unless df.include?("floor")
188
+ raise ArgumentError, "Expected column \"floor\"."
189
+ end
181
190
  else
182
191
  df["floor"] = 0
183
192
  end
@@ -207,7 +216,12 @@ module Prophet
207
216
  def initialize_scales(initialize_scales, df)
208
217
  return unless initialize_scales
209
218
 
210
- floor = 0
219
+ if @growth == "logistic" && df.include?("floor")
220
+ @logistic_floor = true
221
+ floor = df["floor"]
222
+ else
223
+ floor = 0.0
224
+ end
211
225
  @y_scale = (df["y"] - floor).abs.max
212
226
  @y_scale = 1 if @y_scale == 0
213
227
  @start = df["ds"].min
data/lib/prophet/plot.rb CHANGED
@@ -111,6 +111,61 @@ module Prophet
111
111
  artists
112
112
  end
113
113
 
114
# Plots a cross validation metric against the forecast horizon: one faint
# point per prediction plus a line for the rolling-window aggregate.
#
# @param df_cv [Rover::DataFrame] output of Diagnostics.cross_validation
# @param metric [String] a metric name accepted by Diagnostics.performance_metrics
# @param rolling_window [Numeric] fraction of data in each rolling window
# @param ax [Object, nil] existing matplotlib axes; a new figure is created when nil
# @param figsize [Array] figure size used when creating a new figure
# @param color [String] color of the rolling-window line
# @param point_color [String] color of the per-prediction points
# @return [Object] the matplotlib figure
# @raise [Error] if the metric cannot be computed for df_cv
def self.plot_cross_validation_metric(df_cv, metric:, rolling_window: 0.1, ax: nil, figsize: [10, 6], color: "b", point_color: "gray")
  # Get the metric at the level of individual predictions, and with the rolling window.
  # Done before creating a figure so a failure does not leave an empty one behind.
  df_none = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: -1)
  df_h = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: rolling_window)
  # performance_metrics returns nil when no metric remains (e.g. MAPE is
  # skipped when y is close to 0)
  if df_none.nil? || df_h.nil?
    raise Error, "Unable to compute metric: #{metric}"
  end

  if ax.nil?
    fig = plt.figure(facecolor: "w", figsize: figsize)
    ax = fig.add_subplot(111)
  else
    fig = ax.get_figure
  end

  # Horizons are differences of times (seconds); convert to nanoseconds.
  # Target ~10 ticks.
  tick_w = df_none["horizon"].max * 1e9 / 10.0
  dt_names = ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds", "nanoseconds"]
  # Nanoseconds in one unit of each entry of dt_names
  dt_conversions = [
    24 * 60 * 60 * 10 ** 9,
    60 * 60 * 10 ** 9,
    60 * 10 ** 9,
    10 ** 9,
    10 ** 6,
    10 ** 3,
    1.0
  ]
  # Find the largest time resolution that has <1 unit per bin; fall back to
  # the smallest unit when none qualifies.
  i = dt_conversions.index { |unit_ns| unit_ns < tick_w } || dt_conversions.length - 1

  x_plt = df_none["horizon"] * 1e9 / dt_conversions[i].to_f
  x_plt_h = df_h["horizon"] * 1e9 / dt_conversions[i].to_f

  ax.plot(x_plt.to_a, df_none[metric].to_a, ".", alpha: 0.1, c: point_color)
  ax.plot(x_plt_h.to_a, df_h[metric].to_a, "-", c: color)
  ax.grid(true)

  ax.set_xlabel("Horizon (#{dt_names[i]})")
  ax.set_ylabel(metric)
  fig
end
159
+
160
# Lazily loads matplotlib and returns its pyplot module.
#
# @return [Module] Matplotlib::Pyplot
# @raise [Error] when the matplotlib gem cannot be loaded
def self.plt
  require "matplotlib/pyplot"
  Matplotlib::Pyplot
rescue LoadError
  raise Error, "Install the matplotlib gem for plots"
end
168
+
114
169
  private
115
170
 
116
171
  def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
@@ -263,12 +318,7 @@ module Prophet
263
318
  end
264
319
 
265
320
  def plt
266
- begin
267
- require "matplotlib/pyplot"
268
- rescue LoadError
269
- raise Error, "Install the matplotlib gem for plots"
270
- end
271
- Matplotlib::Pyplot
321
+ Plot.plt
272
322
  end
273
323
 
274
324
  def dates
@@ -13,6 +13,11 @@ module Prophet
13
13
 
14
14
  def fit(stan_init, stan_data, **kwargs)
15
15
  stan_init, stan_data = prepare_data(stan_init, stan_data)
16
+
17
+ if !kwargs[:inits] && kwargs[:init]
18
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
19
+ end
20
+
16
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
17
22
  iterations = 10000
18
23
 
@@ -49,6 +54,10 @@ module Prophet
49
54
  def sampling(stan_init, stan_data, samples, **kwargs)
50
55
  stan_init, stan_data = prepare_data(stan_init, stan_data)
51
56
 
57
+ if !kwargs[:inits] && kwargs[:init]
58
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
59
+ end
60
+
52
61
  kwargs[:chains] ||= 4
53
62
  kwargs[:warmup_iters] ||= samples / 2
54
63
 
@@ -128,7 +137,7 @@ module Prophet
128
137
  stan_data["t_change"] = stan_data["t_change"].to_a
129
138
  stan_data["s_a"] = stan_data["s_a"].to_a
130
139
  stan_data["s_m"] = stan_data["s_m"].to_a
131
- stan_data["X"] = stan_data["X"].to_numo.to_a
140
+ stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a
132
141
  stan_init["delta"] = stan_init["delta"].to_a
133
142
  stan_init["beta"] = stan_init["beta"].to_a
134
143
  [stan_init, stan_data]
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.0"
2
+ VERSION = "0.4.1"
3
3
  end
data/lib/prophet.rb CHANGED
@@ -8,6 +8,7 @@ require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
+ require "prophet/diagnostics"
11
12
  require "prophet/holidays"
12
13
  require "prophet/plot"
13
14
  require "prophet/forecaster"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-07 00:00:00.000000000 Z
11
+ date: 2022-07-10 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -66,6 +66,7 @@ files:
66
66
  - data-raw/generated_holidays.csv
67
67
  - lib/prophet-rb.rb
68
68
  - lib/prophet.rb
69
+ - lib/prophet/diagnostics.rb
69
70
  - lib/prophet/forecaster.rb
70
71
  - lib/prophet/holidays.rb
71
72
  - lib/prophet/plot.rb