prophet-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,27 @@
1
+ module Prophet
2
+ module Holidays
3
# Lists the distinct holiday names available for a country.
#
# Asks make_holidays_df for every year from 1995 through 2045 (inclusive)
# and de-duplicates the resulting "holiday" column.
#
# @param country value matched against the "country" column of the holiday data
# @return the unique entries of the "holiday" column
def get_holiday_names(country)
  all_years = Array(1995..2045)
  holidays = make_holidays_df(all_years, country)
  holidays["holiday"].uniq
end
7
+
8
# Selects the holidays of one country that fall in the given years.
#
# @param year_list [Array] years to keep (matched against the "year" column)
# @param country value matched against the "country" column
# @return the "ds" and "holiday" columns of the matching rows
def make_holidays_df(year_list, country)
  table = holidays_df
  same_country = table["country"].eq(country)
  wanted_year = table["year"].in(year_list)
  table.where(same_country & wanted_year)["ds", "holiday"]
end
11
+
12
+ # TODO marshal on installation
13
# Loads the bundled holiday table and memoizes it in @holidays_df.
#
# Reads data-raw/generated_holidays.csv (resolved relative to this file's
# directory) with a header row and the CSV :date and :numeric converters,
# collecting the "ds", "holiday", "country" and "year" fields of every row.
#
# @return [Daru::DataFrame] frame built from those four columns
def holidays_df
  return @holidays_df if @holidays_df

  columns = {"ds" => [], "holiday" => [], "country" => [], "year" => []}
  csv_path = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
  CSV.foreach(csv_path, headers: true, converters: [:date, :numeric]) do |row|
    columns.each { |field, values| values << row[field] }
  end
  @holidays_df = Daru::DataFrame.new(columns)
end
26
+ end
27
+ end
@@ -0,0 +1,269 @@
1
+ module Prophet
2
+ module Plot
3
# Draws the forecast on top of the observed history.
#
# Plots @history as black dots, "yhat" as a solid blue line, optional dashed
# "cap"/"floor" lines and a shaded band between "yhat_lower" and "yhat_upper".
#
# @param fcst forecast frame, read through fcst[...] and fcst.vectors
#   (presumably a Daru::DataFrame - confirm against callers)
# @param ax existing axes to draw on; a new figure and axes are created when nil
# @param uncertainty draw the band (only when @uncertainty_samples is also set)
# @param plot_cap draw the "cap" line, and the "floor" line when @logistic_floor is set
# @param xlabel label for the x axis
# @param ylabel label for the y axis
# @param figsize size handed to plt.figure when a figure is created
# @return the figure owning the axes
def plot(fcst, ax: nil, uncertainty: true, plot_cap: true, xlabel: "ds", ylabel: "y", figsize: [10, 6])
  fig = ax.nil? ? plt.figure(facecolor: "w", figsize: figsize) : ax.get_figure
  ax ||= fig.add_subplot(111)

  columns = fcst.vectors
  values = ->(column) { fcst[column].map(&:to_f) }
  dashed = {ls: "--", c: "k"}
  fcst_t = to_pydatetime(fcst["ds"])

  ax.plot(to_pydatetime(@history["ds"]), @history["y"].map(&:to_f), "k.")
  ax.plot(fcst_t, values.call("yhat"), ls: "-", c: "#0072B2")
  ax.plot(fcst_t, values.call("cap"), **dashed) if columns.include?("cap") && plot_cap
  ax.plot(fcst_t, values.call("floor"), **dashed) if @logistic_floor && columns.include?("floor") && plot_cap
  if uncertainty && @uncertainty_samples
    ax.fill_between(fcst_t, values.call("yhat_lower"), values.call("yhat_upper"), color: "#0072B2", alpha: 0.2)
  end

  # Specify formatting to workaround matplotlib issue #12925
  locator = dates.AutoDateLocator.new(interval_multiples: false)
  formatter = dates.AutoDateFormatter.new(locator)
  ax.xaxis.set_major_locator(locator)
  ax.xaxis.set_major_formatter(formatter)

  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  fig.tight_layout
  fig
