prophet-rb 0.4.2 → 0.5.1

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
@@ -27,7 +27,8 @@ module Prophet
27
27
  changepoint_prior_scale: 0.05,
28
28
  mcmc_samples: 0,
29
29
  interval_width: 0.80,
30
- uncertainty_samples: 1000
30
+ uncertainty_samples: 1000,
31
+ scaling: "absmax"
31
32
  )
32
33
  @growth = growth
33
34
 
@@ -44,7 +45,7 @@ module Prophet
44
45
  @yearly_seasonality = yearly_seasonality
45
46
  @weekly_seasonality = weekly_seasonality
46
47
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
48
+ @holidays = convert_df(holidays)
48
49
 
49
50
  @seasonality_mode = seasonality_mode
50
51
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -54,6 +55,10 @@ module Prophet
54
55
  @mcmc_samples = mcmc_samples
55
56
  @interval_width = interval_width
56
57
  @uncertainty_samples = uncertainty_samples
58
+ if !["absmax", "minmax"].include?(scaling)
59
+ raise ArgumentError, "scaling must be one of \"absmax\" or \"minmax\""
60
+ end
61
+ @scaling = scaling
57
62
 
58
63
  # Set during fitting or by other methods
59
64
  @start = nil
@@ -75,6 +80,7 @@ module Prophet
75
80
  validate_inputs
76
81
 
77
82
  @logger = ::Logger.new($stderr)
83
+ @logger.level = ::Logger::WARN
78
84
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
85
  "[prophet] #{msg}\n"
80
86
  end
@@ -133,10 +139,10 @@ module Prophet
133
139
  if reserved_names.include?(name)
134
140
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
141
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
142
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
143
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
144
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
145
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
146
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
147
  end
142
148
  if check_seasonalities && @seasonalities[name]
@@ -188,7 +194,11 @@ module Prophet
188
194
  raise ArgumentError, "Expected column \"floor\"."
189
195
  end
190
196
  else
191
- df["floor"] = 0
197
+ if @scaling == "absmax"
198
+ df["floor"] = 0
199
+ elsif @scaling == "minmax"
200
+ df["floor"] = @y_min
201
+ end
192
202
  end
193
203
 
194
204
  if @growth == "logistic"
@@ -218,11 +228,22 @@ module Prophet
218
228
 
219
229
  if @growth == "logistic" && df.include?("floor")
220
230
  @logistic_floor = true
221
- floor = df["floor"]
231
+ if @scaling == "absmax"
232
+ @y_min = (df["y"] - df["floor"]).abs.min.to_f
233
+ @y_scale = (df["y"] - df["floor"]).abs.max.to_f
234
+ elsif @scaling == "minmax"
235
+ @y_min = df["floor"].min
236
+ @y_scale = (df["cap"].max - @y_min).to_f
237
+ end
222
238
  else
223
- floor = 0.0
239
+ if @scaling == "absmax"
240
+ @y_min = 0.0
241
+ @y_scale = df["y"].abs.max.to_f
242
+ elsif @scaling == "minmax"
243
+ @y_min = df["y"].min
244
+ @y_scale = (df["y"].max - @y_min).to_f
245
+ end
224
246
  end
225
- @y_scale = (df["y"] - floor).abs.max
226
247
  @y_scale = 1 if @y_scale == 0
227
248
  @start = df["ds"].min
228
249
  @t_scale = df["ds"].max - @start
@@ -546,7 +567,7 @@ module Prophet
546
567
  days = 86400
547
568
 
548
569
  # Yearly seasonality
549
- yearly_disable = last - first < 370 * days
570
+ yearly_disable = last - first < 730 * days
550
571
  fourier_order = parse_seasonality_args("yearly", @yearly_seasonality, yearly_disable, 10)
551
572
  if fourier_order > 0
