prophet-rb 0.4.1 → 0.5.0

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. It is provided for informational purposes only and reflects each version exactly as it appears in its respective public registry.
@@ -179,7 +179,7 @@ module Prophet
179
179
  if metrics.nil?
180
180
  metrics = valid_metrics
181
181
  end
182
- if (df["yhat_lower"].nil? || df["yhat_upper"].nil?) && metrics.include?("coverage")
182
+ if (!df.include?("yhat_lower") || !df.include?("yhat_upper")) && metrics.include?("coverage")
183
183
  metrics.delete("coverage")
184
184
  end
185
185
  if metrics.uniq.length != metrics.length
@@ -44,7 +44,7 @@ module Prophet
44
44
  @yearly_seasonality = yearly_seasonality
45
45
  @weekly_seasonality = weekly_seasonality
46
46
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
47
+ @holidays = convert_df(holidays)
48
48
 
49
49
  @seasonality_mode = seasonality_mode
50
50
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -75,6 +75,7 @@ module Prophet
75
75
  validate_inputs
76
76
 
77
77
  @logger = ::Logger.new($stderr)
78
+ @logger.level = ::Logger::WARN
78
79
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
80
  "[prophet] #{msg}\n"
80
81
  end
@@ -89,7 +90,7 @@ module Prophet
89
90
  raise ArgumentError, "Parameter \"changepoint_range\" must be in [0, 1]"
90
91
  end
91
92
  if @holidays
92
- if !@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday")
93
+ if !(@holidays.is_a?(Rover::DataFrame) && @holidays.include?("ds") && @holidays.include?("holiday"))
93
94
  raise ArgumentError, "holidays must be a DataFrame with \"ds\" and \"holiday\" columns."
94
95
  end
95
96
  @holidays["ds"] = to_datetime(@holidays["ds"])
@@ -125,24 +126,24 @@ module Prophet
125
126
  "holidays", "zeros", "extra_regressors_additive", "yhat",
126
127
  "extra_regressors_multiplicative", "multiplicative_terms",
127
128
  ]
128
- rn_l = reserved_names.map { |n| n + "_lower" }
129
- rn_u = reserved_names.map { |n| n + "_upper" }
129
+ rn_l = reserved_names.map { |n| "#{n}_lower" }
130
+ rn_u = reserved_names.map { |n| "#{n}_upper" }
130
131
  reserved_names.concat(rn_l)
131
132
  reserved_names.concat(rn_u)
132
133
  reserved_names.concat(["ds", "y", "cap", "floor", "y_scaled", "cap_scaled"])
133
134
  if reserved_names.include?(name)
134
135
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
136
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
137
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
138
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
139
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
140
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
141
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
142
  end
142
143
  if check_seasonalities && @seasonalities[name]
143
144
  raise ArgumentError, "Name #{name.inspect} already used for a seasonality."
144
145
  end
145
- if check_regressors and @extra_regressors[name]
146
+ if check_regressors && @extra_regressors[name]
146
147
  raise ArgumentError, "Name #{name.inspect} already used for an added regressor."
147
148
  end
148
149
  end
@@ -167,7 +168,7 @@ module Prophet
167
168
  raise ArgumentError, "Found NaN in column #{name.inspect}"
168
169
  end
169
170
  end
170
- @seasonalities.values.each do |props|
171
+ @seasonalities.each_value do |props|
171
172
  condition_name = props[:condition_name]
172
173
  if condition_name
173
174
  if !df.include?(condition_name)
@@ -481,14 +482,14 @@ module Prophet
481
482
  end
482
483
  # Add totals additive and multiplicative components, and regressors
483
484
  ["additive", "multiplicative"].each do |mode|
484
- components = add_group_component(components, mode + "_terms", modes[mode])
485
+ components = add_group_component(components, "#{mode}_terms", modes[mode])
485
486
  regressors_by_mode = @extra_regressors.select { |r, props| props[:mode] == mode }
486
487
  .map { |r, props| r }
487
- components = add_group_component(components, "extra_regressors_" + mode, regressors_by_mode)
488
+ components = add_group_component(components, "extra_regressors_#{mode}", regressors_by_mode)
488
489
 
489
490
  # Add combination components to modes
490
- modes[mode] << mode + "_terms"
491
- modes[mode] << "extra_regressors_" + mode
491
+ modes[mode] << "#{mode}_terms"
492
+ modes[mode] << "extra_regressors_#{mode}"
492
493
  end
493
494
  # After all of the additive/multiplicative groups have been added,
494
495
  modes[@seasonality_mode] << "holidays"
@@ -631,9 +632,7 @@ module Prophet
631
632
  def fit(df, **kwargs)
632
633
  raise Error, "Prophet object can only be fit once" if @history
633
634
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
635
+ df = convert_df(df)
637
636
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
637
 
639
638
  unless df.include?("ds") && df.include?("y")
@@ -817,8 +816,8 @@ module Prophet
817
816
  end
818
817
  data[component] = comp.mean(axis: 1, nan: true)
819
818
  if @uncertainty_samples
820
- data[component + "_lower"] = comp.percentile(lower_p, axis: 1)
821
- data[component + "_upper"] = comp.percentile(upper_p, axis: 1)
819
+ data["#{component}_lower"] = comp.percentile(lower_p, axis: 1)
820
+ data["#{component}_upper"] = comp.percentile(upper_p, axis: 1)
822
821
  end