end
33
+
34
# Plots each forecast component in its own panel, stacked vertically.
#
# Panels, in order: "trend"; "holidays" (when @train_holiday_names is set);
# "weekly" and "yearly" seasonalities; any other seasonality, sorted by name;
# "extra_regressors_additive" / "extra_regressors_multiplicative" when a
# regressor of that mode exists. A component is only shown when fcst has a
# column of the same name.
#
# @param fcst forecast frame, read through fcst.vectors and passed to the
#   per-component plotters
# @param uncertainty forwarded to the per-component plotters
# @param plot_cap forwarded to the trend panel only
# @param weekly_start forwarded to plot_weekly
# @param yearly_start forwarded to plot_yearly
# @param figsize figure size; defaults to [9, 3 * number of panels]
# @return the figure holding all panels
def plot_components(fcst, uncertainty: true, plot_cap: true, weekly_start: 0, yearly_start: 0, figsize: nil)
  columns = fcst.vectors
  components = ["trend"]
  components << "holidays" if @train_holiday_names && columns.include?("holidays")
  ["weekly", "yearly"].each do |name|
    components << name if @seasonalities[name] && columns.include?(name)
  end
  others = @seasonalities.keys.select { |name| columns.include?(name) && !["weekly", "yearly"].include?(name) }
  components.concat(others.sort)
  regressor_modes = @extra_regressors.values.map { |props| props[:mode] }
  ["additive", "multiplicative"].each do |mode|
    column = "extra_regressors_#{mode}"
    components << column if regressor_modes.include?(mode) && columns.include?(column)
  end

  npanel = components.size
  figsize ||= [9, 3 * npanel]
  fig, axes = plt.subplots(npanel, 1, facecolor: "w", figsize: figsize)
  # A single panel comes back as one bare axes object, which has no #tolist;
  # only the multi-panel result needs converting to a list.
  axes = npanel == 1 ? [axes] : axes.tolist

  multiplicative_axes = []
  axes.zip(components) do |ax, plot_name|
    seasonality = @seasonalities[plot_name]
    if plot_name == "trend"
      plot_forecast_component(fcst, "trend", ax: ax, uncertainty: uncertainty, plot_cap: plot_cap)
    elsif seasonality
      if plot_name == "weekly" || seasonality[:period] == 7
        plot_weekly(name: plot_name, ax: ax, uncertainty: uncertainty, weekly_start: weekly_start)
      elsif plot_name == "yearly" || seasonality[:period] == 365.25
        plot_yearly(name: plot_name, ax: ax, uncertainty: uncertainty, yearly_start: yearly_start)
      else
        plot_seasonality(name: plot_name, ax: ax, uncertainty: uncertainty)
      end
    elsif ["holidays", "extra_regressors_additive", "extra_regressors_multiplicative"].include?(plot_name)
      plot_forecast_component(fcst, plot_name, ax: ax, uncertainty: uncertainty, plot_cap: false)
    end
    multiplicative_axes << ax if @component_modes["multiplicative"].include?(plot_name)
  end

  fig.tight_layout
  # Reset multiplicative axes labels after tight_layout adjustment
  multiplicative_axes.each { |ax| set_y_as_percent(ax) }
  fig
end
95
+
96
+ private
97
+
98
# Plots one forecast column (for example "trend" or "holidays") against "ds".
#
# @param fcst forecast frame, read through fcst[...] and fcst.vectors
# @param name [String] column to plot; "<name>_lower"/"<name>_upper" give the band
# @param ax axes to draw on; a new figure and axes are created when not given
# @param uncertainty draw the band (only when @uncertainty_samples is also set)
# @param plot_cap draw the "cap" line, and the "floor" line when @logistic_floor is set
# @param figsize size handed to plt.figure when a figure is created
# @return [Array] the component line, the cap line and the band, when drawn
def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
  ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)
  columns = fcst.vectors
  values = ->(column) { fcst[column].map(&:to_f) }
  fcst_t = to_pydatetime(fcst["ds"])

  artists = []
  artists += ax.plot(fcst_t, values.call(name), ls: "-", c: "#0072B2")
  if columns.include?("cap") && plot_cap
    artists += ax.plot(fcst_t, values.call("cap"), ls: "--", c: "k")
  end
  # The floor line is drawn but, unlike the cap line, not added to the returned artists.
  if @logistic_floor && columns.include?("floor") && plot_cap
    ax.plot(fcst_t, values.call("floor"), ls: "--", c: "k")
  end
  if uncertainty && @uncertainty_samples
    band = ax.fill_between(fcst_t, values.call("#{name}_lower"), values.call("#{name}_upper"), color: "#0072B2", alpha: 0.2)
    artists << band
  end

  # Specify formatting to workaround matplotlib issue #12925
  locator = dates.AutoDateLocator.new(interval_multiples: false)
  formatter = dates.AutoDateFormatter.new(locator)
  ax.xaxis.set_major_locator(locator)
  ax.xaxis.set_major_formatter(formatter)

  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
  ax.set_xlabel("ds")
  ax.set_ylabel(name)
  set_y_as_percent(ax) if @component_modes["multiplicative"].include?(name)
  artists
