prophet-rb 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +23 -0
- data/README.md +202 -0
- data/data-raw/generated_holidays.csv +96474 -0
- data/ext/prophet/Makefile +5 -0
- data/ext/prophet/extconf.rb +18 -0
- data/lib/prophet-rb.rb +1 -0
- data/lib/prophet.rb +23 -0
- data/lib/prophet/forecaster.rb +986 -0
- data/lib/prophet/holidays.rb +27 -0
- data/lib/prophet/plot.rb +269 -0
- data/lib/prophet/stan_backend.rb +136 -0
- data/lib/prophet/version.rb +3 -0
- data/stan/unix/prophet.stan +131 -0
- data/stan/win/prophet.stan +162 -0
- metadata +170 -0
@@ -0,0 +1,27 @@
|
|
1
|
+
module Prophet
  # Holiday lookups backed by the bundled data-raw/generated_holidays.csv table.
  module Holidays
    # Distinct holiday names known for a country across the years 1995-2045.
    #
    # @param country [String] value matched against the table's "country" column
    # @return the unique entries of the "holiday" column for that country
    def get_holiday_names(country)
      all_years = 1995.upto(2045).to_a
      make_holidays_df(all_years, country)["holiday"].uniq
    end

    # Holiday rows for one country restricted to the given years.
    #
    # @param year_list [Array<Integer>] years to keep (matched against the "year" column)
    # @param country [String] value matched against the "country" column
    # @return [Daru::DataFrame] the matching rows, "ds" and "holiday" columns only
    def make_holidays_df(year_list, country)
      table = holidays_df
      in_country = table["country"].eq(country)
      in_years = table["year"].in(year_list)
      table.where(in_country & in_years)["ds", "holiday"]
    end

    # TODO marshal on installation
    # Full holiday table read from the bundled CSV. It is memoized in
    # @holidays_df, so the file is parsed at most once per object.
    #
    # @return [Daru::DataFrame] columns "ds", "holiday", "country", "year"
    def holidays_df
      @holidays_df ||= begin
        columns = %w[ds holiday country year]
        data = columns.each_with_object({}) { |col, acc| acc[col] = [] }
        csv_path = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)
        # :date turns "ds" into Date objects; :numeric turns "year" into Integers.
        CSV.foreach(csv_path, headers: true, converters: [:date, :numeric]) do |record|
          columns.each { |col| data[col] << record[col] }
        end
        Daru::DataFrame.new(data)
      end
    end
  end
