prophet-rb 0.1.0

@@ -0,0 +1,27 @@
+ module Prophet
+   module Holidays
+     def get_holiday_names(country)
+       years = (1995..2045).to_a
+       make_holidays_df(years, country)["holiday"].uniq
+     end
+
+     def make_holidays_df(year_list, country)
+       holidays_df.where(holidays_df["country"].eq(country) & holidays_df["year"].in(year_list))["ds", "holiday"]
+     end
+
+     # TODO marshal on installation
+     def holidays_df
+       @holidays_df ||= begin
+         holidays = {"ds" => [], "holiday" => [], "country" => [], "year" => []}
+         holidays_file = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
+         CSV.foreach(holidays_file, headers: true, converters: [:date, :numeric]) do |row|
+           holidays["ds"] << row["ds"]
+           holidays["holiday"] << row["holiday"]
+           holidays["country"] << row["country"]
+           holidays["year"] << row["year"]
+         end
+         Daru::DataFrame.new(holidays)
+       end
+     end
+   end
+ end
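
For context, these helpers back the gem's built-in country holidays: holidays_df lazily loads the bundled CSV into a Daru::DataFrame, make_holidays_df filters it by country and year range, and get_holiday_names lists the distinct holiday names. A minimal usage sketch, assuming the module is mixed into the Prophet model class and that add_country_holidays (defined elsewhere in the gem, not shown in this diff) is the public entry point:

  m = Prophet.new
  m.add_country_holidays(country_name: "US")  # pulls rows via make_holidays_df during fitting
  m.get_holiday_names("US")                   # => unique holiday names observed for the US, 1995-2045
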
@@ -0,0 +1,269 @@
+ module Prophet
+   module Plot
+     def plot(fcst, ax: nil, uncertainty: true, plot_cap: true, xlabel: "ds", ylabel: "y", figsize: [10, 6])
+       if ax.nil?
+         fig = plt.figure(facecolor: "w", figsize: figsize)
+         ax = fig.add_subplot(111)
+       else
+         fig = ax.get_figure
+       end
+       fcst_t = to_pydatetime(fcst["ds"])
+       ax.plot(to_pydatetime(@history["ds"]), @history["y"].map(&:to_f), "k.")
+       ax.plot(fcst_t, fcst["yhat"].map(&:to_f), ls: "-", c: "#0072B2")
+       if fcst.vectors.include?("cap") && plot_cap
+         ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
+       end
+       if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
+         ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
+       end
+       if uncertainty && @uncertainty_samples
+         ax.fill_between(fcst_t, fcst["yhat_lower"].map(&:to_f), fcst["yhat_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
+       end
+       # Specify formatting to workaround matplotlib issue #12925
+       locator = dates.AutoDateLocator.new(interval_multiples: false)
+       formatter = dates.AutoDateFormatter.new(locator)
+       ax.xaxis.set_major_locator(locator)
+       ax.xaxis.set_major_formatter(formatter)
+       ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
+       ax.set_xlabel(xlabel)
+       ax.set_ylabel(ylabel)
+       fig.tight_layout
+       fig
+     end
+
+     def plot_components(fcst, uncertainty: true, plot_cap: true, weekly_start: 0, yearly_start: 0, figsize: nil)
+       components = ["trend"]
+       if @train_holiday_names && fcst.vectors.include?("holidays")
+         components << "holidays"
+       end
+       # Plot weekly seasonality, if present
+       if @seasonalities["weekly"] && fcst.vectors.include?("weekly")
+         components << "weekly"
+       end
+       # Yearly if present
+       if @seasonalities["yearly"] && fcst.vectors.include?("yearly")
+         components << "yearly"
+       end
+       # Other seasonalities
+       components.concat(@seasonalities.keys.select { |name| fcst.vectors.include?(name) && !["weekly", "yearly"].include?(name) }.sort)
+       regressors = {"additive" => false, "multiplicative" => false}
+       @extra_regressors.each do |name, props|
+         regressors[props[:mode]] = true
+       end
+       ["additive", "multiplicative"].each do |mode|
+         if regressors[mode] && fcst.vectors.include?("extra_regressors_#{mode}")
+           components << "extra_regressors_#{mode}"
+         end
+       end
+       npanel = components.size
+
+       figsize = figsize || [9, 3 * npanel]
+       fig, axes = plt.subplots(npanel, 1, facecolor: "w", figsize: figsize)
+
+       if npanel == 1
+         axes = [axes]
+       end
+
+       multiplicative_axes = []
+
+       axes.tolist.zip(components) do |ax, plot_name|
+         if plot_name == "trend"
+           plot_forecast_component(fcst, "trend", ax: ax, uncertainty: uncertainty, plot_cap: plot_cap)
+         elsif @seasonalities[plot_name]
+           if plot_name == "weekly" || @seasonalities[plot_name][:period] == 7
+             plot_weekly(name: plot_name, ax: ax, uncertainty: uncertainty, weekly_start: weekly_start)
+           elsif plot_name == "yearly" || @seasonalities[plot_name][:period] == 365.25
+             plot_yearly(name: plot_name, ax: ax, uncertainty: uncertainty, yearly_start: yearly_start)
+           else
+             plot_seasonality(name: plot_name, ax: ax, uncertainty: uncertainty)
+           end
+         elsif ["holidays", "extra_regressors_additive", "extra_regressors_multiplicative"].include?(plot_name)
+           plot_forecast_component(fcst, plot_name, ax: ax, uncertainty: uncertainty, plot_cap: false)
+         end
+         if @component_modes["multiplicative"].include?(plot_name)
+           multiplicative_axes << ax
+         end
+       end
+
+       fig.tight_layout
+       # Reset multiplicative axes labels after tight_layout adjustment
+       multiplicative_axes.each do |ax|
+         ax = set_y_as_percent(ax)
+       end
+       fig
+     end
+
+     private
+
+     def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
+       artists = []
+       if !ax
+         fig = plt.figure(facecolor: "w", figsize: figsize)
+         ax = fig.add_subplot(111)
+       end
+       fcst_t = to_pydatetime(fcst["ds"])
+       artists += ax.plot(fcst_t, fcst[name].map(&:to_f), ls: "-", c: "#0072B2")
+       if fcst.vectors.include?("cap") && plot_cap
+         artists += ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
+       end
+       if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
+         ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
+       end
+       if uncertainty && @uncertainty_samples
+         artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].map(&:to_f), fcst[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
+       end
+       # Specify formatting to workaround matplotlib issue #12925
+       locator = dates.AutoDateLocator.new(interval_multiples: false)
+       formatter = dates.AutoDateFormatter.new(locator)
+       ax.xaxis.set_major_locator(locator)
+       ax.xaxis.set_major_formatter(formatter)
+       ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
+       ax.set_xlabel("ds")
+       ax.set_ylabel(name)
+       if @component_modes["multiplicative"].include?(name)
+         ax = set_y_as_percent(ax)
+       end
+       artists
+     end
+
+     def seasonality_plot_df(ds)
+       df_dict = {"ds" => ds, "cap" => [1.0] * ds.size, "floor" => [0.0] * ds.size}
+       @extra_regressors.each_key do |name|
+         df_dict[name] = [0.0] * ds.size
+       end
+       # Activate all conditional seasonality columns
+       @seasonalities.values.each do |props|
+         if props[:condition_name]
+           df_dict[props[:condition_name]] = [true] * ds.size
+         end
+       end
+       df = Daru::DataFrame.new(df_dict)
+       df = setup_dataframe(df)
+       df
+     end
+
+     def plot_weekly(ax: nil, uncertainty: true, weekly_start: 0, figsize: [10, 6], name: "weekly")
+       artists = []
+       if !ax
+         fig = plt.figure(facecolor: "w", figsize: figsize)
+         ax = fig.add_subplot(111)
+       end
+       # Compute weekly seasonality for a Sun-Sat sequence of dates.
+       start = Date.parse("2017-01-01")
+       days = 7.times.map { |i| start + i + weekly_start }
+       df_w = seasonality_plot_df(days)
+       seas = predict_seasonal_components(df_w)
+       days = days.map { |v| v.strftime("%A") }
+       artists += ax.plot(days.size.times.to_a, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
+       if uncertainty && @uncertainty_samples
+         artists += [ax.fill_between(days.size.times.to_a, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
+       end
+       ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
+       ax.set_xticks(days.size.times.to_a)
+       ax.set_xticklabels(days)
+       ax.set_xlabel("Day of week")
+       ax.set_ylabel(name)
+       if @seasonalities[name][:mode] == "multiplicative"
+         ax = set_y_as_percent(ax)
+       end
+       artists
+     end
+
+     def plot_yearly(ax: nil, uncertainty: true, yearly_start: 0, figsize: [10, 6], name: "yearly")
+       artists = []
+       if !ax
+         fig = plt.figure(facecolor: "w", figsize: figsize)
+         ax = fig.add_subplot(111)
+       end
+       # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
+       start = Date.parse("2017-01-01")
+       days = 365.times.map { |i| start + i + yearly_start }
+       df_y = seasonality_plot_df(days)
+       seas = predict_seasonal_components(df_y)
+       artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].map(&:to_f), ls: "-", c: "#0072B2")
+       if uncertainty && @uncertainty_samples
+         artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
+       end
+       ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
+       months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
+       ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime("%B %-e") }))
+       ax.xaxis.set_major_locator(months)
+       ax.set_xlabel("Day of year")
+       ax.set_ylabel(name)
+       if @seasonalities[name][:mode] == "multiplicative"
+         ax = set_y_as_percent(ax)
+       end
+       artists
+     end
+
+     def plot_seasonality(name:, ax: nil, uncertainty: true, figsize: [10, 6])
+       artists = []
+       if !ax
+         fig = plt.figure(facecolor: "w", figsize: figsize)
+         ax = fig.add_subplot(111)
+       end
+       # Compute seasonality from Jan 1 through a single period.
+       start = Time.utc(2017)
+       period = @seasonalities[name][:period]
+       finish = start + period * 86400
+       plot_points = 200
+       start = start.to_i
+       finish = finish.to_i
+       step = (finish - start) / (plot_points - 1).to_f
+       days = plot_points.times.map { |i| Time.at(start + i * step).utc }
+       df_y = seasonality_plot_df(days)
+       seas = predict_seasonal_components(df_y)
+       artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].map(&:to_f), ls: "-", c: "#0072B2")
+       if uncertainty && @uncertainty_samples
+         artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
+       end
+       ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
+       step = (finish - start) / (7 - 1).to_f
+       xticks = to_pydatetime(7.times.map { |i| Time.at(start + i * step).utc })
+       ax.set_xticks(xticks)
+       if period <= 2
+         fmt_str = "%T"
+       elsif period < 14
+         fmt_str = "%m/%d %R"
+       else
+         fmt_str = "%m/%d"
+       end
+       ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime(fmt_str) }))
+       ax.set_xlabel("ds")
+       ax.set_ylabel(name)
+       if @seasonalities[name][:mode] == "multiplicative"
+         ax = set_y_as_percent(ax)
+       end
+       artists
+     end
+
+     def set_y_as_percent(ax)
+       yticks = 100 * ax.get_yticks
+       yticklabels = yticks.tolist.map { |y| "%.4g%%" % y }
+       ax.set_yticklabels(yticklabels)
+       ax
+     end
+
+     def plt
+       begin
+         require "matplotlib/pyplot"
+       rescue LoadError
+         raise Error, "Install the matplotlib gem for plots"
+       end
+       Matplotlib::Pyplot
+     end
+
+     def dates
+       PyCall.import_module("matplotlib.dates")
+     end
+
+     def ticker
+       PyCall.import_module("matplotlib.ticker")
+     end
+
+     def to_pydatetime(v)
+       datetime = PyCall.import_module("datetime")
+       v.map { |v| datetime.datetime.utcfromtimestamp(v.to_i) }
+     end
+   end
+ end
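
In practice these plotting methods sit on a fitted model and hand back matplotlib Figure objects through PyCall. A rough usage sketch, assuming the usual prophet-rb workflow (Prophet.new, fit, make_future_dataframe, predict, which live in other files of this release) and that the matplotlib gem is installed:

  m = Prophet.new
  m.fit(df)                                      # df: Daru::DataFrame with "ds" and "y" columns
  future = m.make_future_dataframe(periods: 365)
  forecast = m.predict(future)
  m.plot(forecast).savefig("forecast.png")       # plot returns a matplotlib figure
  m.plot_components(forecast).savefig("components.png")
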
@@ -0,0 +1,136 @@
+ module Prophet
+   class StanBackend
+     def initialize(logger)
+       @model = load_model
+       @logger = logger
+     end
+
+     def load_model
+       model_file = File.expand_path("../../stan_model/prophet_model.bin", __dir__)
+       CmdStan::Model.new(exe_file: model_file)
+     end
+
+     def fit(stan_init, stan_data, **kwargs)
+       stan_init, stan_data = prepare_data(stan_init, stan_data)
+       kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
+       iterations = 10000
+
+       stan_fit = nil
+       begin
+         stan_fit = @model.optimize(
+           data: stan_data,
+           inits: stan_init,
+           iter: iterations,
+           **kwargs
+         )
+       rescue => e
+         if kwargs[:algorithm] != "Newton"
+           @logger.warn "Optimization terminated abnormally. Falling back to Newton."
+           kwargs[:algorithm] = "Newton"
+           stan_fit = @model.optimize(
+             data: stan_data,
+             inits: stan_init,
+             iter: iterations,
+             **kwargs
+           )
+         else
+           raise e
+         end
+       end
+
+       params = stan_to_numo(stan_fit.column_names, Numo::NArray.asarray(stan_fit.optimized_params.values))
+       params.each_key do |par|
+         params[par] = params[par].reshape(1, *params[par].shape)
+       end
+       params
+     end
+
+     def sampling(stan_init, stan_data, samples, **kwargs)
+       stan_init, stan_data = prepare_data(stan_init, stan_data)
+
+       kwargs[:chains] ||= 4
+       kwargs[:warmup_iters] ||= samples / 2
+
+       stan_fit = @model.sample(
+         data: stan_data,
+         inits: stan_init,
+         sampling_iters: samples,
+         **kwargs
+       )
+       res = Numo::NArray.asarray(stan_fit.sample)
+       samples, c, columns = res.shape
+       res = res.reshape(samples * c, columns)
+       params = stan_to_numo(stan_fit.column_names, res)
+
+       params.each_key do |par|
+         s = params[par].shape
+
+         if s[1] == 1
+           params[par] = params[par].reshape(s[0])
+         end
+
+         if ["delta", "beta"].include?(par) && s.size < 2
+           params[par] = params[par].reshape(-1, 1)
+         end
+       end
+
+       params
+     end
+
+     private
+
+     def stan_to_numo(column_names, data)
+       output = {}
+
+       prev = nil
+
+       start = 0
+       finish = 0
+
+       two_dims = data.shape.size > 1
+
+       column_names.each do |cname|
+         parsed = cname.split(".")
+
+         curr = parsed[0]
+         prev = curr if prev.nil?
+
+         if curr != prev
+           raise Error, "Found repeated column name" if output[prev]
+           if two_dims
+             output[prev] = Numo::NArray.asarray(data[true, start...finish])
+           else
+             output[prev] = Numo::NArray.asarray(data[start...finish])
+           end
+           prev = curr
+           start = finish
+           finish += 1
+         else
+           finish += 1
+         end
+       end
+
+       raise Error, "Found repeated column name" if output[prev]
+       if two_dims
+         output[prev] = Numo::NArray.asarray(data[true, start...finish])
+       else
+         output[prev] = Numo::NArray.asarray(data[start...finish])
+       end
+
+       output
+     end
+
+     def prepare_data(stan_init, stan_data)
+       stan_data["y"] = stan_data["y"].to_a
+       stan_data["t"] = stan_data["t"].to_a
+       stan_data["cap"] = stan_data["cap"].to_a
+       stan_data["t_change"] = stan_data["t_change"].to_a
+       stan_data["s_a"] = stan_data["s_a"].to_a
+       stan_data["s_m"] = stan_data["s_m"].to_a
+       stan_data["X"] = stan_data["X"].to_matrix.to_a
+       stan_init["delta"] = stan_init["delta"].to_a
+       stan_init["beta"] = stan_init["beta"].to_a
+       [stan_init, stan_data]
+     end
+   end
+ end
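
The private stan_to_numo helper turns CmdStan's flat column layout (scalar columns such as k and m, plus indexed columns such as delta.1, delta.2, ...) into one Numo array per parameter by grouping consecutive columns that share a prefix. A small illustrative sketch of that grouping, bypassing initialize so no compiled Stan model is needed (the inputs are purely hypothetical):

  require "numo/narray"

  backend = Prophet::StanBackend.allocate   # skip initialize; we only want the helper
  names   = ["lp__", "k", "m", "delta.1", "delta.2", "sigma_obs"]
  values  = Numo::DFloat[[0.0, 0.1, 0.2, 0.3, 0.4, 0.5]]
  backend.send(:stan_to_numo, names, values)
  # => {"lp__" => 1x1, "k" => 1x1, "m" => 1x1, "delta" => 1x2, "sigma_obs" => 1x1} (Numo arrays)
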
@@ -0,0 +1,3 @@
+ module Prophet
+   VERSION = "0.1.0"
+ end
@@ -0,0 +1,131 @@
+ // Copyright (c) Facebook, Inc. and its affiliates.
+
+ // This source code is licensed under the MIT license found in the
+ // LICENSE file in the root directory of this source tree.
+
+ functions {
+   matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
+     // Assumes t and t_change are sorted.
+     matrix[T, S] A;
+     row_vector[S] a_row;
+     int cp_idx;
+
+     // Start with an empty matrix.
+     A = rep_matrix(0, T, S);
+     a_row = rep_row_vector(0, S);
+     cp_idx = 1;
+
+     // Fill in each row of A.
+     for (i in 1:T) {
+       while ((cp_idx <= S) && (t[i] >= t_change[cp_idx])) {
+         a_row[cp_idx] = 1;
+         cp_idx = cp_idx + 1;
+       }
+       A[i] = a_row;
+     }
+     return A;
+   }
+
+   // Logistic trend functions
+
+   vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
+     vector[S] gamma;    // adjusted offsets, for piecewise continuity
+     vector[S + 1] k_s;  // actual rate in each segment
+     real m_pr;
+
+     // Compute the rate in each segment
+     k_s = append_row(k, k + cumulative_sum(delta));
+
+     // Piecewise offsets
+     m_pr = m; // The offset in the previous segment
+     for (i in 1:S) {
+       gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
+       m_pr = m_pr + gamma[i];  // update for the next segment
+     }
+     return gamma;
+   }
+
+   vector logistic_trend(
+     real k,
+     real m,
+     vector delta,
+     vector t,
+     vector cap,
+     matrix A,
+     vector t_change,
+     int S
+   ) {
+     vector[S] gamma;
+
+     gamma = logistic_gamma(k, m, delta, t_change, S);
+     return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma)));
+   }
+
+   // Linear trend function
+
+   vector linear_trend(
+     real k,
+     real m,
+     vector delta,
+     vector t,
+     matrix A,
+     vector t_change
+   ) {
+     return (k + A * delta) .* t + (m + A * (-t_change .* delta));
+   }
+ }
+
+ data {
+   int T;                // Number of time periods
+   int<lower=1> K;       // Number of regressors
+   vector[T] t;          // Time
+   vector[T] cap;        // Capacities for logistic trend
+   vector[T] y;          // Time series
+   int S;                // Number of changepoints
+   vector[S] t_change;   // Times of trend changepoints
+   matrix[T,K] X;        // Regressors
+   vector[K] sigmas;     // Scale on seasonality prior
+   real<lower=0> tau;    // Scale on changepoints prior
+   int trend_indicator;  // 0 for linear, 1 for logistic
+   vector[K] s_a;        // Indicator of additive features
+   vector[K] s_m;        // Indicator of multiplicative features
+ }
+
+ transformed data {
+   matrix[T, S] A;
+   A = get_changepoint_matrix(t, t_change, T, S);
+ }
+
+ parameters {
+   real k;                   // Base trend growth rate
+   real m;                   // Trend offset
+   vector[S] delta;          // Trend rate adjustments
+   real<lower=0> sigma_obs;  // Observation noise
+   vector[K] beta;           // Regressor coefficients
+ }
+
+ model {
+   // Priors
+   k ~ normal(0, 5);
+   m ~ normal(0, 5);
+   delta ~ double_exponential(0, tau);
+   sigma_obs ~ normal(0, 0.5);
+   beta ~ normal(0, sigmas);
+
+   // Likelihood
+   if (trend_indicator == 0) {
+     y ~ normal(
+       linear_trend(k, m, delta, t, A, t_change)
+       .* (1 + X * (beta .* s_m))
+       + X * (beta .* s_a),
+       sigma_obs
+     );
+   } else if (trend_indicator == 1) {
+     y ~ normal(
+       logistic_trend(k, m, delta, t, cap, A, t_change, S)
+       .* (1 + X * (beta .* s_m))
+       + X * (beta .* s_a),
+       sigma_obs
+     );
+   }
+ }
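
For intuition, linear_trend above evaluates the piecewise-linear growth curve: the slope at time t is k plus every delta whose changepoint has already passed, and the matching -t_change .* delta offsets keep adjacent segments joined continuously. A plain-Ruby sketch of the same arithmetic (illustrative only; in the gem this runs inside Stan, and the method name here is just borrowed from the Stan function):

  # Piecewise-linear trend: base rate k, rate adjustments delta at times t_change,
  # offset corrections so adjacent segments meet.
  def linear_trend(k, m, delta, t, t_change)
    t.map do |ti|
      active = t_change.each_index.select { |j| ti >= t_change[j] }  # changepoints already passed
      rate   = k + active.sum { |j| delta[j] }
      offset = m + active.sum { |j| -t_change[j] * delta[j] }
      rate * ti + offset
    end
  end

  # Base slope 1.0 that drops by 0.5 at t = 0.5:
  linear_trend(1.0, 0.0, [-0.5], [0.0, 0.25, 0.5, 0.75, 1.0], [0.5])
  # => [0.0, 0.25, 0.5, 0.625, 0.75]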