prophet-rb 0.1.0 → 0.2.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/LICENSE.txt +1 -1
- data/README.md +131 -22
- data/lib/prophet.rb +64 -1
- data/lib/prophet/forecaster.rb +142 -137
- data/lib/prophet/holidays.rb +2 -2
- data/lib/prophet/plot.rb +48 -30
- data/lib/prophet/stan_backend.rb +1 -1
- data/lib/prophet/version.rb +1 -1
- metadata +37 -9
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 756e42a4e2c39e114610d2f41e2403d9b160e77da02c0483e43b4bf460dd828b
+  data.tar.gz: e7bce55f47410227131afcba9d2f90c9812dab093efd1c17465383b3a5e6dc73
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: ec014f32ff39abd49195d7e9a7f38b3b297bc7f6fc28da3d1be6fb19b6edbe816a681c4fbb35d5ea79fc2ced0aa948b0468cff9f6d9d8293590761711944b71b
+  data.tar.gz: 87ecc8706c73e7f1063c8d75cc5c687ebc3fddb92d89e60bb115d21c90a60fbb6172648a0d414baeb11f04572b37e06777d4100acb07b8559dea8e8a711741fb
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,27 @@
+## 0.2.3 (2020-10-14)
+
+- Added support for times to `forecast` method
+
+## 0.2.2 (2020-07-26)
+
+- Fixed error with constant series
+- Fixed error with no changepoints
+
+## 0.2.1 (2020-07-15)
+
+- Added `forecast` method
+
+## 0.2.0 (2020-05-13)
+
+- Switched from Daru to Rover
+
+## 0.1.1 (2020-04-10)
+
+- Added `add_changepoints_to_plot`
+- Fixed error with `changepoints` option
+- Fixed error with `mcmc_samples` option
+- Fixed error with additional regressors
+
 ## 0.1.0 (2020-04-09)
 
 - First release
data/LICENSE.txt
CHANGED
@@ -1,7 +1,7 @@
 MIT License
 
-Copyright (c) 2020 Andrew Kane
 Copyright (c) Facebook, Inc. and its affiliates.
+Copyright (c) 2020 Andrew Kane
 
 Permission is hereby granted, free of charge, to any person obtaining
 a copy of this software and associated documentation files (the
data/README.md
CHANGED
@@ -10,6 +10,8 @@ Supports:
 
 And gracefully handles missing data
 
+[![Build Status](https://travis-ci.org/ankane/prophet.svg?branch=master)](https://travis-ci.org/ankane/prophet) [![Build status](https://ci.appveyor.com/api/projects/status/8ahmsvvhum4ivnmv/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/prophet/branch/master)
+
 ## Installation
 
 Add this line to your application’s Gemfile:
@@ -18,19 +20,47 @@ Add this line to your application’s Gemfile:
 gem 'prophet-rb'
 ```
 
-##
+## Simple API
+
+Get future predictions for a time series
+
+```ruby
+series = {
+  Date.parse("2020-01-01") => 100,
+  Date.parse("2020-01-02") => 150,
+  Date.parse("2020-01-03") => 136,
+  # ...
+}
+
+Prophet.forecast(series)
+```
+
+Specify the number of predictions to return
+
+```ruby
+Prophet.forecast(series, count: 3)
+```
+
+Works great with [Groupdate](https://github.com/ankane/groupdate)
+
+```ruby
+series = User.group_by_day(:created_at).count
+Prophet.forecast(series)
+```
+
+## Advanced API
 
-Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/quick_start.html) for a great explanation of all of the features. The
+Check out the [Prophet documentation](https://facebook.github.io/prophet/docs/quick_start.html) for a great explanation of all of the features. The advanced API follows the Python API and supports the same features. It uses [Rover](https://github.com/ankane/rover) for data frames.
 
-## Quick Start
+## Advanced Quick Start
 
 [Explanation](https://facebook.github.io/prophet/docs/quick_start.html)
 
 Create a data frame with `ds` and `y` columns - here’s [an example](examples/example_wp_log_peyton_manning.csv) you can use
 
 ```ruby
-df =
-df.head
+df = Rover.read_csv("example_wp_log_peyton_manning.csv")
+df.head
 ```
 
 ds | y
@@ -52,7 +82,7 @@ Make a data frame with a `ds` column for future predictions
 
 ```ruby
 future = m.make_future_dataframe(periods: 365)
-future.tail
+future.tail
 ```
 
 ds |
@@ -67,7 +97,7 @@ Make predictions
 
 ```ruby
 forecast = m.predict(future)
-forecast["ds", "yhat", "yhat_lower", "yhat_upper"].tail
+forecast[["ds", "yhat", "yhat_lower", "yhat_upper"]].tail
 ```
 
 ds | yhat | yhat_lower | yhat_upper
@@ -88,7 +118,7 @@ Plot the forecast
 m.plot(forecast).savefig("forecast.png")
 ```
 
-![Forecast](https://blazer.dokkuapp.com/assets/prophet/forecast-
+![Forecast](https://blazer.dokkuapp.com/assets/prophet/forecast-77cf453fda67d1b462c6c22aee3a02572203b71c4517fedecc1f438cd374a876.png)
 
 Plot components
 
@@ -96,7 +126,46 @@ Plot components
 m.plot_components(forecast).savefig("components.png")
 ```
 
-![Components](https://blazer.dokkuapp.com/assets/prophet/components-
+![Components](https://blazer.dokkuapp.com/assets/prophet/components-2cdd260e23bc89824ecca25f6bfe394deb5821d60b7e0e551469c90d204acd67.png)
+
+## Saturating Forecasts
+
+[Explanation](https://facebook.github.io/prophet/docs/saturating_forecasts.html)
+
+Forecast logistic growth instead of linear
+
+```ruby
+df = Rover.read_csv("example_wp_log_R.csv")
+df["cap"] = 8.5
+m = Prophet.new(growth: "logistic")
+m.fit(df)
+future = m.make_future_dataframe(periods: 365)
+future["cap"] = 8.5
+forecast = m.predict(future)
+```
+
+## Trend Changepoints
+
+[Explanation](https://facebook.github.io/prophet/docs/trend_changepoints.html)
+
+Plot changepoints
+
+```ruby
+fig = m.plot(forecast)
+m.add_changepoints_to_plot(fig.gca, forecast)
+```
+
+Adjust trend flexibility
+
+```ruby
+m = Prophet.new(changepoint_prior_scale: 0.5)
+```
+
+Specify the location of changepoints
+
+```ruby
+m = Prophet.new(changepoints: ["2014-01-01"])
+```
 
 ## Holidays and Special Events
 
@@ -105,21 +174,21 @@ m.plot_components(forecast).savefig("components.png")
 Create a data frame with `holiday` and `ds` columns. Include all occurrences in your past data and future occurrences you’d like to forecast.
 
 ```ruby
-playoffs =
-  "holiday" =>
+playoffs = Rover::DataFrame.new(
+  "holiday" => "playoff",
   "ds" => ["2008-01-13", "2009-01-03", "2010-01-16",
            "2010-01-24", "2010-02-07", "2011-01-08",
            "2013-01-12", "2014-01-12", "2014-01-19",
            "2014-02-02", "2015-01-11", "2016-01-17",
            "2016-01-24", "2016-02-07"],
-  "lower_window" =>
-  "upper_window" =>
+  "lower_window" => 0,
+  "upper_window" => 1
 )
-superbowls =
-  "holiday" =>
+superbowls = Rover::DataFrame.new(
+  "holiday" => "superbowl",
   "ds" => ["2010-02-07", "2014-02-02", "2016-02-07"],
-  "lower_window" =>
-  "upper_window" =>
+  "lower_window" => 0,
+  "upper_window" => 1
 )
 holidays = playoffs.concat(superbowls)
 
@@ -141,7 +210,25 @@ Specify custom seasonalities
 m = Prophet.new(weekly_seasonality: false)
 m.add_seasonality(name: "monthly", period: 30.5, fourier_order: 5)
 forecast = m.fit(df).predict(future)
-
+```
+
+Specify additional regressors
+
+```ruby
+nfl_sunday = lambda do |ds|
+  date = ds.respond_to?(:to_date) ? ds.to_date : Date.parse(ds)
+  date.wday == 0 && (date.month > 8 || date.month < 2) ? 1 : 0
+end
+
+df["nfl_sunday"] = df["ds"].map(&nfl_sunday)
+
+m = Prophet.new
+m.add_regressor("nfl_sunday")
+m.fit(df)
+
+future["nfl_sunday"] = future["ds"].map(&nfl_sunday)
+
+forecast = m.predict(future)
 ```
 
 ## Multiplicative Seasonality
@@ -149,13 +236,27 @@ m.plot_components(forecast).savefig("components.png")
 [Explanation](https://facebook.github.io/prophet/docs/multiplicative_seasonality.html)
 
 ```ruby
-df =
+df = Rover.read_csv("example_air_passengers.csv")
 m = Prophet.new(seasonality_mode: "multiplicative")
 m.fit(df)
 future = m.make_future_dataframe(periods: 50, freq: "MS")
 forecast = m.predict(future)
 ```
 
+## Uncertainty Intervals
+
+Specify the width of uncertainty intervals (80% by default)
+
+```ruby
+Prophet.new(interval_width: 0.95)
+```
+
+Get uncertainty in seasonality
+
+```ruby
+Prophet.new(mcmc_samples: 300)
+```
+
 ## Non-Daily Data
 
 [Explanation](https://facebook.github.io/prophet/docs/non-daily_data.html)
@@ -163,17 +264,25 @@ forecast = m.predict(future)
 Sub-daily data
 
 ```ruby
-df =
+df = Rover.read_csv("example_yosemite_temps.csv")
 m = Prophet.new(changepoint_prior_scale: 0.01).fit(df)
 future = m.make_future_dataframe(periods: 300, freq: "H")
-
-m.plot(fcst).savefig("forecast.png")
+forecast = m.predict(future)
 ```
 
 ## Resources
 
 - [Forecasting at Scale](https://peerj.com/preprints/3190.pdf)
 
+## Upgrading
+
+### 0.2.0
+
+Prophet now uses [Rover](https://github.com/ankane/rover) instead of Daru. Two changes you may need to make are:
+
+- `Rover.read_csv` instead of `Daru::DataFrame.from_csv`
+- `df[["ds", "yhat"]]` instead of `df["ds", "yhat"]`
+
 ## Credits
 
 This library was ported from the [Prophet Python library](https://github.com/facebook/prophet) and is available under the same license.
data/lib/prophet.rb
CHANGED
@@ -1,6 +1,6 @@
 # dependencies
 require "cmdstan"
-require "
+require "rover"
 require "numo/narray"
 
 # stdlib
@@ -20,4 +20,67 @@ module Prophet
   def self.new(**kwargs)
     Forecaster.new(**kwargs)
   end
+
+  def self.forecast(series, count: 10)
+    raise ArgumentError, "Series must have at least 10 data points" if series.size < 10
+
+    # check type to determine output format
+    # check for before converting to time
+    keys = series.keys
+    dates = keys.all? { |k| k.is_a?(Date) }
+    time_zone = keys.first.time_zone if keys.first.respond_to?(:time_zone)
+    utc = keys.first.utc? if keys.first.respond_to?(:utc?)
+    times = keys.map(&:to_time)
+
+    day = times.all? { |t| t.hour == 0 && t.min == 0 && t.sec == 0 && t.nsec == 0 }
+    week = day && times.map { |k| k.wday }.uniq.size == 1
+    month = day && times.all? { |k| k.day == 1 }
+    quarter = month && times.all? { |k| k.month % 3 == 1 }
+    year = quarter && times.all? { |k| k.month == 1 }
+
+    freq =
+      if year
+        "YS"
+      elsif quarter
+        "QS"
+      elsif month
+        "MS"
+      elsif week
+        "W"
+      elsif day
+        "D"
+      else
+        diff = Rover::Vector.new(times).sort.diff.to_numo[1..-1]
+        min_diff = diff.min.to_i
+
+        # could be another common divisor
+        # but keep it simple for now
+        raise "Unknown frequency" unless (diff % min_diff).eq(0).all?
+
+        "#{min_diff}S"
+      end
+
+    # use series, not times, so dates are handled correctly
+    df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
+
+    m = Prophet.new
+    m.logger.level = ::Logger::FATAL # no logging
+    m.fit(df)
+
+    future = m.make_future_dataframe(periods: count, include_history: false, freq: freq)
+    forecast = m.predict(future)
+    result = forecast[["ds", "yhat"]].to_a
+
+    # use the same format as input
+    if dates
+      result.each { |v| v["ds"] = v["ds"].to_date }
+    elsif time_zone
+      result.each { |v| v["ds"] = v["ds"].in_time_zone(time_zone) }
+    elsif utc
+      result.each { |v| v["ds"] = v["ds"].utc }
+    else
+      result.each { |v| v["ds"] = v["ds"].localtime }
+    end
+    result.map { |v| [v["ds"], v["yhat"]] }.to_h
+  end
 end
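Note on the frequency detection added to `Prophet.forecast` in `data/lib/prophet.rb`: the checks run from the coarsest candidate (`"YS"`) down to a fixed number of seconds. The following is a minimal standalone sketch of the same cascade in plain Ruby, without Rover or Numo. The helper name `infer_freq` is hypothetical and not part of prophet-rb, and the guard against a missing or zero gap is an addition for illustration only.

```ruby
require "date"

# Hypothetical helper (not part of prophet-rb): mirrors the checks in the
# Prophet.forecast diff, using plain arrays instead of Rover::Vector/Numo.
def infer_freq(keys)
  times = keys.map(&:to_time)

  # Each flag builds on the previous one, as in the diff.
  day = times.all? { |t| t.hour == 0 && t.min == 0 && t.sec == 0 && t.nsec == 0 }
  week = day && times.map(&:wday).uniq.size == 1
  month = day && times.all? { |t| t.day == 1 }
  quarter = month && times.all? { |t| t.month % 3 == 1 }
  year = quarter && times.all? { |t| t.month == 1 }

  return "YS" if year
  return "QS" if quarter
  return "MS" if month
  return "W" if week
  return "D" if day

  # Fall back to the smallest gap in whole seconds; every gap must be a multiple of it.
  diffs = times.map(&:to_i).sort.each_cons(2).map { |a, b| b - a }
  min_diff = diffs.min
  # Assumption: this guard is not in the diff; it avoids dividing by zero on duplicate keys.
  raise "Unknown frequency" if min_diff.nil? || min_diff.zero?
  raise "Unknown frequency" unless diffs.all? { |d| (d % min_diff).zero? }

  "#{min_diff}S"
end

puts infer_freq([Date.new(2020, 1, 1), Date.new(2020, 1, 2), Date.new(2020, 1, 3)])
puts infer_freq([Time.utc(2020, 1, 1, 0), Time.utc(2020, 1, 1, 1), Time.utc(2020, 1, 1, 2)])
```

The order of the early returns matters: a series of first-of-month dates also satisfies the daily check, so the coarser labels have to be tested first.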
data/lib/prophet/forecaster.rb
CHANGED
@@ -82,12 +82,12 @@ module Prophet
         raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
       end
       if @holidays
-        if !@holidays.is_a?(
+        if !@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday")
           raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
         end
         @holidays["ds"] = to_datetime(@holidays["ds"])
-        has_lower = @holidays.
-        has_upper = @holidays.
+        has_lower = @holidays.include?("lower_window")
+        has_upper = @holidays.include?("upper_window")
         if has_lower ^ has_upper # xor
           raise ArgumentError, "Holidays must have both lower_window and upper_window, or neither"
         end
@@ -141,7 +141,7 @@ module Prophet
     end
 
     def setup_dataframe(df, initialize_scales: false)
-      if df.
+      if df.include?("y")
         df["y"] = df["y"].map(&:to_f)
         raise ArgumentError "Found infinity in column y." unless df["y"].all?(&:finite?)
       end
@@ -152,18 +152,18 @@ module Prophet
       raise ArgumentError, "Found NaN in column ds." if df["ds"].any?(&:nil?)
 
       @extra_regressors.each_key do |name|
-        if !df.
+        if !df.include?(name)
           raise ArgumentError, "Regressor #{name.inspect} missing from dataframe"
         end
         df[name] = df[name].map(&:to_f)
-        if df[name].any?(&:nil)
+        if df[name].any?(&:nil?)
           raise ArgumentError, "Found NaN in column #{name.inspect}"
         end
       end
       @seasonalities.values.each do |props|
         condition_name = props[:condition_name]
         if condition_name
-          if !df.
+          if !df.include?(condition_name)
             raise ArgumentError, "Condition #{condition_name.inspect} missing from dataframe"
           end
           if df.where(!df[condition_name].in([true, false])).any?
@@ -172,36 +172,33 @@ module Prophet
         end
       end
 
-
-        df.index.name = nil
-      end
-      df = df.sort(["ds"])
+      df = df.sort_by { |r| r["ds"] }
 
       initialize_scales(initialize_scales, df)
 
-      if @logistic_floor && !df.
+      if @logistic_floor && !df.include?("floor")
         raise ArgumentError, "Expected column \"floor\"."
       else
         df["floor"] = 0
       end
 
       if @growth == "logistic"
-        unless df.
+        unless df.include?("cap")
           raise ArgumentError, "Capacities must be supplied for logistic growth in column \"cap\""
         end
-        if df
+        if df[df["cap"] <= df["floor"]].size > 0
           raise ArgumentError, "cap must be greater than floor (which defaults to 0)."
         end
-        df["cap_scaled"] = (df["cap"] - df["floor"]) / @y_scale
+        df["cap_scaled"] = (df["cap"] - df["floor"]) / @y_scale.to_f
       end
 
       df["t"] = (df["ds"] - @start) / @t_scale.to_f
-      if df.
-        df["y_scaled"] = (df["y"] - df["floor"]) / @y_scale
+      if df.include?("y")
+        df["y_scaled"] = (df["y"] - df["floor"]) / @y_scale.to_f
       end
 
       @extra_regressors.each do |name, props|
-        df[name] = (
+        df[name] = (df[name] - props[:mu]) / props[:std].to_f
       end
 
       df
@@ -218,32 +215,40 @@ module Prophet
     end
 
     def set_changepoints
-
+      if @changepoints
+        if @changepoints.size > 0
+          too_low = @changepoints.min < @history["ds"].min
+          too_high = @changepoints.max > @history["ds"].max
+          if too_low || too_high
+            raise ArgumentError, "Changepoints must fall within training data."
+          end
+        end
+      else
+        hist_size = (@history.shape[0] * @changepoint_range).floor
 
-
-
-
-
+        if @n_changepoints + 1 > hist_size
+          @n_changepoints = hist_size - 1
+          logger.info "n_changepoints greater than number of observations. Using #{@n_changepoints}"
+        end
 
-
-
-
-
-
-
+        if @n_changepoints > 0
+          step = (hist_size - 1) / @n_changepoints.to_f
+          cp_indexes = (@n_changepoints + 1).times.map { |i| (i * step).round }
+          @changepoints = Rover::Vector.new(@history["ds"].to_a.values_at(*cp_indexes)).tail(-1)
+        else
+          @changepoints = []
+        end
       end
 
       if @changepoints.size > 0
-        @changepoints_t =
+        @changepoints_t = (@changepoints.map(&:to_i).sort.to_numo.cast_to(Numo::DFloat) - @start.to_i) / @t_scale.to_f
       else
         @changepoints_t = Numo::NArray.asarray([0])
       end
     end
 
     def fourier_series(dates, period, series_order)
-
-      # uses to_datetime first so we get UTC
-      t = Numo::DFloat.asarray(dates.map { |v| v.to_i - start }) / (3600 * 24.0)
+      t = dates.map(&:to_i).to_numo / (3600 * 24.0)
 
       # no need for column_stack
       series_order.times.flat_map do |i|
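Note on the `set_changepoints` rewrite in this hunk: when no changepoints are supplied, candidate positions are spread evenly over the first `changepoint_range` share of the history, and the first position (index 0) is then dropped, which is what `tail(-1)` appears to do in the diff. The following is a minimal sketch of just the index arithmetic in plain Ruby. The helper name `changepoint_indexes` is hypothetical, and the defaults of 25 and 0.8 are assumed from upstream Prophet rather than shown in this diff.

```ruby
# Hypothetical helper (not part of prophet-rb): reproduces the index
# arithmetic from set_changepoints on a bare row count.
# Assumption: defaults of 25 changepoints over the first 80% of rows.
def changepoint_indexes(n_rows, n_changepoints: 25, changepoint_range: 0.8)
  hist_size = (n_rows * changepoint_range).floor

  # Cap the count when there are fewer usable rows than requested changepoints.
  n_changepoints = hist_size - 1 if n_changepoints + 1 > hist_size
  return [] unless n_changepoints > 0

  step = (hist_size - 1) / n_changepoints.to_f
  indexes = (n_changepoints + 1).times.map { |i| (i * step).round }

  # Drop index 0 so the first observation is never a changepoint.
  indexes.drop(1)
end

p changepoint_indexes(10)
p changepoint_indexes(100).size
```

With very short histories the capped count can reach zero or below, which is the case the new `else` branch handles by setting `@changepoints = []`.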
@@ -255,11 +260,11 @@ module Prophet
 
     def make_seasonality_features(dates, period, series_order, prefix)
       features = fourier_series(dates, period, series_order)
-
+      Rover::DataFrame.new(features.map.with_index { |v, i| ["#{prefix}_delim_#{i + 1}", v] }.to_h)
     end
 
     def construct_holiday_dataframe(dates)
-      all_holidays =
+      all_holidays = Rover::DataFrame.new
       if @holidays
         all_holidays = @holidays.dup
       end
@@ -271,12 +276,12 @@ module Prophet
       # Drop future holidays not previously seen in training data
       if @train_holiday_names
         # Remove holiday names didn't show up in fit
-        all_holidays = all_holidays
+        all_holidays = all_holidays[all_holidays["holiday"].in?(@train_holiday_names)]
 
         # Add holiday names in fit but not in predict with ds as NA
-        holidays_to_add =
-          "holiday" => @train_holiday_names
-        )
+        holidays_to_add = Rover::DataFrame.new({
+          "holiday" => @train_holiday_names[!@train_holiday_names.in?(all_holidays["holiday"])]
+        })
         all_holidays = all_holidays.concat(holidays_to_add)
       end
 
@@ -310,7 +315,7 @@ module Prophet
 
         lw.upto(uw).each do |offset|
           occurrence = dt ? dt + offset : nil
-          loc = occurrence ? row_index.index(occurrence) : nil
+          loc = occurrence ? row_index.to_a.index(occurrence) : nil
           key = "#{row["holiday"]}_delim_#{offset >= 0 ? "+" : "-"}#{offset.abs}"
           if loc
             expanded_holidays[key][loc] = 1.0
@@ -319,14 +324,14 @@ module Prophet
           end
         end
       end
-      holiday_features =
-      #
-      holiday_features = holiday_features[
-      prior_scale_list = holiday_features.
+      holiday_features = Rover::DataFrame.new(expanded_holidays)
+      # Make sure column order is consistent
+      holiday_features = holiday_features[holiday_features.vector_names.sort]
+      prior_scale_list = holiday_features.vector_names.map { |h| prior_scales[h.split("_delim_")[0]] }
       holiday_names = prior_scales.keys
       # Store holiday names used in fit
-      if
-        @train_holiday_names =
+      if @train_holiday_names.nil?
+        @train_holiday_names = Rover::Vector.new(holiday_names)
       end
       [holiday_features, prior_scale_list, holiday_names]
     end
@@ -424,16 +429,16 @@ module Prophet
         modes[@seasonality_mode].concat(holiday_names)
       end
 
-      #
+      # Additional regressors
       @extra_regressors.each do |name, props|
-        seasonal_features << df[name]
+        seasonal_features << Rover::DataFrame.new({name => df[name]})
         prior_scales << props[:prior_scale]
         modes[props[:mode]] << name
       end
 
-      #
+      # Dummy to prevent empty X
       if seasonal_features.size == 0
-        seasonal_features <<
+        seasonal_features << Rover::DataFrame.new({"zeros" => [0] * df.shape[0]})
         prior_scales << 1.0
       end
 
@@ -445,16 +450,16 @@ module Prophet
     end
 
     def regressor_column_matrix(seasonal_features, modes)
-      components =
+      components = Rover::DataFrame.new(
         "col" => seasonal_features.shape[1].times.to_a,
-        "component" => seasonal_features.
+        "component" => seasonal_features.vector_names.map { |x| x.split("_delim_")[0] }
       )
 
-      #
+      # Add total for holidays
       if @train_holiday_names
         components = add_group_component(components, "holidays", @train_holiday_names.uniq)
       end
-      #
+      # Add totals additive and multiplicative components, and regressors
       ["additive", "multiplicative"].each do |mode|
         components = add_group_component(components, mode + "_terms", modes[mode])
         regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
@@ -465,20 +470,15 @@ module Prophet
         modes[mode] << mode + "_terms"
         modes[mode] << "extra_regressors_" + mode
       end
-      #
+      # After all of the additive/multiplicative groups have been added,
       modes[@seasonality_mode] << "holidays"
-      #
-      component_cols =
-
-      )
-      component_cols.each_vector do |v|
-        v.map! { |vi| vi.nil? ? 0 : vi }
-      end
-      component_cols.rename_vectors(:_id => "col")
+      # Convert to a binary matrix
+      component_cols = components["col"].crosstab(components["component"])
+      component_cols["col"] = component_cols.delete("_")
 
       # Add columns for additive and multiplicative terms, if missing
       ["additive_terms", "multiplicative_terms"].each do |name|
-        component_cols[name] = 0 unless component_cols.
+        component_cols[name] = 0 unless component_cols.include?(name)
       end
 
       # TODO validation
@@ -487,10 +487,10 @@ module Prophet
     end
 
     def add_group_component(components, name, group)
-      new_comp = components
+      new_comp = components[components["component"].in?(group)].dup
       group_cols = new_comp["col"].uniq
       if group_cols.size > 0
-        new_comp =
+        new_comp = Rover::DataFrame.new({"col" => group_cols, "component" => name})
         components = components.concat(new_comp)
       end
       components
@@ -566,8 +566,8 @@ module Prophet
     end
 
     def linear_growth_init(df)
-      i0 =
-      i1 = df
+      i0 = 0
+      i1 = df.size - 1
       t = df["t"][i1] - df["t"][i0]
       k = (df["y_scaled"][i1] - df["y_scaled"][i0]) / t
       m = df["y_scaled"][i0] - k * df["t"][i0]
@@ -575,8 +575,8 @@ module Prophet
     end
 
     def logistic_growth_init(df)
-      i0 =
-      i1 = df
+      i0 = 0
+      i1 = df.size - 1
       t = df["t"][i1] - df["t"][i0]
 
       # Force valid values, in case y > cap or y < 0
@@ -605,8 +605,13 @@ module Prophet
     def fit(df, **kwargs)
       raise Error, "Prophet object can only be fit once" if @history
 
-
-
+      if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
+        df = Rover::DataFrame.new(df.to_h)
+      end
+      raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
+
+      history = df[!df["y"].missing]
+      raise Error, "Data has less than 2 non-nil rows" if history.size < 2
 
       @history_dates = to_datetime(df["ds"]).sort
       history = setup_dataframe(history, initialize_scales: true)
@@ -654,8 +659,8 @@ module Prophet
         # Nothing to fit.
         @params = stan_init
         @params["sigma_obs"] = 1e-9
-        @params.each do |par|
-          @params[par] = Numo::NArray.asarray(@params[par])
+        @params.each do |par, _|
+          @params[par] = Numo::NArray.asarray([@params[par]])
         end
       elsif @mcmc_samples > 0
         @params = @stan_backend.sampling(stan_init, dat, @mcmc_samples, **kwargs)
@@ -666,8 +671,10 @@ module Prophet
       # If no changepoints were requested, replace delta with 0s
       if @changepoints.size == 0
         # Fold delta into the base rate k
-
-
+        # Numo doesn't support -1 with reshape
+        negative_one = @params["delta"].shape.inject(&:*)
+        @params["k"] = @params["k"] + @params["delta"].reshape(negative_one)
+        @params["delta"] = Numo::DFloat.zeros(@params["delta"].shape).reshape(negative_one, 1)
       end
 
       self
@@ -693,10 +700,10 @@ module Prophet
 
       # Drop columns except ds, cap, floor, and trend
       cols = ["ds", "trend"]
-      cols << "cap" if df.
+      cols << "cap" if df.include?("cap")
       cols << "floor" if @logistic_floor
       # Add in forecast components
-      df2 = df_concat_axis_one([df[
+      df2 = df_concat_axis_one([df[cols], intervals, seasonal_components])
       df2["yhat"] = df2["trend"] * (df2["multiplicative_terms"] + 1) + df2["additive_terms"]
       df2
     end
@@ -731,8 +738,7 @@ module Prophet
         k_t[indx] += deltas[s]
         m_t[indx] += gammas[s]
       end
-
-      df_values(cap) / (1 + Numo::NMath.exp(-k_t * (t - m_t)))
+      cap.to_numo / (1 + Numo::NMath.exp(-k_t * (t - m_t)))
     end
 
     def predict_trend(df)
@@ -758,10 +764,10 @@ module Prophet
         upper_p = 100 * (1.0 + @interval_width) / 2
       end
 
-      x =
+      x = seasonal_features.to_numo
       data = {}
-      component_cols.
-        beta_c =
+      component_cols.vector_names.each do |component|
+        beta_c = @params["beta"] * component_cols[component].to_numo
 
         comp = x.dot(beta_c.transpose)
         if @component_modes["additive"].include?(component)
@@ -769,11 +775,11 @@ module Prophet
         end
         data[component] = comp.mean(axis: 1, nan: true)
         if @uncertainty_samples
-          data[component + "_lower"] = percentile(
-          data[component + "_upper"] = percentile(
+          data[component + "_lower"] = comp.percentile(lower_p, axis: 1)
+          data[component + "_upper"] = comp.percentile(upper_p, axis: 1)
         end
       end
-
+      Rover::DataFrame.new(data)
     end
 
     def sample_posterior_predictive(df)
@@ -784,9 +790,9 @@ module Prophet
       seasonal_features, _, component_cols, _ = make_all_seasonality_features(df)
 
       # convert to Numo for performance
-      seasonal_features =
-      additive_terms =
-      multiplicative_terms =
+      seasonal_features = seasonal_features.to_numo
+      additive_terms = component_cols["additive_terms"].to_numo
+      multiplicative_terms = component_cols["multiplicative_terms"].to_numo
 
       sim_values = {"yhat" => [], "trend" => []}
       n_iterations.times do |i|
@@ -823,11 +829,11 @@ module Prophet
 
       series = {}
       ["yhat", "trend"].each do |key|
-        series["#{key}_lower"] =
-        series["#{key}_upper"] =
+        series["#{key}_lower"] = sim_values[key].percentile(lower_p, axis: 1)
+        series["#{key}_upper"] = sim_values[key].percentile(upper_p, axis: 1)
       end
 
-
+      Rover::DataFrame.new(series)
     end
 
     def sample_model(df, seasonal_features, iteration, s_a, s_m)
@@ -848,8 +854,8 @@ module Prophet
     end
 
     def sample_predictive_trend(df, iteration)
-      k = @params["k"][iteration
-      m = @params["m"][iteration
+      k = @params["k"][iteration]
+      m = @params["m"][iteration]
       deltas = @params["delta"][iteration, true]
 
       t = Numo::NArray.asarray(df["t"].to_a)
@@ -889,82 +895,81 @@ module Prophet
       trend * @y_scale + Numo::NArray.asarray(df["floor"].to_a)
     end
 
-    def percentile(a, percentile, axis:)
-      raise Error, "Axis must be 1" if axis != 1
-
-      sorted = a.sort(axis: axis)
-      x = percentile / 100.0 * (sorted.shape[axis] - 1)
-      r = x % 1
-      i = x.floor
-      # this should use axis, but we only need axis: 1
-      if i == sorted.shape[axis] - 1
-        sorted[true, -1]
-      else
-        sorted[true, i] + r * (sorted[true, i + 1] - sorted[true, i])
-      end
-    end
-
     def make_future_dataframe(periods:, freq: "D", include_history: true)
       raise Error, "Model has not been fit" unless @history_dates
       last_date = @history_dates.max
+      # TODO add more freq
+      # https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timeseries-offset-aliases
       case freq
+      when /\A\d+S\z/
+        secs = freq.to_i
+        dates = (periods + 1).times.map { |i| last_date + i * secs }
+      when "H"
+        hour = 3600
+        dates = (periods + 1).times.map { |i| last_date + i * hour }
       when "D"
         # days have constant length with UTC (no DST or leap seconds)
-
-
-
+        day = 24 * 3600
+        dates = (periods + 1).times.map { |i| last_date + i * day }
+      when "W"
+        week = 7 * 24 * 3600
+        dates = (periods + 1).times.map { |i| last_date + i * week }
       when "MS"
         dates = [last_date]
+        # TODO reset day from last date, but keep time
         periods.times do
           dates << dates.last.to_datetime.next_month.to_time.utc
         end
+      when "QS"
+        dates = [last_date]
+        # TODO reset day and month from last date, but keep time
+        periods.times do
+          dates << dates.last.to_datetime.next_month.next_month.next_month.to_time.utc
+        end
+      when "YS"
+        dates = [last_date]
+        # TODO reset day and month from last date, but keep time
+        periods.times do
+          dates << dates.last.to_datetime.next_year.to_time.utc
+        end
       else
         raise ArgumentError, "Unknown freq: #{freq}"
       end
       dates.select! { |d| d > last_date }
       dates = dates.last(periods)
-      dates = @history_dates
-
+      dates = @history_dates.to_numo.concatenate(Numo::NArray.cast(dates)) if include_history
+      Rover::DataFrame.new({"ds" => dates})
     end
 
     private
 
-    # Time is
+    # Time is preferred over DateTime in Ruby docs
     # use UTC to be consistent with Python
     # and so days have equal length (no DST)
     def to_datetime(vec)
       return if vec.nil?
-      vec
-
-
-
-
-
-
-
+      vec =
+        vec.map do |v|
+          case v
+          when Time
+            v.utc
+          when Date
+            v.to_datetime.to_time.utc
+          else
+            DateTime.parse(v.to_s).to_time.utc
960
|
+
end
|
945
961
|
end
|
946
|
-
|
962
|
+
Rover::Vector.new(vec)
|
947
963
|
end
|
948
964
|
|
949
965
|
# okay to do in-place
|
950
966
|
def df_concat_axis_one(dfs)
|
951
967
|
dfs[1..-1].each do |df|
|
952
|
-
df
|
953
|
-
dfs[0][k] = v
|
954
|
-
end
|
968
|
+
dfs[0].merge!(df)
|
955
969
|
end
|
956
970
|
dfs[0]
|
957
971
|
end
|
958
972
|
|
959
|
-
def df_values(df)
|
960
|
-
if df.is_a?(Daru::Vector)
|
961
|
-
Numo::NArray.asarray(df.to_a)
|
962
|
-
else
|
963
|
-
# TODO make more performant
|
964
|
-
Numo::NArray.asarray(df.to_matrix.to_a)
|
965
|
-
end
|
966
|
-
end
|
967
|
-
|
968
973
|
# https://en.wikipedia.org/wiki/Poisson_distribution#Generating_Poisson-distributed_random_variables
|
969
974
|
def poisson(lam)
|
970
975
|
l = Math.exp(-lam)
|
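Note on `make_future_dataframe` in the hunk above: the fixed-length frequencies (`"H"`, `"D"`, `"W"` and the `\d+S` pattern) all come down to adding a constant number of seconds to the last history timestamp. The following is a minimal stand-alone sketch of that idea, not the gem's method. `future_times` and `STEP_SECONDS` are hypothetical names, and the month-based frequencies (`"MS"`, `"QS"`, `"YS"`) and the `include_history` option are left out.

```ruby
# Hypothetical stand-alone sketch, not the gem's method: generate future
# UTC timestamps for the fixed-length frequencies seen in the diff.
STEP_SECONDS = { "H" => 3600, "D" => 24 * 3600, "W" => 7 * 24 * 3600 }.freeze

def future_times(last_time, periods:, freq: "D")
  step =
    if freq.match?(/\A\d+S\z/)
      freq.to_i # "30S".to_i => 30; String#to_i stops at the first non-digit
    else
      STEP_SECONDS.fetch(freq) { raise ArgumentError, "Unknown freq: #{freq}" }
    end

  # Start at 1 so every generated time is strictly after last_time.
  (1..periods).map { |i| last_time + i * step }
end

last = Time.utc(2020, 10, 14, 12, 0, 0)
hourly = future_times(last, periods: 3, freq: "H")
half_minutes = future_times(last, periods: 2, freq: "30S")
```

Working in UTC keeps every day the same length, which is the reason given in the diff's own comments.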
@@ -979,7 +984,7 @@ module Prophet
|
|
979
984
|
|
980
985
|
# https://en.wikipedia.org/wiki/Laplace_distribution#Generating_values_from_the_Laplace_distribution
|
981
986
|
def laplace(loc, scale, size)
|
982
|
-
u = Numo::DFloat.new(size).rand
|
987
|
+
u = Numo::DFloat.new(size).rand(-0.5, 0.5)
|
983
988
|
loc - scale * u.sign * Numo::NMath.log(1 - 2 * u.abs)
|
984
989
|
end
|
985
990
|
end
|
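Note on the `laplace` change above: the formula on the following line, `loc - scale * u.sign * log(1 - 2 * u.abs)`, is the inverse-CDF method and expects `u` to be uniform on (-0.5, 0.5). Assuming the argument-less `rand` draws from [0, 1), the old code never produced a negative sign and took the logarithm of a negative number whenever `u` exceeded 0.5. Below is a scalar plain-Ruby sketch of the corrected sampling. `laplace_from_uniform` and `laplace_samples` are hypothetical names, and the gem itself operates on Numo arrays.

```ruby
# Hypothetical scalar version of the formula in the diff; the gem itself
# works on Numo arrays. `u` is expected to lie in (-0.5, 0.5).
def laplace_from_uniform(loc, scale, u)
  sign = u <=> 0.0 # -1, 0 or 1
  loc - scale * sign * Math.log(1 - 2 * u.abs)
end

# Draw `size` samples by feeding uniform values from (-0.5, 0.5)
# through the inverse-CDF formula.
def laplace_samples(loc, scale, size, rng: Random.new)
  Array.new(size) { laplace_from_uniform(loc, scale, rng.rand(-0.5...0.5)) }
end

samples = laplace_samples(0.0, 1.0, 5, rng: Random.new(42))
```

Passing a seeded `Random` makes the draws repeatable, which is convenient when comparing runs.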
data/lib/prophet/holidays.rb
CHANGED
@@ -6,7 +6,7 @@ module Prophet
|
|
6
6
|
end
|
7
7
|
|
8
8
|
def make_holidays_df(year_list, country)
|
9
|
-
holidays_df
|
9
|
+
holidays_df[(holidays_df["country"] == country) & (holidays_df["year"].in?(year_list))][["ds", "holiday"]]
|
10
10
|
end
|
11
11
|
|
12
12
|
# TODO marshal on installation
|
@@ -20,7 +20,7 @@ module Prophet
|
|
20
20
|
holidays["country"] << row["country"]
|
21
21
|
holidays["year"] << row["year"]
|
22
22
|
end
|
23
|
-
|
23
|
+
Rover::DataFrame.new(holidays)
|
24
24
|
end
|
25
25
|
end
|
26
26
|
end
|
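Note on `make_holidays_df` above: the new line keeps the rows whose country matches and whose year is in `year_list`, then selects only the `ds` and `holiday` columns. The same selection can be sketched on an Array of Hashes. The rows below are made up for illustration, and `select_holidays` is a hypothetical name, not part of the gem.

```ruby
# Hypothetical illustration with made-up rows; the gem filters a
# Rover::DataFrame, this sketch uses an Array of Hashes instead.
HOLIDAY_ROWS = [
  { "ds" => "2019-01-01", "holiday" => "New Year's Day", "country" => "US", "year" => 2019 },
  { "ds" => "2020-01-01", "holiday" => "New Year's Day", "country" => "US", "year" => 2020 },
  { "ds" => "2020-07-01", "holiday" => "Canada Day", "country" => "CA", "year" => 2020 }
].freeze

def select_holidays(rows, year_list, country)
  rows
    .select { |row| row["country"] == country && year_list.include?(row["year"]) }
    .map { |row| row.slice("ds", "holiday") } # keep only the two output columns
end

us_2020 = select_holidays(HOLIDAY_ROWS, [2020, 2021], "US")
```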
data/lib/prophet/plot.rb
CHANGED
@@ -8,16 +8,16 @@ module Prophet
|
|
8
8
|
fig = ax.get_figure
|
9
9
|
end
|
10
10
|
fcst_t = to_pydatetime(fcst["ds"])
|
11
|
-
ax.plot(to_pydatetime(@history["ds"]), @history["y"].
|
12
|
-
ax.plot(fcst_t, fcst["yhat"].
|
13
|
-
if fcst.
|
14
|
-
ax.plot(fcst_t, fcst["cap"].
|
11
|
+
ax.plot(to_pydatetime(@history["ds"]), @history["y"].to_a, "k.")
|
12
|
+
ax.plot(fcst_t, fcst["yhat"].to_a, ls: "-", c: "#0072B2")
|
13
|
+
if fcst.include?("cap") && plot_cap
|
14
|
+
ax.plot(fcst_t, fcst["cap"].to_a, ls: "--", c: "k")
|
15
15
|
end
|
16
|
-
if @logistic_floor && fcst.
|
17
|
-
ax.plot(fcst_t, fcst["floor"].
|
16
|
+
if @logistic_floor && fcst.include?("floor") && plot_cap
|
17
|
+
ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
|
18
18
|
end
|
19
19
|
if uncertainty && @uncertainty_samples
|
20
|
-
ax.fill_between(fcst_t, fcst["yhat_lower"].
|
20
|
+
ax.fill_between(fcst_t, fcst["yhat_lower"].to_a, fcst["yhat_upper"].to_a, color: "#0072B2", alpha: 0.2)
|
21
21
|
end
|
22
22
|
# Specify formatting to workaround matplotlib issue #12925
|
23
23
|
locator = dates.AutoDateLocator.new(interval_multiples: false)
|
@@ -33,25 +33,25 @@ module Prophet
|
|
33
33
|
|
34
34
|
def plot_components(fcst, uncertainty: true, plot_cap: true, weekly_start: 0, yearly_start: 0, figsize: nil)
|
35
35
|
components = ["trend"]
|
36
|
-
if @train_holiday_names && fcst.
|
36
|
+
if @train_holiday_names && fcst.include?("holidays")
|
37
37
|
components << "holidays"
|
38
38
|
end
|
39
39
|
# Plot weekly seasonality, if present
|
40
|
-
if @seasonalities["weekly"] && fcst.
|
40
|
+
if @seasonalities["weekly"] && fcst.include?("weekly")
|
41
41
|
components << "weekly"
|
42
42
|
end
|
43
43
|
# Yearly if present
|
44
|
-
if @seasonalities["yearly"] && fcst.
|
44
|
+
if @seasonalities["yearly"] && fcst.include?("yearly")
|
45
45
|
components << "yearly"
|
46
46
|
end
|
47
47
|
# Other seasonalities
|
48
|
-
components.concat(@seasonalities.keys.select { |name| fcst.
|
48
|
+
components.concat(@seasonalities.keys.select { |name| fcst.include?(name) && !["weekly", "yearly"].include?(name) }.sort)
|
49
49
|
regressors = {"additive" => false, "multiplicative" => false}
|
50
50
|
@extra_regressors.each do |name, props|
|
51
51
|
regressors[props[:mode]] = true
|
52
52
|
end
|
53
53
|
["additive", "multiplicative"].each do |mode|
|
54
|
-
if regressors[mode] && fcst.
|
54
|
+
if regressors[mode] && fcst.include?("extra_regressors_#{mode}")
|
55
55
|
components << "extra_regressors_#{mode}"
|
56
56
|
end
|
57
57
|
end
|
@@ -93,6 +93,24 @@ module Prophet
|
|
93
93
|
fig
|
94
94
|
end
|
95
95
|
|
96
|
+
# in Python, this is a separate method
|
97
|
+
def add_changepoints_to_plot(ax, fcst, threshold: 0.01, cp_color: "r", cp_linestyle: "--", trend: true)
|
98
|
+
artists = []
|
99
|
+
if trend
|
100
|
+
artists << ax.plot(to_pydatetime(fcst["ds"]), fcst["trend"].to_a, c: cp_color)
|
101
|
+
end
|
102
|
+
signif_changepoints =
|
103
|
+
if @changepoints.size > 0
|
104
|
+
(@params["delta"].mean(axis: 0, nan: true).abs >= threshold).mask(@changepoints.to_numo)
|
105
|
+
else
|
106
|
+
[]
|
107
|
+
end
|
108
|
+
to_pydatetime(signif_changepoints).each do |cp|
|
109
|
+
artists << ax.axvline(x: cp, c: cp_color, ls: cp_linestyle)
|
110
|
+
end
|
111
|
+
artists
|
112
|
+
end
|
113
|
+
|
96
114
|
private
|
97
115
|
|
98
116
|
def plot_forecast_component(fcst, name, ax: nil, uncertainty: true, plot_cap: false, figsize: [10, 6])
|
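Note on `add_changepoints_to_plot` above: a changepoint is drawn only when the absolute value of its mean delta across samples is at least `threshold`. Here is a plain-Ruby sketch of that selection step. `significant_changepoints` is a hypothetical name, the input values are made up, and the NaN handling implied by `nan: true` in the diff is omitted.

```ruby
# Hypothetical plain-Ruby sketch of the selection step; the gem does this
# with Numo arrays. `delta_samples` holds one row of deltas per sample.
def significant_changepoints(changepoints, delta_samples, threshold: 0.01)
  return [] if changepoints.empty?

  # Average each changepoint's delta across samples (column-wise mean).
  mean_deltas = delta_samples.transpose.map { |column| column.sum / column.size.to_f }
  changepoints.zip(mean_deltas).select { |_, delta| delta.abs >= threshold }.map(&:first)
end

changepoints = ["2020-01-01", "2020-02-01", "2020-03-01"]
delta_samples = [[0.5, 0.001, -0.25], [0.25, -0.001, -0.5]]
selected = significant_changepoints(changepoints, delta_samples)
```

The early return for an empty changepoint list matches the `@changepoints.size > 0` guard in the diff, which corresponds to the 0.2.2 changelog entry about models with no changepoints.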
@@ -102,15 +120,15 @@ module Prophet
|
|
102
120
|
ax = fig.add_subplot(111)
|
103
121
|
end
|
104
122
|
fcst_t = to_pydatetime(fcst["ds"])
|
105
|
-
artists += ax.plot(fcst_t, fcst[name].
|
106
|
-
if fcst.
|
107
|
-
artists += ax.plot(fcst_t, fcst["cap"].
|
123
|
+
artists += ax.plot(fcst_t, fcst[name].to_a, ls: "-", c: "#0072B2")
|
124
|
+
if fcst.include?("cap") && plot_cap
|
125
|
+
artists += ax.plot(fcst_t, fcst["cap"].to_a, ls: "--", c: "k")
|
108
126
|
end
|
109
|
-
if @logistic_floor && fcst.
|
110
|
-
ax.plot(fcst_t, fcst["floor"].
|
127
|
+
if @logistic_floor && fcst.include?("floor") && plot_cap
|
128
|
+
ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
|
111
129
|
end
|
112
130
|
if uncertainty && @uncertainty_samples
|
113
|
-
artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].
|
131
|
+
artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].to_a, fcst[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
114
132
|
end
|
115
133
|
# Specify formatting to workaround matplotlib issue #12925
|
116
134
|
locator = dates.AutoDateLocator.new(interval_multiples: false)
|
@@ -127,17 +145,17 @@ module Prophet
|
|
127
145
|
end
|
128
146
|
|
129
147
|
def seasonality_plot_df(ds)
|
130
|
-
df_dict = {"ds" => ds, "cap" =>
|
131
|
-
@extra_regressors.
|
132
|
-
df_dict[name] =
|
148
|
+
df_dict = {"ds" => ds, "cap" => 1.0, "floor" => 0.0}
|
149
|
+
@extra_regressors.each_key do |name|
|
150
|
+
df_dict[name] = 0.0
|
133
151
|
end
|
134
152
|
# Activate all conditional seasonality columns
|
135
153
|
@seasonalities.values.each do |props|
|
136
154
|
if props[:condition_name]
|
137
|
-
df_dict[props[:condition_name]] =
|
155
|
+
df_dict[props[:condition_name]] = true
|
138
156
|
end
|
139
157
|
end
|
140
|
-
df =
|
158
|
+
df = Rover::DataFrame.new(df_dict)
|
141
159
|
df = setup_dataframe(df)
|
142
160
|
df
|
143
161
|
end
|
@@ -154,9 +172,9 @@ module Prophet
|
|
154
172
|
df_w = seasonality_plot_df(days)
|
155
173
|
seas = predict_seasonal_components(df_w)
|
156
174
|
days = days.map { |v| v.strftime("%A") }
|
157
|
-
artists += ax.plot(days.size.times.to_a, seas[name].
|
175
|
+
artists += ax.plot(days.size.times.to_a, seas[name].to_a, ls: "-", c: "#0072B2")
|
158
176
|
if uncertainty && @uncertainty_samples
|
159
|
-
artists += [ax.fill_between(days.size.times.to_a, seas[name + "_lower"].
|
177
|
+
artists += [ax.fill_between(days.size.times.to_a, seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
160
178
|
end
|
161
179
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
162
180
|
ax.set_xticks(days.size.times.to_a)
|
@@ -180,9 +198,9 @@ module Prophet
|
|
180
198
|
days = 365.times.map { |i| start + i + yearly_start }
|
181
199
|
df_y = seasonality_plot_df(days)
|
182
200
|
seas = predict_seasonal_components(df_y)
|
183
|
-
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].
|
201
|
+
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
184
202
|
if uncertainty && @uncertainty_samples
|
185
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].
|
203
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
186
204
|
end
|
187
205
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
188
206
|
months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
|
@@ -213,9 +231,9 @@ module Prophet
|
|
213
231
|
days = plot_points.times.map { |i| Time.at(start + i * step).utc }
|
214
232
|
df_y = seasonality_plot_df(days)
|
215
233
|
seas = predict_seasonal_components(df_y)
|
216
|
-
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].
|
234
|
+
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
217
235
|
if uncertainty && @uncertainty_samples
|
218
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].
|
236
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
219
237
|
end
|
220
238
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
221
239
|
step = (finish - start) / (7 - 1).to_f
|
@@ -263,7 +281,7 @@ module Prophet
|
|
263
281
|
|
264
282
|
def to_pydatetime(v)
|
265
283
|
datetime = PyCall.import_module("datetime")
|
266
|
-
v.map { |v| datetime.datetime.utcfromtimestamp(v.to_i) }
|
284
|
+
v.map { |v| datetime.datetime.utcfromtimestamp(v.to_i) }.to_a
|
267
285
|
end
|
268
286
|
end
|
269
287
|
end
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -127,7 +127,7 @@ module Prophet
|
|
127
127
|
stan_data["t_change"] = stan_data["t_change"].to_a
|
128
128
|
stan_data["s_a"] = stan_data["s_a"].to_a
|
129
129
|
stan_data["s_m"] = stan_data["s_m"].to_a
|
130
|
-
stan_data["X"] = stan_data["X"].
|
130
|
+
stan_data["X"] = stan_data["X"].to_numo.to_a
|
131
131
|
stan_init["delta"] = stan_init["delta"].to_a
|
132
132
|
stan_init["beta"] = stan_init["beta"].to_a
|
133
133
|
[stan_init, stan_data]
|
data/lib/prophet/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-10-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -16,30 +16,30 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 0.1.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
26
|
+
version: 0.1.2
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
|
-
name:
|
28
|
+
name: numo-narray
|
29
29
|
requirement: !ruby/object:Gem::Requirement
|
30
30
|
requirements:
|
31
31
|
- - ">="
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version:
|
33
|
+
version: 0.9.1.7
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version:
|
40
|
+
version: 0.9.1.7
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
|
-
name:
|
42
|
+
name: rover-df
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
44
44
|
requirements:
|
45
45
|
- - ">="
|
@@ -94,6 +94,20 @@ dependencies:
|
|
94
94
|
- - ">="
|
95
95
|
- !ruby/object:Gem::Version
|
96
96
|
version: '5'
|
97
|
+
- !ruby/object:Gem::Dependency
|
98
|
+
name: daru
|
99
|
+
requirement: !ruby/object:Gem::Requirement
|
100
|
+
requirements:
|
101
|
+
- - ">="
|
102
|
+
- !ruby/object:Gem::Version
|
103
|
+
version: '0'
|
104
|
+
type: :development
|
105
|
+
prerelease: false
|
106
|
+
version_requirements: !ruby/object:Gem::Requirement
|
107
|
+
requirements:
|
108
|
+
- - ">="
|
109
|
+
- !ruby/object:Gem::Version
|
110
|
+
version: '0'
|
97
111
|
- !ruby/object:Gem::Dependency
|
98
112
|
name: matplotlib
|
99
113
|
requirement: !ruby/object:Gem::Requirement
|
@@ -109,7 +123,21 @@ dependencies:
|
|
109
123
|
- !ruby/object:Gem::Version
|
110
124
|
version: '0'
|
111
125
|
- !ruby/object:Gem::Dependency
|
112
|
-
name:
|
126
|
+
name: activesupport
|
127
|
+
requirement: !ruby/object:Gem::Requirement
|
128
|
+
requirements:
|
129
|
+
- - ">="
|
130
|
+
- !ruby/object:Gem::Version
|
131
|
+
version: '0'
|
132
|
+
type: :development
|
133
|
+
prerelease: false
|
134
|
+
version_requirements: !ruby/object:Gem::Requirement
|
135
|
+
requirements:
|
136
|
+
- - ">="
|
137
|
+
- !ruby/object:Gem::Version
|
138
|
+
version: '0'
|
139
|
+
- !ruby/object:Gem::Dependency
|
140
|
+
name: tzinfo-data
|
113
141
|
requirement: !ruby/object:Gem::Requirement
|
114
142
|
requirements:
|
115
143
|
- - ">="
|