end
128
+
129
# Builds the frame used to evaluate seasonal components on the dates +ds+.
#
# Every row gets cap 1.0 and floor 0.0, each extra regressor column is set to
# 0.0, and each seasonality condition column (props[:condition_name]) is set
# to true. The frame is then passed through setup_dataframe.
#
# @param ds [Array] dates (or times) to evaluate
# @return the result of setup_dataframe on the assembled Daru::DataFrame
def seasonality_plot_df(ds)
  rows = ds.size
  df_dict = {"ds" => ds, "cap" => [1.0] * rows, "floor" => [0.0] * rows}
  # @extra_regressors maps name => props; only the name is a column key.
  @extra_regressors.each_key do |name|
    df_dict[name] = [0.0] * rows
  end
  # Activate all conditional seasonality columns
  @seasonalities.each_value do |props|
    condition = props[:condition_name]
    df_dict[condition] = [true] * rows if condition
  end
  setup_dataframe(Daru::DataFrame.new(df_dict))
end
144
+
145
# Plots a weekly seasonality over seven consecutive days.
#
# @param ax axes to draw on; a new figure and axes are created when not given
# @param uncertainty draw the band (only when @uncertainty_samples is also set)
# @param weekly_start day offset from Sunday 2017-01-01 for the first tick
# @param figsize size handed to plt.figure when a figure is created
# @param name [String] seasonality to plot
# @return [Array] the plotted line and, when drawn, the band
def plot_weekly(ax: nil, uncertainty: true, weekly_start: 0, figsize: [10, 6], name: "weekly")
  artists = []
  ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)

  # Compute weekly seasonality for a Sun-Sat sequence of dates.
  start = Date.parse("2017-01-01")
  days = 7.times.map { |i| start + i + weekly_start }
  seas = predict_seasonal_components(seasonality_plot_df(days))
  day_names = days.map { |day| day.strftime("%A") }
  positions = day_names.size.times.to_a

  artists += ax.plot(positions, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
  if uncertainty && @uncertainty_samples
    artists += [ax.fill_between(positions, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
  end
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
  ax.set_xticks(positions)
  ax.set_xticklabels(day_names)
  ax.set_xlabel("Day of week")
  ax.set_ylabel(name)
  # Seasonality properties use symbol keys (:mode, :period, :condition_name).
  set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
  artists
end
171
+
172
# Plots a yearly seasonality over 365 consecutive days.
#
# @param ax axes to draw on; a new figure and axes are created when not given
# @param uncertainty draw the band (only when @uncertainty_samples is also set)
# @param yearly_start day offset from 2017-01-01 for the first plotted day
# @param figsize size handed to plt.figure when a figure is created
# @param name [String] seasonality to plot
# @return [Array] the plotted line and, when drawn, the band
def plot_yearly(ax: nil, uncertainty: true, yearly_start: 0, figsize: [10, 6], name: "yearly")
  artists = []
  ax = plt.figure(facecolor: "w", figsize: figsize).add_subplot(111) unless ax

  # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  first_day = Date.parse("2017-01-01")
  days = Array.new(365) { |i| first_day + i + yearly_start }
  df_y = seasonality_plot_df(days)
  seas = predict_seasonal_components(df_y)
  times = to_pydatetime(df_y["ds"])

  artists += ax.plot(times, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
  if uncertainty && @uncertainty_samples
    band = ax.fill_between(times, seas["#{name}_lower"].map(&:to_f), seas["#{name}_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
    artists << band
  end
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)

  months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
  label_for = ->(x, pos = nil) { dates.num2date(x).strftime("%B %-e") }
  ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(label_for))
  ax.xaxis.set_major_locator(months)
  ax.set_xlabel("Day of year")
  ax.set_ylabel(name)
  set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
  artists
end
198
+
199
# Plots a custom seasonality over one period starting at 2017-01-01 UTC.
#
# The period (in days, from @seasonalities[name][:period]) is sampled at 200
# evenly spaced instants; seven evenly spaced x ticks are labelled with a
# format chosen from the period length.
#
# @param name [String] seasonality to plot
# @param ax axes to draw on; a new figure and axes are created when not given
# @param uncertainty draw the band (only when @uncertainty_samples is also set)
# @param figsize size handed to plt.figure when a figure is created
# @return [Array] the plotted line and, when drawn, the band
def plot_seasonality(name:, ax: nil, uncertainty: true, figsize: [10, 6])
  artists = []
  ax = plt.figure(facecolor: "w", figsize: figsize).add_subplot(111) unless ax

  # Compute seasonality from Jan 1 through a single period.
  origin = Time.utc(2017)
  period = @seasonalities[name][:period]
  first = origin.to_i
  last = (origin + period * 86400).to_i
  plot_points = 200
  spacing = (last - first) / (plot_points - 1).to_f
  days = Array.new(plot_points) { |i| Time.at(first + i * spacing).utc }
  df_y = seasonality_plot_df(days)
  seas = predict_seasonal_components(df_y)
  times = to_pydatetime(df_y["ds"])

  artists += ax.plot(times, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
  if uncertainty && @uncertainty_samples
    band = ax.fill_between(times, seas["#{name}_lower"].map(&:to_f), seas["#{name}_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
    artists << band
  end
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)

  tick_spacing = (last - first) / (7 - 1).to_f
  xticks = to_pydatetime(Array.new(7) { |i| Time.at(first + i * tick_spacing).utc })
  ax.set_xticks(xticks)
  # Shorter periods get finer-grained tick labels.
  fmt_str =
    if period <= 2
      "%T"
    elsif period < 14
      "%m/%d %R"
    else
      "%m/%d"
    end
  ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(->(x, pos = nil) { dates.num2date(x).strftime(fmt_str) }))
  ax.set_xlabel("ds")
  ax.set_ylabel(name)
  set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
  artists
end
239
+
240
# Relabels the y ticks of +ax+ as percentages.
#
# Each current tick value is multiplied by 100 and formatted with up to four
# significant digits followed by "%".
#
# @param ax axes whose y tick labels are replaced
# @return the same axes
def set_y_as_percent(ax)
  percent_ticks = 100 * ax.get_yticks
  labels = percent_ticks.tolist.map { |tick| format("%.4g%%", tick) }
  ax.set_yticklabels(labels)
  ax
end
246
+
247
# Lazily loads matplotlib's pyplot bindings.
#
# @return the Matplotlib::Pyplot module
# @raise [Error] when "matplotlib/pyplot" cannot be required
def plt
  require "matplotlib/pyplot"
rescue LoadError
  raise Error, "Install the matplotlib gem for plots"
else
  # Kept outside the rescued section so only the require is guarded.
  Matplotlib::Pyplot
end
255
+
256
# @return the Python matplotlib.dates module, imported through PyCall
def dates
  PyCall.import_module("matplotlib.dates")
end

# @return the Python matplotlib.ticker module, imported through PyCall
def ticker
  PyCall.import_module("matplotlib.ticker")
end

# Converts Ruby date/time values to Python datetime objects.
#
# Each value is reduced to whole seconds with #to_i and handed to Python's
# datetime.utcfromtimestamp.
#
# @param v collection of values responding to #to_i
# @return the converted values, in the same order
def to_pydatetime(v)
  py_datetime = PyCall.import_module("datetime").datetime
  v.map { |value| py_datetime.utcfromtimestamp(value.to_i) }
end
268
+ end
269
+ end
@@ -0,0 +1,136 @@
1
+ module Prophet
2
+ class StanBackend
3
# Sets up the backend with a logger and the model returned by load_model.
#
# @param logger object used for warnings (fit calls #warn on it)
def initialize(logger)
  @logger = logger
  @model = load_model
end
7
+
8
# Wraps the model executable shipped at stan_model/prophet_model.bin
# (resolved relative to this file's directory).
#
# @return [CmdStan::Model] model bound to that executable
def load_model
  CmdStan::Model.new(exe_file: File.expand_path("../../stan_model/prophet_model.bin", __dir__))
end
12
+
13
# Fits the model by optimization and returns the optimized parameters.
#
# Unless an :algorithm is given, "Newton" is used when stan_data["T"] is below
# 100 and "LBFGS" otherwise. If the optimizer raises with any algorithm other
# than "Newton", a warning is logged and the run is repeated once with
# "Newton"; a failure under "Newton" is re-raised.
#
# @param stan_init initial values, passed through prepare_data
# @param stan_data model data, passed through prepare_data
# @param kwargs extra options forwarded to @model.optimize
# @return [Hash] parameter name => values with an added leading axis of size 1
def fit(stan_init, stan_data, **kwargs)
  stan_init, stan_data = prepare_data(stan_init, stan_data)
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"

  optimize = lambda do
    @model.optimize(data: stan_data, inits: stan_init, iter: 10000, **kwargs)
  end

  stan_fit =
    begin
      optimize.call
    rescue
      raise if kwargs[:algorithm] == "Newton"

      @logger.warn "Optimization terminated abnormally. Falling back to Newton."
      kwargs[:algorithm] = "Newton"
      optimize.call
    end

  params = stan_to_numo(stan_fit.column_names, Numo::NArray.asarray(stan_fit.optimized_params.values))
  # Prepend a length-1 axis to every parameter.
  params.each do |name, values|
    params[name] = values.reshape(1, *values.shape)
  end
  params
end
47
+
48
# Draws samples from the model and returns them per parameter.
#
# Defaults to 4 chains and samples / 2 warmup iterations unless overridden in
# kwargs. Draws from all chains are flattened into one leading axis before
# being split by parameter name.
#
# @param stan_init initial values, passed through prepare_data
# @param stan_data model data, passed through prepare_data
# @param samples [Integer] number of sampling iterations
# @param kwargs extra options forwarded to @model.sample
# @return [Hash] parameter name => draws
def sampling(stan_init, stan_data, samples, **kwargs)
  stan_init, stan_data = prepare_data(stan_init, stan_data)
  kwargs[:chains] ||= 4
  kwargs[:warmup_iters] ||= samples / 2

  stan_fit = @model.sample(data: stan_data, inits: stan_init, sampling_iters: samples, **kwargs)

  draws = Numo::NArray.asarray(stan_fit.sample)
  n_draws, n_chains, n_columns = draws.shape
  flat = draws.reshape(n_draws * n_chains, n_columns)
  params = stan_to_numo(stan_fit.column_names, flat)

  params.each_key do |par|
    dims = params[par].shape
    # Single-column parameters are flattened to one axis.
    params[par] = params[par].reshape(dims[0]) if dims[1] == 1
    # NOTE(review): dims is the shape from before the reshape above, so this
    # looks unreachable for the 2-D slices built here - confirm the intent.
    params[par] = params[par].reshape(-1, 1) if ["delta", "beta"].include?(par) && dims.size < 2
  end
  params
end
79
+
80
+ private
81
+
82
# Splits a flat result array into one array per parameter.
#
# Column names such as "delta.1", "delta.2" are grouped by the text before
# the first "."; each run of consecutive columns with the same base name
# becomes one entry. For 2-D data the columns are sliced on the second axis.
#
# @param column_names [Array<String>] names of the columns, in data order
# @param data array responding to #shape and range indexing
# @return [Hash] base name => Numo array holding that parameter's columns
# @raise [Error] when a base name appears in two separate runs
def stan_to_numo(column_names, data)
  output = {}
  two_dims = data.shape.size > 1

  # Stores the given column range under +name+, refusing to overwrite.
  store = lambda do |name, columns|
    raise Error, "Found repeated column name" if output[name]

    slice = two_dims ? data[true, columns] : data[columns]
    output[name] = Numo::NArray.asarray(slice)
  end

  prev = nil
  start = 0
  finish = 0
  column_names.each do |cname|
    curr = cname.split(".")[0]
    prev = curr if prev.nil?

    if curr != prev
      # A new base name begins: flush the run collected so far.
      store.call(prev, start...finish)
      prev = curr
      start = finish
    end
    finish += 1
  end

  # Flush the final run.
  store.call(prev, start...finish)
  output
end
122
+
123
# Converts the model inputs to plain Ruby arrays, in place.
#
# The vector entries of stan_data and the "delta"/"beta" entries of stan_init
# are replaced by their #to_a form; stan_data["X"] is converted through
# #to_matrix first.
#
# @param stan_init [Hash] initial values; "delta" and "beta" are rewritten
# @param stan_data [Hash] model data; "y", "t", "cap", "t_change", "s_a",
#   "s_m" and "X" are rewritten
# @return [Array] the same two hashes as [stan_init, stan_data]
def prepare_data(stan_init, stan_data)
  ["y", "t", "cap", "t_change", "s_a", "s_m"].each do |key|
    stan_data[key] = stan_data[key].to_a
  end
  stan_data["X"] = stan_data["X"].to_matrix.to_a
  ["delta", "beta"].each do |key|
    stan_init[key] = stan_init[key].to_a
  end
  [stan_init, stan_data]
end
135
+ end
136
+ end
@@ -0,0 +1,3 @@
1
module Prophet
  # Version of the gem; frozen so the string cannot be mutated in place.
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1,131 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ // This source code is licensed under the MIT license found in the
4
+ // LICENSE file in the root directory of this source tree.
5
+
6
+ functions {
7
matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
  // Builds the T x S indicator matrix A in which A[i, j] is 1 once t[i] has
  // reached changepoint j, and 0 before. Assumes t and t_change are sorted.
  matrix[T, S] A = rep_matrix(0, T, S);         // Start with an empty matrix.
  row_vector[S] a_row = rep_row_vector(0, S);   // Row carried from one time to the next.
  int cp_idx = 1;                               // Next changepoint not yet reached.

  // Fill in each row of A.
  for (i in 1:T) {
    // Switch on every changepoint reached by time t[i].
    while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
      a_row[cp_idx] = 1;
      cp_idx = cp_idx + 1;
    }
    A[i] = a_row;
  }
  return A;
}
28
+
29
+ // Logistic trend functions
30
+
31
vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
  // Computes the offset adjustment at each changepoint so that the logistic
  // trend stays continuous when the growth rate changes.
  vector[S] gamma;                                                // adjusted offsets, for piecewise continuity
  vector[S + 1] k_s = append_row(k, k + cumulative_sum(delta));   // actual rate in each segment
  real m_pr = m;                                                  // the offset in the previous segment

  // Piecewise offsets
  for (i in 1:S) {
    gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
    m_pr = m_pr + gamma[i];  // update for the next segment
  }
  return gamma;
}
47
+
48
vector logistic_trend(
  real k,
  real m,
  vector delta,
  vector t,
  vector cap,
  matrix A,
  vector t_change,
  int S
) {
  // Logistic growth curve scaled by cap, with rate k + A * delta and offset
  // m + A * gamma, where gamma comes from logistic_gamma.
  vector[S] gamma = logistic_gamma(k, m, delta, t_change, S);

  return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma)));
}
63
+
64
+ // Linear trend function
65
+
66
vector linear_trend(
  real k,
  real m,
  vector delta,
  vector t,
  matrix A,
  vector t_change
) {
  // Piecewise-linear trend: the slope changes by delta at each changepoint
  // and the intercept is shifted by -t_change .* delta to match.
  vector[rows(t)] slope = k + A * delta;
  vector[rows(t)] intercept = m + A * (-t_change .* delta);

  return slope .* t + intercept;
}
76
+ }
77
+
78
data {
  // Inputs supplied by the caller. T, K and S size the containers below.
  int T;                  // Number of time periods
  int<lower=1> K;         // Number of regressors
  vector[T] t;            // Time
  vector[T] cap;          // Capacities for logistic trend
  vector[T] y;            // Time series
  int S;                  // Number of changepoints
  vector[S] t_change;     // Times of trend changepoints
  matrix[T,K] X;          // Regressors
  vector[K] sigmas;       // Scale on seasonality prior
  real<lower=0> tau;      // Scale on changepoints prior
  int trend_indicator;    // 0 for linear, 1 for logistic
  vector[K] s_a;          // Indicator of additive features
  vector[K] s_m;          // Indicator of multiplicative features
}
93
+
94
transformed data {
  // Changepoint indicator matrix, computed once from the input times.
  matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
}
98
+
99
parameters {
  // Quantities estimated by the optimizer or sampler.
  real k;                     // Base trend growth rate
  real m;                     // Trend offset
  vector[S] delta;            // Trend rate adjustments
  real<lower=0> sigma_obs;    // Observation noise
  vector[K] beta;             // Regressor coefficients
}
106
+
107
model {
  // Priors
  k ~ normal(0, 5);
  m ~ normal(0, 5);
  delta ~ double_exponential(0, tau);
  sigma_obs ~ normal(0, 0.5);
  beta ~ normal(0, sigmas);

  // Likelihood: trend scaled by the multiplicative features plus the additive
  // features. Any other trend_indicator value adds no likelihood term.
  if (trend_indicator == 0 || trend_indicator == 1) {
    vector[T] trend;

    if (trend_indicator == 0) {
      trend = linear_trend(k, m, delta, t, A, t_change);
    } else {
      trend = logistic_trend(k, m, delta, t, cap, A, t_change, S);
    }
    y ~ normal(
      trend
      .* (1 + X * (beta .* s_m))
      + X * (beta .* s_a),
      sigma_obs
    );
  }
}