552
573
  @seasonalities["yearly"] = {
@@ -631,9 +652,7 @@ module Prophet
631
652
  def fit(df, **kwargs)
632
653
  raise Error, "Prophet object can only be fit once" if @history
633
654
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
655
+ df = convert_df(df)
637
656
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
657
 
639
658
  unless df.include?("ds") && df.include?("y")
@@ -808,7 +827,7 @@ module Prophet
808
827
 
809
828
  x = seasonal_features.to_numo
810
829
  data = {}
811
- component_cols.vector_names.each do |component|
830
+ (component_cols.vector_names - ["col"]).each do |component|
812
831
  beta_c = @params["beta"] * component_cols[component].to_numo
813
832
 
814
833
  comp = x.dot(beta_c.transpose)
@@ -993,6 +1012,16 @@ module Prophet
993
1012
 
994
1013
  private
995
1014
 
1015
+ def convert_df(df)
1016
+ if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
1017
+ Rover::DataFrame.new(df.to_h)
1018
+ elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
1019
+ Rover::DataFrame.new(df.to_h(as_series: false))
1020
+ else
1021
+ df
1022
+ end
1023
+ end
1024
+
996
1025
  # Time is preferred over DateTime in Ruby docs
997
1026
  # use UTC to be consistent with Python
998
1027
  # and so days have equal length (no DST)
@@ -1043,7 +1072,7 @@ module Prophet
1043
1072
  "yearly_seasonality", "weekly_seasonality", "daily_seasonality",
1044
1073
  "seasonality_mode", "seasonality_prior_scale", "changepoint_prior_scale",
1045
1074
  "holidays_prior_scale", "mcmc_samples", "interval_width", "uncertainty_samples",
1046
- "y_scale", "logistic_floor", "country_holidays", "component_modes"
1075
+ "y_scale", "y_min", "scaling", "logistic_floor", "country_holidays", "component_modes"
1047
1076
  ]
1048
1077
 
1049
1078
  PD_SERIES = ["changepoints", "history_dates", "train_holiday_names"]
@@ -1077,7 +1106,7 @@ module Prophet
1077
1106
  d = {
1078
1107
  "name" => "ds",
1079
1108
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1109
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1110
  }
1082
1111
  model_dict[attribute] = JSON.generate(d)
1083
1112
  end
@@ -1096,7 +1125,7 @@ module Prophet
1096
1125
  v = instance_variable_get("@#{attribute}")
1097
1126
 
1098
1127
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1128
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1129
  v.delete("col")
1101
1130
 
1102
1131
  fields =
@@ -1116,7 +1145,7 @@ module Prophet
1116
1145
  d = {
1117
1146
  "schema" => {
1118
1147
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1148
+ "pandas_version" => "1.4.0"
1120
1149
  },
1121
1150
  "data" => v.to_a
1122
1151
  }
@@ -1151,8 +1180,7 @@ module Prophet
1151
1180
  # Params (Dict[str, np.ndarray])
1152
1181
  model_dict["params"] = params.transform_values(&:to_a)
1153
1182
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1183
+ model_dict["__prophet_version"] = "1.1.2"
1156
1184
  model_dict
1157
1185
  end
1158
1186
 
@@ -1161,6 +1189,12 @@ module Prophet
1161
1189
 
1162
1190
  model_dict = JSON.parse(model_json)
1163
1191
 
1192
+ # handle_simple_attributes_backwards_compat
1193
+ if !model_dict["scaling"]
1194
+ model_dict["scaling"] = "absmax"
1195
+ model_dict["y_min"] = 0.0
1196
+ end
1197
+
1164
1198
  # We will overwrite all attributes set in init anyway
1165
1199
  model = Prophet.new
1166
1200
  # Simple types
@@ -1174,7 +1208,7 @@ module Prophet
1174
1208
  d = JSON.parse(model_dict.fetch(attribute))
1175
1209
  s = Rover::Vector.new(d["data"])
1176
1210
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1211
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1212
  end
1179
1213
  model.instance_variable_set("@#{attribute}", s)
1180
1214
  end
@@ -1191,7 +1225,7 @@ module Prophet
1191
1225
  else
1192
1226
  d = JSON.parse(model_dict.fetch(attribute))
1193
1227
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1228
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1229
  if attribute == "train_component_cols"
1196
1230
  # Special handling because of named index column
1197
1231
  # df.columns.name = 'component'
@@ -4,9 +4,7 @@ module Prophet
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
6
  if holiday_names.size == 0
7
- # TODO raise error in 0.5.0
8
- # raise ArgumentError, "Holidays in #{country} are not currently supported"
9
- logger.warn "Holidays in #{country} are not currently supported"
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
10
8
  end
11
9
  holiday_names
12
10
  end
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.2"
2
+ VERSION = "0.5.1"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.2
4
+ version: 0.5.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-12 00:00:00.000000000 Z
11
+ date: 2024-05-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.5.9
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby