prophet-rb 0.4.2 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +15 -0
- data/LICENSE.txt +1 -1
- data/README.md +4 -4
- data/data-raw/LICENSE-holidays.txt +1 -1
- data/data-raw/README.md +1 -1
- data/data-raw/generated_holidays.csv +56365 -32386
- data/lib/prophet/forecaster.rb +55 -21
- data/lib/prophet/holidays.rb +1 -3
- data/lib/prophet/stan_backend.rb +8 -6
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet-rb.rb +1 -1
- data/lib/prophet.rb +10 -10
- metadata +4 -4
data/lib/prophet/forecaster.rb
CHANGED
@@ -27,7 +27,8 @@ module Prophet
|
|
27
27
|
changepoint_prior_scale: 0.05,
|
28
28
|
mcmc_samples: 0,
|
29
29
|
interval_width: 0.80,
|
30
|
-
uncertainty_samples: 1000
|
30
|
+
uncertainty_samples: 1000,
|
31
|
+
scaling: "absmax"
|
31
32
|
)
|
32
33
|
@growth = growth
|
33
34
|
|
@@ -44,7 +45,7 @@ module Prophet
|
|
44
45
|
@yearly_seasonality = yearly_seasonality
|
45
46
|
@weekly_seasonality = weekly_seasonality
|
46
47
|
@daily_seasonality = daily_seasonality
|
47
|
-
@holidays = holidays
|
48
|
+
@holidays = convert_df(holidays)
|
48
49
|
|
49
50
|
@seasonality_mode = seasonality_mode
|
50
51
|
@seasonality_prior_scale = seasonality_prior_scale.to_f
|
@@ -54,6 +55,10 @@ module Prophet
|
|
54
55
|
@mcmc_samples = mcmc_samples
|
55
56
|
@interval_width = interval_width
|
56
57
|
@uncertainty_samples = uncertainty_samples
|
58
|
+
if !["absmax", "minmax"].include?(scaling)
|
59
|
+
raise ArgumentError, "scaling must be one of \"absmax\" or \"minmax\""
|
60
|
+
end
|
61
|
+
@scaling = scaling
|
57
62
|
|
58
63
|
# Set during fitting or by other methods
|
59
64
|
@start = nil
|
@@ -75,6 +80,7 @@ module Prophet
|
|
75
80
|
validate_inputs
|
76
81
|
|
77
82
|
@logger = ::Logger.new($stderr)
|
83
|
+
@logger.level = ::Logger::WARN
|
78
84
|
@logger.formatter = proc do |severity, datetime, progname, msg|
|
79
85
|
"[prophet] #{msg}\n"
|
80
86
|
end
|
@@ -133,10 +139,10 @@ module Prophet
|
|
133
139
|
if reserved_names.include?(name)
|
134
140
|
raise ArgumentError, "Name #{name.inspect} is reserved."
|
135
141
|
end
|
136
|
-
if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
|
142
|
+
if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
|
137
143
|
raise ArgumentError, "Name #{name.inspect} already used for a holiday."
|
138
144
|
end
|
139
|
-
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
|
145
|
+
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
|
140
146
|
raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
|
141
147
|
end
|
142
148
|
if check_seasonalities && @seasonalities[name]
|
@@ -188,7 +194,11 @@ module Prophet
|
|
188
194
|
raise ArgumentError, "Expected column \"floor\"."
|
189
195
|
end
|
190
196
|
else
|
191
|
-
|
197
|
+
if @scaling == "absmax"
|
198
|
+
df["floor"] = 0
|
199
|
+
elsif @scaling == "minmax"
|
200
|
+
df["floor"] = @y_min
|
201
|
+
end
|
192
202
|
end
|
193
203
|
|
194
204
|
if @growth == "logistic"
|
@@ -218,11 +228,22 @@ module Prophet
|
|
218
228
|
|
219
229
|
if @growth == "logistic" && df.include?("floor")
|
220
230
|
@logistic_floor = true
|
221
|
-
|
231
|
+
if @scaling == "absmax"
|
232
|
+
@y_min = (df["y"] - df["floor"]).abs.min.to_f
|
233
|
+
@y_scale = (df["y"] - df["floor"]).abs.max.to_f
|
234
|
+
elsif @scaling == "minmax"
|
235
|
+
@y_min = df["floor"].min
|
236
|
+
@y_scale = (df["cap"].max - @y_min).to_f
|
237
|
+
end
|
222
238
|
else
|
223
|
-
|
239
|
+
if @scaling == "absmax"
|
240
|
+
@y_min = 0.0
|
241
|
+
@y_scale = df["y"].abs.max.to_f
|
242
|
+
elsif @scaling == "minmax"
|
243
|
+
@y_min = df["y"].min
|
244
|
+
@y_scale = (df["y"].max - @y_min).to_f
|
245
|
+
end
|
224
246
|
end
|
225
|
-
@y_scale = (df["y"] - floor).abs.max
|
226
247
|
@y_scale = 1 if @y_scale == 0
|
227
248
|
@start = df["ds"].min
|
228
249
|
@t_scale = df["ds"].max - @start
|
@@ -546,7 +567,7 @@ module Prophet
|
|
546
567
|
days = 86400
|
547
568
|
|
548
569
|
# Yearly seasonality
|
549
|
-
yearly_disable = last - first <
|
570
|
+
yearly_disable = last - first < 730 * days
|
550
571
|
fourier_order = parse_seasonality_args("yearly", @yearly_seasonality, yearly_disable, 10)
|
551
572
|
if fourier_order > 0
|
552
573
|
@seasonalities["yearly"] = {
|
@@ -631,9 +652,7 @@ module Prophet
|
|
631
652
|
def fit(df, **kwargs)
|
632
653
|
raise Error, "Prophet object can only be fit once" if @history
|
633
654
|
|
634
|
-
|
635
|
-
df = Rover::DataFrame.new(df.to_h)
|
636
|
-
end
|
655
|
+
df = convert_df(df)
|
637
656
|
raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
|
638
657
|
|
639
658
|
unless df.include?("ds") && df.include?("y")
|
@@ -808,7 +827,7 @@ module Prophet
|
|
808
827
|
|
809
828
|
x = seasonal_features.to_numo
|
810
829
|
data = {}
|
811
|
-
component_cols.vector_names.each do |component|
|
830
|
+
(component_cols.vector_names - ["col"]).each do |component|
|
812
831
|
beta_c = @params["beta"] * component_cols[component].to_numo
|
813
832
|
|
814
833
|
comp = x.dot(beta_c.transpose)
|
@@ -993,6 +1012,16 @@ module Prophet
|
|
993
1012
|
|
994
1013
|
private
|
995
1014
|
|
1015
|
+
def convert_df(df)
|
1016
|
+
if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
|
1017
|
+
Rover::DataFrame.new(df.to_h)
|
1018
|
+
elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
|
1019
|
+
Rover::DataFrame.new(df.to_h(as_series: false))
|
1020
|
+
else
|
1021
|
+
df
|
1022
|
+
end
|
1023
|
+
end
|
1024
|
+
|
996
1025
|
# Time is preferred over DateTime in Ruby docs
|
997
1026
|
# use UTC to be consistent with Python
|
998
1027
|
# and so days have equal length (no DST)
|
@@ -1043,7 +1072,7 @@ module Prophet
|
|
1043
1072
|
"yearly_seasonality", "weekly_seasonality", "daily_seasonality",
|
1044
1073
|
"seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
|
1045
1074
|
"holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
|
1046
|
-
"y_scale", "logistic_floor", "country_holidays", "component_modes"
|
1075
|
+
"y_scale", "y_min", "scaling", "logistic_floor", "country_holidays", "component_modes"
|
1047
1076
|
]
|
1048
1077
|
|
1049
1078
|
PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
|
@@ -1077,7 +1106,7 @@ module Prophet
|
|
1077
1106
|
d = {
|
1078
1107
|
"name" => "ds",
|
1079
1108
|
"index" => v.size.times.to_a,
|
1080
|
-
"data" => v.to_a.map { |v| v.iso8601(3) }
|
1109
|
+
"data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
|
1081
1110
|
}
|
1082
1111
|
model_dict[attribute] = JSON.generate(d)
|
1083
1112
|
end
|
@@ -1096,7 +1125,7 @@ module Prophet
|
|
1096
1125
|
v = instance_variable_get("@#{attribute}")
|
1097
1126
|
|
1098
1127
|
v = v.dup
|
1099
|
-
v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
|
1128
|
+
v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
|
1100
1129
|
v.delete("col")
|
1101
1130
|
|
1102
1131
|
fields =
|
@@ -1116,7 +1145,7 @@ module Prophet
|
|
1116
1145
|
d = {
|
1117
1146
|
"schema" => {
|
1118
1147
|
"fields" => fields,
|
1119
|
-
"pandas_version" => "
|
1148
|
+
"pandas_version" => "1.4.0"
|
1120
1149
|
},
|
1121
1150
|
"data" => v.to_a
|
1122
1151
|
}
|
@@ -1151,8 +1180,7 @@ module Prophet
|
|
1151
1180
|
# Params (Dict[str, np.ndarray])
|
1152
1181
|
model_dict["params"] = params.transform_values(&:to_a)
|
1153
1182
|
# Attributes that are skipped: stan_fit, stan_backend
|
1154
|
-
|
1155
|
-
model_dict["__prophet_version"] = "1.0"
|
1183
|
+
model_dict["__prophet_version"] = "1.1.2"
|
1156
1184
|
model_dict
|
1157
1185
|
end
|
1158
1186
|
|
@@ -1161,6 +1189,12 @@ module Prophet
|
|
1161
1189
|
|
1162
1190
|
model_dict = JSON.parse(model_json)
|
1163
1191
|
|
1192
|
+
# handle_simple_attributes_backwards_compat
|
1193
|
+
if !model_dict["scaling"]
|
1194
|
+
model_dict["scaling"] = "absmax"
|
1195
|
+
model_dict["y_min"] = 0.0
|
1196
|
+
end
|
1197
|
+
|
1164
1198
|
# We will overwrite all attributes set in init anyway
|
1165
1199
|
model = Prophet.new
|
1166
1200
|
# Simple types
|
@@ -1174,7 +1208,7 @@ module Prophet
|
|
1174
1208
|
d = JSON.parse(model_dict.fetch(attribute))
|
1175
1209
|
s = Rover::Vector.new(d["data"])
|
1176
1210
|
if d["name"] == "ds"
|
1177
|
-
s = s.map { |v|
|
1211
|
+
s = s.map { |v| DateTime.parse(v).to_time.utc }
|
1178
1212
|
end
|
1179
1213
|
model.instance_variable_set("@#{attribute}", s)
|
1180
1214
|
end
|
@@ -1191,7 +1225,7 @@ module Prophet
|
|
1191
1225
|
else
|
1192
1226
|
d = JSON.parse(model_dict.fetch(attribute))
|
1193
1227
|
df = Rover::DataFrame.new(d["data"])
|
1194
|
-
df["ds"] = df["ds"].map { |v|
|
1228
|
+
df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
|
1195
1229
|
if attribute == "train_component_cols"
|
1196
1230
|
# Special handling because of named index column
|
1197
1231
|
# df.columns.name = 'component'
|
data/lib/prophet/holidays.rb
CHANGED
@@ -4,9 +4,7 @@ module Prophet
|
|
4
4
|
years = (1995..2045).to_a
|
5
5
|
holiday_names = make_holidays_df(years, country)["holiday"].uniq
|
6
6
|
if holiday_names.size == 0
|
7
|
-
#
|
8
|
-
# raise ArgumentError, "Holidays in #{country} are not currently supported"
|
9
|
-
logger.warn "Holidays in #{country} are not currently supported"
|
7
|
+
raise ArgumentError, "Holidays in #{country} are not currently supported"
|
10
8
|
end
|
11
9
|
holiday_names
|
12
10
|
end
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -21,14 +21,16 @@ module Prophet
|
|
21
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
22
22
|
iterations = 10000
|
23
23
|
|
24
|
+
args = {
|
25
|
+
data: stan_data,
|
26
|
+
inits: stan_init,
|
27
|
+
iter: iterations
|
28
|
+
}
|
29
|
+
args.merge!(kwargs)
|
30
|
+
|
24
31
|
stan_fit = nil
|
25
32
|
begin
|
26
|
-
stan_fit = @model.optimize(
|
27
|
-
data: stan_data,
|
28
|
-
inits: stan_init,
|
29
|
-
iter: iterations,
|
30
|
-
**kwargs
|
31
|
-
)
|
33
|
+
stan_fit = @model.optimize(**args)
|
32
34
|
rescue => e
|
33
35
|
if kwargs[:algorithm] != "Newton"
|
34
36
|
@logger.warn "Optimization terminated abnormally. Falling back to Newton."
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet-rb.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require_relative "prophet"
|
data/lib/prophet.rb
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "cmdstan"
|
3
|
-
require "rover"
|
4
3
|
require "numo/narray"
|
4
|
+
require "rover"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "logger"
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
11
|
+
require_relative "prophet/diagnostics"
|
12
|
+
require_relative "prophet/holidays"
|
13
|
+
require_relative "prophet/plot"
|
14
|
+
require_relative "prophet/forecaster"
|
15
|
+
require_relative "prophet/stan_backend"
|
16
|
+
require_relative "prophet/version"
|
17
17
|
|
18
18
|
module Prophet
|
19
19
|
class Error < StandardError; end
|
@@ -68,7 +68,7 @@ module Prophet
|
|
68
68
|
df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
|
69
69
|
df["cap"] = cap if cap
|
70
70
|
|
71
|
-
m.logger.level = ::Logger::
|
71
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
72
72
|
m.add_country_holidays(country_holidays) if country_holidays
|
73
73
|
m.fit(df)
|
74
74
|
|
@@ -87,7 +87,7 @@ module Prophet
|
|
87
87
|
else
|
88
88
|
result.each { |v| v["ds"] = v["ds"].localtime }
|
89
89
|
end
|
90
|
-
result.
|
90
|
+
result.to_h { |v| [v["ds"], v["yhat"]] }
|
91
91
|
end
|
92
92
|
|
93
93
|
# TODO better name for interval_width
|
@@ -97,7 +97,7 @@ module Prophet
|
|
97
97
|
df["cap"] = cap if cap
|
98
98
|
|
99
99
|
m = Prophet.new(interval_width: interval_width, **options)
|
100
|
-
m.logger.level = ::Logger::
|
100
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
101
101
|
m.add_country_holidays(country_holidays) if country_holidays
|
102
102
|
m.fit(df)
|
103
103
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.2
|
4
|
+
version: 0.5.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-05-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
159
159
|
requirements:
|
160
160
|
- - ">="
|
161
161
|
- !ruby/object:Gem::Version
|
162
|
-
version: '
|
162
|
+
version: '3'
|
163
163
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
164
164
|
requirements:
|
165
165
|
- - ">="
|
166
166
|
- !ruby/object:Gem::Version
|
167
167
|
version: '0'
|
168
168
|
requirements: []
|
169
|
-
rubygems_version: 3.
|
169
|
+
rubygems_version: 3.5.9
|
170
170
|
signing_key:
|
171
171
|
specification_version: 4
|
172
172
|
summary: Time series forecasting for Ruby
|