prophet-rb 0.4.1 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/LICENSE.txt +1 -1
- data/README.md +36 -9
- data/data-raw/LICENSE-holidays.txt +1 -1
- data/data-raw/README.md +1 -1
- data/data-raw/generated_holidays.csv +56365 -32386
- data/lib/prophet/diagnostics.rb +1 -1
- data/lib/prophet/forecaster.rb +32 -24
- data/lib/prophet/holidays.rb +3 -2
- data/lib/prophet/plot.rb +4 -4
- data/lib/prophet/stan_backend.rb +8 -6
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet-rb.rb +1 -1
- data/lib/prophet.rb +10 -10
- metadata +4 -4
data/lib/prophet/diagnostics.rb
CHANGED
@@ -179,7 +179,7 @@ module Prophet
|
|
179
179
|
if metrics.nil?
|
180
180
|
metrics = valid_metrics
|
181
181
|
end
|
182
|
-
if (df
|
182
|
+
if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
|
183
183
|
metrics.delete("coverage")
|
184
184
|
end
|
185
185
|
if metrics.uniq.length != metrics.length
|
data/lib/prophet/forecaster.rb
CHANGED
@@ -44,7 +44,7 @@ module Prophet
|
|
44
44
|
@yearly_seasonality = yearly_seasonality
|
45
45
|
@weekly_seasonality = weekly_seasonality
|
46
46
|
@daily_seasonality = daily_seasonality
|
47
|
-
@holidays = holidays
|
47
|
+
@holidays = convert_df(holidays)
|
48
48
|
|
49
49
|
@seasonality_mode = seasonality_mode
|
50
50
|
@seasonality_prior_scale = seasonality_prior_scale.to_f
|
@@ -75,6 +75,7 @@ module Prophet
|
|
75
75
|
validate_inputs
|
76
76
|
|
77
77
|
@logger = ::Logger.new($stderr)
|
78
|
+
@logger.level = ::Logger::WARN
|
78
79
|
@logger.formatter = proc do |severity, datetime, progname, msg|
|
79
80
|
"[prophet] #{msg}\n"
|
80
81
|
end
|
@@ -89,7 +90,7 @@ module Prophet
|
|
89
90
|
raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
|
90
91
|
end
|
91
92
|
if @holidays
|
92
|
-
if
|
93
|
+
if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
|
93
94
|
raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
|
94
95
|
end
|
95
96
|
@holidays["ds"] = to_datetime(@holidays["ds"])
|
@@ -125,24 +126,24 @@ module Prophet
|
|
125
126
|
"holidays", "zeros", "extra_regressors_additive", "yhat",
|
126
127
|
"extra_regressors_multiplicative", "multiplicative_terms",
|
127
128
|
]
|
128
|
-
rn_l = reserved_names.map { |n| n
|
129
|
-
rn_u = reserved_names.map { |n| n
|
129
|
+
rn_l = reserved_names.map { |n| "#{n}_lower" }
|
130
|
+
rn_u = reserved_names.map { |n| "#{n}_upper" }
|
130
131
|
reserved_names.concat(rn_l)
|
131
132
|
reserved_names.concat(rn_u)
|
132
133
|
reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
|
133
134
|
if reserved_names.include?(name)
|
134
135
|
raise ArgumentError, "Name #{name.inspect} is reserved."
|
135
136
|
end
|
136
|
-
if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
|
137
|
+
if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
|
137
138
|
raise ArgumentError, "Name #{name.inspect} already used for a holiday."
|
138
139
|
end
|
139
|
-
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
|
140
|
+
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
|
140
141
|
raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
|
141
142
|
end
|
142
143
|
if check_seasonalities && @seasonalities[name]
|
143
144
|
raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
|
144
145
|
end
|
145
|
-
if check_regressors
|
146
|
+
if check_regressors && @extra_regressors[name]
|
146
147
|
raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
|
147
148
|
end
|
148
149
|
end
|
@@ -167,7 +168,7 @@ module Prophet
|
|
167
168
|
raise ArgumentError, "Found NaN in column #{name.inspect}"
|
168
169
|
end
|
169
170
|
end
|
170
|
-
@seasonalities.
|
171
|
+
@seasonalities.each_value do |props|
|
171
172
|
condition_name = props[:condition_name]
|
172
173
|
if condition_name
|
173
174
|
if !df.include?(condition_name)
|
@@ -481,14 +482,14 @@ module Prophet
|
|
481
482
|
end
|
482
483
|
# Add totals additive and multiplicative components, and regressors
|
483
484
|
["additive", "multiplicative"].each do |mode|
|
484
|
-
components = add_group_component(components, mode
|
485
|
+
components = add_group_component(components, "#{mode}_terms", modes[mode])
|
485
486
|
regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
|
486
487
|
.map { |r, props| r }
|
487
|
-
components = add_group_component(components, "extra_regressors_"
|
488
|
+
components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
|
488
489
|
|
489
490
|
# Add combination components to modes
|
490
|
-
modes[mode] << mode
|
491
|
-
modes[mode] << "extra_regressors_"
|
491
|
+
modes[mode] << "#{mode}_terms"
|
492
|
+
modes[mode] << "extra_regressors_#{mode}"
|
492
493
|
end
|
493
494
|
# After all of the additive/multiplicative groups have been added,
|
494
495
|
modes[@seasonality_mode] << "holidays"
|
@@ -631,9 +632,7 @@ module Prophet
|
|
631
632
|
def fit(df, **kwargs)
|
632
633
|
raise Error, "Prophet object can only be fit once" if @history
|
633
634
|
|
634
|
-
|
635
|
-
df = Rover::DataFrame.new(df.to_h)
|
636
|
-
end
|
635
|
+
df = convert_df(df)
|
637
636
|
raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
|
638
637
|
|
639
638
|
unless df.include?("ds") && df.include?("y")
|
@@ -817,8 +816,8 @@ module Prophet
|
|
817
816
|
end
|
818
817
|
data[component] = comp.mean(axis: 1, nan: true)
|
819
818
|
if @uncertainty_samples
|
820
|
-
data[component
|
821
|
-
data[component
|
819
|
+
data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
|
820
|
+
data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
|
822
821
|
end
|
823
822
|
end
|
824
823
|
Rover::DataFrame.new(data)
|
@@ -993,6 +992,16 @@ module Prophet
|
|
993
992
|
|
994
993
|
private
|
995
994
|
|
995
|
+
def convert_df(df)
|
996
|
+
if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
|
997
|
+
Rover::DataFrame.new(df.to_h)
|
998
|
+
elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
|
999
|
+
Rover::DataFrame.new(df.to_h(as_series: false))
|
1000
|
+
else
|
1001
|
+
df
|
1002
|
+
end
|
1003
|
+
end
|
1004
|
+
|
996
1005
|
# Time is preferred over DateTime in Ruby docs
|
997
1006
|
# use UTC to be consistent with Python
|
998
1007
|
# and so days have equal length (no DST)
|
@@ -1077,7 +1086,7 @@ module Prophet
|
|
1077
1086
|
d = {
|
1078
1087
|
"name" => "ds",
|
1079
1088
|
"index" => v.size.times.to_a,
|
1080
|
-
"data" => v.to_a.map { |v| v.iso8601(3) }
|
1089
|
+
"data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
|
1081
1090
|
}
|
1082
1091
|
model_dict[attribute] = JSON.generate(d)
|
1083
1092
|
end
|
@@ -1096,7 +1105,7 @@ module Prophet
|
|
1096
1105
|
v = instance_variable_get("@#{attribute}")
|
1097
1106
|
|
1098
1107
|
v = v.dup
|
1099
|
-
v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
|
1108
|
+
v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
|
1100
1109
|
v.delete("col")
|
1101
1110
|
|
1102
1111
|
fields =
|
@@ -1116,7 +1125,7 @@ module Prophet
|
|
1116
1125
|
d = {
|
1117
1126
|
"schema" => {
|
1118
1127
|
"fields" => fields,
|
1119
|
-
"pandas_version" => "
|
1128
|
+
"pandas_version" => "1.4.0"
|
1120
1129
|
},
|
1121
1130
|
"data" => v.to_a
|
1122
1131
|
}
|
@@ -1151,8 +1160,7 @@ module Prophet
|
|
1151
1160
|
# Params (Dict[str, np.ndarray])
|
1152
1161
|
model_dict["params"] = params.transform_values(&:to_a)
|
1153
1162
|
# Attributes that are skipped: stan_fit, stan_backend
|
1154
|
-
|
1155
|
-
model_dict["__prophet_version"] = "1.0"
|
1163
|
+
model_dict["__prophet_version"] = "1.1.2"
|
1156
1164
|
model_dict
|
1157
1165
|
end
|
1158
1166
|
|
@@ -1174,7 +1182,7 @@ module Prophet
|
|
1174
1182
|
d = JSON.parse(model_dict.fetch(attribute))
|
1175
1183
|
s = Rover::Vector.new(d["data"])
|
1176
1184
|
if d["name"] == "ds"
|
1177
|
-
s = s.map { |v|
|
1185
|
+
s = s.map { |v| DateTime.parse(v).to_time.utc }
|
1178
1186
|
end
|
1179
1187
|
model.instance_variable_set("@#{attribute}", s)
|
1180
1188
|
end
|
@@ -1191,7 +1199,7 @@ module Prophet
|
|
1191
1199
|
else
|
1192
1200
|
d = JSON.parse(model_dict.fetch(attribute))
|
1193
1201
|
df = Rover::DataFrame.new(d["data"])
|
1194
|
-
df["ds"] = df["ds"].map { |v|
|
1202
|
+
df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
|
1195
1203
|
if attribute == "train_component_cols"
|
1196
1204
|
# Special handling because of named index column
|
1197
1205
|
# df.columns.name = 'component'
|
data/lib/prophet/holidays.rb
CHANGED
@@ -3,8 +3,9 @@ module Prophet
|
|
3
3
|
def get_holiday_names(country)
|
4
4
|
years = (1995..2045).to_a
|
5
5
|
holiday_names = make_holidays_df(years, country)["holiday"].uniq
|
6
|
-
|
7
|
-
|
6
|
+
if holiday_names.size == 0
|
7
|
+
raise ArgumentError, "Holidays in #{country} are not currently supported"
|
8
|
+
end
|
8
9
|
holiday_names
|
9
10
|
end
|
10
11
|
|
data/lib/prophet/plot.rb
CHANGED
@@ -183,7 +183,7 @@ module Prophet
|
|
183
183
|
ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
|
184
184
|
end
|
185
185
|
if uncertainty && @uncertainty_samples
|
186
|
-
artists += [ax.fill_between(fcst_t, fcst[name
|
186
|
+
artists += [ax.fill_between(fcst_t, fcst["#{name}_lower"].to_a, fcst["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
187
187
|
end
|
188
188
|
# Specify formatting to workaround matplotlib issue #12925
|
189
189
|
locator = dates.AutoDateLocator.new(interval_multiples: false)
|
@@ -229,7 +229,7 @@ module Prophet
|
|
229
229
|
days = days.map { |v| v.strftime("%A") }
|
230
230
|
artists += ax.plot(days.size.times.to_a, seas[name].to_a, ls: "-", c: "#0072B2")
|
231
231
|
if uncertainty && @uncertainty_samples
|
232
|
-
artists += [ax.fill_between(days.size.times.to_a, seas[name
|
232
|
+
artists += [ax.fill_between(days.size.times.to_a, seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
233
233
|
end
|
234
234
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
235
235
|
ax.set_xticks(days.size.times.to_a)
|
@@ -255,7 +255,7 @@ module Prophet
|
|
255
255
|
seas = predict_seasonal_components(df_y)
|
256
256
|
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
257
257
|
if uncertainty && @uncertainty_samples
|
258
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name
|
258
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
259
259
|
end
|
260
260
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
261
261
|
months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
|
@@ -288,7 +288,7 @@ module Prophet
|
|
288
288
|
seas = predict_seasonal_components(df_y)
|
289
289
|
artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
|
290
290
|
if uncertainty && @uncertainty_samples
|
291
|
-
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name
|
291
|
+
artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
|
292
292
|
end
|
293
293
|
ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
|
294
294
|
step = (finish - start) / (7 - 1).to_f
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -21,14 +21,16 @@ module Prophet
|
|
21
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
22
22
|
iterations = 10000
|
23
23
|
|
24
|
+
args = {
|
25
|
+
data: stan_data,
|
26
|
+
inits: stan_init,
|
27
|
+
iter: iterations
|
28
|
+
}
|
29
|
+
args.merge!(kwargs)
|
30
|
+
|
24
31
|
stan_fit = nil
|
25
32
|
begin
|
26
|
-
stan_fit = @model.optimize(
|
27
|
-
data: stan_data,
|
28
|
-
inits: stan_init,
|
29
|
-
iter: iterations,
|
30
|
-
**kwargs
|
31
|
-
)
|
33
|
+
stan_fit = @model.optimize(**args)
|
32
34
|
rescue => e
|
33
35
|
if kwargs[:algorithm] != "Newton"
|
34
36
|
@logger.warn "Optimization terminated abnormally. Falling back to Newton."
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet-rb.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require_relative "prophet"
|
data/lib/prophet.rb
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "cmdstan"
|
3
|
-
require "rover"
|
4
3
|
require "numo/narray"
|
4
|
+
require "rover"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "logger"
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
11
|
+
require_relative "prophet/diagnostics"
|
12
|
+
require_relative "prophet/holidays"
|
13
|
+
require_relative "prophet/plot"
|
14
|
+
require_relative "prophet/forecaster"
|
15
|
+
require_relative "prophet/stan_backend"
|
16
|
+
require_relative "prophet/version"
|
17
17
|
|
18
18
|
module Prophet
|
19
19
|
class Error < StandardError; end
|
@@ -68,7 +68,7 @@ module Prophet
|
|
68
68
|
df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
|
69
69
|
df["cap"] = cap if cap
|
70
70
|
|
71
|
-
m.logger.level = ::Logger::
|
71
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
72
72
|
m.add_country_holidays(country_holidays) if country_holidays
|
73
73
|
m.fit(df)
|
74
74
|
|
@@ -87,7 +87,7 @@ module Prophet
|
|
87
87
|
else
|
88
88
|
result.each { |v| v["ds"] = v["ds"].localtime }
|
89
89
|
end
|
90
|
-
result.
|
90
|
+
result.to_h { |v| [v["ds"], v["yhat"]] }
|
91
91
|
end
|
92
92
|
|
93
93
|
# TODO better name for interval_width
|
@@ -97,7 +97,7 @@ module Prophet
|
|
97
97
|
df["cap"] = cap if cap
|
98
98
|
|
99
99
|
m = Prophet.new(interval_width: interval_width, **options)
|
100
|
-
m.logger.level = ::Logger::
|
100
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
101
101
|
m.add_country_holidays(country_holidays) if country_holidays
|
102
102
|
m.fit(df)
|
103
103
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.1
|
4
|
+
version: 0.5.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-09-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
159
159
|
requirements:
|
160
160
|
- - ">="
|
161
161
|
- !ruby/object:Gem::Version
|
162
|
-
version: '
|
162
|
+
version: '3'
|
163
163
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
164
164
|
requirements:
|
165
165
|
- - ">="
|
166
166
|
- !ruby/object:Gem::Version
|
167
167
|
version: '0'
|
168
168
|
requirements: []
|
169
|
-
rubygems_version: 3.
|
169
|
+
rubygems_version: 3.4.10
|
170
170
|
signing_key:
|
171
171
|
specification_version: 4
|
172
172
|
summary: Time series forecasting for Ruby
|