prophet-rb 0.4.2 → 0.5.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -27,7 +27,8 @@ module Prophet
27
27
  changepoint_prior_scale: 0.05,
28
28
  mcmc_samples: 0,
29
29
  interval_width: 0.80,
30
- uncertainty_samples: 1000
30
+ uncertainty_samples: 1000,
31
+ scaling: "absmax"
31
32
  )
32
33
  @growth = growth
33
34
 
@@ -44,7 +45,7 @@ module Prophet
44
45
  @yearly_seasonality = yearly_seasonality
45
46
  @weekly_seasonality = weekly_seasonality
46
47
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
48
+ @holidays = convert_df(holidays)
48
49
 
49
50
  @seasonality_mode = seasonality_mode
50
51
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -54,6 +55,10 @@ module Prophet
54
55
  @mcmc_samples = mcmc_samples
55
56
  @interval_width = interval_width
56
57
  @uncertainty_samples = uncertainty_samples
58
+ if !["absmax", "minmax"].include?(scaling)
59
+ raise ArgumentError, "scaling must be one of \"absmax\" or \"minmax\""
60
+ end
61
+ @scaling = scaling
57
62
 
58
63
  # Set during fitting or by other methods
59
64
  @start = nil
@@ -75,6 +80,7 @@ module Prophet
75
80
  validate_inputs
76
81
 
77
82
  @logger = ::Logger.new($stderr)
83
+ @logger.level = ::Logger::WARN
78
84
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
85
  "[prophet] #{msg}\n"
80
86
  end
@@ -133,10 +139,10 @@ module Prophet
133
139
  if reserved_names.include?(name)
134
140
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
141
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
142
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
143
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
144
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
145
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
146
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
147
  end
142
148
  if check_seasonalities && @seasonalities[name]
@@ -188,7 +194,11 @@ module Prophet
188
194
  raise ArgumentError, "Expected column \"floor\"."
189
195
  end
190
196
  else
191
- df["floor"] = 0
197
+ if @scaling == "absmax"
198
+ df["floor"] = 0
199
+ elsif @scaling == "minmax"
200
+ df["floor"] = @y_min
201
+ end
192
202
  end
193
203
 
194
204
  if @growth == "logistic"
@@ -218,11 +228,22 @@ module Prophet
218
228
 
219
229
  if @growth == "logistic" && df.include?("floor")
220
230
  @logistic_floor = true
221
- floor = df["floor"]
231
+ if @scaling == "absmax"
232
+ @y_min = (df["y"] - df["floor"]).abs.min.to_f
233
+ @y_scale = (df["y"] - df["floor"]).abs.max.to_f
234
+ elsif @scaling == "minmax"
235
+ @y_min = df["floor"].min
236
+ @y_scale = (df["cap"].max - @y_min).to_f
237
+ end
222
238
  else
223
- floor = 0.0
239
+ if @scaling == "absmax"
240
+ @y_min = 0.0
241
+ @y_scale = df["y"].abs.max.to_f
242
+ elsif @scaling == "minmax"
243
+ @y_min = df["y"].min
244
+ @y_scale = (df["y"].max - @y_min).to_f
245
+ end
224
246
  end
225
- @y_scale = (df["y"] - floor).abs.max
226
247
  @y_scale = 1 if @y_scale == 0
227
248
  @start = df["ds"].min
228
249
  @t_scale = df["ds"].max - @start
@@ -546,7 +567,7 @@ module Prophet
546
567
  days = 86400
547
568
 
548
569
  # Yearly seasonality
549
- yearly_disable = last - first < 370 * days
570
+ yearly_disable = last - first < 730 * days
550
571
  fourier_order = parse_seasonality_args("yearly", @yearly_seasonality, yearly_disable, 10)
551
572
  if fourier_order > 0