823
822
  end
824
823
  Rover::DataFrame.new(data)
@@ -993,6 +992,16 @@ module Prophet
993
992
 
994
993
  private
995
994
 
995
+ def convert_df(df)
996
+ if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
997
+ Rover::DataFrame.new(df.to_h)
998
+ elsif defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
999
+ Rover::DataFrame.new(df.to_h(as_series: false))
1000
+ else
1001
+ df
1002
+ end
1003
+ end
1004
+
996
1005
  # Time is preferred over DateTime in Ruby docs
997
1006
  # use UTC to be consistent with Python
998
1007
  # and so days have equal length (no DST)
@@ -1077,7 +1086,7 @@ module Prophet
1077
1086
  d = {
1078
1087
  "name" => "ds",
1079
1088
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1089
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1090
  }
1082
1091
  model_dict[attribute] = JSON.generate(d)
1083
1092
  end
@@ -1096,7 +1105,7 @@ module Prophet
1096
1105
  v = instance_variable_get("@#{attribute}")
1097
1106
 
1098
1107
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1108
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1109
  v.delete("col")
1101
1110
 
1102
1111
  fields =
@@ -1116,7 +1125,7 @@ module Prophet
1116
1125
  d = {
1117
1126
  "schema" => {
1118
1127
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1128
+ "pandas_version" => "1.4.0"
1120
1129
  },
1121
1130
  "data" => v.to_a
1122
1131
  }
@@ -1151,8 +1160,7 @@ module Prophet
1151
1160
  # Params (Dict[str, np.ndarray])
1152
1161
  model_dict["params"] = params.transform_values(&:to_a)
1153
1162
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1163
+ model_dict["__prophet_version"] = "1.1.2"
1156
1164
  model_dict
1157
1165
  end
1158
1166
 
@@ -1174,7 +1182,7 @@ module Prophet
1174
1182
  d = JSON.parse(model_dict.fetch(attribute))
1175
1183
  s = Rover::Vector.new(d["data"])
1176
1184
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1185
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1186
  end
1179
1187
  model.instance_variable_set("@#{attribute}", s)
1180
1188
  end
@@ -1191,7 +1199,7 @@ module Prophet
1191
1199
  else
1192
1200
  d = JSON.parse(model_dict.fetch(attribute))
1193
1201
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1202
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1203
  if attribute == "train_component_cols"
1196
1204
  # Special handling because of named index column
1197
1205
  # df.columns.name = 'component'
@@ -3,8 +3,9 @@ module Prophet
3
3
  def get_holiday_names(country)
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
- # TODO raise error in 0.4.0
7
- logger.warn "Holidays in #{country} are not currently supported"
6
+ if holiday_names.size == 0
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
8
+ end
8
9
  holiday_names
9
10
  end
10
11
 
data/lib/prophet/plot.rb CHANGED
@@ -183,7 +183,7 @@ module Prophet
183
183
  ax.plot(fcst_t, fcst["floor"].to_a, ls: "--", c: "k")
184
184
  end
185
185
  if uncertainty && @uncertainty_samples
186
- artists += [ax.fill_between(fcst_t, fcst[name + "_lower"].to_a, fcst[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
186
+ artists += [ax.fill_between(fcst_t, fcst["#{name}_lower"].to_a, fcst["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
187
187
  end
188
188
  # Specify formatting to workaround matplotlib issue #12925
189
189
  locator = dates.AutoDateLocator.new(interval_multiples: false)
@@ -229,7 +229,7 @@ module Prophet
229
229
  days = days.map { |v| v.strftime("%A") }
230
230
  artists += ax.plot(days.size.times.to_a, seas[name].to_a, ls: "-", c: "#0072B2")
231
231
  if uncertainty && @uncertainty_samples
232
- artists += [ax.fill_between(days.size.times.to_a, seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
232
+ artists += [ax.fill_between(days.size.times.to_a, seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
233
233
  end
234
234
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
235
235
  ax.set_xticks(days.size.times.to_a)
@@ -255,7 +255,7 @@ module Prophet
255
255
  seas = predict_seasonal_components(df_y)
256
256
  artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
257
257
  if uncertainty && @uncertainty_samples
258
- artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
258
+ artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
259
259
  end
260
260
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
261
261
  months = dates.MonthLocator.new((1..12).to_a, bymonthday: 1, interval: 2)
@@ -288,7 +288,7 @@ module Prophet
288
288
  seas = predict_seasonal_components(df_y)
289
289
  artists += ax.plot(to_pydatetime(df_y["ds"]), seas[name].to_a, ls: "-", c: "#0072B2")
290
290
  if uncertainty && @uncertainty_samples
291
- artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas[name + "_lower"].to_a, seas[name + "_upper"].to_a, color: "#0072B2", alpha: 0.2)]
291
+ artists += [ax.fill_between(to_pydatetime(df_y["ds"]), seas["#{name}_lower"].to_a, seas["#{name}_upper"].to_a, color: "#0072B2", alpha: 0.2)]
292
292
  end
293
293
  ax.grid(true, which: "major", c: "gray", ls: "-", lw: 1, alpha: 0.2)
294
294
  step = (finish - start) / (7 - 1).to_f
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.1"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.1
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-10 00:00:00.000000000 Z
11
+ date: 2023-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.4.10
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby