prophet-rb 0.4.2 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -44,7 +44,7 @@ module Prophet
44
44
  @yearly_seasonality = yearly_seasonality
45
45
  @weekly_seasonality = weekly_seasonality
46
46
  @daily_seasonality = daily_seasonality
47
- @holidays = holidays
47
+ @holidays = convert_df(holidays)
48
48
 
49
49
  @seasonality_mode = seasonality_mode
50
50
  @seasonality_prior_scale = seasonality_prior_scale.to_f
@@ -75,6 +75,7 @@ module Prophet
75
75
  validate_inputs
76
76
 
77
77
  @logger = ::Logger.new($stderr)
78
+ @logger.level = ::Logger::WARN
78
79
  @logger.formatter = proc do |severity, datetime, progname, msg|
79
80
  "[prophet] #{msg}\n"
80
81
  end
@@ -133,10 +134,10 @@ module Prophet
133
134
  if reserved_names.include?(name)
134
135
  raise ArgumentError, "Name #{name.inspect} is reserved."
135
136
  end
136
- if check_holidays && @holidays && @holidays["holiday"].uniq.include?(name)
137
+ if check_holidays && @holidays && @holidays["holiday"].uniq.to_a.include?(name)
137
138
  raise ArgumentError, "Name #{name.inspect} already used for a holiday."
138
139
  end
139
- if check_holidays && @country_holidays && get_holiday_names(@country_holidays).include?(name)
140
+ if check_holidays && @country_holidays && get_holiday_names(@country_holidays).to_a.include?(name)
140
141
  raise ArgumentError, "Name #{name.inspect} is a holiday name in #{@country_holidays.inspect}."
141
142
  end
142
143
  if check_seasonalities && @seasonalities[name]
@@ -631,9 +632,7 @@ module Prophet
631
632
  def fit(df, **kwargs)
632
633
  raise Error, "Prophet object can only be fit once" if @history
633
634
 
634
- if defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
635
- df = Rover::DataFrame.new(df.to_h)
636
- end
635
+ df = convert_df(df)
637
636
  raise ArgumentError, "Must be a data frame" unless df.is_a?(Rover::DataFrame)
638
637
 
639
638
  unless df.include?("ds") && df.include?("y")
@@ -993,6 +992,16 @@ module Prophet
993
992
 
994
993
  private
995
994
 
995
# Normalizes a supported third-party data frame into a Rover::DataFrame.
#
# Daru and Polars frames are only recognized when those libraries are
# loaded (checked with +defined?+), so neither is a hard dependency.
# Any other value (an existing Rover::DataFrame, nil, or anything else) is
# handed back unchanged, leaving validation to the caller.
#
# @param df [Object] a Daru::DataFrame, Polars::DataFrame, Rover::DataFrame, or nil
# @return [Object] a new Rover::DataFrame for Daru/Polars input, otherwise +df+ itself
def convert_df(df)
  daru = defined?(Daru::DataFrame) && df.is_a?(Daru::DataFrame)
  return Rover::DataFrame.new(df.to_h) if daru

  # as_series: false asks Polars for plain column values rather than Series objects
  polars = defined?(Polars::DataFrame) && df.is_a?(Polars::DataFrame)
  return Rover::DataFrame.new(df.to_h(as_series: false)) if polars

  df
end
1004
+
996
1005
  # Time is preferred over DateTime in Ruby docs
997
1006
  # use UTC to be consistent with Python
998
1007
  # and so days have equal length (no DST)
@@ -1077,7 +1086,7 @@ module Prophet
1077
1086
  d = {
1078
1087
  "name" => "ds",
1079
1088
  "index" => v.size.times.to_a,
1080
- "data" => v.to_a.map { |v| v.iso8601(3) }
1089
+ "data" => v.to_a.map { |v| v.iso8601(3).chomp("Z") }
1081
1090
  }
1082
1091
  model_dict[attribute] = JSON.generate(d)
1083
1092
  end
@@ -1096,7 +1105,7 @@ module Prophet
1096
1105
  v = instance_variable_get("@#{attribute}")
1097
1106
 
1098
1107
  v = v.dup
1099
- v["ds"] = v["ds"].map { |v| v.iso8601(3) } if v["ds"]
1108
+ v["ds"] = v["ds"].map { |v| v.iso8601(3).chomp("Z") } if v["ds"]
1100
1109
  v.delete("col")
1101
1110
 
1102
1111
  fields =
@@ -1116,7 +1125,7 @@ module Prophet
1116
1125
  d = {
1117
1126
  "schema" => {
1118
1127
  "fields" => fields,
1119
- "pandas_version" => "0.20.0"
1128
+ "pandas_version" => "1.4.0"
1120
1129
  },
1121
1130
  "data" => v.to_a
1122
1131
  }
@@ -1151,8 +1160,7 @@ module Prophet
1151
1160
  # Params (Dict[str, np.ndarray])
1152
1161
  model_dict["params"] = params.transform_values(&:to_a)
1153
1162
  # Attributes that are skipped: stan_fit, stan_backend
1154
- # Returns 1.0 for Prophet 1.1
1155
- model_dict["__prophet_version"] = "1.0"
1163
+ model_dict["__prophet_version"] = "1.1.2"
1156
1164
  model_dict
1157
1165
  end
1158
1166
 
@@ -1174,7 +1182,7 @@ module Prophet
1174
1182
  d = JSON.parse(model_dict.fetch(attribute))
1175
1183
  s = Rover::Vector.new(d["data"])
