prophet-rb 0.4.1 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -179,7 +179,7 @@ module Prophet
179
179
  if metrics.nil?
180
180
  metrics = valid_metrics
181
181
  end
182
- if (df["yhat_lower"].nil? || df["yhat_upper"].nil?) && metrics.include?("coverage")
182
+ if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
183
183
  metrics.delete("coverage")
184
184
  end
185
185
  if metrics.uniq.length != metrics.length
@@ -44,7 +44,7 @@ module Prophet
44
44
  @yearly_seasonality = yearly_seasonality
45
45
  @weekly_seasonality = weekly_seasonality
46
46
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
47
+ @holidays = convert_df(holidays)
48
48
 
49
49
  @seasonality_mode = seasonality_mode
50
50
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -75,6 +75,7 @@ module Prophet
75
75
  validate_inputs
76
76
 
77
77
  @logger = ::Logger.new($stderr)
78
+ @logger.level = ::Logger::WARN
78
79
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
80
  "[prophet] #{msg}\n"
80
81
  end
@@ -89,7 +90,7 @@ module Prophet
89
90
  raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
90
91
  end
91
92
  if @holidays
92
- if !@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday")
93
+ if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
93
94
  raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
94
95
  end
95
96
  @holidays["ds"] = to_datetime(@holidays["ds"])
@@ -125,24 +126,24 @@ module Prophet
125
126
  "holidays", "zeros", "extra_regressors_additive", "yhat",
126
127
  "extra_regressors_multiplicative", "multiplicative_terms",
127
128
  ]
128
- rn_l = reserved_names.map { |n| n + "_lower" }
129
- rn_u = reserved_names.map { |n| n + "_upper" }
129
+ rn_l = reserved_names.map { |n| "#{n}_lower" }
130
+ rn_u = reserved_names.map { |n| "#{n}_upper" }
130
131
  reserved_names.concat(rn_l)
131
132
  reserved_names.concat(rn_u)
132
133
  reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
133
134
  if reserved_names.include?(name)
134
135
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
136
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
137
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
138
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
139
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
140
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
141
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
142
  end
142
143
  if check_seasonalities && @seasonalities[name]
143
144
  raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
144
145
  end
145
- if check_regressors and @extra_regressors[name]
146
+ if check_regressors && @extra_regressors[name]
146
147
  raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
147
148
  end
148
149
  end
@@ -167,7 +168,7 @@ module Prophet
167
168
  raise ArgumentError, "Found NaN in column #{name.inspect}"
168
169
  end
169
170
  end
170
- @seasonalities.values.each do |props|
171
+ @seasonalities.each_value do |props|
171
172
  condition_name = props[:condition_name]
172
173
  if condition_name
173
174
  if !df.include?(condition_name)
@@ -481,14 +482,14 @@ module Prophet
481
482
  end
482
483
  # Add totals additive and multiplicative components, and regressors
483
484
  ["additive", "multiplicative"].each do |mode|
484
- components = add_group_component(components, mode + "_terms", modes[mode])
485
+ components = add_group_component(components, "#{mode}_terms", modes[mode])
485
486
  regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
486
487
  .map { |r, props| r }
487
- components = add_group_component(components, "extra_regressors_" + mode, regressors_by_mode)
488
+ components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
488
489
 
489
490
  # Add combination components to modes
490
- modes[mode] << mode + "_terms"
491
- modes[mode] << "extra_regressors_" + mode
491
+ modes[mode] << "#{mode}_terms"
492
+ modes[mode] << "extra_regressors_#{mode}"
492
493
  end
493
494
  # After all of the additive/multiplicative groups have been added,
494
495
  modes[@seasonality_mode] << "holidays"
@@ -631,9 +632,7 @@ module Prophet
631
632
  def fit(df, **kwargs)
632
633
  raise Error, "Prophet object can only be fit once" if @history
633
634
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
635
+ df = convert_df(df)
637
636
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
637
 
639
638
  unless df.include?("ds") && df.include?("y")
@@ -817,8 +816,8 @@ module Prophet
817
816
  end
818
817
  data[component] = comp.mean(axis: 1, nan: true)
819
818
  if @uncertainty_samples
820
- data[component + "_lower"] = comp.percentile(lower_p, axis: 1)
821
- data[component + "_upper"] = comp.percentile(upper_p, axis: 1)
819
+ data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
820
+ data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
822
821
  end
823
822
  end
824
823
  Rover::DataFrame.new(data)