552
573
  @seasonalities["yearly"] = {
@@ -631,9 +652,7 @@ module Prophet
631
652
  def fit(df, **kwargs)
632
653
  raise Error, "Prophet object can only be fit once" if @history
633
654
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
655
+ df = convert_df(df)
637
656
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
657
 
639
658
  unless df.include?("ds") && df.include?("y")
@@ -808,7 +827,7 @@ module Prophet
808
827
 
809
828
  x = seasonal_features.to_numo
810
829
  data = {}
811
- component_cols.vector_names.each do |component|
830
+ (component_cols.vector_names - ["col"]).each do |component|
812
831
  beta_c = @params["beta"] * component_cols[component].to_numo
813
832
 
814
833
  comp = x.dot(beta_c.transpose)
@@ -993,6 +1012,16 @@ module Prophet
993
1012
 
994
1013
  private
995
1014
 
1015
+ def convert_df(df)
1016
+ if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
1017
+ Rover::DataFrame.new(df.to_h)
1018
+ elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
1019
+ Rover::DataFrame.new(df.to_h(as_series: false))
1020
+ else
1021
+ df
1022
+ end
1023
+ end
1024
+
996
1025
  # Time is preferred over DateTime in Ruby docs
997
1026
  # use UTC to be consistent with Python
998
1027
  # and so days have equal length (no DST)
@@ -1043,7 +1072,7 @@ module Prophet
1043
1072
  "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
1044
1073
  "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
1045
1074
  "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
1046
- "y_scale", "logistic_floor", "country_holidays", "component_modes"
1075
+ "y_scale", "y_min", "scaling", "logistic_floor", "country_holidays", "component_modes"
1047
1076
  ]
1048
1077
 
1049
1078
  PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
@@ -1077,7 +1106,7 @@ module Prophet
1077
1106
  d = {
1078
1107
  "name" => "ds",
1079
1108
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1109
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1110
  }
1082
1111
  model_dict[attribute] = JSON.generate(d)
1083
1112
  end
@@ -1096,7 +1125,7 @@ module Prophet
1096
1125
  v = instance_variable_get("@#{attribute}")
1097
1126
 
1098
1127
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1128
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1129
  v.delete("col")
1101
1130
 
1102
1131
  fields =
@@ -1116,7 +1145,7 @@ module Prophet
1116
1145
  d = {
1117
1146
  "schema" => {
1118
1147
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1148
+ "pandas_version" => "1.4.0"
1120
1149
  },
1121
1150
  "data" => v.to_a
1122
1151
  }
@@ -1151,8 +1180,7 @@ module Prophet
1151
1180
  # Params (Dict[str, np.ndarray])
1152
1181
  model_dict["params"] = params.transform_values(&:to_a)
1153
1182
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1183
+ model_dict["__prophet_version"] = "1.1.2"
1156
1184
  model_dict
1157
1185
  end
1158
1186
 
@@ -1161,6 +1189,12 @@ module Prophet
1161
1189
 
1162
1190
  model_dict = JSON.parse(model_json)
1163
1191
 
1192
+ # handle_simple_attributes_backwards_compat
1193
+ if !model_dict["scaling"]
1194
+ model_dict["scaling"] = "absmax"
1195
+ model_dict["y_min"] = 0.0
1196
+ end
1197
+
1164
1198
  # We will overwrite all attributes set in init anyway
1165
1199
  model = Prophet.new
1166
1200
  # Simple types
@@ -1174,7 +1208,7 @@ module Prophet
1174
1208
  d = JSON.parse(model_dict.fetch(attribute))
1175
1209
  s = Rover::Vector.new(d["data"])
1176
1210
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1211
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1212
  end
1179
1213
  model.instance_variable_set("@#{attribute}", s)
1180
1214
  end
@@ -1191,7 +1225,7 @@ module Prophet
1191
1225
  else
1192
1226
  d = JSON.parse(model_dict.fetch(attribute))
1193
1227
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1228
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1229
  if attribute == "train_component_cols"
1196
1230
  # Special handling because of named index column
1197
1231
  # df.columns.name = 'component'
@@ -4,9 +4,7 @@ module Prophet
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
6
  if holiday_names.size == 0
7
- # TODO raise error in 0.5.0
8
- # raise ArgumentError, "Holidays in #{country} are not currently supported"
9
- logger.warn "Holidays in #{country} are not currently supported"
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
10
8
  end
11
9
  holiday_names
12
10
  end
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.2"
2
+ VERSION = "0.5.1"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.2
4
+ version: 0.5.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-12 00:00:00.000000000 Z
11
+ date: 2024-05-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.5.9
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby