prophet-rb 0.4.0 → 0.4.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 2bcbd43f3750bc8c70fe28f2b0b42e88464724601daa754a9d60d8e8e354c701
4
- data.tar.gz: c16f1ab7f2d48419543b54f326fc98ac1261c7e5ceaf7d3b0fd352b87a2fdcb3
3
+ metadata.gz: 69d58f060a9bda44b1ab8666ded81b3e61ff1b95b08ae40727fa7a2dfee57ee0
4
+ data.tar.gz: dc76685a8b45ca7cad79561f986af9c66e5e49bd45dd5e10c72f91088d9b470a
5
5
  SHA512:
6
- metadata.gz: e8b3cf363a665f063d7045b3c34c124021281e063e61ff99e7bd98ba774a0e40a2d379aeb53f521833724b11d5ff1c7934820b73095ad5384b70066fc20e5e36
7
- data.tar.gz: 177f8eab90b9be0e5112c3d158ef9b522309143c559a97ffa206b7cc63a83f0eb70e75d7e1f8ab9640e57e403bd85d1469ecdb9137b0be8acc03c7ff09a91447
6
+ metadata.gz: 34d2fd0587110c6de9db334c44a4859a5de92d1cc124efb5f865c845d504816f3d3cd34d63776a709de396af8f9b2cf82a390cfb9f3e49e685ef8be4f469fb3e
7
+ data.tar.gz: '0587986407bb68a9ca97928ad65bcbb42fe101bd9d2e50a44b2126df38d67c04544932ec39009c6f1d6467736aa0a8404ac4aad04187dcb953285d4d5ee80a77'
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.4.1 (2022-07-10)
2
+
3
+ - Added support for cross validation and performance metrics
4
+ - Added support for updating fitted models
5
+ - Added support for saturating minimum forecasts
6
+
1
7
  ## 0.4.0 (2022-07-07)
2
8
 
3
9
  - Added support for saving and loading models