@@ -993,6 +992,16 @@ module Prophet
993
992
 
994
993
  private
995
994
 
995
+ def convert_df(df)
996
+ if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
997
+ Rover::DataFrame.new(df.to_h)
998
+ elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
999
+ Rover::DataFrame.new(df.to_h(as_series: false))
1000
+ else
1001
+ df
1002
+ end
1003
+ end
1004
+
996
1005
  # Time is preferred over DateTime in Ruby docs
997
1006
  # use UTC to be consistent with Python
998
1007
  # and so days have equal length (no DST)
@@ -1077,7 +1086,7 @@ module Prophet
1077
1086
  d = {
1078
1087
  "name" => "ds",
1079
1088
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1089
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1090
  }
1082
1091
  model_dict[attribute] = JSON.generate(d)
1083
1092
  end
@@ -1096,7 +1105,7 @@ module Prophet
1096
1105
  v = instance_variable_get("@#{attribute}")
1097
1106
 
1098
1107
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1108
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1109
  v.delete("col")
1101
1110
 
1102
1111
  fields =
@@ -1116,7 +1125,7 @@ module Prophet
1116
1125
  d = {
1117
1126
  "schema" => {
1118
1127
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1128
+ "pandas_version" => "1.4.0"
1120
1129
  },
1121
1130
  "data" => v.to_a
1122
1131
  }
@@ -1151,8 +1160,7 @@ module Prophet
1151
1160
  # Params (Dict[str, np.ndarray])
1152
1161
  model_dict["params"] = params.transform_values(&:to_a)
1153
1162
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1163
+ model_dict["__prophet_version"] = "1.1.2"
1156
1164
  model_dict
1157
1165
  end
1158
1166
 
@@ -1174,7 +1182,7 @@ module Prophet
1174
1182
  d = JSON.parse(model_dict.fetch(attribute))
1175
1183
  s = Rover::Vector.new(d["data"])
1176
1184
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1185
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1186
  end
1179
1187
  model.instance_variable_set("@#{attribute}", s)
1180
1188
  end
@@ -1191,7 +1199,7 @@ module Prophet
1191
1199
  else
1192
1200
  d = JSON.parse(model_dict.fetch(attribute))
1193
1201
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1202
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1203
  if attribute == "train_component_cols"
1196
1204
  # Special handling because of named index column
1197
1205
  # df.columns.name = 'component'
@@ -3,8 +3,9 @@ module Prophet
3
3
  def get_holiday_names(country)
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
- # TODO raise error in 0.4.0
7
- logger.warn "Holidays in #{country} are not currently supported"
6
+ if holiday_names.size == 0
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
8
+ end
8
9
  holiday_names
9
10
  end
10
11
 
data/lib/prophet/plot.rb CHANGED
@@ -183,7 +183,7 @@ module Prophet
183
183
  ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
184
184
  end
185
185
  if uncertainty && @uncertainty_samples
186
- artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].to_a, fcst[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
186
+ artists += [ax.fill_between(fcst_t, fcst["#{name}_lower"].to_a, fcst["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
187
187
  end
188
188
  # Specify formatting to workaround matplotlib issue #12925
189
189
  locator = dates.AutoDateLocator.new(interval_multiples: false)
@@ -229,7 +229,7 @@ module Prophet
229
229
  days = days.map { |v| v.strftime("%A") }
230
230
  artists += ax.plot(days.size.times.to_a, seas[name].to_a, ls: "-", c: "#0072B2")
231
231
  if uncertainty && @uncertainty_samples
232
- artists += [ax.fill_between(days.size.times.to_a, seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
232
+ artists += [ax.fill_between(days.size.times.to_a, seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
233
233
  end
234
234
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
235
235
  ax.set_xticks(days.size.times.to_a)
@@ -255,7 +255,7 @@ module Prophet
255
255
  seas = predict_seasonal_components(df_y)
256
256
  artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
257
257
  if uncertainty && @uncertainty_samples
258
- artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
258
+ artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
259
259
  end
260
260
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
261
261
  months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
@@ -288,7 +288,7 @@ module Prophet
288
288
  seas = predict_seasonal_components(df_y)
289
289
  artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
290
290
  if uncertainty && @uncertainty_samples
291
- artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
291
+ artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
292
292
  end
293
293
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
294
294
  step = (finish - start) / (7 - 1).to_f
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.1"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.1
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-10 00:00:00.000000000 Z
11
+ date: 2023-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.4.10
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby