prophet-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +23 -0
- data/README.md +202 -0
- data/data-raw/generated_holidays.csv +96474 -0
- data/ext/prophet/Makefile +5 -0
- data/ext/prophet/extconf.rb +18 -0
- data/lib/prophet-rb.rb +1 -0
- data/lib/prophet.rb +23 -0
- data/lib/prophet/forecaster.rb +986 -0
- data/lib/prophet/holidays.rb +27 -0
- data/lib/prophet/plot.rb +269 -0
- data/lib/prophet/stan_backend.rb +136 -0
- data/lib/prophet/version.rb +3 -0
- data/stan/unix/prophet.stan +131 -0
- data/stan/win/prophet.stan +162 -0
- metadata +170 -0
@@ -0,0 +1,27 @@
|
|
1
|
+
module Prophet
  module Holidays
    # Bundled table of generated holidays (columns: ds, holiday, country, year).
    HOLIDAYS_FILE = File.expand_path("../../data-raw/generated_holidays.csv", __dir__)

    # Returns the distinct holiday names for +country+ over the years 1995-2045.
    #
    # @param country [String] value matched against the "country" column
    # @return [Daru::Vector] unique values of the "holiday" column
    def get_holiday_names(country)
      years = (1995..2045).to_a
      make_holidays_df(years, country)["holiday"].uniq
    end

    # Selects the "ds" and "holiday" columns of the bundled holiday table for
    # rows whose "country" equals +country+ and whose "year" is in +year_list+.
    #
    # @param year_list [Array<Integer>] years to keep
    # @param country [String] value matched against the "country" column
    # @return [Daru::DataFrame]
    def make_holidays_df(year_list, country)
      df = holidays_df
      df.where(df["country"].eq(country) & df["year"].in(year_list))["ds", "holiday"]
    end

    # The full bundled holiday table.
    #
    # Delegates to a module-level cache: the CSV is large, and memoizing it in
    # an instance variable made every including object parse it again.
    # TODO marshal on installation
    def holidays_df
      Holidays.holidays_df
    end

    # Lazily parses HOLIDAYS_FILE once and caches the resulting data frame on
    # the module itself.
    def self.holidays_df
      @holidays_df ||= load_holidays_df(HOLIDAYS_FILE)
    end

    # Parses a holiday CSV with a header row into a Daru::DataFrame holding the
    # ds, holiday, country and year columns.
    #
    # The :date and :numeric CSV converters turn date-like cells into Date
    # objects and numeric cells (the year) into Integers.
    #
    # @param path [String] path of the CSV file
    # @return [Daru::DataFrame]
    def self.load_holidays_df(path)
      holidays = {"ds" => [], "holiday" => [], "country" => [], "year" => []}
      CSV.foreach(path, headers: true, converters: [:date, :numeric]) do |row|
        holidays.each_key { |column| holidays[column] << row[column] }
      end
      Daru::DataFrame.new(holidays)
    end
  end
