prophet-rb 0.4.2 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +4 -4
- data/data-raw/LICENSE-holidays.txt +1 -1
- data/data-raw/README.md +1 -1
- data/data-raw/generated_holidays.csv +56365 -32386
- data/lib/prophet/forecaster.rb +21 -13
- data/lib/prophet/holidays.rb +1 -3
- data/lib/prophet/stan_backend.rb +8 -6
- data/lib/prophet/version.rb +1 -1
- data/lib/prophet-rb.rb +1 -1
- data/lib/prophet.rb +10 -10
- metadata +4 -4
data/lib/prophet/forecaster.rb
CHANGED
@@ -44,7 +44,7 @@ module Prophet
|
|
44
44
|
@yearly_seasonality = yearly_seasonality
|
45
45
|
@weekly_seasonality = weekly_seasonality
|
46
46
|
@daily_seasonality = daily_seasonality
|
47
|
-
@holidays = holidays
|
47
|
+
@holidays = convert_df(holidays)
|
48
48
|
|
49
49
|
@seasonality_mode = seasonality_mode
|
50
50
|
@seasonality_prior_scale = seasonality_prior_scale.to_f
|
@@ -75,6 +75,7 @@ module Prophet
|
|
75
75
|
validate_inputs
|
76
76
|
|
77
77
|
@logger = ::Logger.new($stderr)
|
78
|
+
@logger.level = ::Logger::WARN
|
78
79
|
@logger.formatter = proc do |severity, datetime, progname, msg|
|
79
80
|
"[prophet] #{msg}\n"
|
80
81
|
end
|
@@ -133,10 +134,10 @@ module Prophet
|
|
133
134
|
if reserved_names.include?(name)
|
134
135
|
raise ArgumentError, "Name #{name.inspect} is reserved."
|
135
136
|
end
|
136
|
-
if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
|
137
|
+
if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
|
137
138
|
raise ArgumentError, "Name #{name.inspect} already used for a holiday."
|
138
139
|
end
|
139
|
-
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
|
140
|
+
if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
|
140
141
|
raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
|
141
142
|
end
|
142
143
|
if check_seasonalities && @seasonalities[name]
|
@@ -631,9 +632,7 @@ module Prophet
|
|
631
632
|
def fit(df, **kwargs)
|
632
633
|
raise Error, "Prophet object can only be fit once" if @history
|
633
634
|
|
634
|
-
|
635
|
-
df = Rover::DataFrame.new(df.to_h)
|
636
|
-
end
|
635
|
+
df = convert_df(df)
|
637
636
|
raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
|
638
637
|
|
639
638
|
unless df.include?("ds") && df.include?("y")
|
@@ -993,6 +992,16 @@ module Prophet
|
|
993
992
|
|
994
993
|
private
|
995
994
|
|
995
|
+
def convert_df(df)
|
996
|
+
if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
|
997
|
+
Rover::DataFrame.new(df.to_h)
|
998
|
+
elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
|
999
|
+
Rover::DataFrame.new(df.to_h(as_series: false))
|
1000
|
+
else
|
1001
|
+
df
|
1002
|
+
end
|
1003
|
+
end
|
1004
|
+
|
996
1005
|
# Time is preferred over DateTime in Ruby docs
|
997
1006
|
# use UTC to be consistent with Python
|
998
1007
|
# and so days have equal length (no DST)
|
@@ -1077,7 +1086,7 @@ module Prophet
|
|
1077
1086
|
d = {
|
1078
1087
|
"name" => "ds",
|
1079
1088
|
"index" => v.size.times.to_a,
|
1080
|
-
"data" => v.to_a.map { |v| v.iso8601(3) }
|
1089
|
+
"data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
|
1081
1090
|
}
|
1082
1091
|
model_dict[attribute] = JSON.generate(d)
|
1083
1092
|
end
|
@@ -1096,7 +1105,7 @@ module Prophet
|
|
1096
1105
|
v = instance_variable_get("@#{attribute}")
|
1097
1106
|
|
1098
1107
|
v = v.dup
|
1099
|
-
v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
|
1108
|
+
v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
|
1100
1109
|
v.delete("col")
|
1101
1110
|
|
1102
1111
|
fields =
|
@@ -1116,7 +1125,7 @@ module Prophet
|
|
1116
1125
|
d = {
|
1117
1126
|
"schema" => {
|
1118
1127
|
"fields" => fields,
|
1119
|
-
"pandas_version" => "
|
1128
|
+
"pandas_version" => "1.4.0"
|
1120
1129
|
},
|
1121
1130
|
"data" => v.to_a
|
1122
1131
|
}
|
@@ -1151,8 +1160,7 @@ module Prophet
|
|
1151
1160
|
# Params (Dict[str, np.ndarray])
|
1152
1161
|
model_dict["params"] = params.transform_values(&:to_a)
|
1153
1162
|
# Attributes that are skipped: stan_fit, stan_backend
|
1154
|
-
|
1155
|
-
model_dict["__prophet_version"] = "1.0"
|
1163
|
+
model_dict["__prophet_version"] = "1.1.2"
|
1156
1164
|
model_dict
|
1157
1165
|
end
|
1158
1166
|
|
@@ -1174,7 +1182,7 @@ module Prophet
|
|
1174
1182
|
d = JSON.parse(model_dict.fetch(attribute))
|
1175
1183
|
s = Rover::Vector.new(d["data"])
|
1176
1184
|
if d["name"] == "ds"
|
1177
|
-
s = s.map { |v|
|
1185
|
+
s = s.map { |v| DateTime.parse(v).to_time.utc }
|
1178
1186
|
end
|
1179
1187
|
model.instance_variable_set("@#{attribute}", s)
|
1180
1188
|
end
|
@@ -1191,7 +1199,7 @@ module Prophet
|
|
1191
1199
|
else
|
1192
1200
|
d = JSON.parse(model_dict.fetch(attribute))
|
1193
1201
|
df = Rover::DataFrame.new(d["data"])
|
1194
|
-
df["ds"] = df["ds"].map { |v|
|
1202
|
+
df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
|
1195
1203
|
if attribute == "train_component_cols"
|
1196
1204
|
# Special handling because of named index column
|
1197
1205
|
# df.columns.name = 'component'
|
data/lib/prophet/holidays.rb
CHANGED
@@ -4,9 +4,7 @@ module Prophet
|
|
4
4
|
years = (1995..2045).to_a
|
5
5
|
holiday_names = make_holidays_df(years, country)["holiday"].uniq
|
6
6
|
if holiday_names.size == 0
|
7
|
-
#
|
8
|
-
# raise ArgumentError, "Holidays in #{country} are not currently supported"
|
9
|
-
logger.warn "Holidays in #{country} are not currently supported"
|
7
|
+
raise ArgumentError, "Holidays in #{country} are not currently supported"
|
10
8
|
end
|
11
9
|
holiday_names
|
12
10
|
end
|
data/lib/prophet/stan_backend.rb
CHANGED
@@ -21,14 +21,16 @@ module Prophet
|
|
21
21
|
kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
|
22
22
|
iterations = 10000
|
23
23
|
|
24
|
+
args = {
|
25
|
+
data: stan_data,
|
26
|
+
inits: stan_init,
|
27
|
+
iter: iterations
|
28
|
+
}
|
29
|
+
args.merge!(kwargs)
|
30
|
+
|
24
31
|
stan_fit = nil
|
25
32
|
begin
|
26
|
-
stan_fit = @model.optimize(
|
27
|
-
data: stan_data,
|
28
|
-
inits: stan_init,
|
29
|
-
iter: iterations,
|
30
|
-
**kwargs
|
31
|
-
)
|
33
|
+
stan_fit = @model.optimize(**args)
|
32
34
|
rescue => e
|
33
35
|
if kwargs[:algorithm] != "Newton"
|
34
36
|
@logger.warn "Optimization terminated abnormally. Falling back to Newton."
|
data/lib/prophet/version.rb
CHANGED
data/lib/prophet-rb.rb
CHANGED
@@ -1 +1 @@
|
|
1
|
-
|
1
|
+
require_relative "prophet"
|
data/lib/prophet.rb
CHANGED
@@ -1,19 +1,19 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "cmdstan"
|
3
|
-
require "rover"
|
4
3
|
require "numo/narray"
|
4
|
+
require "rover"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "logger"
|
8
8
|
require "set"
|
9
9
|
|
10
10
|
# modules
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
11
|
+
require_relative "prophet/diagnostics"
|
12
|
+
require_relative "prophet/holidays"
|
13
|
+
require_relative "prophet/plot"
|
14
|
+
require_relative "prophet/forecaster"
|
15
|
+
require_relative "prophet/stan_backend"
|
16
|
+
require_relative "prophet/version"
|
17
17
|
|
18
18
|
module Prophet
|
19
19
|
class Error < StandardError; end
|
@@ -68,7 +68,7 @@ module Prophet
|
|
68
68
|
df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
|
69
69
|
df["cap"] = cap if cap
|
70
70
|
|
71
|
-
m.logger.level = ::Logger::
|
71
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
72
72
|
m.add_country_holidays(country_holidays) if country_holidays
|
73
73
|
m.fit(df)
|
74
74
|
|
@@ -87,7 +87,7 @@ module Prophet
|
|
87
87
|
else
|
88
88
|
result.each { |v| v["ds"] = v["ds"].localtime }
|
89
89
|
end
|
90
|
-
result.
|
90
|
+
result.to_h { |v| [v["ds"], v["yhat"]] }
|
91
91
|
end
|
92
92
|
|
93
93
|
# TODO better name for interval_width
|
@@ -97,7 +97,7 @@ module Prophet
|
|
97
97
|
df["cap"] = cap if cap
|
98
98
|
|
99
99
|
m = Prophet.new(interval_width: interval_width, **options)
|
100
|
-
m.logger.level = ::Logger::
|
100
|
+
m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
|
101
101
|
m.add_country_holidays(country_holidays) if country_holidays
|
102
102
|
m.fit(df)
|
103
103
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: prophet-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2023-09-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: cmdstan
|
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
159
159
|
requirements:
|
160
160
|
- - ">="
|
161
161
|
- !ruby/object:Gem::Version
|
162
|
-
version: '
|
162
|
+
version: '3'
|
163
163
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
164
164
|
requirements:
|
165
165
|
- - ">="
|
166
166
|
- !ruby/object:Gem::Version
|
167
167
|
version: '0'
|
168
168
|
requirements: []
|
169
|
-
rubygems_version: 3.
|
169
|
+
rubygems_version: 3.4.10
|
170
170
|
signing_key:
|
171
171
|
specification_version: 4
|
172
172
|
summary: Time series forecasting for Ruby
|