data/README.md CHANGED
@@ -88,7 +88,8 @@ Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/qu
88
88
  - [Multiplicative Seasonality](#multiplicative-seasonality)
89
89
  - [Uncertainty Intervals](#uncertainty-intervals)
90
90
  - [Non-Daily Data](#non-daily-data)
91
- - [Saving Models](#saving-models)
91
+ - [Diagnostics](#diagnostics)
92
+ - [Additional Topics](#additional-topics)
92
93
 
93
94
  ## Advanced Quick Start
94
95
 
@@ -177,11 +178,24 @@ df = Rover.read_csv("example_wp_log_R.csv")
177
178
  df["cap"] = 8.5
178
179
  m = Prophet.new(growth: "logistic")
179
180
  m.fit(df)
180
- future = m.make_future_dataframe(periods: 365)
181
+ future = m.make_future_dataframe(periods: 1826)
181
182
  future["cap"] = 8.5
182
183
  forecast = m.predict(future)
183
184
  ```
184
185
 
186
+ Saturating minimum
187
+
188
+ ```ruby
189
+ df["y"] = 10 - df["y"]
190
+ df["cap"] = 6
191
+ df["floor"] = 1.5
192
+ future["cap"] = 6
193
+ future["floor"] = 1.5
194
+ m = Prophet.new(growth: "logistic")
195
+ m.fit(df)
196
+ forecast = m.predict(future)
197
+ ```
198
+
185
199
  ## Trend Changepoints
186
200
 
187
201
  [Explanation](https://facebook.github.io/prophet/docs/trend_changepoints.html)
@@ -308,9 +322,64 @@ future = m.make_future_dataframe(periods: 300, freq: "H")
308
322
  forecast = m.predict(future)
309
323
  ```
310
324
 
311
- ## Saving Models
325
+ ## Diagnostics
326
+
327
+ [Explanation](https://facebook.github.io/prophet/docs/diagnostics.html)
328
+
329
+ Cross validation
330
+
331
+ ```ruby
332
+ df_cv = Prophet::Diagnostics.cross_validation(m, initial: "730 days", period: "180 days", horizon: "365 days")
333
+ ```
334
+
335
+ Custom cutoffs
336
+
337
+ ```ruby
338
+ cutoffs = ["2013-02-15", "2013-08-15", "2014-02-15"].map { |v| Time.parse("#{v} 00:00:00 UTC") }
339
+ df_cv2 = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "365 days")
340
+ ```
341
+
342
+ Get performance metrics
312
343
 
313
- [Explanation](https://facebook.github.io/prophet/docs/additional_topics.html#saving-models)
344
+ ```ruby
345
+ df_p = Prophet::Diagnostics.performance_metrics(df_cv)
346
+ ```
347
+
348
+ Plot cross validation metrics
349
+
350
+ ```ruby
351
+ Prophet::Plot.plot_cross_validation_metric(df_cv, metric: "mape")
352
+ ```
353
+
354
+ Hyperparameter tuning
355
+
356
+ ```ruby
357
+ param_grid = {
358
+ changepoint_prior_scale: [0.001, 0.01, 0.1, 0.5],
359
+ seasonality_prior_scale: [0.01, 0.1, 1.0, 10.0]
360
+ }
361
+
362
+ # Generate all combinations of parameters
363
+ all_params = param_grid.values[0].product(*param_grid.values[1..-1]).map { |v| param_grid.keys.zip(v).to_h }
364
+ rmses = [] # Store the RMSEs for each params here
365
+
366
+ # Use cross validation to evaluate all parameters
367
+ all_params.each do |params|
368
+ m = Prophet.new(**params).fit(df) # Fit model with given params
369
+ df_cv = Prophet::Diagnostics.cross_validation(m, cutoffs: cutoffs, horizon: "30 days")
370
+ df_p = Prophet::Diagnostics.performance_metrics(df_cv, rolling_window: 1)
371
+ rmses << df_p["rmse"][0]
372
+ end
373
+
374
+ # Find the best parameters
375
+ tuning_results = Rover::DataFrame.new(all_params)
376
+ tuning_results["rmse"] = rmses
377
+ p tuning_results
378
+ ```
379
+
380
+ ## Additional Topics
381
+
382
+ [Explanation](https://facebook.github.io/prophet/docs/additional_topics.html)
314
383
 
315
384
  Save a model
316
385
 
@@ -326,6 +395,34 @@ m = Prophet.from_json(File.read("model.json"))
326
395
 
327
396
  Uses the same format as Python, so models can be saved and loaded in either language
328
397
 
398
+ Flat trend
399
+
400
+ ```ruby
401
+ m = Prophet.new(growth: "flat")
402
+ ```
403
+
404
+ Updating fitted models
405
+
406
+ ```ruby
407
+ def stan_init(m)
408
+ res = {}
409
+ ["k", "m", "sigma_obs"].each do |pname|
410
+ res[pname] = m.params[pname][0, true][0]
411
+ end
412
+ ["delta", "beta"].each do |pname|
413
+ res[pname] = m.params[pname][0, true]
414
+ end
415
+ res
416
+ end
417
+
418
+ df = Rover.read_csv("example_wp_log_peyton_manning.csv")
419
+ df1 = df[df["ds"] <= "2016-01-19"] # All data except the last day
420
+ m1 = Prophet.new.fit(df1) # A model fit to all data except the last day
421
+
422
+ m2 = Prophet.new.fit(df) # Adding the last day, fitting from scratch
423
+ m2 = Prophet.new.fit(df, init: stan_init(m1)) # Adding the last day, warm-starting from m1
424
+ ```
425
+
329
426
  ## Resources
330
427
 
331
428
  - [Forecasting at Scale](https://peerj.com/preprints/3190.pdf)
@@ -0,0 +1,349 @@
1
module Prophet
  # Cross validation and performance metrics for fitted Prophet models.
  #
  # Durations (horizon, period, initial) are handled in seconds. They may be
  # given as Numeric values (e.g. an ActiveSupport::Duration) or as strings
  # such as "365 days" or "12 hours"; see .timedelta.
  module Diagnostics
    # Seconds per unit for the string durations accepted by .timedelta.
    TIMEDELTA_UNITS = {
      "day" => 86400,
      "hour" => 3600,
      "minute" => 60,
      "second" => 1
    }.freeze

    # "<number> <unit>[s]", e.g. "730 days", "0.5 days", "1 day", "12 hours".
    TIMEDELTA_PATTERN = /\A(\d+(?:\.\d+)?)\s*(day|hour|minute|second)s?\z/

    # Generates cutoff times for cross validation, walking backwards from the
    # end of the history in steps of +period+.
    #
    # @param df [Rover::DataFrame] model history; only the "ds" column is read
    # @param horizon [Numeric] forecast horizon, in seconds
    # @param initial [Numeric] minimum training window before the first cutoff, in seconds
    # @param period [Numeric] spacing between cutoffs, in seconds
    # @return [Array] cutoffs in ascending order
    # @raise [Prophet::Error] if there is not enough data for the horizon and initial window
    def self.generate_cutoffs(df, horizon, initial, period)
      ds = df["ds"]
      min_ds = ds.min
      # Last cutoff is 'latest date in data - horizon' date
      cutoff = ds.max - horizon
      if cutoff < min_ds
        raise Error, "Less data than horizon."
      end
      result = [cutoff]
      while result[-1] >= min_ds + initial
        cutoff -= period
        # If data does not exist in data range (cutoff, cutoff + horizon]
        unless ((ds > cutoff) & (ds <= cutoff + horizon)).any?
          # Next cutoff point is 'last date before cutoff in data - horizon'
          if cutoff > min_ds
            # Max of the "ds" vector (not of the whole data frame)
            closest_date = ds[ds <= cutoff].max
            cutoff = closest_date - horizon
          end
          # else no data left, leave cutoff as is, it will be dropped.
        end
        result << cutoff
      end
      # The loop exits after appending a cutoff that leaves less than +initial+
      # of training data, so that last cutoff is discarded.
      result = result[0...-1]
      if result.empty?
        raise Error, "Less data than horizon after initial window. Make horizon or initial shorter."
      end
      # logger.info("Making #{result.length} forecasts with cutoffs between #{result[-1]} and #{result[0]}")
      result.reverse
    end

    # Simulated historical forecasts: for each cutoff, refits a copy of +model+
    # on the history up to the cutoff and predicts the following +horizon+.
    #
    # @param model fitted Prophet model (must have a history)
    # @param horizon [Numeric, String] forecast horizon (see .timedelta)
    # @param period [Numeric, String, nil] spacing between cutoffs; defaults to half the horizon
    # @param initial [Numeric, String, nil] initial training window; defaults to the larger
    #   of three horizons and the longest seasonality period
    # @param cutoffs [Array, nil] explicit cutoff times; when given, +period+ and +initial+ are ignored
    # @return [Rover::DataFrame] ds, yhat, (yhat_lower, yhat_upper,) y and cutoff for every forecast point
    # @raise [Prophet::Error] if the model is not fit or the cutoffs/windows are invalid
    def self.cross_validation(model, horizon:, period: nil, initial: nil, cutoffs: nil)
      if model.history.nil?
        raise Error, "Model has not been fit. Fitting the model provides contextual parameters for cross validation."
      end

      df = model.history.dup
      horizon = timedelta(horizon)

      predict_columns = ["ds", "yhat"]
      # NOTE(review): in Ruby 0 is truthy, so uncertainty_samples: 0 still requests
      # the interval columns -- confirm predict emits them in that case.
      if model.uncertainty_samples
        predict_columns.concat(["yhat_lower", "yhat_upper"])
      end

      # Identify largest seasonality period (in days)
      period_max = model.seasonalities.reduce(0.0) { |acc, (_, s)| [acc, s[:period]].max }
      seasonality_dt = timedelta("#{period_max} days")

      if cutoffs.nil?
        # Set period
        period = period.nil? ? 0.5 * horizon : timedelta(period)

        # Set initial
        initial = initial.nil? ? [3 * horizon, seasonality_dt].max : timedelta(initial)

        # Compute Cutoffs
        cutoffs = generate_cutoffs(df, horizon, initial, period)
      else
        # Fail clearly instead of calling min/max on an empty collection
        if cutoffs.size.zero?
          raise Error, "Cutoffs must not be empty"
        end
        # add validation of the cutoff to make sure that the min cutoff is strictly greater than the min date in the history
        if cutoffs.min <= df["ds"].min
          raise Error, "Minimum cutoff value is not strictly greater than min date in history"
        end
        # max value of cutoffs is <= (end date minus horizon)
        end_date_minus_horizon = df["ds"].max - horizon
        if cutoffs.max > end_date_minus_horizon
          raise Error, "Maximum cutoff value is greater than end date minus horizon, no value for cross-validation remaining"
        end
        initial = cutoffs[0] - df["ds"].min
      end

      # Check if the initial window
      # (that is, the amount of time between the start of the history and the first cutoff)
      # is less than the maximum seasonality period.
      # The warning is currently not emitted (the logger call is disabled).
      if initial < seasonality_dt
        msg = "Seasonality has period of #{period_max} days "
        msg += "which is larger than initial window. "
        msg += "Consider increasing initial."
        # logger.warn(msg)
      end

      predicts = cutoffs.map { |cutoff| single_cutoff_forecast(df, model, cutoff, horizon, predict_columns) }

      # Combine all predicted DataFrame into one DataFrame
      predicts.reduce(Rover::DataFrame.new) { |memo, v| memo.concat(v) }
    end

    # Refits a copy of +model+ on the history up to +cutoff+ and forecasts the
    # observed points in (cutoff, cutoff + horizon].
    #
    # @return [Rover::DataFrame] +predict_columns+ plus "y" and "cutoff"
    # @raise [Prophet::Error] if fewer than two data points precede the cutoff
    def self.single_cutoff_forecast(df, model, cutoff, horizon, predict_columns)
      # Generate new object with copying fitting options
      m = prophet_copy(model, cutoff)
      # Train model
      history_c = df[df["ds"] <= cutoff]
      if history_c.shape[0] < 2
        raise Error, "Less than two datapoints before cutoff. Increase initial window."
      end
      m.fit(history_c, **model.fit_kwargs)
      # Calculate yhat
      index_predicted = (df["ds"] > cutoff) & (df["ds"] <= cutoff + horizon)
      df_predicted = df[index_predicted]
      # Get the columns for the future dataframe
      columns = ["ds"]
      if m.growth == "logistic"
        columns << "cap"
        # initialize_scales turns on the logistic floor when the training frame
        # has a "floor" column, and history_c has the same columns as df. So this
        # mirrors m's logistic_floor without depending on a public reader for it.
        columns << "floor" if df.include?("floor")
      end
      columns.concat(m.extra_regressors.keys)
      columns.concat(m.seasonalities.map { |_, props| props[:condition_name] }.compact)
      yhat = m.predict(df_predicted[columns])
      # Merge yhat(predicts), y(df, original data) and cutoff
      yhat[predict_columns]
        .merge(df_predicted[["y"]])
        .merge(Rover::DataFrame.new({"cutoff" => [cutoff] * yhat.length}))
    end

    # Builds an unfitted model with the same options as the fitted model +m+.
    # User-specified changepoints are kept only if they fall before the last
    # history date at or before +cutoff+.
    #
    # @raise [Prophet::Error] if +m+ has not been fit
    def self.prophet_copy(m, cutoff = nil)
      if m.history.nil?
        raise Error, "This is for copying a fitted Prophet object."
      end

      if m.specified_changepoints
        changepoints = m.changepoints
        unless cutoff.nil?
          # Filter change points '< cutoff'
          last_history_date = m.history["ds"][m.history["ds"] <= cutoff].max
          changepoints = changepoints[changepoints < last_history_date]
        end
      else
        changepoints = nil
      end

      # Auto seasonalities are set to false because they are already set in
      # m.seasonalities, which is copied below.
      m2 = m.class.new(
        growth: m.growth,
        n_changepoints: m.n_changepoints,
        changepoint_range: m.changepoint_range,
        changepoints: changepoints,
        yearly_seasonality: false,
        weekly_seasonality: false,
        daily_seasonality: false,
        holidays: m.holidays,
        seasonality_mode: m.seasonality_mode,
        seasonality_prior_scale: m.seasonality_prior_scale,
        changepoint_prior_scale: m.changepoint_prior_scale,
        holidays_prior_scale: m.holidays_prior_scale,
        mcmc_samples: m.mcmc_samples,
        interval_width: m.interval_width,
        uncertainty_samples: m.uncertainty_samples
      )
      m2.extra_regressors = deepcopy(m.extra_regressors)
      m2.seasonalities = deepcopy(m.seasonalities)
      m2.country_holidays = deepcopy(m.country_holidays)
      m2
    end

    # Converts a duration to seconds.
    #
    # @param value [Numeric, String] seconds, or a string such as "30 days",
    #   "1 day", "12 hours", "15 minutes" or "90 seconds"
    # @return [Numeric] the duration in seconds (Numeric input is returned unchanged)
    # @raise [Prophet::Error] for any other value
    def self.timedelta(value)
      if value.is_a?(Numeric)
        # ActiveSupport::Duration is a numeric
        value
      elsif (m = TIMEDELTA_PATTERN.match(value.to_s))
        m[1].to_f * TIMEDELTA_UNITS.fetch(m[2])
      else
        raise Error, "Unknown time delta"
      end
    end

    # Recursively copies nested hashes and arrays; any other value is dup'ed.
    def self.deepcopy(value)
      case value
      when Hash
        value.to_h { |k, v| [deepcopy(k), deepcopy(v)] }
      when Array
        value.map { |v| deepcopy(v) }
      else
        value.dup
      end
    end

    # Computes forecast accuracy metrics by horizon from cross_validation output.
    #
    # @param df [Rover::DataFrame] output of .cross_validation
    # @param metrics [Array<String>, nil] subset of mse, rmse, mae, mape, mdape,
    #   smape and coverage (default: all). The caller's array is not modified.
    # @param rolling_window [Float] fraction of the rows included in each rolling
    #   window; a negative value returns one row per prediction instead
    # @param monthly [Boolean] not implemented yet
    # @return [Rover::DataFrame, nil] a "horizon" column plus one column per
    #   metric, or nil when no requested metric can be computed
    # @raise [ArgumentError] for duplicate or unknown metrics
    def self.performance_metrics(df, metrics: nil, rolling_window: 0.1, monthly: false)
      valid_metrics = ["mse", "rmse", "mae", "mape", "mdape", "smape", "coverage"]
      # Work on a copy so metrics dropped below never mutate the caller's array
      metrics = (metrics || valid_metrics).dup
      if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
        metrics.delete("coverage")
      end
      if metrics.uniq.length != metrics.length
        raise ArgumentError, "Input metrics must be a list of unique values"
      end
      unless Set.new(metrics).subset?(Set.new(valid_metrics))
        raise ArgumentError, "Valid values for metrics are: #{valid_metrics}"
      end
      df_m = df.dup
      if monthly
        # df_m["horizon"] = df_m["ds"].dt.to_period("M").astype(int) - df_m["cutoff"].dt.to_period("M").astype(int)
        raise Error, "Not implemented yet"
      end
      df_m["horizon"] = df_m["ds"] - df_m["cutoff"]
      df_m.sort_by! { |r| r["horizon"] }
      if metrics.include?("mape") && df_m["y"].abs.min < 1e-8
        # logger.info("Skipping MAPE because y close to 0")
        metrics.delete("mape")
      end
      return nil if metrics.empty?

      w = (rolling_window * df_m.shape[0]).to_i
      if w >= 0
        w = [w, 1].max
        w = [w, df_m.shape[0]].min
      end
      # Compute all metrics (names were validated against valid_metrics above)
      dfs = metrics.to_h { |metric| [metric, public_send(metric, df_m, w)] }
      res = dfs[metrics[0]]
      metrics.each do |metric|
        res[metric] = dfs[metric][metric]
      end
      res
    end

    # Rolling mean of +x+ by horizon. Working from the longest horizon backwards,
    # each output row averages +w+ points taken from that horizon and shorter
    # ones, weighting the shortest contributing horizon partially. Horizons
    # without +w+ points available are dropped.
    #
    # @return [Rover::DataFrame] "horizon" and +name+ columns
    def self.rolling_mean_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      df2 = df.group("h").sum("x").inner_join(df.group("h").count).sort_by { |r| r["h"] }
      xs = df2["sum_x"]
      ns = df2["count"]
      hs = df2["h"]

      trailing_i = df2.length - 1
      x_sum = 0
      n_sum = 0
      # We don't know output size but it is bounded by len(df2)
      res_x = [nil] * df2.length

      # Start from the right and work backwards
      (df2.length - 1).downto(0) do |i|
        x_sum += xs[i]
        n_sum += ns[i]
        while n_sum >= w
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise weight the mean by the difference.
          # to_f keeps these true divisions even for integer inputs.
          excess_n = n_sum - w
          excess_x = excess_n * xs[i] / ns[i].to_f
          res_x[trailing_i] = (x_sum - excess_x) / w.to_f
          x_sum -= xs[trailing_i]
          n_sum -= ns[trailing_i]
          trailing_i -= 1
        end
      end

      res_h = hs[(trailing_i + 1)..-1]
      res_x = res_x[(trailing_i + 1)..-1]

      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Rolling median of +x+ by horizon. Each output row takes all points at that
    # horizon, topped up with the points immediately before them in +x+ (shorter
    # horizons, since performance_metrics sorts by horizon) until there are at
    # least +w+. Stops at the first horizon that cannot reach +w+ points.
    #
    # @return [Rover::DataFrame] "horizon" and +name+ columns
    def self.rolling_median_by_h(x, h, w, name)
      # Aggregate over h
      df = Rover::DataFrame.new({"x" => x, "h" => h})
      grouped = df.group("h")
      df2 = grouped.count.sort_by { |r| r["h"] }
      hs = df2["h"]

      res_h = []
      res_x = []
      # Start from the right and work backwards
      i = hs.length - 1
      while i >= 0
        h_i = hs[i]
        xs = df[df["h"] == h_i]["x"].to_a

        # Index just before the first occurrence of h_i
        next_idx_to_add = (h == h_i).to_numo.cast_to(Numo::UInt8).argmax - 1
        while xs.length < w && next_idx_to_add >= 0
          # Include points from the previous horizon. All of them if still
          # less than w, otherwise just enough to get to w.
          xs << x[next_idx_to_add]
          next_idx_to_add -= 1
        end
        if xs.length < w
          # Ran out of points before getting enough.
          break
        end
        res_h << hs[i]
        res_x << Rover::Vector.new(xs).median
        i -= 1
      end
      res_h.reverse!
      res_x.reverse!
      Rover::DataFrame.new({"horizon" => res_h, name => res_x})
    end

    # Mean squared error (per point when w < 0, otherwise rolling by horizon).
    def self.mse(df, w)
      se = (df["y"] - df["yhat"]) ** 2
      return Rover::DataFrame.new({"horizon" => df["horizon"], "mse" => se}) if w < 0
      rolling_mean_by_h(se, df["horizon"], w, "mse")
    end

    # Root mean squared error, derived from mse.
    def self.rmse(df, w)
      res = mse(df, w)
      res["rmse"] = res.delete("mse").map { |v| Math.sqrt(v) }
      res
    end

    # Mean absolute error.
    def self.mae(df, w)
      ae = (df["y"] - df["yhat"]).abs
      return Rover::DataFrame.new({"horizon" => df["horizon"], "mae" => ae}) if w < 0
      rolling_mean_by_h(ae, df["horizon"], w, "mae")
    end

    # Mean absolute percent error.
    def self.mape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      return Rover::DataFrame.new({"horizon" => df["horizon"], "mape" => ape}) if w < 0
      rolling_mean_by_h(ape, df["horizon"], w, "mape")
    end

    # Median absolute percent error.
    def self.mdape(df, w)
      ape = ((df["y"] - df["yhat"]) / df["y"]).abs
      return Rover::DataFrame.new({"horizon" => df["horizon"], "mdape" => ape}) if w < 0
      rolling_median_by_h(ape, df["horizon"], w, "mdape")
    end

    # Symmetric mean absolute percent error.
    def self.smape(df, w)
      sape = (df["y"] - df["yhat"]).abs / ((df["y"].abs + df["yhat"].abs) / 2)
      return Rover::DataFrame.new({"horizon" => df["horizon"], "smape" => sape}) if w < 0
      rolling_mean_by_h(sape, df["horizon"], w, "smape")
    end

    # Share of observations inside [yhat_lower, yhat_upper].
    def self.coverage(df, w)
      is_covered = (df["y"] >= df["yhat_lower"]) & (df["y"] <= df["yhat_upper"])
      return Rover::DataFrame.new({"horizon" => df["horizon"], "coverage" => is_covered}) if w < 0
      rolling_mean_by_h(is_covered.to(:float), df["horizon"], w, "coverage")
    end
  end
end
@@ -3,7 +3,14 @@ module Prophet
3
3
  include Holidays
4
4
  include Plot
5
5
 
6
- attr_reader :logger, :params, :train_holiday_names
6
+ attr_reader :logger, :params, :train_holiday_names,
7
+ :history, :seasonalities, :specified_changepoints, :fit_kwargs,
8
+ :growth, :changepoints, :n_changepoints, :changepoint_range,
9
+ :holidays, :seasonality_mode, :seasonality_prior_scale,
10
+ :holidays_prior_scale, :changepoint_prior_scale, :mcmc_samples,
11
+ :interval_width, :uncertainty_samples
12
+
13
+ attr_accessor :extra_regressors, :seasonalities, :country_holidays
7
14
 
8
15
  def initialize(
9
16
  growth: "linear",
@@ -176,8 +183,10 @@ module Prophet
176
183
 
177
184
  initialize_scales(initialize_scales, df)
178
185
 
179
- if @logistic_floor && !df.include?("floor")
180
- raise ArgumentError, "Expected column \"floor\"."
186
+ if @logistic_floor
187
+ unless df.include?("floor")
188
+ raise ArgumentError, "Expected column \"floor\"."
189
+ end
181
190
  else
182
191
  df["floor"] = 0
183
192
  end
@@ -207,7 +216,12 @@ module Prophet
207
216
  def initialize_scales(initialize_scales, df)
208
217
  return unless initialize_scales
209
218
 
210
- floor = 0
219
+ if @growth == "logistic" && df.include?("floor")
220
+ @logistic_floor = true
221
+ floor = df["floor"]
222
+ else
223
+ floor = 0.0
224
+ end
211
225
  @y_scale = (df["y"] - floor).abs.max
212
226
  @y_scale = 1 if @y_scale == 0
213
227
  @start = df["ds"].min
data/lib/prophet/plot.rb CHANGED
@@ -111,6 +111,61 @@ module Prophet
111
111
  artists
112
112
  end
113
113
 
114
# Plots a cross validation metric against forecast horizon: one faint point
# per prediction plus the rolling-window curve.
#
# @param df_cv [Rover::DataFrame] output of Diagnostics.cross_validation
# @param metric [String] metric name accepted by Diagnostics.performance_metrics
# @param rolling_window [Float] passed to Diagnostics.performance_metrics
# @return the matplotlib figure
# @raise [Error] if the metric cannot be computed for +df_cv+
def self.plot_cross_validation_metric(df_cv, metric:, rolling_window: 0.1, ax: nil, figsize: [10, 6], color: "b", point_color: "gray")
  # Get the metric at the level of individual predictions, and with the rolling window.
  df_none = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: -1)
  df_h = Diagnostics.performance_metrics(df_cv, metrics: [metric], rolling_window: rolling_window)
  # performance_metrics returns nil when the metric had to be skipped
  # (e.g. mape with y close to 0, or coverage without yhat_lower/yhat_upper)
  if df_none.nil? || df_h.nil?
    raise Error, "Metric #{metric} could not be computed for these cross validation results"
  end

  if ax.nil?
    fig = plt.figure(facecolor: "w", figsize: figsize)
    ax = fig.add_subplot(111)
  else
    fig = ax.get_figure
  end

  # Horizons are in seconds; rescale them to a unit that gives readable ticks.
  unit_ns, unit_name = horizon_plot_unit(df_none["horizon"].max)
  x_plt = df_none["horizon"] * 1e9 / unit_ns.to_f
  x_plt_h = df_h["horizon"] * 1e9 / unit_ns.to_f

  ax.plot(x_plt.to_a, df_none[metric].to_a, ".", alpha: 0.1, c: point_color)
  ax.plot(x_plt_h.to_a, df_h[metric].to_a, "-", c: color)
  ax.grid(true)

  ax.set_xlabel("Horizon (#{unit_name})")
  ax.set_ylabel(metric)
  fig
end

# Picks the x-axis unit for plot_cross_validation_metric. It returns the
# largest unit shorter than one tenth of the maximum horizon (targeting ~10
# ticks) and falls back to nanoseconds.
#
# @param max_horizon [Numeric] largest horizon, in seconds
# @return [Array] nanoseconds per unit and the unit's plural name
def self.horizon_plot_unit(max_horizon)
  tick_w = max_horizon * 1e9 / 10.0
  units = [
    [24 * 60 * 60 * 10**9, "days"],
    [60 * 60 * 10**9, "hours"],
    [60 * 10**9, "minutes"],
    [10**9, "seconds"],
    [10**6, "milliseconds"],
    [10**3, "microseconds"],
    [1.0, "nanoseconds"]
  ]
  units.find { |ns, _| ns < tick_w } || units.last
end
159
+
160
# Loads matplotlib's pyplot on first use.
#
# @return [Module] Matplotlib::Pyplot
# @raise [Error] with an install hint when matplotlib/pyplot cannot be required
def self.plt
  require "matplotlib/pyplot"
  Matplotlib::Pyplot
rescue LoadError
  raise Error, "Install the matplotlib gem for plots"
end
168
+
114
169
  private
115
170
 
116
171
  def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
@@ -263,12 +318,7 @@ module Prophet
263
318
  end
264
319
 
265
320
  # Instance-level wrapper that delegates to the module-level Plot.plt, which
  # requires matplotlib/pyplot or raises Error when the gem is missing.
  def plt
266
- begin
267
- require "matplotlib/pyplot"
268
- rescue LoadError
269
- raise Error, "Install the matplotlib gem for plots"
270
- end
271
- Matplotlib::Pyplot
321
+ Plot.plt
272
322
  end
273
323
 
274
324
  def dates
@@ -13,6 +13,11 @@ module Prophet
13
13
 
14
14
  def fit(stan_init, stan_data, **kwargs)
15
15
  stan_init, stan_data = prepare_data(stan_init, stan_data)
16
+
17
+ if !kwargs[:inits] && kwargs[:init]
18
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
19
+ end
20
+
16
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
17
22
  iterations = 10000
18
23
 
@@ -49,6 +54,10 @@ module Prophet
49
54
  def sampling(stan_init, stan_data, samples, **kwargs)
50
55
  stan_init, stan_data = prepare_data(stan_init, stan_data)
51
56
 
57
+ if !kwargs[:inits] && kwargs[:init]
58
+ kwargs[:inits] = prepare_data(kwargs.delete(:init), stan_data)[0]
59
+ end
60
+
52
61
  kwargs[:chains] ||= 4
53
62
  kwargs[:warmup_iters] ||= samples / 2
54
63
 
@@ -128,7 +137,7 @@ module Prophet
128
137
  stan_data["t_change"] = stan_data["t_change"].to_a
129
138
  stan_data["s_a"] = stan_data["s_a"].to_a
130
139
  stan_data["s_m"] = stan_data["s_m"].to_a
131
- stan_data["X"] = stan_data["X"].to_numo.to_a
140
+ stan_data["X"] = stan_data["X"].respond_to?(:to_numo) ? stan_data["X"].to_numo.to_a : stan_data["X"].to_a
132
141
  stan_init["delta"] = stan_init["delta"].to_a
133
142
  stan_init["beta"] = stan_init["beta"].to_a
134
143
  [stan_init, stan_data]
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.0"
2
+ VERSION = "0.4.1"
3
3
  end
data/lib/prophet.rb CHANGED
@@ -8,6 +8,7 @@ require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
+ require "prophet/diagnostics"
11
12
  require "prophet/holidays"
12
13
  require "prophet/plot"
13
14
  require "prophet/forecaster"
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.0
4
+ version: 0.4.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-07 00:00:00.000000000 Z
11
+ date: 2022-07-10 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -66,6 +66,7 @@ files:
66
66
  - data-raw/generated_holidays.csv
67
67
  - lib/prophet-rb.rb
68
68
  - lib/prophet.rb
69
+ - lib/prophet/diagnostics.rb
69
70
  - lib/prophet/forecaster.rb
70
71
  - lib/prophet/holidays.rb
71
72
  - lib/prophet/plot.rb