end
|
data/lib/prophet/plot.rb
ADDED
@@ -0,0 +1,269 @@
|
|
1
|
+
module Prophet
  module Plot
    # Plots the forecast over the observed history.
    #
    # Draws @history "y" as black dots and fcst "yhat" as a blue line, plus
    # dashed "cap"/"floor" lines and a shaded yhat_lower..yhat_upper band when
    # enabled and present.
    #
    # @param fcst [Daru::DataFrame] forecast frame with at least "ds" and "yhat"
    # @param ax [PyObject, nil] matplotlib Axes to draw on; a new figure is created when nil
    # @param uncertainty [Boolean] whether to shade the uncertainty interval
    # @param plot_cap [Boolean] whether to draw the cap (and floor) lines
    # @param xlabel [String] x-axis label
    # @param ylabel [String] y-axis label
    # @param figsize [Array<Numeric>] size of a newly created figure
    # @return [PyObject] the matplotlib Figure
    def plot(fcst, ax: nil, uncertainty: true, plot_cap: true, xlabel: "ds", ylabel: "y", figsize: [10, 6])
      if ax.nil?
        fig = plt.figure(facecolor: "w", figsize: figsize)
        ax = fig.add_subplot(111)
      else
        fig = ax.get_figure
      end
      fcst_t = to_pydatetime(fcst["ds"])
      ax.plot(to_pydatetime(@history["ds"]), @history["y"].map(&:to_f), "k.")
      ax.plot(fcst_t, fcst["yhat"].map(&:to_f), ls: "-", c: "#0072B2")
      if fcst.vectors.include?("cap") && plot_cap
        ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
      end
      if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
        ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
      end
      if plot_uncertainty?(uncertainty)
        ax.fill_between(fcst_t, fcst["yhat_lower"].map(&:to_f), fcst["yhat_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      format_date_axis(ax)
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xlabel(xlabel)
      ax.set_ylabel(ylabel)
      fig.tight_layout
      fig
    end

    # Plots each forecast component on its own panel: trend, holidays, weekly,
    # yearly, other seasonalities (sorted by name) and the additive /
    # multiplicative extra-regressor totals, keeping only those present in fcst.
    #
    # @param fcst [Daru::DataFrame] forecast frame containing component columns
    # @param uncertainty [Boolean] whether to shade uncertainty intervals
    # @param plot_cap [Boolean] whether to draw cap/floor on the trend panel
    # @param weekly_start [Integer] day offset applied to the weekly panel's start date
    # @param yearly_start [Integer] day offset applied to the yearly panel's start date
    # @param figsize [Array<Numeric>, nil] defaults to [9, 3 * number_of_panels]
    # @return [PyObject] the matplotlib Figure
    def plot_components(fcst, uncertainty: true, plot_cap: true, weekly_start: 0, yearly_start: 0, figsize: nil)
      components = ["trend"]
      components << "holidays" if @train_holiday_names && fcst.vectors.include?("holidays")
      # Plot weekly seasonality, if present
      components << "weekly" if @seasonalities["weekly"] && fcst.vectors.include?("weekly")
      # Yearly if present
      components << "yearly" if @seasonalities["yearly"] && fcst.vectors.include?("yearly")
      # Other seasonalities
      components.concat(@seasonalities.keys.select { |name| fcst.vectors.include?(name) && !["weekly", "yearly"].include?(name) }.sort)
      regressors = {"additive" => false, "multiplicative" => false}
      @extra_regressors.each_value do |props|
        regressors[props[:mode]] = true
      end
      ["additive", "multiplicative"].each do |mode|
        if regressors[mode] && fcst.vectors.include?("extra_regressors_#{mode}")
          components << "extra_regressors_#{mode}"
        end
      end
      npanel = components.size

      figsize ||= [9, 3 * npanel]
      fig, axes = plt.subplots(npanel, 1, facecolor: "w", figsize: figsize)
      # For a single panel, subplots returns one Axes rather than a numpy array,
      # so only the multi-panel result is converted with tolist.
      axes = npanel == 1 ? [axes] : axes.tolist

      multiplicative_axes = []
      axes.zip(components) do |ax, plot_name|
        if plot_name == "trend"
          plot_forecast_component(fcst, "trend", ax: ax, uncertainty: uncertainty, plot_cap: plot_cap)
        elsif @seasonalities[plot_name]
          period = @seasonalities[plot_name][:period]
          if plot_name == "weekly" || period == 7
            plot_weekly(name: plot_name, ax: ax, uncertainty: uncertainty, weekly_start: weekly_start)
          elsif plot_name == "yearly" || period == 365.25
            plot_yearly(name: plot_name, ax: ax, uncertainty: uncertainty, yearly_start: yearly_start)
          else
            plot_seasonality(name: plot_name, ax: ax, uncertainty: uncertainty)
          end
        elsif ["holidays", "extra_regressors_additive", "extra_regressors_multiplicative"].include?(plot_name)
          plot_forecast_component(fcst, plot_name, ax: ax, uncertainty: uncertainty, plot_cap: false)
        end
        multiplicative_axes << ax if @component_modes["multiplicative"].include?(plot_name)
      end

      fig.tight_layout
      # Reset multiplicative axes labels after tight_layout adjustment
      multiplicative_axes.each { |ax| set_y_as_percent(ax) }
      fig
    end

    private

    # Plots one forecast column (and its _lower/_upper band) against ds.
    #
    # @return [Array] the matplotlib artists that were drawn
    def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
      artists = []
      ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)
      fcst_t = to_pydatetime(fcst["ds"])
      artists += ax.plot(fcst_t, fcst[name].map(&:to_f), ls: "-", c: "#0072B2")
      if fcst.vectors.include?("cap") && plot_cap
        artists += ax.plot(fcst_t, fcst["cap"].map(&:to_f), ls: "--", c: "k")
      end
      if @logistic_floor && fcst.vectors.include?("floor") && plot_cap
        artists += ax.plot(fcst_t, fcst["floor"].map(&:to_f), ls: "--", c: "k")
      end
      if plot_uncertainty?(uncertainty)
        artists << ax.fill_between(fcst_t, fcst[name + "_lower"].map(&:to_f), fcst[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      format_date_axis(ax)
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xlabel("ds")
      ax.set_ylabel(name)
      set_y_as_percent(ax) if @component_modes["multiplicative"].include?(name)
      artists
    end

    # Builds the frame used to evaluate seasonalities at the dates +ds+:
    # cap 1.0, floor 0.0, every extra regressor at 0.0, and every seasonality
    # condition column switched on, then passed through setup_dataframe.
    def seasonality_plot_df(ds)
      df_dict = {"ds" => ds, "cap" => [1.0] * ds.size, "floor" => [0.0] * ds.size}
      # Columns are keyed by regressor name (the hash keys of @extra_regressors)
      @extra_regressors.each_key do |name|
        df_dict[name] = [0.0] * ds.size
      end
      # Activate all conditional seasonality columns
      @seasonalities.each_value do |props|
        df_dict[props[:condition_name]] = [true] * ds.size if props[:condition_name]
      end
      setup_dataframe(Daru::DataFrame.new(df_dict))
    end

    # Plots a weekly seasonality over seven consecutive days starting from
    # 2017-01-01 (a Sunday) shifted by +weekly_start+ days.
    #
    # @return [Array] the matplotlib artists that were drawn
    def plot_weekly(ax: nil, uncertainty: true, weekly_start: 0, figsize: [10, 6], name: "weekly")
      artists = []
      ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)
      # Compute weekly seasonality for a Sun-Sat sequence of dates.
      start = Date.parse("2017-01-01")
      days = 7.times.map { |i| start + i + weekly_start }
      df_w = seasonality_plot_df(days)
      seas = predict_seasonal_components(df_w)
      days = days.map { |v| v.strftime("%A") }
      positions = days.size.times.to_a
      artists += ax.plot(positions, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if plot_uncertainty?(uncertainty)
        artists << ax.fill_between(positions, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      ax.set_xticks(positions)
      ax.set_xticklabels(days)
      ax.set_xlabel("Day of week")
      ax.set_ylabel(name)
      # Seasonality props use symbol keys (:period, :mode, :condition_name)
      set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
      artists
    end

    # Plots a yearly seasonality over 365 days starting from 2017-01-01
    # shifted by +yearly_start+ days.
    #
    # @return [Array] the matplotlib artists that were drawn
    def plot_yearly(ax: nil, uncertainty: true, yearly_start: 0, figsize: [10, 6], name: "yearly")
      artists = []
      ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)
      # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
      start = Date.parse("2017-01-01")
      days = 365.times.map { |i| start + i + yearly_start }
      df_y = seasonality_plot_df(days)
      seas = predict_seasonal_components(df_y)
      ds_t = to_pydatetime(df_y["ds"])
      artists += ax.plot(ds_t, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if plot_uncertainty?(uncertainty)
        artists << ax.fill_between(ds_t, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
      ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime("%B %-e") }))
      ax.xaxis.set_major_locator(months)
      ax.set_xlabel("Day of year")
      ax.set_ylabel(name)
      set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
      artists
    end

    # Plots an arbitrary seasonality over one period starting at 2017-01-01 UTC,
    # sampled at 200 evenly spaced points, with seven x ticks whose label
    # format depends on the period length (in days).
    #
    # @return [Array] the matplotlib artists that were drawn
    def plot_seasonality(name:, ax: nil, uncertainty: true, figsize: [10, 6])
      artists = []
      ax ||= plt.figure(facecolor: "w", figsize: figsize).add_subplot(111)
      # Compute seasonality from Jan 1 through a single period.
      period = @seasonalities[name][:period]
      start_time = Time.utc(2017)
      start = start_time.to_i
      finish = (start_time + period * 86400).to_i
      plot_points = 200
      step = (finish - start) / (plot_points - 1).to_f
      days = plot_points.times.map { |i| Time.at(start + i * step).utc }
      df_y = seasonality_plot_df(days)
      seas = predict_seasonal_components(df_y)
      ds_t = to_pydatetime(df_y["ds"])
      artists += ax.plot(ds_t, seas[name].map(&:to_f), ls: "-", c: "#0072B2")
      if plot_uncertainty?(uncertainty)
        artists << ax.fill_between(ds_t, seas[name + "_lower"].map(&:to_f), seas[name + "_upper"].map(&:to_f), color: "#0072B2", alpha: 0.2)
      end
      ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
      tick_step = (finish - start) / (7 - 1).to_f
      xticks = to_pydatetime(7.times.map { |i| Time.at(start + i * tick_step).utc })
      ax.set_xticks(xticks)
      fmt_str =
        if period <= 2
          "%T"
        elsif period < 14
          "%m/%d %R"
        else
          "%m/%d"
        end
      ax.xaxis.set_major_formatter(ticker.FuncFormatter.new(lambda { |x, pos = nil| dates.num2date(x).strftime(fmt_str) }))
      ax.set_xlabel("ds")
      ax.set_ylabel(name)
      set_y_as_percent(ax) if @seasonalities[name][:mode] == "multiplicative"
      artists
    end

    # Relabels the y ticks of +ax+ as percentages (tick value * 100).
    def set_y_as_percent(ax)
      yticks = 100 * ax.get_yticks
      yticklabels = yticks.tolist.map { |y| "%.4g%%" % y }
      ax.set_yticklabels(yticklabels)
      ax
    end

    # Whether uncertainty intervals should be drawn. The sample count is
    # compared with 0 explicitly because 0 is truthy in Ruby; a zero count is
    # treated as "intervals disabled", as in the Python implementation this ports.
    def plot_uncertainty?(uncertainty)
      !!(uncertainty && @uncertainty_samples && @uncertainty_samples != 0)
    end

    # Installs an AutoDateLocator/AutoDateFormatter pair on the x axis.
    # Specify formatting to workaround matplotlib issue #12925
    def format_date_axis(ax)
      locator = dates.AutoDateLocator.new(interval_multiples: false)
      formatter = dates.AutoDateFormatter.new(locator)
      ax.xaxis.set_major_locator(locator)
      ax.xaxis.set_major_formatter(formatter)
    end

    # Lazily loads matplotlib's pyplot through the matplotlib gem.
    def plt
      begin
        require "matplotlib/pyplot"
      rescue LoadError
        raise Error, "Install the matplotlib gem for plots"
      end
      Matplotlib::Pyplot
    end

    def dates
      PyCall.import_module("matplotlib.dates")
    end

    def ticker
      PyCall.import_module("matplotlib.ticker")
    end

    # Converts each value (via #to_i, i.e. epoch seconds) to a naive UTC Python datetime.
    def to_pydatetime(values)
      datetime = PyCall.import_module("datetime")
      values.map { |value| datetime.datetime.utcfromtimestamp(value.to_i) }
    end
  end
end
|
@@ -0,0 +1,136 @@
|
|
1
|
+
module Prophet
  class StanBackend
    # Iteration budget passed to every CmdStan optimization run.
    OPTIMIZE_ITERATIONS = 10_000

    # @param logger [#warn] receives a warning when optimization falls back to Newton
    def initialize(logger)
      @model = load_model
      @logger = logger
    end

    # Loads the compiled Prophet model executable from stan_model/prophet_model.bin.
    #
    # @return [CmdStan::Model]
    def load_model
      model_file = File.expand_path("../../stan_model/prophet_model.bin", __dir__)
      CmdStan::Model.new(exe_file: model_file)
    end

    # Fits the model by optimization and returns the point estimates.
    #
    # The algorithm defaults to Newton when stan_data["T"] < 100 and to LBFGS
    # otherwise. If a non-Newton run raises, it is retried once with Newton.
    #
    # @param stan_init [Hash] initial values; "delta" and "beta" must respond to #to_a
    # @param stan_data [Hash] model data keyed by Stan data names
    # @param kwargs [Hash] extra options forwarded to CmdStan::Model#optimize
    # @return [Hash{String => Numo::NArray}] one array per parameter, with a leading axis of length 1
    # @raise [StandardError] whatever the Newton run raises
    def fit(stan_init, stan_data, **kwargs)
      stan_init, stan_data = prepare_data(stan_init, stan_data)
      kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"

      stan_fit =
        begin
          run_optimize(stan_init, stan_data, kwargs)
        rescue StandardError
          raise if kwargs[:algorithm] == "Newton"

          @logger.warn "Optimization terminated abnormally. Falling back to Newton."
          run_optimize(stan_init, stan_data, kwargs.merge(algorithm: "Newton"))
        end

      params = stan_to_numo(stan_fit.column_names, Numo::NArray.asarray(stan_fit.optimized_params.values))
      # Prepend an axis of length 1: the optimum is returned as a single draw
      params.transform_values { |arr| arr.reshape(1, *arr.shape) }
    end

    # Draws MCMC samples and returns them grouped by parameter.
    #
    # @param stan_init [Hash] initial values; "delta" and "beta" must respond to #to_a
    # @param stan_data [Hash] model data keyed by Stan data names
    # @param samples [Integer] sampling iterations; warmup defaults to half of this
    # @param kwargs [Hash] extra options forwarded to CmdStan::Model#sample (chains defaults to 4)
    # @return [Hash{String => Numo::NArray}] one array per parameter, one row per draw
    def sampling(stan_init, stan_data, samples, **kwargs)
      stan_init, stan_data = prepare_data(stan_init, stan_data)

      kwargs[:chains] ||= 4
      kwargs[:warmup_iters] ||= samples / 2

      stan_fit = @model.sample(data: stan_data, inits: stan_init, sampling_iters: samples, **kwargs)
      res = Numo::NArray.asarray(stan_fit.sample)
      # Merge the first two axes so each row holds one draw of every column.
      # Separate names keep the +samples+ argument from being overwritten.
      n_draws, n_chains, n_columns = res.shape
      res = res.reshape(n_draws * n_chains, n_columns)
      params = stan_to_numo(stan_fit.column_names, res)

      params.each_key do |par|
        s = params[par].shape
        params[par] = params[par].reshape(s[0]) if s[1] == 1
        # s is the shape from before the reshape above
        params[par] = params[par].reshape(-1, 1) if ["delta", "beta"].include?(par) && s.size < 2
      end

      params
    end

    private

    # Runs one CmdStan optimization with the fixed iteration budget.
    def run_optimize(stan_init, stan_data, options)
      @model.optimize(data: stan_data, inits: stan_init, iter: OPTIMIZE_ITERATIONS, **options)
    end

    # Splits the columns of +data+ into one array per Stan parameter.
    #
    # Consecutive columns sharing the prefix before the first "." (e.g. "k",
    # "delta.1", "delta.2") are grouped under that prefix. For two-dimensional
    # +data+ the grouping is applied to the second axis.
    #
    # @param column_names [Array<String>] one name per column of +data+
    # @param data [Numo::NArray]
    # @return [Hash{String => Numo::NArray}] empty when there are no columns
    # @raise [Error] if a parameter's columns are not contiguous
    def stan_to_numo(column_names, data)
      output = {}
      return output if column_names.empty?

      two_dims = data.shape.size > 1
      prefixes = column_names.map { |cname| cname.split(".")[0] }

      start = 0
      prefixes.chunk_while { |a, b| a == b }.each do |group|
        name = group[0]
        raise Error, "Found repeated column name" if output[name]

        range = start...(start + group.size)
        output[name] = Numo::NArray.asarray(two_dims ? data[true, range] : data[range])
        start += group.size
      end

      output
    end

    # Returns copies of +stan_init+ and +stan_data+ in which array-like values
    # are converted to plain Ruby arrays ("X" via #to_matrix). The arguments
    # themselves are left unmodified.
    def prepare_data(stan_init, stan_data)
      data = stan_data.merge(
        "y" => stan_data["y"].to_a,
        "t" => stan_data["t"].to_a,
        "cap" => stan_data["cap"].to_a,
        "t_change" => stan_data["t_change"].to_a,
        "s_a" => stan_data["s_a"].to_a,
        "s_m" => stan_data["s_m"].to_a,
        "X" => stan_data["X"].to_matrix.to_a
      )
      init = stan_init.merge(
        "delta" => stan_init["delta"].to_a,
        "beta" => stan_init["beta"].to_a
      )
      [init, data]
    end
  end
end
|
@@ -0,0 +1,131 @@
|
|
1
|
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
2
|
+
|
3
|
+
// This source code is licensed under the MIT license found in the
|
4
|
+
// LICENSE file in the root directory of this source tree.
|
5
|
+
|
6
|
+
functions {
  // Builds the T x S changepoint indicator matrix: entry (n, j) is 1 once
  // t[n] >= t_change[j], and 0 before that.
  // Assumes t and t_change are sorted.
  matrix get_changepoint_matrix(vector t, vector t_change, int T, int S) {
    matrix[T, S] indicators = rep_matrix(0, T, S);
    row_vector[S] active = rep_row_vector(0, S);
    int next_cp = 1;

    for (row_idx in 1:T) {
      // Switch on every changepoint that t[row_idx] has reached.
      while ((next_cp <= S) && (t[row_idx] >= t_change[next_cp])) {
        active[next_cp] = 1;
        next_cp = next_cp + 1;
      }
      indicators[row_idx] = active;
    }
    return indicators;
  }

  // Logistic trend functions

  // Offset adjustments that keep the piecewise logistic trend continuous at
  // each changepoint; returns one adjustment per changepoint.
  vector logistic_gamma(real k, real m, vector delta, vector t_change, int S) {
    vector[S] adjustments;
    vector[S + 1] rates = append_row(k, k + cumulative_sum(delta));  // growth rate of each segment
    real prev_m = m;  // offset of the previous segment

    for (cp in 1:S) {
      adjustments[cp] = (t_change[cp] - prev_m) * (1 - rates[cp] / rates[cp + 1]);
      prev_m = prev_m + adjustments[cp];
    }
    return adjustments;
  }

  // Piecewise logistic trend scaled by the capacity cap.
  vector logistic_trend(
    real k, real m, vector delta, vector t, vector cap, matrix A, vector t_change, int S
  ) {
    vector[S] adjustments = logistic_gamma(k, m, delta, t_change, S);
    return cap .* inv_logit((k + A * delta) .* (t - (m + A * adjustments)));
  }

  // Linear trend function

  // Piecewise linear trend: rate k plus the active rate changes, with offsets
  // shifted so that the segments join at each changepoint.
  vector linear_trend(real k, real m, vector delta, vector t, matrix A, vector t_change) {
    return (k + A * delta) .* t + (m + A * (-t_change .* delta));
  }
}
|
77
|
+
|
78
|
+
data {
  int<lower=1> T;                         // Number of time periods
  int<lower=1> K;                         // Number of regressors
  vector[T] t;                            // Time
  vector[T] cap;                          // Capacities for logistic trend
  vector[T] y;                            // Time series
  int<lower=0> S;                         // Number of changepoints
  vector[S] t_change;                     // Times of trend changepoints
  matrix[T,K] X;                          // Regressors
  vector[K] sigmas;                       // Scale on seasonality prior
  real<lower=0> tau;                      // Scale on changepoints prior
  // 0 for linear, 1 for logistic. Constrained so that any other value fails
  // data validation instead of silently dropping the likelihood below.
  int<lower=0, upper=1> trend_indicator;
  vector[K] s_a;                          // Indicator of additive features
  vector[K] s_m;                          // Indicator of multiplicative features
}

transformed data {
  // Changepoint indicator matrix; it depends only on data, so it is built once.
  matrix[T, S] A;
  A = get_changepoint_matrix(t, t_change, T, S);
}

parameters {
  real k;                   // Base trend growth rate
  real m;                   // Trend offset
  vector[S] delta;          // Trend rate adjustments
  real<lower=0> sigma_obs;  // Observation noise
  vector[K] beta;           // Regressor coefficients
}

model {
  //priors
  k ~ normal(0, 5);
  m ~ normal(0, 5);
  delta ~ double_exponential(0, tau);
  sigma_obs ~ normal(0, 0.5);
  beta ~ normal(0, sigmas);

  // Likelihood: trend times (1 + multiplicative terms) plus additive terms.
  // trend_indicator is constrained to {0, 1}, so exactly one branch runs.
  if (trend_indicator == 0) {
    y ~ normal(
      linear_trend(k, m, delta, t, A, t_change)
      .* (1 + X * (beta .* s_m))
      + X * (beta .* s_a),
      sigma_obs
    );
  } else {
    y ~ normal(
      logistic_trend(k, m, delta, t, cap, A, t_change, S)
      .* (1 + X * (beta .* s_m))
      + X * (beta .* s_a),
      sigma_obs
    );
  }
}
|