prophet-rb 0.4.2 → 0.5.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/LICENSE.txt +1 -1
- data/README.md +4 -4
- data/data-raw/LICENSE-holidays.txt +1 -1
- data/data-raw/README.md +1 -1
- data/data-raw/generated_holidays.csv +56365 -32386
- data/lib/prophet/forecaster.rb +55 -21
- data/lib/prophet/holidays.rb +1 -3
- data/lib/prophet/stan_backend.rb +8 -6
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet-rb.rb +1 -1
- data/lib/prophet.rb +10 -10
- metadata +4 -4
data/lib/prophet/forecaster.rb
CHANGED
@@ -27,7 +27,8 @@ module Prophet
|
|
27
27
|
changepoint_prior_scale: 0.05,
|
28
28
|
mcmc_samples: 0,
|
29
29
|
interval_width: 0.80,
|
30
|
-
uncertainty_samples: 1000
|
30
|
+
uncertainty_samples: 1000,
|
31
|
+
scaling: "absmax"
|
31
32
|
)
|
32
33
|
@growth = growth
|
33
34
|
|
@@ -44,7 +45,7 @@ module Prophet
|
|
44
45
|
@yearly_seasonality = yearly_seasonality
|
45
46
|
@weekly_seasonality = weekly_seasonality
|
46
47
|
@daily_seasonality = daily_seasonality
|
47
|
-
@holidays = holidays
|
48
|
+
@holidays = convert_df(holidays)
|
48
49
|
|
49
50
|
@seasonality_mode = seasonality_mode
|
50
51
|
@seasonality_prior_scale = seasonality_prior_scale.to_f
|
@@ -54,6 +55,10 @@ module Prophet
|
|
54
55
|
@mcmc_samples = mcmc_samples
|
55
56
|
@interval_width = interval_width
|
56
57
|
@uncertainty_samples = uncertainty_samples
|
58
|
+
if !["absmax", "minmax"].include?(scaling)
|
59
|
+
raise ArgumentError, "scaling must be one of \"absmax\" or \"minmax\""
|
60
|
+
end
|
61
|
+
@scaling = scaling
|
57
62
|
|
58
63
|
# Set during fitting or by other methods
|
59
64
|
@start = nil
|
@@ -75,6 +80,7 @@ module Prophet
|
|
75
80
|
validate_inputs
|
76
81
|
|
77
82
|
@logger = ::Logger.new($stderr)
|
83
|
+
@logger.level = ::Logger::WARN
|
78
84
|
@logger.formatter = proc do |severity, datetime, progname, msg|
|
79
85
|
"[prophet] #{msg}\n"
|
80
86
|
end
|
@@ -133,10 +139,10 @@ module Prophet
|
|
133
139
|
if reserved_names.include?(name)
|
134
140
|
raise ArgumentError, "Name #{name.inspect} is reserved."
|
135
141
|
end
|
136
|
-
if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
|
142
|
+
if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
|
137
143
|
raise ArgumentError, "Name #{name.inspect} already used for a holiday."
|
138
144
|
end
|
139
|
-
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
|
145
|
+
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
|
140
146
|
raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
|
141
147
|
end
|
142
148
|
if check_seasonalities && @seasonalities[name]
|
@@ -188,7 +194,11 @@ module Prophet
|
|
188
194
|
raise ArgumentError, "Expected column \"floor\"."
|
189
195
|
end
|
190
196
|
else
|
191
|
-
|
197
|
+
if @scaling == "absmax"
|
198
|
+
df["floor"] = 0
|
199
|
+
elsif @scaling == "minmax"
|
200
|
+
df["floor"] = @y_min
|
201
|
+
end
|
192
202
|
end
|
193
203
|
|
194
204
|
if @growth == "logistic"
|
@@ -218,11 +228,22 @@ module Prophet
|
|
218
228
|
|
219
229
|
if @growth == "logistic" && df.include?("floor")
|
220
230
|
@logistic_floor = true
|
221
|
-
|
231
|
+
if @scaling == "absmax"
|
232
|
+
@y_min = (df["y"] - df["floor"]).abs.min.to_f
|
233
|
+
@y_scale = (df["y"] - df["floor"]).abs.max.to_f
|
234
|
+
elsif @scaling == "minmax"
|
235
|
+
@y_min = df["floor"].min
|
236
|
+
@y_scale = (df["cap"].max - @y_min).to_f
|
237
|
+
end
|
222
238
|
else
|
223
|
-
|
239
|
+
if @scaling == "absmax"
|
240
|
+
@y_min = 0.0
|
241
|
+
@y_scale = df["y"].abs.max.to_f
|
242
|
+
elsif @scaling == "minmax"
|
243
|
+
@y_min = df["y"].min
|
244
|
+
@y_scale = (df["y"].max - @y_min).to_f
|
245
|
+
end
|
224
246
|
end
|
225
|
-
@y_scale = (df["y"] - floor).abs.max
|
226
247
|
@y_scale = 1 if @y_scale == 0
|
227
248
|
@start = df["ds"].min
|
228
249
|
@t_scale = df["ds"].max - @start
|
@@ -546,7 +567,7 @@ module Prophet
|
|
546
567
|
days = 86400
|
547
568
|
|
548
569
|
# Yearly seasonality
|
549
|
-
yearly_disable = last - first <
|
570
|
+
yearly_disable = last - first < 730 * days
|
550
571
|
fourier_order = parse_seasonality_args("yearly", @yearly_seasonality, yearly_disable, 10)
|
551
572
|
if fourier_order > 0
|
552
573
|
@seasonalities["yearly"] = {
|
@@ -631,9 +652,7 @@ module Prophet
|
|
631
652
|
def fit(df, **kwargs)
|
632
653
|
raise Error, "Prophet object can only be fit once" if @history
|
633
654
|
|
634
|
-
|
635
|
-
df = Rover::DataFrame.new(df.to_h)
|
636
|
-
end
|
655
|
+
df = convert_df(df)
|
637
656
|
raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
|
638
657
|
|
639
658
|
unless df.include?("ds") && df.include?("y")
|
@@ -808,7 +827,7 @@ module Prophet
|
|
808
827
|
|
809
828
|
x = seasonal_features.to_numo
|
810
829
|
data = {}
|
811
|
-
component_cols.vector_names.each do |component|
|
830
|
+
(component_cols.vector_names - ["col"]).each do |component|
|
812
831
|
beta_c = @params["beta"] * component_cols[component].to_numo
|
813
832
|
|
814
833
|
comp = x.dot(beta_c.transpose)
|
@@ -993,6 +1012,16 @@ module Prophet
|
|
993
1012
|
|
994
1013
|
private
|
995
1014
|
|
1015
|
+
def convert_df(df)
|
1016
|
+
if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
|
1017
|
+
Rover::DataFrame.new(df.to_h)
|
1018
|
+
elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
|
1019
|
+
Rover::DataFrame.new(df.to_h(as_series: false))
|
1020
|
+
else
|
1021
|
+
df
|
1022
|
+
end
|
1023
|
+
end
|
1024
|
+
|
996
1025
|
# Time is preferred over DateTime in Ruby docs
|
997
1026
|
# use UTC to be consistent with Python
|
998
1027
|
# and so days have equal length (no DST)
|
@@ -1043,7 +1072,7 @@ module Prophet
|
|
1043
1072
|
"yearly_seasonality", "weekly_seasonality", "daily_seasonality",
|
1044
1073
|
"seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
|
1045
1074
|
"holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
|
1046
|
-
"y_scale", "logistic_floor", "country_holidays", "component_modes"
|
1075
|
+
"y_scale", "y_min", "scaling", "logistic_floor", "country_holidays", "component_modes"
|
1047
1076
|
]
|
1048
1077
|
|
1049
1078
|
PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
|
@@ -1077,7 +1106,7 @@ module Prophet
|
|
1077
1106
|
d = {
|
1078
1107
|
"name" => "ds",
|
1079
1108
|
"index" => v.size.times.to_a,
|
1080
|
-
"data" => v.to_a.map { |v| v.iso8601(3) }
|
1109
|
+
"data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
|
1081
1110
|
}
|
1082
1111
|
model_dict[attribute] = JSON.generate(d)
|
1083
1112
|
end
|
@@ -1096,7 +1125,7 @@ module Prophet
|
|
1096
1125
|
v = instance_variable_get("@#{attribute}")
|
1097
1126
|
|
1098
1127
|
v = v.dup
|
1099
|
-
v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
|
1128
|
+
v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
|
1100
1129
|
v.delete("col")
|
1101
1130
|
|
1102
1131
|
fields =
|
@@ -1116,7 +1145,7 @@ module Prophet
|
|
1116
1145
|
d = {
|
1117
1146
|
"schema" => {
|
1118
1147
|
"fields" => fields,
|
1119
|
-
"pandas_version" => "
|
1148
|
+
"pandas_version" => "1.4.0"
|
1120
1149
|
},
|
1121
1150
|
"data" => v.to_a
|
1122
1151
|
}
|
@@ -1151,8 +1180,7 @@ module Prophet
|
|
1151
1180
|
# Params (Dict[str, np.ndarray])
|
1152
1181
|
model_dict["params"] = params.transform_values(&:to_a)
|
1153
1182
|
# Attributes that are skipped: stan_fit, stan_backend
|
1154
|
-
|
1155
|
-
model_dict["__prophet_version"] = "1.0"
|
1183
|
+
model_dict["__prophet_version"] = "1.1.2"
|
1156
1184
|
model_dict
|
1157
1185
|
end
|
1158
1186
|
|
@@ -1161,6 +1189,12 @@ module Prophet
|
|
1161
1189
|
|
1162
1190
|
model_dict = JSON.parse(model_json)
|
1163
1191
|
|
1192
|
+
# handle_simple_attributes_backwards_compat
|
1193
|
+
if !model_dict["scaling"]
|
1194
|
+
model_dict["scaling"] = "absmax"
|
1195
|
+
model_dict["y_min"] = 0.0
|
1196
|
+
end
|
1197
|
+
|
1164
1198
|
# We will overwrite all attributes set in init anyway
|
1165
1199
|
model = Prophet.new
|
1166
1200
|
# Simple types
|
@@ -1174,7 +1208,7 @@ module Prophet
|
|
1174
1208
|
d = JSON.parse(model_dict.fetch(attribute))
|
1175
1209
|
s = Rover::Vector.new(d["data"])
|
1176
1210
|
if d["name"] == "ds"
|
1177
|
-
s = s.map { |v|
|
1211
|
+
s = s.map { |v| DateTime.parse(v).to_time.utc }
|
1178
1212
|
end
|
1179
1213
|
model.instance_variable_set("@#{attribute}", s)
|
1180
1214
|
end
|
@@ -1191,7 +1225,7 @@ module Prophet
|
|
1191
1225
|
else
|
1192
1226
|
d = JSON.parse(model_dict.fetch(attribute))
|
1193
1227
|
df = Rover::DataFrame.new(d["data"])
|
1194
|
-
df["ds"] = df["ds"].map { |v|
|
1228
|
+
df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
|
1195
1229
|
if attribute == "train_component_cols"
|
1196
1230
|
# Special handling because of named index column
|
1197
1231
|
# df.columns.name = 'component'
|
data/lib/prophet/holidays.rb
CHANGED
@@ -4,9 +4,7 @@ module Prophet
|
|
4
4
|
years = (1995..2045).to_a
|
5
5
|
holiday_names = make_holidays_df(years, country)["holiday"].uniq
|
6
6
|
if holiday_names.size == 0
|
7
|
-
#
|
8
|
-
# raise ArgumentError, "Holidays in #{country} are not currently supported"
|
9
|
-
logger.warn "Holidays in #{country} are not currently supported"
|
7
|
+
raise ArgumentError, "Holidays in #{country} are not currently supported"
|
10
8
|
end
|
11
9
|
holiday_names
|
12
10
|
end
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -21,14 +21,16 @@ module Prophet
|
|
21
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
22
22
|
iterations = 10000
|
23
23
|
|
24
|
+
args = {
|
25
|
+
data: stan_data,
|
26
|
+
inits: stan_init,
|
27
|
+
iter: iterations
|
28
|
+
}
|
29
|
+
args.merge!(kwargs)
|
30
|
+
|
24
31
|
stan_fit = nil
|
25
32
|
begin
|
26
|
-
stan_fit = @model.optimize(
|
27
|
-
data: stan_data,
|
28
|
-
inits: stan_init,
|
29
|
-
iter: iterations,
|
30
|
-
**kwargs
|
31
|
-
)
|
33
|
+
stan_fit = @model.optimize(**args)
|
32
34
|
rescue => e
|
33
35
|
if kwargs[:algorithm] != "Newton"
|
34
36
|
@logger.warn "Optimization terminated abnormally. Falling back to Newton."
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet-rb.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require_relative "prophet"
|
data/lib/prophet.rb
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "cmdstan"
|
3
|
-
require "rover"
|
4
3
|
require "numo/narray"
|
4
|
+
require "rover"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "logger"
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
11
|
+
require_relative "prophet/diagnostics"
|
12
|
+
require_relative "prophet/holidays"
|
13
|
+
require_relative "prophet/plot"
|
14
|
+
require_relative "prophet/forecaster"
|
15
|
+
require_relative "prophet/stan_backend"
|
16
|
+
require_relative "prophet/version"
|
17
17
|
|
18
18
|
module Prophet
|
19
19
|
class Error < StandardError; end
|
@@ -68,7 +68,7 @@ module Prophet
|
|
68
68
|
df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
|
69
69
|
df["cap"] = cap if cap
|
70
70
|
|
71
|
-
m.logger.level = ::Logger::
|
71
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
72
72
|
m.add_country_holidays(country_holidays) if country_holidays
|
73
73
|
m.fit(df)
|
74
74
|
|
@@ -87,7 +87,7 @@ module Prophet
|
|
87
87
|
else
|
88
88
|
result.each { |v| v["ds"] = v["ds"].localtime }
|
89
89
|
end
|
90
|
-
result.
|
90
|
+
result.to_h { |v| [v["ds"], v["yhat"]] }
|
91
91
|
end
|
92
92
|
|
93
93
|
# TODO better name for interval_width
|
@@ -97,7 +97,7 @@ module Prophet
|
|
97
97
|
df["cap"] = cap if cap
|
98
98
|
|
99
99
|
m = Prophet.new(interval_width: interval_width, **options)
|
100
|
-
m.logger.level = ::Logger::
|
100
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
101
101
|
m.add_country_holidays(country_holidays) if country_holidays
|
102
102
|
m.fit(df)
|
103
103
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.2
|
4
|
+
version: 0.5.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-05-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
159
159
|
requirements:
|
160
160
|
- - ">="
|
161
161
|
- !ruby/object:Gem::Version
|
162
|
-
version: '
|
162
|
+
version: '3'
|
163
163
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
164
164
|
requirements:
|
165
165
|
- - ">="
|
166
166
|
- !ruby/object:Gem::Version
|
167
167
|
version: '0'
|
168
168
|
requirements: []
|
169
|
-
rubygems_version: 3.
|
169
|
+
rubygems_version: 3.5.9
|
170
170
|
signing_key:
|
171
171
|
specification_version: 4
|
172
172
|
summary: Time series forecasting for Ruby
|