1176
1184
  if d["name"] == "ds"
1177
- s = s.map { |v| Time.parse(v).utc }
1185
+ s = s.map { |v| DateTime.parse(v).to_time.utc }
1178
1186
  end
1179
1187
  model.instance_variable_set("@#{attribute}", s)
1180
1188
  end
@@ -1191,7 +1199,7 @@ module Prophet
1191
1199
  else
1192
1200
  d = JSON.parse(model_dict.fetch(attribute))
1193
1201
  df = Rover::DataFrame.new(d["data"])
1194
- df["ds"] = df["ds"].map { |v| Time.parse(v).utc } if df["ds"]
1202
+ df["ds"] = df["ds"].map { |v| DateTime.parse(v).to_time.utc } if df["ds"]
1195
1203
  if attribute == "train_component_cols"
1196
1204
  # Special handling because of named index column
1197
1205
  # df.columns.name = 'component'
@@ -4,9 +4,7 @@ module Prophet
4
4
  years = (1995..2045).to_a
5
5
  holiday_names = make_holidays_df(years, country)["holiday"].uniq
6
6
  if holiday_names.size == 0
7
- # TODO raise error in 0.5.0
8
- # raise ArgumentError, "Holidays in #{country} are not currently supported"
9
- logger.warn "Holidays in #{country} are not currently supported"
7
+ raise ArgumentError, "Holidays in #{country} are not currently supported"
10
8
  end
11
9
  holiday_names
12
10
  end
@@ -21,14 +21,16 @@ module Prophet
21
21
  kwargs[:algorithm] ||= stan_data["T"] < 100 ? "Newton" : "LBFGS"
22
22
  iterations = 10000
23
23
 
24
+ args = {
25
+ data: stan_data,
26
+ inits: stan_init,
27
+ iter: iterations
28
+ }
29
+ args.merge!(kwargs)
30
+
24
31
  stan_fit = nil
25
32
  begin
26
- stan_fit = @model.optimize(
27
- data: stan_data,
28
- inits: stan_init,
29
- iter: iterations,
30
- **kwargs
31
- )
33
+ stan_fit = @model.optimize(**args)
32
34
  rescue => e
33
35
  if kwargs[:algorithm] != "Newton"
34
36
  @logger.warn "Optimization terminated abnormally. Falling back to Newton."
@@ -1,3 +1,3 @@
1
1
  module Prophet
2
- VERSION = "0.4.2"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/prophet-rb.rb CHANGED
@@ -1 +1 @@
1
- require "prophet"
1
+ require_relative "prophet"
data/lib/prophet.rb CHANGED
@@ -1,19 +1,19 @@
1
1
  # dependencies
2
2
  require "cmdstan"
3
- require "rover"
4
3
  require "numo/narray"
4
+ require "rover"
5
5
 
6
6
  # stdlib
7
7
  require "logger"
8
8
  require "set"
9
9
 
10
10
  # modules
11
- require "prophet/diagnostics"
12
- require "prophet/holidays"
13
- require "prophet/plot"
14
- require "prophet/forecaster"
15
- require "prophet/stan_backend"
16
- require "prophet/version"
11
+ require_relative "prophet/diagnostics"
12
+ require_relative "prophet/holidays"
13
+ require_relative "prophet/plot"
14
+ require_relative "prophet/forecaster"
15
+ require_relative "prophet/stan_backend"
16
+ require_relative "prophet/version"
17
17
 
18
18
  module Prophet
19
19
  class Error < StandardError; end
@@ -68,7 +68,7 @@ module Prophet
68
68
  df = Rover::DataFrame.new({"ds" => series.keys, "y" => series.values})
69
69
  df["cap"] = cap if cap
70
70
 
71
- m.logger.level = ::Logger::FATAL unless verbose
71
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
72
72
  m.add_country_holidays(country_holidays) if country_holidays
73
73
  m.fit(df)
74
74
 
@@ -87,7 +87,7 @@ module Prophet
87
87
  else
88
88
  result.each { |v| v["ds"] = v["ds"].localtime }
89
89
  end
90
- result.map { |v| [v["ds"], v["yhat"]] }.to_h
90
+ result.to_h { |v| [v["ds"], v["yhat"]] }
91
91
  end
92
92
 
93
93
  # TODO better name for interval_width
@@ -97,7 +97,7 @@ module Prophet
97
97
  df["cap"] = cap if cap
98
98
 
99
99
  m = Prophet.new(interval_width: interval_width, **options)
100
- m.logger.level = ::Logger::FATAL unless verbose
100
+ m.logger.level = verbose ? ::Logger::INFO : ::Logger::FATAL
101
101
  m.add_country_holidays(country_holidays) if country_holidays
102
102
  m.fit(df)
103
103
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: prophet-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.4.2
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-12 00:00:00.000000000 Z
11
+ date: 2023-09-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: cmdstan
@@ -159,14 +159,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
159
159
  requirements:
160
160
  - - ">="
161
161
  - !ruby/object:Gem::Version
162
- version: '2.7'
162
+ version: '3'
163
163
  required_rubygems_version: !ruby/object:Gem::Requirement
164
164
  requirements:
165
165
  - - ">="
166
166
  - !ruby/object:Gem::Version
167
167
  version: '0'
168
168
  requirements: []
169
- rubygems_version: 3.3.7
169
+ rubygems_version: 3.4.10
170
170
  signing_key:
171
171
  specification_version: 4
172
172
  summary: Time series forecasting for Ruby