end
|
data/lib/prophet/plot.rb
ADDED
@@ -0,0 +1,269 @@
|
|
1
|
+
module Prophet
  # Matplotlib (via PyCall) plotting, mixed into the forecaster. Reads model
  # state from @history, @seasonalities, @extra_regressors, @component_modes,
  # @train_holiday_names, @logistic_floor and @uncertainty_samples.
  module Plot
    # Plots the observed history (black dots) and the forecast line.
    #
    # @param fcst forecast data frame with at least "ds" and "yhat"
    # @param ax existing matplotlib axes to draw on; a new figure is created when nil
    # @param uncertainty [Boolean] shade yhat_lower..yhat_upper (only if @uncertainty_samples)
    # @param plot_cap [Boolean] draw the "cap" column (and "floor" when @logistic_floor)
    # @param xlabel [String] x-axis label
    # @param ylabel [String] y-axis label
    # @param figsize [Array] figure size, used only when a new figure is created
    # @return the matplotlib figure holding +ax+
    def plot(fcst, ax: nil, uncertainty: true, plot_cap: true, xlabel: "ds", ylabel: "y", figsize: [10, 6])
      if ax.nil?
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      else
        fig = ax.get_figure
      end
      fcst_t = to_pydatetime(fcst["ds"])
      ax.plot(to_pydatetime(@history["ds"]), @history["y"].map(&:to_f), "k.")
      ax.plot(fcst_t, fcst["yhat"].map(&:to_f), ls: "-", c: "#0072B2")
      if fcst.vectors.include?("cap") && plot_cap
        ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
      end
      if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
        ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
      end
      if uncertainty && @uncertainty_samples
        ax.fill_between(fcst_t, fcst["yhat_lower"].map(&:to_f), fcst["yhat_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      # Specify formatting to workaround matplotlib issue #12925
      locator = dates.AutoDateLocator.new(interval_multiples: false)
      formatter = dates.AutoDateFormatter.new(locator)
      ax.xaxis.set_major_locator(locator)
      ax.xaxis.set_major_formatter(formatter)
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xlabel(xlabel)
      ax.set_ylabel(ylabel)
      fig.tight_layout
      fig
    end

    # Plots each forecast component (trend, holidays, seasonalities, extra
    # regressors) on its own vertically stacked panel.
    #
    # @param fcst forecast data frame
    # @param uncertainty [Boolean] shade the component intervals
    # @param plot_cap [Boolean] draw cap/floor on the trend panel
    # @param weekly_start [Integer] day offset for the weekly panel's first day
    # @param yearly_start [Integer] day offset for the yearly panel's first day
    # @param figsize [Array, nil] defaults to [9, 3 * number_of_panels]
    # @return the matplotlib figure
    def plot_components(fcst, uncertainty: true, plot_cap: true, weekly_start: 0, yearly_start: 0, figsize: nil)
      components = ["trend"]
      if @train_holiday_names && fcst.vectors.include?("holidays")
        components << "holidays"
      end
      # Plot weekly seasonality, if present
      if @seasonalities["weekly"] && fcst.vectors.include?("weekly")
        components << "weekly"
      end
      # Yearly if present
      if @seasonalities["yearly"] && fcst.vectors.include?("yearly")
        components << "yearly"
      end
      # Other seasonalities, in name order
      components.concat(@seasonalities.keys.select { |name| fcst.vectors.include?(name) && !["weekly", "yearly"].include?(name) }.sort)
      regressors = {"additive" => false, "multiplicative" => false}
      @extra_regressors.each_value do |props|
        regressors[props[:mode]] = true
      end
      ["additive", "multiplicative"].each do |mode|
        if regressors[mode] && fcst.vectors.include?("extra_regressors_#{mode}")
          components << "extra_regressors_#{mode}"
        end
      end
      npanel = components.size

      figsize ||= [9, 3 * npanel]
      fig, axes = plt.subplots(npanel, 1, facecolor: "w", figsize: figsize)

      # For a single panel subplots returns one Axes object rather than an
      # array, so wrap it; only the multi-panel array needs #tolist.
      axes = npanel == 1 ? [axes] : axes.tolist

      multiplicative_axes = []

      axes.zip(components) do |ax, plot_name|
        if plot_name == "trend"
          plot_forecast_component(fcst, "trend", ax: ax, uncertainty: uncertainty, plot_cap: plot_cap)
        elsif @seasonalities[plot_name]
          if plot_name == "weekly" || @seasonalities[plot_name][:period] == 7
            plot_weekly(name: plot_name, ax: ax, uncertainty: uncertainty, weekly_start: weekly_start)
          elsif plot_name == "yearly" || @seasonalities[plot_name][:period] == 365.25
            plot_yearly(name: plot_name, ax: ax, uncertainty: uncertainty, yearly_start: yearly_start)
          else
            plot_seasonality(name: plot_name, ax: ax, uncertainty: uncertainty)
          end
        elsif ["holidays", "extra_regressors_additive", "extra_regressors_multiplicative"].include?(plot_name)
          plot_forecast_component(fcst, plot_name, ax: ax, uncertainty: uncertainty, plot_cap: false)
        end
        multiplicative_axes << ax if @component_modes["multiplicative"].include?(plot_name)
      end

      fig.tight_layout
      # Reset multiplicative axes labels after tight_layout adjustment
      multiplicative_axes.each { |ax| set_y_as_percent(ax) }
      fig
    end

    private

    # Plots one forecast column (e.g. "trend", "holidays") against time.
    #
    # @param fcst forecast data frame; "<name>_lower"/"<name>_upper" are read for the band
    # @param name [String] column to plot
    # @param ax axes to draw on; a new figure is created when nil
    # @param uncertainty [Boolean] shade the interval (only if @uncertainty_samples)
    # @param plot_cap [Boolean] also draw cap/floor lines
    # @param figsize [Array] figure size, used only when a new figure is created
    # @return [Array] the matplotlib artists drawn
    def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
      artists = []
      unless ax
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      end
      fcst_t = to_pydatetime(fcst["ds"])
      artists += ax.plot(fcst_t, fcst[name].map(&:to_f), ls: "-", c: "#0072B2")
      if fcst.vectors.include?("cap") && plot_cap
        artists += ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
      end
      if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
        # Collected like the cap line so callers get every artist drawn.
        artists += ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
      end
      if uncertainty && @uncertainty_samples
        artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].map(&:to_f), fcst[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
      end
      # Specify formatting to workaround matplotlib issue #12925
      locator = dates.AutoDateLocator.new(interval_multiples: false)
      formatter = dates.AutoDateFormatter.new(locator)
      ax.xaxis.set_major_locator(locator)
      ax.xaxis.set_major_formatter(formatter)
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xlabel("ds")
      ax.set_ylabel(name)
      if @component_modes["multiplicative"].include?(name)
        ax = set_y_as_percent(ax)
      end
      artists
    end

    # Builds the synthetic frame used to evaluate seasonalities over +ds+.
    # It sets cap 1.0 and floor 0.0, sets every extra regressor column to 0.0,
    # and sets every seasonality condition column to true.
    #
    # @param ds [Array] timestamps (Date or Time)
    # @return the frame after setup_dataframe
    def seasonality_plot_df(ds)
      df_dict = {"ds" => ds, "cap" => [1.0] * ds.size, "floor" => [0.0] * ds.size}
      # @extra_regressors maps name => props; the column key must be the name.
      @extra_regressors.each_key do |name|
        df_dict[name] = [0.0] * ds.size
      end
      # Activate all conditional seasonality columns
      @seasonalities.each_value do |props|
        df_dict[props[:condition_name]] = [true] * ds.size if props[:condition_name]
      end
      setup_dataframe(Daru::DataFrame.new(df_dict))
    end

    # Plots a weekly seasonality over seven consecutive days starting from
    # Sunday 2017-01-01, shifted by +weekly_start+ days.
    #
    # @return [Array] the matplotlib artists drawn
    def plot_weekly(ax: nil, uncertainty: true, weekly_start: 0, figsize: [10, 6], name: "weekly")
      artists = []
      unless ax
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      end
      # Compute weekly seasonality for a Sun-Sat sequence of dates.
      start = Date.parse("2017-01-01")
      days = 7.times.map { |i| start + i + weekly_start }
      seas = predict_seasonal_components(seasonality_plot_df(days))
      day_names = days.map { |d| d.strftime("%A") }
      xs = day_names.size.times.to_a
      artists += ax.plot(xs, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if uncertainty && @uncertainty_samples
        artists += [ax.fill_between(xs, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xticks(xs)
      ax.set_xticklabels(day_names)
      ax.set_xlabel("Day of week")
      ax.set_ylabel(name)
      # Seasonality props are symbol-keyed (:mode), like :period and :condition_name.
      if @seasonalities[name][:mode] == "multiplicative"
        ax = set_y_as_percent(ax)
      end
      artists
    end

    # Plots a yearly seasonality over 365 days starting 2017-01-01, shifted by
    # +yearly_start+ days.
    #
    # @return [Array] the matplotlib artists drawn
    def plot_yearly(ax: nil, uncertainty: true, yearly_start: 0, figsize: [10, 6], name: "yearly")
      artists = []
      unless ax
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      end
      # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
      start = Date.parse("2017-01-01")
      days = 365.times.map { |i| start + i + yearly_start }
      df_y = seasonality_plot_df(days)
      seas = predict_seasonal_components(df_y)
      ds_t = to_pydatetime(df_y["ds"])
      artists += ax.plot(ds_t, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if uncertainty && @uncertainty_samples
        artists += [ax.fill_between(ds_t, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
      ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime("%B %-e") }))
      ax.xaxis.set_major_locator(months)
      ax.set_xlabel("Day of year")
      ax.set_ylabel(name)
      if @seasonalities[name][:mode] == "multiplicative"
        ax = set_y_as_percent(ax)
      end
      artists
    end

    # Plots a custom seasonality over one period, sampled at 200 evenly spaced
    # points starting 2017-01-01 UTC, with 7 evenly spaced x ticks.
    #
    # @param name [String] seasonality name; its :period (in days) sets the span
    # @return [Array] the matplotlib artists drawn
    def plot_seasonality(name:, ax: nil, uncertainty: true, figsize: [10, 6])
      artists = []
      unless ax
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      end
      # Compute seasonality from Jan 1 through a single period.
      start = Time.utc(2017)
      period = @seasonalities[name][:period]
      finish = start + period * 86400
      plot_points = 200
      start = start.to_i
      finish = finish.to_i
      step = (finish - start) / (plot_points - 1).to_f
      days = plot_points.times.map { |i| Time.at(start + i * step).utc }
      df_y = seasonality_plot_df(days)
      seas = predict_seasonal_components(df_y)
      ds_t = to_pydatetime(df_y["ds"])
      artists += ax.plot(ds_t, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if uncertainty && @uncertainty_samples
        artists += [ax.fill_between(ds_t, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)]
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      step = (finish - start) / (7 - 1).to_f
      xticks = to_pydatetime(7.times.map { |i| Time.at(start + i * step).utc })
      ax.set_xticks(xticks)
      # Tick format depends on the period length (in days).
      fmt_str =
        if period <= 2
          "%T"
        elsif period < 14
          "%m/%d %R"
        else
          "%m/%d"
        end
      ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime(fmt_str) }))
      ax.set_xlabel("ds")
      ax.set_ylabel(name)
      if @seasonalities[name][:mode] == "multiplicative"
        ax = set_y_as_percent(ax)
      end
      artists
    end

    # Relabels the current y ticks as percentages (value * 100, "%.4g%").
    #
    # @return the same axes
    def set_y_as_percent(ax)
      yticks = 100 * ax.get_yticks
      yticklabels = yticks.tolist.map { |y| "%.4g%%" % y }
      ax.set_yticklabels(yticklabels)
      ax
    end

    # matplotlib.pyplot, loaded lazily. Raises Error if the matplotlib gem is
    # missing.
    def plt
      begin
        require "matplotlib/pyplot"
      rescue LoadError
        raise Error, "Install the matplotlib gem for plots"
      end
      Matplotlib::Pyplot
    end

    # Python's matplotlib.dates module via PyCall.
    def dates
      PyCall.import_module("matplotlib.dates")
    end

    # Python's matplotlib.ticker module via PyCall.
    def ticker
      PyCall.import_module("matplotlib.ticker")
    end

    # Converts each value to a Python datetime via datetime.utcfromtimestamp
    # of its integer epoch seconds (#to_i), so sub-second precision is dropped.
    def to_pydatetime(v)
      datetime = PyCall.import_module("datetime")
      v.map { |t| datetime.datetime.utcfromtimestamp(t.to_i) }
    end
  end
end
|
@@ -0,0 +1,136 @@
|
|
1
|
+
module Prophet
  # Runs the bundled, precompiled Prophet Stan model through CmdStan and
  # converts its flat output columns into Numo arrays keyed by parameter name.
  class StanBackend
    # @param logger object responding to #warn; used when optimization falls back to Newton
    def initialize(logger)
      @model = load_model
      @logger = logger
    end

    # Loads the compiled executable stan_model/prophet_model.bin, located two
    # directories above this file.
    def load_model
      exe_path = File.expand_path("../../stan_model/prophet_model.bin", __dir__)
      CmdStan::Model.new(exe_file: exe_path)
    end

    # Point estimate via CmdStan optimize (10_000 iterations).
    #
    # Unless the caller passes :algorithm, Newton is used when stan_data["T"] < 100
    # and LBFGS otherwise. If a non-Newton run raises, it is retried once with Newton.
    # A failure under Newton is re-raised.
    #
    # @return [Hash{String => Numo::NArray}] each parameter with a leading axis of length 1
    def fit(stan_init, stan_data, **kwargs)
      stan_init, stan_data = prepare_data(stan_init, stan_data)
      kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
      max_iter = 10_000

      # kwargs is splatted at call time, so a changed :algorithm is picked up on retry.
      run_optimizer = lambda do
        @model.optimize(data: stan_data, inits: stan_init, iter: max_iter, **kwargs)
      end

      stan_fit =
        begin
          run_optimizer.call
        rescue => e
          raise e if kwargs[:algorithm] == "Newton"

          @logger.warn "Optimization terminated abnormally. Falling back to Newton."
          kwargs[:algorithm] = "Newton"
          run_optimizer.call
        end

      estimates = Numo::NArray.asarray(stan_fit.optimized_params.values)
      # Leading length-1 axis gives the same (draws, ...) layout that #sampling returns.
      stan_to_numo(stan_fit.column_names, estimates).transform_values { |arr| arr.reshape(1, *arr.shape) }
    end

    # Posterior draws via CmdStan sample.
    #
    # Defaults: 4 chains and +samples+ / 2 warmup iterations. Draws from all
    # chains are stacked along the first axis.
    #
    # @return [Hash{String => Numo::NArray}]
    def sampling(stan_init, stan_data, samples, **kwargs)
      stan_init, stan_data = prepare_data(stan_init, stan_data)
      kwargs[:chains] ||= 4
      kwargs[:warmup_iters] ||= samples / 2

      stan_fit = @model.sample(data: stan_data, inits: stan_init, sampling_iters: samples, **kwargs)
      draws = Numo::NArray.asarray(stan_fit.sample)
      n_draws, n_chains, n_columns = draws.shape
      params = stan_to_numo(stan_fit.column_names, draws.reshape(n_draws * n_chains, n_columns))

      params.each_key do |par|
        shape = params[par].shape
        # Single-column parameters (e.g. the scalars k, m, sigma_obs) become 1-D.
        params[par] = params[par].reshape(shape[0]) if shape[1] == 1
        # NOTE(review): `shape` is taken before the reshape above, and stan_to_numo
        # returns 2-D slices here, so this branch does not appear to ever fire.
        params[par] = params[par].reshape(-1, 1) if ["delta", "beta"].include?(par) && shape.size < 2
      end

      params
    end

    private

    # Groups consecutive output columns that share the prefix before the first
    # "." (e.g. "delta.1", "delta.2" -> "delta") and slices them out of +data+.
    # Slices are taken along the second axis when +data+ is 2-D, else directly.
    #
    # @raise [Error] when a prefix reappears after a different one
    def stan_to_numo(column_names, data)
      matrix_input = data.shape.size > 1
      take = lambda do |range|
        Numo::NArray.asarray(matrix_input ? data[true, range] : data[range])
      end

      result = {}
      group = nil
      first = 0
      last = 0

      column_names.each do |column|
        name = column.split(".")[0]
        group ||= name
        if name == group
          last += 1
          next
        end

        raise Error, "Found repeated column name" if result[group]
        result[group] = take.call(first...last)
        group = name
        first = last
        last += 1
      end

      raise Error, "Found repeated column name" if result[group]
      result[group] = take.call(first...last)
      result
    end

    # Converts the array-like inputs to plain Ruby arrays (X via #to_matrix).
    # This mutates and returns the given hashes.
    #
    # @return [Array(Hash, Hash)] [stan_init, stan_data]
    def prepare_data(stan_init, stan_data)
      %w[y t cap t_change s_a s_m].each { |key| stan_data[key] = stan_data[key].to_a }
      stan_data["X"] = stan_data["X"].to_matrix.to_a
      %w[delta beta].each { |key| stan_init[key] = stan_init[key].to_a }
      [stan_init, stan_data]
    end
  end
end
|
@@ -0,0 +1,131 @@
|
|
1
|
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2
|
+
|
3
|
+
// This source code is licensed under the MIT license found in the
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
5
|
+
|
6
|
+
functions {
  // T x S changepoint indicator matrix: row i has a 1 in column j once
  // t[i] >= t_change[j], and 0 otherwise. Assumes t and t_change are sorted.
  matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
    matrix[T, S] A = rep_matrix(0, T, S);
    row_vector[S] active = rep_row_vector(0, S);
    int next_cp = 1;

    for (i in 1:T) {
      // Switch on every changepoint already reached at time t[i].
      while ((next_cp <= S) && (t[i] >= t_change[next_cp])) {
        active[next_cp] = 1;
        next_cp = next_cp + 1;
      }
      A[i] = active;
    }
    return A;
  }

  // Logistic trend functions

  // Offset adjustments that keep the piecewise logistic trend continuous at
  // each changepoint.
  vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
    vector[S] gamma;
    // Growth rate of each of the S + 1 segments.
    vector[S + 1] k_s = append_row(k, k + cumulative_sum(delta));
    // Offset of the segment preceding the current changepoint.
    real m_pr = m;

    for (i in 1:S) {
      gamma[i] = (t_change[i] - m_pr) * (1 - k_s[i] / k_s[i + 1]);
      m_pr = m_pr + gamma[i];
    }
    return gamma;
  }

  // Piecewise logistic trend scaled by cap.
  vector logistic_trend(real k, real m, vector delta, vector t, vector cap,
                        matrix A, vector t_change, int S) {
    vector[S] gamma = logistic_gamma(k, m, delta, t_change, S);
    return cap .* inv_logit((k + A * delta) .* (t - (m + A * gamma)));
  }

  // Linear trend function: piecewise linear with rate k + A * delta.
  vector linear_trend(real k, real m, vector delta, vector t, matrix A,
                      vector t_change) {
    return (k + A * delta) .* t + (m + A * (-t_change .* delta));
  }
}
|
77
|
+
|
78
|
+
data {
  // Number of time points.
  int T;
  // Number of regressor columns in X.
  int<lower=1> K;
  // Time of each observation.
  vector[T] t;
  // Capacity at each time point, used by the logistic trend.
  vector[T] cap;
  // Observed series.
  vector[T] y;
  // Number of trend changepoints.
  int S;
  // Changepoint times.
  vector[S] t_change;
  // Regressor matrix.
  matrix[T, K] X;
  // Prior scale of each regressor coefficient (beta ~ normal(0, sigmas)).
  vector[K] sigmas;
  // Prior scale of the changepoint rate adjustments.
  real<lower=0> tau;
  // Trend type: 0 = linear, 1 = logistic.
  int trend_indicator;
  // Indicator of additive features (multiplies beta elementwise).
  vector[K] s_a;
  // Indicator of multiplicative features (multiplies beta elementwise).
  vector[K] s_m;
}
|
93
|
+
|
94
|
+
transformed data {
  // Changepoint indicator matrix; depends only on the data.
  matrix[T, S] A = get_changepoint_matrix(t, t_change, T, S);
}
|
98
|
+
|
99
|
+
parameters {
  // Base growth rate of the trend.
  real k;
  // Trend offset.
  real m;
  // Growth-rate change at each changepoint.
  vector[S] delta;
  // Observation noise scale.
  real<lower=0> sigma_obs;
  // One coefficient per column of X.
  vector[K] beta;
}
|
106
|
+
|
107
|
+
model {
  // Regressor terms shared by both trend types.
  vector[T] mult_factor = 1 + X * (beta .* s_m);
  vector[T] additive = X * (beta .* s_a);

  // Priors
  k ~ normal(0, 5);
  m ~ normal(0, 5);
  delta ~ double_exponential(0, tau);
  sigma_obs ~ normal(0, 0.5);
  beta ~ normal(0, sigmas);

  // Likelihood: trend * (1 + multiplicative terms) + additive terms.
  if (trend_indicator == 0) {
    y ~ normal(linear_trend(k, m, delta, t, A, t_change) .* mult_factor + additive,
               sigma_obs);
  } else if (trend_indicator == 1) {
    y ~ normal(logistic_trend(k, m, delta, t, cap, A, t_change, S) .* mult_factor + additive,
               sigma_obs);
  }
}
|