prophet-rb 0.4.1 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/LICENSE.txt +1 -1
- data/README.md +36 -9
- data/data-raw/LICENSE-holidays.txt +1 -1
- data/data-raw/README.md +1 -1
- data/data-raw/generated_holidays.csv +56365 -32386
- data/lib/prophet/diagnostics.rb +1 -1
- data/lib/prophet/forecaster.rb +32 -24
- data/lib/prophet/holidays.rb +3 -2
- data/lib/prophet/plot.rb +4 -4
- data/lib/prophet/stan_backend.rb +8 -6
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet-rb.rb +1 -1
- data/lib/prophet.rb +10 -10
- metadata +4 -4
data/lib/prophet/diagnostics.rb
CHANGED
@@ -179,7 +179,7 @@ module Prophet
|
|
179
179
|
if metrics.nil?
|
180
180
|
metrics = valid_metrics
|
181
181
|
end
|
182
|
-
if (df
|
182
|
+
if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
|
183
183
|
metrics.delete("coverage")
|
184
184
|
end
|
185
185
|
if metrics.uniq.length != metrics.length
|
data/lib/prophet/forecaster.rb
CHANGED
@@ -44,7 +44,7 @@ module Prophet
|
|
44
44
|
@yearly_seasonality = yearly_seasonality
|
45
45
|
@weekly_seasonality = weekly_seasonality
|
46
46
|
@daily_seasonality = daily_seasonality
|
47
|
-
@holidays = holidays
|
47
|
+
@holidays = convert_df(holidays)
|
48
48
|
|
49
49
|
@seasonality_mode = seasonality_mode
|
50
50
|
@seasonality_prior_scale = seasonality_prior_scale.to_f
|
@@ -75,6 +75,7 @@ module Prophet
|
|
75
75
|
validate_inputs
|
76
76
|
|
77
77
|
@logger = ::Logger.new($stderr)
|
78
|
+
@logger.level = ::Logger::WARN
|
78
79
|
@logger.formatter = proc do |severity, datetime, progname, msg|
|
79
80
|
"[prophet] #{msg}\n"
|
80
81
|
end
|
@@ -89,7 +90,7 @@ module Prophet
|
|
89
90
|
raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
|
90
91
|
end
|
91
92
|
if @holidays
|
92
|
-
if
|
93
|
+
if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
|
93
94
|
raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
|
94
95
|
end
|
95
96
|
@holidays["ds"] = to_datetime(@holidays["ds"])
|
@@ -125,24 +126,24 @@ module Prophet
|
|
125
126
|
"holidays", "zeros", "extra_regressors_additive", "yhat",
|
126
127
|
"extra_regressors_multiplicative", "multiplicative_terms",
|
127
128
|
]
|
128
|
-
rn_l = reserved_names.map { |n| n
|
129
|
-
rn_u = reserved_names.map { |n| n
|
129
|
+
rn_l = reserved_names.map { |n| "#{n}_lower" }
|
130
|
+
rn_u = reserved_names.map { |n| "#{n}_upper" }
|
130
131
|
reserved_names.concat(rn_l)
|
131
132
|
reserved_names.concat(rn_u)
|
132
133
|
reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
|
133
134
|
if reserved_names.include?(name)
|
134
135
|
raise ArgumentError, "Name #{name.inspect} is reserved."
|
135
136
|
end
|
136
|
-
if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
|
137
|
+
if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
|
137
138
|
raise ArgumentError, "Name #{name.inspect} already used for a holiday."
|
138
139
|
end
|
139
|
-
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
|
140
|
+
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
|
140
141
|
raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
|
141
142
|
end
|
142
143
|
if check_seasonalities && @seasonalities[name]
|
143
144
|
raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
|
144
145
|
end
|
145
|
-
if check_regressors
|
146
|
+
if check_regressors && @extra_regressors[name]
|
146
147
|
raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
|
147
148
|
end
|
148
149
|
end
|
@@ -167,7 +168,7 @@ module Prophet
|
|
167
168
|
raise ArgumentError, "Found NaN in column #{name.inspect}"
|
168
169
|
end
|
169
170
|
end
|
170
|
-
@seasonalities.
|
171
|
+
@seasonalities.each_value do |props|
|
171
172
|
condition_name = props[:condition_name]
|
172
173
|
if condition_name
|
173
174
|
if !df.include?(condition_name)
|
@@ -481,14 +482,14 @@ module Prophet
|
|
481
482
|
end
|
482
483
|
# Add totals additive and multiplicative components, and regressors
|
483
484
|
["additive", "multiplicative"].each do |mode|
|
484
|
-
components = add_group_component(components, mode
|
485
|
+
components = add_group_component(components, "#{mode}_terms", modes[mode])
|
485
486
|
regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
|
486
487
|
.map { |r, props| r }
|
487
|
-
components = add_group_component(components, "extra_regressors_"
|
488
|
+
components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
|
488
489
|
|
489
490
|
# Add combination components to modes
|
490
|
-
modes[mode] << mode
|
491
|
-
modes[mode] << "extra_regressors_"
|
491
|
+
modes[mode] << "#{mode}_terms"
|
492
|
+
modes[mode] << "extra_regressors_#{mode}"
|
492
493
|
end
|
493
494
|
# After all of the additive/multiplicative groups have been added,
|
494
495
|
modes[@seasonality_mode] << "holidays"
|
@@ -631,9 +632,7 @@ module Prophet
|
|
631
632
|
def fit(df, **kwargs)
|
632
633
|
raise Error, "Prophet object can only be fit once" if @history
|
633
634
|
|
634
|
-
|
635
|
-
df = Rover::DataFrame.new(df.to_h)
|
636
|
-
end
|
635
|
+
df = convert_df(df)
|
637
636
|
raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
|
638
637
|
|
639
638
|
unless df.include?("ds") && df.include?("y")
|
@@ -817,8 +816,8 @@ module Prophet
|
|
817
816
|
end
|
818
817
|
data[component] = comp.mean(axis: 1, nan: true)
|
819
818
|
if @uncertainty_samples
|
820
|
-
data[component
|
821
|
-
data[component
|
819
|
+
data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
|
820
|
+
data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
|
822
821
|
end
|
823
822
|
end
|
824
823
|
Rover::DataFrame.new(data)
|
@@ -993,6 +992,16 @@ module Prophet
|
|
993
992
|
|
994
993
|
private
|
995
994
|
|
995
|
+
def convert_df(df)
|
996
|
+
if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
|
997
|
+
Rover::DataFrame.new(df.to_h)
|
998
|
+
elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
|
999
|
+
Rover::DataFrame.new(df.to_h(as_series: false))
|
1000
|
+
else
|
1001
|
+
df
|
1002
|
+
end
|
1003
|
+
end
|
1004
|
+
|
996
1005
|
# Time is preferred over DateTime in Ruby docs
|
997
1006
|
# use UTC to be consistent with Python
|
998
1007
|
# and so days have equal length (no DST)
|
@@ -1077,7 +1086,7 @@ module Prophet
|
|
1077
1086
|
d = {
|
1078
1087
|
"name" => "ds",
|
1079
1088
|
"index" => v.size.times.to_a,
|
1080
|
-
"data" => v.to_a.map { |v| v.iso8601(3) }
|
1089
|
+
"data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
|
1081
1090
|
}
|
1082
1091
|
model_dict[attribute] = JSON.generate(d)
|
1083
1092
|
end
|
@@ -1096,7 +1105,7 @@ module Prophet
|
|
1096
1105
|
v = instance_variable_get("@#{attribute}")
|
1097
1106
|
|
1098
1107
|
v = v.dup
|
1099
|
-
v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
|
1108
|
+
v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
|
1100
1109
|
v.delete("col")
|
1101
1110
|
|
1102
1111
|
fields =
|
@@ -1116,7 +1125,7 @@ module Prophet
|
|
1116
1125
|
d = {
|
1117
1126
|
"schema" => {
|
1118
1127
|
"fields" => fields,
|
1119
|
-
"pandas_version" => "
|
1128
|
+
"pandas_version" => "1.4.0"
|
1120
1129
|
},
|
1121
1130
|
"data" => v.to_a
|
1122
1131
|
}
|
@@ -1151,8 +1160,7 @@ module Prophet
|
|
1151
1160
|
# Params (Dict[str, np.ndarray])
|
1152
1161
|
model_dict["params"] = params.transform_values(&:to_a)
|
1153
1162
|
# Attributes that are skipped: stan_fit, stan_backend
|
1154
|
-
|
1155
|
-
model_dict["__prophet_version"] = "1.0"
|
1163
|
+
model_dict["__prophet_version"] = "1.1.2"
|
1156
1164
|
model_dict
|
1157
1165
|
end
|
1158
1166
|
|
@@ -1174,7 +1182,7 @@ module Prophet
|
|
1174
1182
|
d = JSON.parse(model_dict.fetch(attribute))
|
1175
1183
|
s = Rover::Vector.new(d["data"])
|
1176
1184
|
if d["name"] == "ds"
|
1177
|
-
s = s.map { |v|
|
1185
|
+
s = s.map { |v| DateTime.parse(v).to_time.utc }
|
1178
1186
|
end
|
1179
1187
|
model.instance_variable_set("@#{attribute}", s)
|
1180
1188
|
end
|
@@ -1191,7 +1199,7 @@ module Prophet
|
|
1191
1199
|
else
|
1192
1200
|
d = JSON.parse(model_dict.fetch(attribute))
|
1193
1201
|
df = Rover::DataFrame.new(d["data"])
|
1194
|
-
df["ds"] = df["ds"].map { |v|
|
1202
|
+
df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
|
1195
1203
|
if attribute == "train_component_cols"
|
1196
1204
|
# Special handling because of named index column
|
1197
1205
|
# df.columns.name = 'component'
|
data/lib/prophet/holidays.rb
CHANGED
@@ -3,8 +3,9 @@ module Prophet
|
|
3
3
|
def get_holiday_names(country)
|
4
4
|
years = (1995..2045).to_a
|
5
5
|
holiday_names = make_holidays_df(years, country)["holiday"].uniq
|
6
|
-
|
7
|
-
|
6
|
+
if holiday_names.size == 0
|
7
|
+
raise ArgumentError, "Holidays in #{country} are not currently supported"
|
8
|
+
end
|
8
9
|
holiday_names
|
9
10
|
end
|
10
11
|
|
data/lib/prophet/plot.rb
CHANGED
@@ -183,7 +183,7 @@ module Prophet
|
|
183
183
|
ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
|
184
184
|
end
|
185
185
|
if uncertainty && @uncertainty_samples
|
186
|
-
artists += [ax.fill_between(fcst_t, fcst[name
|
186
|
+
artists += [ax.fill_between(fcst_t, fcst["#{name}_lower"].to_a, fcst["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
187
187
|
end
|
188
188
|
# Specify formatting to workaround matplotlib issue #12925
|
189
189
|
locator = dates.AutoDateLocator.new(interval_multiples: false)
|
@@ -229,7 +229,7 @@ module Prophet
|
|
229
229
|
days = days.map { |v| v.strftime("%A") }
|
230
230
|
artists += ax.plot(days.size.times.to_a, seas[name].to_a, ls: "-", c: "#0072B2")
|
231
231
|
if uncertainty && @uncertainty_samples
|
232
|
-
artists += [ax.fill_between(days.size.times.to_a, seas[name
|
232
|
+
artists += [ax.fill_between(days.size.times.to_a, seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
233
233
|
end
|
234
234
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
235
235
|
ax.set_xticks(days.size.times.to_a)
|
@@ -255,7 +255,7 @@ module Prophet
|
|
255
255
|
seas = predict_seasonal_components(df_y)
|
256
256
|
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
257
257
|
if uncertainty && @uncertainty_samples
|
258
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name
|
258
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
259
259
|
end
|
260
260
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
261
261
|
months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
|
@@ -288,7 +288,7 @@ module Prophet
|
|
288
288
|
seas = predict_seasonal_components(df_y)
|
289
289
|
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
290
290
|
if uncertainty && @uncertainty_samples
|
291
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name
|
291
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
292
292
|
end
|
293
293
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
294
294
|
step = (finish - start) / (7 - 1).to_f
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -21,14 +21,16 @@ module Prophet
|
|
21
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
22
22
|
iterations = 10000
|
23
23
|
|
24
|
+
args = {
|
25
|
+
data: stan_data,
|
26
|
+
inits: stan_init,
|
27
|
+
iter: iterations
|
28
|
+
}
|
29
|
+
args.merge!(kwargs)
|
30
|
+
|
24
31
|
stan_fit = nil
|
25
32
|
begin
|
26
|
-
stan_fit = @model.optimize(
|
27
|
-
data: stan_data,
|
28
|
-
inits: stan_init,
|
29
|
-
iter: iterations,
|
30
|
-
**kwargs
|
31
|
-
)
|
33
|
+
stan_fit = @model.optimize(**args)
|
32
34
|
rescue => e
|
33
35
|
if kwargs[:algorithm] != "Newton"
|
34
36
|
@logger.warn "Optimization terminated abnormally. Falling back to Newton."
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet-rb.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require_relative "prophet"
|
data/lib/prophet.rb
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "cmdstan"
|
3
|
-
require "rover"
|
4
3
|
require "numo/narray"
|
4
|
+
require "rover"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "logger"
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
11
|
+
require_relative "prophet/diagnostics"
|
12
|
+
require_relative "prophet/holidays"
|
13
|
+
require_relative "prophet/plot"
|
14
|
+
require_relative "prophet/forecaster"
|
15
|
+
require_relative "prophet/stan_backend"
|
16
|
+
require_relative "prophet/version"
|
17
17
|
|
18
18
|
module Prophet
|
19
19
|
class Error < StandardError; end
|
@@ -68,7 +68,7 @@ module Prophet
|
|
68
68
|
df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
|
69
69
|
df["cap"] = cap if cap
|
70
70
|
|
71
|
-
m.logger.level = ::Logger::
|
71
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
72
72
|
m.add_country_holidays(country_holidays) if country_holidays
|
73
73
|
m.fit(df)
|
74
74
|
|
@@ -87,7 +87,7 @@ module Prophet
|
|
87
87
|
else
|
88
88
|
result.each { |v| v["ds"] = v["ds"].localtime }
|
89
89
|
end
|
90
|
-
result.
|
90
|
+
result.to_h { |v| [v["ds"], v["yhat"]] }
|
91
91
|
end
|
92
92
|
|
93
93
|
# TODO better name for interval_width
|
@@ -97,7 +97,7 @@ module Prophet
|
|
97
97
|
df["cap"] = cap if cap
|
98
98
|
|
99
99
|
m = Prophet.new(interval_width: interval_width, **options)
|
100
|
-
m.logger.level = ::Logger::
|
100
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
101
101
|
m.add_country_holidays(country_holidays) if country_holidays
|
102
102
|
m.fit(df)
|
103
103
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-09-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
159
159
|
requirements:
|
160
160
|
- - ">="
|
161
161
|
- !ruby/object:Gem::Version
|
162
|
-
version: '
|
162
|
+
version: '3'
|
163
163
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
164
164
|
requirements:
|
165
165
|
- - ">="
|
166
166
|
- !ruby/object:Gem::Version
|
167
167
|
version: '0'
|
168
168
|
requirements: []
|
169
|
-
rubygems_version: 3.
|
169
|
+
rubygems_version: 3.4.10
|
170
170
|
signing_key:
|
171
171
|
specification_version: 4
|
172
172
|
summary: Time series forecasting for Ruby
|